Skip to content

Commit c36eca6

Browse files
committed
🎉 LLM as a judge to write haikus
1 parent c623a76 commit c36eca6

File tree

12 files changed

+1300
-181
lines changed

12 files changed

+1300
-181
lines changed

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@
1414
Well documented examples of running distributed training jobs on [Modal](https://modal.com).
1515
Use this repository to learn how to build distributed training jobs on Modal.
1616

17+
## Example of Async RL using slime on Modal
18+
19+
```
20+
modal profile activate modal-labs
21+
modal config set-environment clairez-dev
22+
modal deploy slime/tests/modal_train.py # once
23+
modal run slime/tests/modal_train.py::prepare # once
24+
modal run slime/tests/modal_train.py::execute
25+
```
26+
<!-- prepare_dataset
27+
download_model
28+
train -->
29+
1730
# Examples
1831

1932
- [**`benchmark/`**](/benchmark/) contains performance and reliability testing, using AWS EFA by default.
@@ -39,3 +52,5 @@ Other relevant documentation in our guide:
3952
## License
4053

4154
The [MIT license](LICENSE).
55+
56+

haiku/haiku_config.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""Configuration for Qwen3-4B GRPO training on Haiku dataset."""
2+
3+
from dataclasses import dataclass, field
4+
from pathlib import Path
5+
6+
7+
@dataclass
class RLConfig:
    """Training config that passes raw CLI args directly to slime.

    Holds Modal deployment settings plus a raw argument string that is
    cleaned (comments stripped, whitespace collapsed) and forwarded
    verbatim to the slime training entry point.
    """

    # Directory name of the model under the models volume.
    model_name: str
    # Hugging Face model id, e.g. "Qwen/Qwen3-4B".
    model_id: str

    # Modal settings
    app_name: str = "slime-grpo"
    n_nodes: int = 4
    gpu: str = "H100:8"

    # Training mode: True -> synchronous trainer, False -> async trainer.
    sync: bool = True

    # Wandb
    wandb_project: str = "slime-grpo"
    wandb_run_name_prefix: str = ""

    # Raw CLI args passed directly to slime. May contain "#" comments and
    # {data_path} / {models_path} placeholders, substituted at build time.
    slime_args: str = ""

    # Extra args that get appended (for easy overrides).
    extra_args: list[str] = field(default_factory=list)

    @property
    def train_script(self) -> str:
        """Path of the slime entry point for the selected training mode."""
        return "slime/train.py" if self.sync else "slime/train_async.py"

    def _clean_args(self, args: str) -> str:
        """Remove comments and normalize whitespace.

        Everything from the first "#" on each line is treated as a comment;
        blank lines are dropped and the rest is collapsed onto one line.
        """
        lines = []
        for line in args.split("\n"):
            if "#" in line:
                line = line[: line.index("#")]
            line = line.strip()
            if line:
                lines.append(line)
        return " ".join(lines)

    def generate_train_args(self, models_path: Path, data_path: Path, is_infinite_run: bool) -> str:
        """Build the full CLI argument string for slime.

        Args:
            models_path: Volume path containing model checkpoints.
            data_path: Volume path containing datasets.
            is_infinite_run: Currently unused; reserved for long-running jobs.

        Returns:
            A single space-separated argument string with placeholder
            substitution applied.
        """
        base_args = f"--hf-checkpoint {models_path}/{self.model_name} --ref-load {models_path}/{self.model_name}"

        cleaned_slime_args = self._clean_args(self.slime_args)
        cleaned_slime_args = cleaned_slime_args.replace("{data_path}", str(data_path))
        cleaned_slime_args = cleaned_slime_args.replace("{models_path}", str(models_path))

        extra = " ".join(self.extra_args) if self.extra_args else ""

        # Join only the non-empty segments so we never emit internal double
        # spaces (e.g. empty slime_args with non-empty extra_args) or a
        # trailing space.
        return " ".join(part for part in (base_args, cleaned_slime_args, extra) if part)
58+
59+
# ── Model architecture constants ──

# Megatron architecture flags — presumably mirroring the Qwen3-4B HF config
# (36 layers, GQA with 8 KV groups, RMSNorm + SwiGLU, RoPE base 1e6);
# verify against the upstream model card if the model is swapped.
QWEN3_4B_MODEL_ARGS = """
--num-layers 36 --hidden-size 2560 --ffn-hidden-size 9728
--num-attention-heads 32 --group-query-attention --num-query-groups 8
--kv-channels 128 --vocab-size 151936
--normalization RMSNorm --norm-epsilon 1e-6 --swiglu
--disable-bias-linear --qk-layernorm
--use-rotary-position-embeddings --rotary-base 1000000
"""

# Parallelism / memory settings: TP=2 with sequence parallel, full uniform
# activation recompute, and dynamic batching capped at 9216 tokens per GPU.
DEFAULT_TRAINING_ARGS = """
--tensor-model-parallel-size 2 --sequence-parallel
--recompute-granularity full --recompute-method uniform --recompute-num-layers 1
--use-dynamic-batch-size --max-tokens-per-gpu 9216
--megatron-to-hf-mode bridge
--attention-dropout 0.0 --hidden-dropout 0.0
--accumulate-allreduce-grads-in-fp32 --attention-softmax-in-fp32
"""

# Adam with a constant 1e-6 learning rate (no decay schedule).
DEFAULT_OPTIMIZER_ARGS = """
--optimizer adam
--lr 1e-6 --lr-decay-style constant
--weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98
"""

# GRPO settings: KL and entropy terms present but zero-weighted (coef 0.00),
# asymmetric clipping (0.2 low / 0.28 high).
DEFAULT_GRPO_ARGS = """
--advantage-estimator grpo
--use-kl-loss --kl-loss-coef 0.00 --kl-loss-type low_var_kl
--entropy-coef 0.00
--eps-clip 0.2 --eps-clip-high 0.28
"""
92+
93+
# ── Config factory ──

def get_config(run_name: str = "qwen3-4b-haiku") -> RLConfig:
    """Build the single-node Qwen3-4B GRPO config for the haiku task.

    Args:
        run_name: Used as the wandb run-name prefix.

    Returns:
        A fully-populated RLConfig. The "#" comments inside ``slime_args``
        are stripped by ``RLConfig._clean_args``; the ``{data_path}`` /
        ``{models_path}`` placeholders (escaped here as ``{{...}}``) are
        substituted later by ``generate_train_args``.
    """
    return RLConfig(
        model_name="Qwen3-4B",
        model_id="Qwen/Qwen3-4B",
        n_nodes=1,
        gpu="H200:8",
        app_name="slime-qwen3-4b-haiku",
        sync=True,
        wandb_project="slime-grpo-haiku",
        wandb_run_name_prefix=run_name,
        slime_args=f"""
# Model architecture
{QWEN3_4B_MODEL_ARGS}

# Training parallelism and optimization
{DEFAULT_TRAINING_ARGS}

# Optimizer
{DEFAULT_OPTIMIZER_ARGS}

# GRPO algorithm
{DEFAULT_GRPO_ARGS}

# Data
--input-key messages --label-key label
--apply-chat-template --rollout-shuffle
--prompt-data {{data_path}}/haiku/train.parquet

# Custom reward model
--rm-type remote_rm
--rm-url https://modal-labs-joy-dev--llm-judge-reward-model-llmjudgeflash.us-east.modal.direct/score

--num-rollout 50
--rollout-batch-size 128
--n-samples-per-prompt 8
--global-batch-size 64

# SGLang
--rollout-num-gpus-per-engine 2
--sglang-mem-fraction-static 0.7

--rollout-max-response-len 300

--rollout-temperature 1
--rollout-skip-special-tokens

# Orchestration
--actor-num-nodes 1
--actor-num-gpus-per-node 8
--colocate

# Eval
--eval-prompt-data haiku {{data_path}}/haiku/test.parquet
--eval-interval 20
--n-samples-per-eval-prompt 8
--eval-max-response-len 300
--eval-top-p 1
""",
    )

haiku/llm_judges/__init__.py

Whitespace-only changes.

haiku/llm_judges/base.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""
2+
Haiku LLM judge — scores structure (syllable counting) and style (LLM evaluation).
3+
4+
Structure scoring uses CMUdict for syllable counting.
5+
Style scoring uses a local vLLM instance to evaluate relevance, poetic quality, etc.
6+
"""
7+
8+
import re
9+
10+
import aiohttp
11+
12+
from llm_judges.deploy import VLLM_PORT
13+
from llm_judges.nlp import score_haiku_structure
14+
15+
16+
# Modal-specific vocabulary the judge rewards when a haiku uses these words
# coherently (see the "Uses Modal vocabulary" criterion in the judge prompt).
MODAL_VOCABS = [
    "modal",
    "volume",
    "function",
    "sandbox",
    "flash",
    "inference",
    "train",
]
25+
26+
27+
def _build_judge_prompt(prompt: str, response: str, label: str = "") -> tuple[str, int]:
28+
"""Build the LLM judge prompt. Returns (prompt_text, max_score)."""
29+
modal_vocab_str = ", ".join(MODAL_VOCABS)
30+
31+
max_score = 15 # relevance(5) + poetic(5) + modal vocab(5)
32+
33+
text = f"""You are evaluating a haiku poem.
34+
35+
Score the response based on the following criteria:
36+
37+
Relevance (5 points total)
38+
- 5 points: if the central theme and punchline of the haiku is "{prompt}"
39+
- 3 points: if the response directly discusses "{prompt}" but it is not the central theme
40+
- 2 points: if the response is relevant to the topic "{prompt}" but very plain
41+
- 0 points: if the response is not relevant to the topic "{prompt}"
42+
43+
Poetic quality (5 points total)
44+
- 5 points: if the response makes sense, can be considered a poetic haiku, with a clear theme and punchline
45+
- 3 point: if the response makes sense, but is not very poetic
46+
- 1 point: if the response doesn't make sense
47+
- 0 points: if the response is not poetic and incoherent
48+
"""
49+
50+
if label:
51+
max_score = 20
52+
text += f"""
53+
Better than the existing poem (5 points total):
54+
Given the existing poem, score the response by comparing its quality to the existing poem:
55+
{label}
56+
- 5 points: if the response is better than the poem "{label}".
57+
- 3 points: if the response is equal in quality to the poem "{label}".
58+
- 0 points: if the response is worse than the poem "{label}".
59+
"""
60+
61+
prereq_score = max_score - 5
62+
text += f"""
63+
Uses Modal vocabulary (5 points total): (modal vocab: {modal_vocab_str})
64+
- 5 points: if the response uses the above words in a way that is coherent and relevant to the topic "{prompt}"
65+
- 3 points: if the response uses the above words in a way that is not relevant to the topic "{prompt}"
66+
- 0 points: if the response does not use the above words
67+
DO NOT GIVE ANY POINTS TO USE MODAL VOCABULARY IF THE POEM ITSELF DOES NOT ALREADY ACHIEVE A SCORE OF {prereq_score} OR HIGHER
68+
69+
Add up the scores from the above criteria to get the total score.
70+
71+
--
72+
**Topic:** {prompt}
73+
74+
**Response to evaluate:**
75+
{response}
76+
---
77+
78+
Output ONLY a single number (0-{max_score}), nothing else."""
79+
80+
return text, max_score
81+
82+
83+
class HaikuJudge:
    """Scores haikus on structure (syllable counting) and style (LLM evaluation).

    Args:
        gate_style_on_structure: If True, only evaluate style when structure
            score is perfect (1.0). If False, always evaluate style.
    """

    def __init__(self, gate_style_on_structure: bool = True):
        # Whether style scoring is skipped for structurally imperfect haikus.
        self.gate_style_on_structure = gate_style_on_structure

    async def score_style(
        self,
        model_name: str,
        session: aiohttp.ClientSession,
        prompt: str,
        response: str,
        label: str = "",
        vllm_base_url: str = f"http://localhost:{VLLM_PORT}",
    ) -> float:
        """Score haiku style via LLM judge, normalized to [0, 1].

        Posts one chat-completion request to the local vLLM server and parses
        the first number out of the judge's reply. Any HTTP error, transport
        failure, or unparseable reply is logged and scored as 0 — best-effort
        by design, so a judge outage does not crash the rollout.
        """
        judge_prompt, max_score = _build_judge_prompt(prompt, response, label)

        try:
            async with session.post(
                f"{vllm_base_url}/v1/chat/completions",
                headers={"content-type": "application/json"},
                json={
                    "model": model_name,
                    "messages": [{"role": "user", "content": judge_prompt}],
                    "max_tokens": 100,
                },
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    print(f"vLLM error: {resp.status} - {error_text}")
                    return 0

                data = await resp.json()
                score_text = data["choices"][0]["message"]["content"].strip()
                print(f"Scored {response} with score {score_text}")

                # The prompt asks for a bare number, but models sometimes add
                # extra text — take the first integer/float found in the reply.
                match = re.search(r"(\d+(?:\.\d+)?)", score_text)
                if match:
                    score = float(match.group(1))
                    # Clamp to [0, max_score], then normalize to [0, 1].
                    return min(max(score, 0), max_score) / max_score
                return 0
        except Exception as e:
            # Broad catch is deliberate: reward scoring must not raise.
            print(f"Error scoring response: {e}")
            return 0

    async def score_single(
        self,
        model_name: str,
        session: aiohttp.ClientSession,
        prompt: str,
        response: str,
        cmudict: dict,
        label: str = "",
    ) -> float:
        """Score a single haiku. Returns a score in [0, 2].

        Total = structure score (CMUdict syllable check — presumably in
        [0, 1], as the gate below treats 1.0 as perfect; confirm against
        score_haiku_structure) + style score (LLM judge, [0, 1]).
        """
        structure_score = score_haiku_structure(response, cmudict)

        style_score = 0.0
        # Optionally gate the (expensive) LLM call on a perfect structure score.
        if not self.gate_style_on_structure or structure_score >= 1.0:
            style_score = await self.score_style(
                model_name, session, prompt, response, label
            )
            style_score = max(style_score, 0.0)

        total = structure_score + style_score
        print(f"[HaikuJudge] structure={structure_score}, style={style_score}, gated={self.gate_style_on_structure}")
        return total

0 commit comments

Comments
 (0)