Skip to content

Commit fabf7d5

Browse files
committed
Wire up Gemma 3 270M for GRPO training
- Switch default model from gemma-3n-e2b-it to gemma-3-270m-it (0.3B params)
- Dynamically load ALL forms from abide.forms (140+ forms)
- Reduce batch sizes and sequence lengths for smaller model
- Use 0.7 GPU memory utilization for vLLM
- Add flashinfer-python for faster sampling
1 parent 879d870 commit fabf7d5

File tree

7 files changed

+265
-103
lines changed

7 files changed

+265
-103
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,5 @@ Thumbs.db
4646

4747
# uv
4848
.uv/
49+
logs/
50+
results/

pyproject.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ classifiers = [
2323
]
2424

2525
dependencies = [
26-
"pronouncing>=0.2.0", # CMU Pronouncing Dictionary wrapper
27-
"setuptools", # Required by pronouncing for pkg_resources
28-
"numpy>=1.24.0", # Numerical operations
29-
"jellyfish>=1.0.0", # Phonetic encodings (Soundex, Metaphone, etc.)
30-
"pyphen>=0.14.0", # Hyphenation for syllable fallback
26+
"pronouncing>=0.2.0", # CMU Pronouncing Dictionary wrapper
27+
"setuptools", # Required by pronouncing for pkg_resources
28+
"numpy>=1.24.0", # Numerical operations
29+
"jellyfish>=1.0.0", # Phonetic encodings (Soundex, Metaphone, etc.)
30+
"pyphen>=0.14.0", # Hyphenation for syllable fallback
31+
"flashinfer-python>=0.5.3",
3132
]
3233

3334
[project.optional-dependencies]

scripts/eval_local.py

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,8 @@
2626
# Add src to path
2727
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
2828

29-
# Forms to evaluate (same as training)
30-
EVAL_FORMS = [
31-
# Hard forms
32-
"StaircasePoem",
33-
"VowelBudgetPoem",
34-
"PrecisionVerse",
35-
"ExactWordPoem",
36-
"CharacterBudgetPoem",
37-
# Mathematical forms
38-
"FibonacciVerse",
39-
"TriangularVerse",
40-
"PiKu",
41-
# Novel forms
42-
"HourglassVerse",
43-
"PrimeVerse",
44-
"GoldenRatio",
45-
]
29+
# Forms to evaluate - dynamically loaded, but can override with --forms
30+
EVAL_FORMS = None # Will load all forms dynamically
4631

4732
TOPICS = [
4833
"the passage of time",
@@ -53,34 +38,79 @@
5338
]
5439

5540

56-
def get_forms(form_names: list[str]) -> dict[str, object]:
57-
"""Load form instances."""
58-
from abide.forms import hard, mathematical, novel
59-
60-
form_configs = {
61-
# Hard forms
62-
"StaircasePoem": (hard, "StaircasePoem", {"num_words": 7}),
63-
"VowelBudgetPoem": (hard, "VowelBudgetPoem", {"vowel_count": 30}),
64-
"PrecisionVerse": (hard, "PrecisionVerse", {"chars_per_line": 25}),
65-
"ExactWordPoem": (hard, "ExactWordPoem", {"word_count": 20}),
66-
"CharacterBudgetPoem": (hard, "CharacterBudgetPoem", {"character": "e", "count": 10}),
67-
# Mathematical forms
68-
"FibonacciVerse": (mathematical, "FibonacciVerse", {"num_lines": 5}),
69-
"TriangularVerse": (mathematical, "TriangularVerse", {"num_lines": 4}),
70-
"PiKu": (mathematical, "PiKu", {"num_lines": 5}),
71-
# Novel forms
72-
"HourglassVerse": (novel, "HourglassVerse", {}),
73-
"PrimeVerse": (novel, "PrimeVerse", {}),
74-
"GoldenRatio": (novel, "GoldenRatio", {}),
75-
}
76-
77-
forms = {}
78-
for name in form_names:
79-
if name in form_configs:
80-
module, cls_name, kwargs = form_configs[name]
81-
form_class = getattr(module, cls_name)
82-
forms[name] = form_class(**kwargs)
83-
return forms
41+
def get_forms(form_names: list[str] | None = None) -> dict[str, object]:
42+
"""Load form instances. If form_names is None, load ALL forms."""
43+
import abide.forms as forms_module
44+
45+
all_forms = {}
46+
names_to_load = form_names if form_names else forms_module.__all__
47+
48+
for name in names_to_load:
49+
try:
50+
form_class = getattr(forms_module, name)
51+
# Try to instantiate with no args first
52+
try:
53+
all_forms[name] = form_class()
54+
except TypeError:
55+
# Some forms need specific params - use sensible defaults
56+
if name == "StaircasePoem" or name == "DescendingStaircasePoem":
57+
all_forms[name] = form_class(num_words=7)
58+
elif name == "VowelBudgetPoem":
59+
all_forms[name] = form_class(vowel_count=30)
60+
elif name == "PrecisionVerse":
61+
all_forms[name] = form_class(chars_per_line=25)
62+
elif name == "ExactWordPoem":
63+
all_forms[name] = form_class(word_count=20)
64+
elif name == "CharacterBudgetPoem":
65+
all_forms[name] = form_class(character="e", count=10)
66+
elif name == "TotalCharacterPoem":
67+
all_forms[name] = form_class(total_chars=100)
68+
elif name == "FibonacciVerse":
69+
all_forms[name] = form_class(num_lines=5)
70+
elif name == "TriangularVerse":
71+
all_forms[name] = form_class(num_lines=4)
72+
elif name == "PiKu":
73+
all_forms[name] = form_class(num_lines=5)
74+
elif name == "PrecisionHaiku":
75+
all_forms[name] = form_class(chars_per_line=17)
76+
elif name == "ArithmeticVerse":
77+
all_forms[name] = form_class(start=2, diff=2, num_lines=5)
78+
elif name == "PositionalPoem":
79+
all_forms[name] = form_class(positions=[1, 2, 3])
80+
elif name == "IsolatedCouplet":
81+
all_forms[name] = form_class(position=3)
82+
elif name == "AlternatingIsolation":
83+
all_forms[name] = form_class(num_lines=6)
84+
elif name == "DoubleAcrosticPoem":
85+
all_forms[name] = form_class(word="POETRY")
86+
elif name == "CombinedChallenge":
87+
all_forms[name] = form_class(num_lines=4)
88+
elif name == "Lipogram":
89+
all_forms[name] = form_class(forbidden="e")
90+
elif name == "Univocalic":
91+
all_forms[name] = form_class(vowel="a")
92+
elif name == "Mesostic":
93+
all_forms[name] = form_class(spine="POEM")
94+
elif name == "Anaphora":
95+
all_forms[name] = form_class(phrase="I am", num_lines=4)
96+
elif name == "ModularVerse":
97+
all_forms[name] = form_class(modulus=3, num_lines=6)
98+
elif name == "CoprimeVerse":
99+
all_forms[name] = form_class(base=6, num_lines=4)
100+
elif name == "SquareStanzas":
101+
all_forms[name] = form_class(size=4)
102+
elif name == "SelfReferential":
103+
all_forms[name] = form_class(num_lines=4)
104+
elif name == "GoldenRatioVerse":
105+
all_forms[name] = form_class(num_lines=6)
106+
elif name == "PythagoreanTercet":
107+
all_forms[name] = form_class(scale=2)
108+
else:
109+
continue
110+
except Exception:
111+
continue
112+
113+
return all_forms
84114

85115

86116
def evaluate_form(
@@ -161,8 +191,8 @@ def main() -> int:
161191
print(f"Failed to connect to vLLM at {args.url}: {e}")
162192
return 1
163193

164-
# Load forms
165-
form_names = args.forms.split(",") if args.forms else EVAL_FORMS
194+
# Load forms (None means all)
195+
form_names = args.forms.split(",") if args.forms else None
166196
forms = get_forms(form_names)
167197
print(f"\nEvaluating {len(forms)} forms with {args.samples} samples per topic")
168198
print(f"Forms: {', '.join(forms.keys())}")

scripts/prompt_generator.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -139,28 +139,80 @@
139139

140140

141141
# Constructor keyword defaults for forms that cannot be built with no arguments.
# Keyed by the form's exported name in ``abide.forms``.
_FORM_DEFAULT_KWARGS: dict[str, dict[str, object]] = {
    "StaircasePoem": {"num_words": 7},
    "DescendingStaircasePoem": {"num_words": 7},
    "VowelBudgetPoem": {"vowel_count": 30},
    "PrecisionVerse": {"chars_per_line": 25},
    "ExactWordPoem": {"word_count": 20},
    "CharacterBudgetPoem": {"character": "e", "count": 10},
    "TotalCharacterPoem": {"total_chars": 100},
    "FibonacciVerse": {"num_lines": 5},
    "TriangularVerse": {"num_lines": 4},
    "PiKu": {"num_lines": 5},
    "PrecisionHaiku": {"chars_per_line": 17},
    "ArithmeticVerse": {"start": 2, "diff": 2, "num_lines": 5},
    "PositionalPoem": {"positions": [1, 2, 3]},
    "IsolatedCouplet": {"position": 3},
    "AlternatingIsolation": {"num_lines": 6},
    "DoubleAcrosticPoem": {"word": "POETRY"},
    "CombinedChallenge": {"num_lines": 4},
    "Lipogram": {"forbidden": "e"},
    "Univocalic": {"vowel": "a"},
    "Mesostic": {"spine": "POEM"},
    "Anaphora": {"phrase": "I am", "num_lines": 4},
    "ModularVerse": {"modulus": 3, "num_lines": 6},
    "CoprimeVerse": {"base": 6, "num_lines": 4},
    "SquareStanzas": {"size": 4},
    "SelfReferential": {"num_lines": 4},
    "GoldenRatioVerse": {"num_lines": 6},
    "PythagoreanTercet": {"scale": 2},
}


def get_forms() -> dict[str, object]:
    """Load ALL training forms from ``abide.forms``.

    Every name in ``abide.forms.__all__`` is first instantiated with no
    arguments. If that raises ``TypeError``, the form is retried with the
    keyword defaults listed in ``_FORM_DEFAULT_KWARGS``. Forms that cannot be
    built are left out of the result and reported in a single warning instead
    of being dropped silently.

    Returns:
        Mapping of form name to the constructed form instance, in the order
        the names appear in ``abide.forms.__all__``.
    """
    import copy
    import logging

    # Add src to path so ``abide`` resolves when run from a checkout.
    sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

    import abide.forms as forms_module

    all_forms: dict[str, object] = {}
    skipped: list[str] = []

    for name in forms_module.__all__:
        try:
            form_class = getattr(forms_module, name)
            try:
                all_forms[name] = form_class()
            except TypeError:
                kwargs = _FORM_DEFAULT_KWARGS.get(name)
                if kwargs is None:
                    skipped.append(f"{name} (no default arguments configured)")
                    continue
                # Deep-copy so mutable defaults (e.g. the ``positions`` list)
                # are never shared between instances or across calls.
                all_forms[name] = form_class(**copy.deepcopy(kwargs))
        except Exception as exc:
            # Best-effort loading: one bad form must not abort dataset
            # generation, but the failure is recorded for diagnosis.
            skipped.append(f"{name} ({type(exc).__name__}: {exc})")

    if skipped:
        logging.getLogger(__name__).warning(
            "Skipped %d form(s): %s", len(skipped), "; ".join(skipped)
        )

    return all_forms
164216

165217

166218
def generate_prompt(
@@ -224,9 +276,12 @@ def generate_dataset(
224276
dataset.append(
225277
{
226278
"prompt": [{"role": "user", "content": prompt}],
227-
"form_name": form_name,
228-
"topic": topic,
229-
"style": style,
279+
# verifiers passes 'info' dict to reward functions
280+
"info": {
281+
"form_name": form_name,
282+
"topic": topic,
283+
"style": style,
284+
},
230285
}
231286
)
232287

scripts/run_grpo.sh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set -e
66
# (GPU order matters for NCCL - training must be on GPU 0)
77

88
# Configuration
9-
MODEL="${ABIDE_MODEL:-google/gemma-3n-e2b-it}"
9+
MODEL="${ABIDE_MODEL:-google/gemma-3-270m-it}"
1010
PORT=8000
1111
VLLM_PID=""
1212

@@ -42,9 +42,9 @@ echo "Starting vf-vllm on GPU 1..."
4242
CUDA_VISIBLE_DEVICES=1 nohup /home/darren/miniconda3/bin/vf-vllm \
4343
--model "$MODEL" \
4444
--port $PORT \
45-
--gpu-memory-utilization 0.85 \
45+
--gpu-memory-utilization 0.7 \
4646
--tensor-parallel-size 1 \
47-
--max-model-len 2048 \
47+
--max-model-len 1024 \
4848
--trust-remote-code \
4949
--disable-log-stats \
5050
--enforce-eager \
@@ -66,6 +66,7 @@ echo "vLLM is ready!"
6666
echo ""
6767
echo "Starting training on GPU 0..."
6868
export OMP_NUM_THREADS=4
69+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
6970
CUDA_VISIBLE_DEVICES=0 /home/darren/miniconda3/bin/torchrun --nproc_per_node=1 scripts/train_grpo.py
7071

7172
# Cleanup

0 commit comments

Comments
 (0)