|
26 | 26 | # Add src to path |
27 | 27 | sys.path.insert(0, str(Path(__file__).parent.parent / "src")) |
28 | 28 |
|
29 | | -# Forms to evaluate (same as training) |
30 | | -EVAL_FORMS = [ |
31 | | - # Hard forms |
32 | | - "StaircasePoem", |
33 | | - "VowelBudgetPoem", |
34 | | - "PrecisionVerse", |
35 | | - "ExactWordPoem", |
36 | | - "CharacterBudgetPoem", |
37 | | - # Mathematical forms |
38 | | - "FibonacciVerse", |
39 | | - "TriangularVerse", |
40 | | - "PiKu", |
41 | | - # Novel forms |
42 | | - "HourglassVerse", |
43 | | - "PrimeVerse", |
44 | | - "GoldenRatio", |
45 | | -] |
| 29 | +# Forms to evaluate - dynamically loaded, but can override with --forms |
| 30 | +EVAL_FORMS = None # Will load all forms dynamically |
46 | 31 |
|
47 | 32 | TOPICS = [ |
48 | 33 | "the passage of time", |
|
53 | 38 | ] |
54 | 39 |
|
55 | 40 |
|
56 | | -def get_forms(form_names: list[str]) -> dict[str, object]: |
57 | | - """Load form instances.""" |
58 | | - from abide.forms import hard, mathematical, novel |
59 | | - |
60 | | - form_configs = { |
61 | | - # Hard forms |
62 | | - "StaircasePoem": (hard, "StaircasePoem", {"num_words": 7}), |
63 | | - "VowelBudgetPoem": (hard, "VowelBudgetPoem", {"vowel_count": 30}), |
64 | | - "PrecisionVerse": (hard, "PrecisionVerse", {"chars_per_line": 25}), |
65 | | - "ExactWordPoem": (hard, "ExactWordPoem", {"word_count": 20}), |
66 | | - "CharacterBudgetPoem": (hard, "CharacterBudgetPoem", {"character": "e", "count": 10}), |
67 | | - # Mathematical forms |
68 | | - "FibonacciVerse": (mathematical, "FibonacciVerse", {"num_lines": 5}), |
69 | | - "TriangularVerse": (mathematical, "TriangularVerse", {"num_lines": 4}), |
70 | | - "PiKu": (mathematical, "PiKu", {"num_lines": 5}), |
71 | | - # Novel forms |
72 | | - "HourglassVerse": (novel, "HourglassVerse", {}), |
73 | | - "PrimeVerse": (novel, "PrimeVerse", {}), |
74 | | - "GoldenRatio": (novel, "GoldenRatio", {}), |
75 | | - } |
76 | | - |
77 | | - forms = {} |
78 | | - for name in form_names: |
79 | | - if name in form_configs: |
80 | | - module, cls_name, kwargs = form_configs[name] |
81 | | - form_class = getattr(module, cls_name) |
82 | | - forms[name] = form_class(**kwargs) |
83 | | - return forms |
| 41 | +def get_forms(form_names: list[str] | None = None) -> dict[str, object]: |
| 42 | + """Load form instances. If form_names is None, load ALL forms.""" |
| 43 | + import abide.forms as forms_module |
| 44 | + |
| 45 | + all_forms = {} |
| 46 | + names_to_load = form_names if form_names else forms_module.__all__ |
| 47 | + |
| 48 | + for name in names_to_load: |
| 49 | + try: |
| 50 | + form_class = getattr(forms_module, name) |
| 51 | + # Try to instantiate with no args first |
| 52 | + try: |
| 53 | + all_forms[name] = form_class() |
| 54 | + except TypeError: |
| 55 | + # Some forms need specific params - use sensible defaults |
| 56 | + if name == "StaircasePoem" or name == "DescendingStaircasePoem": |
| 57 | + all_forms[name] = form_class(num_words=7) |
| 58 | + elif name == "VowelBudgetPoem": |
| 59 | + all_forms[name] = form_class(vowel_count=30) |
| 60 | + elif name == "PrecisionVerse": |
| 61 | + all_forms[name] = form_class(chars_per_line=25) |
| 62 | + elif name == "ExactWordPoem": |
| 63 | + all_forms[name] = form_class(word_count=20) |
| 64 | + elif name == "CharacterBudgetPoem": |
| 65 | + all_forms[name] = form_class(character="e", count=10) |
| 66 | + elif name == "TotalCharacterPoem": |
| 67 | + all_forms[name] = form_class(total_chars=100) |
| 68 | + elif name == "FibonacciVerse": |
| 69 | + all_forms[name] = form_class(num_lines=5) |
| 70 | + elif name == "TriangularVerse": |
| 71 | + all_forms[name] = form_class(num_lines=4) |
| 72 | + elif name == "PiKu": |
| 73 | + all_forms[name] = form_class(num_lines=5) |
| 74 | + elif name == "PrecisionHaiku": |
| 75 | + all_forms[name] = form_class(chars_per_line=17) |
| 76 | + elif name == "ArithmeticVerse": |
| 77 | + all_forms[name] = form_class(start=2, diff=2, num_lines=5) |
| 78 | + elif name == "PositionalPoem": |
| 79 | + all_forms[name] = form_class(positions=[1, 2, 3]) |
| 80 | + elif name == "IsolatedCouplet": |
| 81 | + all_forms[name] = form_class(position=3) |
| 82 | + elif name == "AlternatingIsolation": |
| 83 | + all_forms[name] = form_class(num_lines=6) |
| 84 | + elif name == "DoubleAcrosticPoem": |
| 85 | + all_forms[name] = form_class(word="POETRY") |
| 86 | + elif name == "CombinedChallenge": |
| 87 | + all_forms[name] = form_class(num_lines=4) |
| 88 | + elif name == "Lipogram": |
| 89 | + all_forms[name] = form_class(forbidden="e") |
| 90 | + elif name == "Univocalic": |
| 91 | + all_forms[name] = form_class(vowel="a") |
| 92 | + elif name == "Mesostic": |
| 93 | + all_forms[name] = form_class(spine="POEM") |
| 94 | + elif name == "Anaphora": |
| 95 | + all_forms[name] = form_class(phrase="I am", num_lines=4) |
| 96 | + elif name == "ModularVerse": |
| 97 | + all_forms[name] = form_class(modulus=3, num_lines=6) |
| 98 | + elif name == "CoprimeVerse": |
| 99 | + all_forms[name] = form_class(base=6, num_lines=4) |
| 100 | + elif name == "SquareStanzas": |
| 101 | + all_forms[name] = form_class(size=4) |
| 102 | + elif name == "SelfReferential": |
| 103 | + all_forms[name] = form_class(num_lines=4) |
| 104 | + elif name == "GoldenRatioVerse": |
| 105 | + all_forms[name] = form_class(num_lines=6) |
| 106 | + elif name == "PythagoreanTercet": |
| 107 | + all_forms[name] = form_class(scale=2) |
| 108 | + else: |
| 109 | + continue |
| 110 | + except Exception: |
| 111 | + continue |
| 112 | + |
| 113 | + return all_forms |
84 | 114 |
|
85 | 115 |
|
86 | 116 | def evaluate_form( |
@@ -161,8 +191,8 @@ def main() -> int: |
161 | 191 | print(f"Failed to connect to vLLM at {args.url}: {e}") |
162 | 192 | return 1 |
163 | 193 |
|
164 | | - # Load forms |
165 | | - form_names = args.forms.split(",") if args.forms else EVAL_FORMS |
| 194 | + # Load forms (None means all) |
| 195 | + form_names = args.forms.split(",") if args.forms else None |
166 | 196 | forms = get_forms(form_names) |
167 | 197 | print(f"\nEvaluating {len(forms)} forms with {args.samples} samples per topic") |
168 | 198 | print(f"Forms: {', '.join(forms.keys())}") |
|
0 commit comments