diff --git a/README.md b/README.md index 5f7c625..eb3325d 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,19 @@ Well documented examples of running distributed training jobs on [Modal](https://modal.com). Use this repository to learn how to build distributed training jobs on Modal. +## Example of Async RL using slime on Modal + +``` +modal profile activate modal-labs +modal config set-environment clairez-dev +modal deploy slime/tests/modal_train.py # once +modal run slime/tests/modal_train.py::prepare # once +modal run slime/tests/modal_train.py::execute +``` + + # Examples - [**`benchmark/`**](/benchmark/) contains performance and reliability testing, using AWS EFA by default. @@ -39,3 +52,5 @@ Other relevant documentation in our guide: ## License The [MIT license](LICENSE). + + diff --git a/haiku/config.py b/haiku/config.py new file mode 100644 index 0000000..f06a18c --- /dev/null +++ b/haiku/config.py @@ -0,0 +1,193 @@ +"""Configuration for Qwen3-4B GRPO training on Haiku dataset.""" + +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path + + + + +_MODEL_INFO = { + "Qwen/Qwen3-30B-A3B-Instruct-2507": ("qwen3-30b-a3b-instruct", "30b"), + "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": ("qwen3-235b-a22b-instruct-fp8", "235b"), +} + + +class JudgeType(str, Enum): + STRICT = "strict" + STRICT_LEVELED = "strict-leveled" + NO_LLM = "no-llm" # only use the structure score + +class JudgeModelSize(str, Enum): + QWEN3_30B = "Qwen/Qwen3-30B-A3B-Instruct-2507" + QWEN3_235B = "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8" + + @property + def model_name(self) -> str: + return _MODEL_INFO[self.value][0] + + @property + def shorthand(self) -> str: + return _MODEL_INFO[self.value][1] + + + +ACTIVE_JUDGE_TYPE = JudgeType.NO_LLM +ACTIVE_JUDGE_MODEL_SIZE = JudgeModelSize.QWEN3_30B + + + +@dataclass +class RLConfig: + """Training config that passes raw CLI args directly to slime.""" + + model_name: str + model_id: str + + # Modal settings + n_nodes: int 
= 4 + gpu: str = "H100:8" + + # Wandb + wandb_project: str = "example-train-haiku" + wandb_run_name_prefix: str = "" + + # Raw CLI args passed directly to slime + slime_args: str = "" + + save_steps: int = 10 + + # Extra args that get appended (for easy overrides) + extra_args: list[str] = field(default_factory=list) + + def _clean_args(self, args: str) -> str: + """Remove comments and normalize whitespace.""" + lines = [] + for line in args.split("\n"): + if "#" in line: + line = line[: line.index("#")] + line = line.strip() + if line: + lines.append(line) + return " ".join(lines) + + def generate_train_args(self, data_path: Path) -> str: + from huggingface_hub import snapshot_download + + model_path = snapshot_download(self.model_id) + base_args = f"--hf-checkpoint {model_path} --ref-load {model_path}" + + cleaned_slime_args = self._clean_args(self.slime_args) + cleaned_slime_args = cleaned_slime_args.replace("{data_path}", str(data_path)) + + extra = " ".join(self.extra_args) if self.extra_args else "" + + return f"{base_args} {cleaned_slime_args} {extra}".strip() + + +# ── Model architecture constants ── + +QWEN3_4B_MODEL_ARGS = """ + --num-layers 36 --hidden-size 2560 --ffn-hidden-size 9728 + --num-attention-heads 32 --group-query-attention --num-query-groups 8 + --kv-channels 128 --vocab-size 151936 + --normalization RMSNorm --norm-epsilon 1e-6 --swiglu + --disable-bias-linear --qk-layernorm + --use-rotary-position-embeddings --rotary-base 1000000 +""" + +DEFAULT_TRAINING_ARGS = """ + --tensor-model-parallel-size 2 --sequence-parallel + --recompute-granularity full --recompute-method uniform --recompute-num-layers 1 + --use-dynamic-batch-size --max-tokens-per-gpu 9216 + --megatron-to-hf-mode bridge + --attention-dropout 0.0 --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 --attention-softmax-in-fp32 +""" + +DEFAULT_OPTIMIZER_ARGS = """ + --optimizer adam + --lr 1e-6 --lr-decay-style constant + --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98 
+""" + +DEFAULT_GRPO_ARGS = """ + --advantage-estimator grpo + --use-kl-loss --kl-loss-coef 0.00 --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 --eps-clip-high 0.28 +""" + + +# ── Config factory ── + +def _get_judge_url(judge_type: JudgeType, judge_model_size: JudgeModelSize) -> str: + return f"https://modal-labs-joy-dev--llm-judge-{judge_model_size.shorthand}-{judge_type.value}-llmjudge.us-east.modal.direct" + + +def _get_reward_model_args_from_judge_type(judge_type: JudgeType, judge_model_size: JudgeModelSize) -> str: + if judge_type == JudgeType.STRICT or judge_type == JudgeType.STRICT_LEVELED: + return f"""--rm-type remote_rm + --rm-url {_get_judge_url(judge_type, judge_model_size)}/score""" + elif judge_type == JudgeType.NO_LLM: + return """--rm-type async_rm + --custom-rm-path llm_judges.nlp.haiku_rm""" + +def get_config(run_name: str = "qwen3-4b-haiku", judge_type = ACTIVE_JUDGE_TYPE, judge_model_size = ACTIVE_JUDGE_MODEL_SIZE) -> RLConfig: + return RLConfig( + model_name="Qwen3-4B", + model_id="Qwen/Qwen3-4B", + n_nodes=1, + gpu="H200:8", + wandb_project="example-train-haiku", + wandb_run_name_prefix=run_name, + save_steps=10, + slime_args=f""" + # Model architecture + {QWEN3_4B_MODEL_ARGS} + + # Training parallelism and optimization + {DEFAULT_TRAINING_ARGS} + + # Optimizer + {DEFAULT_OPTIMIZER_ARGS} + + # GRPO algorithm + {DEFAULT_GRPO_ARGS} + + # Data + --input-key messages --label-key label + --apply-chat-template --rollout-shuffle + --apply-chat-template-kwargs '{{"enable_thinking": false}}' + --prompt-data {{data_path}}/haiku/train.parquet + + # Custom reward model + {_get_reward_model_args_from_judge_type(judge_type, judge_model_size)} + + --num-rollout 50 + --rollout-batch-size 128 + --n-samples-per-prompt 8 + --global-batch-size 64 + + # SGLang + --rollout-num-gpus-per-engine 2 + --sglang-mem-fraction-static 0.7 + + --rollout-max-response-len 300 + + --rollout-temperature 1 + --rollout-skip-special-tokens + + # Orchestration + 
--actor-num-nodes 1 + --actor-num-gpus-per-node 8 + --colocate + + # Eval + --eval-prompt-data haiku {{data_path}}/haiku/test.parquet + --eval-interval 20 + --n-samples-per-eval-prompt 8 + --eval-max-response-len 300 + --eval-top-p 1 + """, + ) diff --git a/haiku/eval/README.md b/haiku/eval/README.md new file mode 100644 index 0000000..62cf7ee --- /dev/null +++ b/haiku/eval/README.md @@ -0,0 +1,9 @@ +# Host and Evaluate Haiku Models + +To host the demo playground that evaluates various finetuned haiku models, run the following commands. + +``` +modal deploy eval.serve_haiku_model + +modal deploy eval.haiku_app +``` \ No newline at end of file diff --git a/haiku/eval/__init__.py b/haiku/eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/haiku/eval/haiku_app.py b/haiku/eval/haiku_app.py new file mode 100644 index 0000000..20ac8c3 --- /dev/null +++ b/haiku/eval/haiku_app.py @@ -0,0 +1,111 @@ +"""FastAPI backend for the Haiku Playground, deployed on Modal. + +modal deploy eval.haiku_app +""" + +from pathlib import Path + +import modal + +app = modal.App("haiku-playground") + +image = ( + modal.Image.debian_slim(python_version="3.12") + .pip_install("fastapi[standard]", "httpx", "nltk") + .run_commands( + "python -c \"import nltk; nltk.download('cmudict')\"" + ) + .add_local_dir("eval", "/root/eval") + .add_local_dir("llm_judges", "/root/llm_judges") + .add_local_file("config.py", "/root/config.py") +) + + +@app.function( + image=image, +) +@modal.asgi_app() +def serve_playground(): + from contextlib import asynccontextmanager + + import httpx + import nltk + from fastapi import FastAPI + from fastapi.responses import FileResponse + from pydantic import BaseModel + + from eval.shared import ( + MODAL_VOCABS, + MODEL_CHECKPOINTS, + build_system_prompt, + get_model_endpoint, + query_model, + ) + from llm_judges.nlp import ( + count_syllables_for_word, + score_haiku_structure, + segment_haiku_lines, + ) + + class GenerateRequest(BaseModel): + prompt: 
str + model_key: str = "base-model" + include_vocab: bool = True + + @asynccontextmanager + async def lifespan(app: FastAPI): + app.state.cmudict = nltk.corpus.cmudict.dict() + app.state.http_client = httpx.AsyncClient() + yield + await app.state.http_client.aclose() + + fastapi_app = FastAPI(title="Haiku Playground", lifespan=lifespan) + + @fastapi_app.post("/api/generate") + async def generate(request: GenerateRequest): + import re + + client = fastapi_app.state.http_client + cmudict = fastapi_app.state.cmudict + + endpoint = get_model_endpoint(request.model_key) + system_prompt = build_system_prompt(include_vocab=request.include_vocab) + + haiku = await query_model( + client, + endpoint, + request.prompt, + model_name=request.model_key, + system_prompt=system_prompt, + ) + + structure_score = score_haiku_structure(haiku, cmudict) + + lines = segment_haiku_lines(haiku) + syllable_counts = [] + for line in lines: + words = re.findall(r"[a-zA-Z]+", line) + count = sum(count_syllables_for_word(w, cmudict) for w in words) + syllable_counts.append(count) + + return { + "haiku": haiku, + "structure_score": structure_score, + "syllable_counts": syllable_counts, + "passed": structure_score == 1, + } + + @fastapi_app.get("/api/models") + async def get_models(): + return MODEL_CHECKPOINTS + + @fastapi_app.get("/api/vocabs") + async def get_vocabs(): + return MODAL_VOCABS + + @fastapi_app.get("/") + async def index(): + html_path = Path("/root/eval/haiku_playground.html") + return FileResponse(html_path, media_type="text/html") + + return fastapi_app diff --git a/haiku/eval/haiku_playground.html b/haiku/eval/haiku_playground.html new file mode 100644 index 0000000..a9bee26 --- /dev/null +++ b/haiku/eval/haiku_playground.html @@ -0,0 +1,441 @@ + + + + + + Haiku Playground + + + +
+
+

haiku playground

+

compare haiku generation across training checkpoints

+
+ +
+
+ + +
+
+
+ + +
+
+ bonus words: try to get the model to use: loading... +
+
+ +
+
+ + + + \ No newline at end of file diff --git a/haiku/eval/run_eval.py b/haiku/eval/run_eval.py new file mode 100644 index 0000000..91111e8 --- /dev/null +++ b/haiku/eval/run_eval.py @@ -0,0 +1,114 @@ +"""Haiku eval script — queries a served model checkpoint and scores haiku structure.""" + +import argparse +import asyncio +import json +from dataclasses import asdict, dataclass +from datetime import datetime, timezone + +import httpx +import modal +import nltk + +from eval.shared import ( + DEFAULT_CONCURRENCY, + EVAL_QUESTIONS, + MODELS, + get_model_endpoint, + query_model, +) +from llm_judges.nlp import score_haiku_structure + +EVALS_PATH = "/opt/evals" + + +@dataclass(frozen=True) +class EvalResult: + question: str + response: str + passed: bool + + +async def eval_problem( + client: httpx.AsyncClient, + semaphore: asyncio.Semaphore, + question: str, + endpoint: str, + model_key: str, + cmudict: dict, +) -> EvalResult: + response = await query_model( + client, endpoint, question, model_name=model_key, semaphore=semaphore + ) + structure_score = score_haiku_structure(response, cmudict) + print(f"Structure score: {structure_score}") + + print("=" * 70) + print(f"Question: {question}") + print(f"Response: {response}") + print("=" * 70) + + passed = structure_score >= 0.75 + print(f"Passed: {passed}") + + return EvalResult(question=question, response=response, passed=passed) + + +async def run_eval( + model_key: str = "base-model", + file_path: str | None = None, +): + if file_path is None: + file_path = f"{EVALS_PATH}/{model_key}_eval.jsonl" + + endpoint = get_model_endpoint(model_key) + cmudict = nltk.corpus.cmudict.dict() + + print(f"Model: {model_key}") + print(f"Endpoint: {endpoint}") + print(f"Loaded {len(EVAL_QUESTIONS)} questions") + print(f"Running with concurrency={DEFAULT_CONCURRENCY}\n") + print("=" * 70) + + semaphore = asyncio.Semaphore(DEFAULT_CONCURRENCY) + + async with httpx.AsyncClient() as client: + tasks = [ + eval_problem(client, semaphore, 
question, endpoint, model_key, cmudict) + for question in EVAL_QUESTIONS + ] + results = await asyncio.gather(*tasks) + + with open(file_path, "w") as f: + for result in results: + f.write(json.dumps(asdict(result)) + "\n") + + success_rate = sum(result.passed for result in results) / len(results) + print(f"Success rate: {success_rate}") + + # Save results to a Modal Dict + eval_dict = modal.Dict.from_name("haiku-eval-results", create_if_missing=True) + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + key = f"{model_key}/{timestamp}" + eval_dict[key] = { + "model_key": model_key, + "timestamp": timestamp, + "success_rate": success_rate, + "results": [asdict(r) for r in results], + } + print(f"Saved eval results to Modal Dict 'haiku-eval-results' with key '{key}'") + + return results + + +if __name__ == "__main__": + model_choices = list(MODELS.keys()) + parser = argparse.ArgumentParser(description="Run haiku structure eval against a served checkpoint.") + parser.add_argument( + "--model", + default="base-model", + choices=model_choices, + help=f"Model to evaluate (choices: {', '.join(model_choices)})", + ) + args = parser.parse_args() + asyncio.run(run_eval(model_key=args.model)) diff --git a/haiku/eval/serve_haiku_model.py b/haiku/eval/serve_haiku_model.py new file mode 100644 index 0000000..78218e3 --- /dev/null +++ b/haiku/eval/serve_haiku_model.py @@ -0,0 +1,215 @@ +"""Serve SLIME-trained Haiku models with vLLM. 
+ +modal deploy eval.serve_haiku_model +""" + +from pathlib import Path +import modal +import modal.experimental + +from eval.shared import MODEL_CONFIG, ModelConfig, _to_class_name + +APP_NAME = "serve-haiku-model" + +app = modal.App(APP_NAME) + +MODELS_PATH: Path = Path("/models") + +HF_DIR = "hf" + +N_GPU = 1 +MINUTES = 60 +VLLM_PORT = 8000 + + +# Same volume used in training +checkpoints_volume: modal.Volume = modal.Volume.from_name("grpo-slime-haiku-checkpoints") +hf_cache_vol = modal.Volume.from_name("huggingface-cache") +vllm_cache_vol = modal.Volume.from_name("vllm-cache") + + +vllm_image = ( + modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12") + .entrypoint([]) + .uv_pip_install( + "vllm==0.11.2", + "huggingface-hub==0.36.0", + "flashinfer-python==0.5.2", + ) + .env({"HF_XET_HIGH_PERFORMANCE": "1"}) +) + +slime_image = ( + modal.Image.from_registry("slimerl/slime:nightly-dev-20260126a") + .run_commands( + "uv pip install --system git+https://github.com/huggingface/transformers.git@eebf856", # 4.54.1 + "uv pip install --system aiohttp", # For LLM judge reward model + """sed -i 's/AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)/AutoImageProcessor.register(config, slow_image_processor_class=image_processor, exist_ok=True)/g' /sgl-workspace/sglang/python/sglang/srt/configs/utils.py""", + # Fix rope_theta access for transformers 5.x (moved to rope_parameters dict) + r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/glm/glm45_bridge.py""", + r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/qwen/qwen3_bridge.py""", + ) + .add_local_dir("tools", remote_path="/root/tools", copy=True) + .entrypoint([]) +) + +def get_hf_model_path(config: ModelConfig) -> str: + return f"{MODELS_PATH / config.model_path / HF_DIR}" + +def 
get_megatron_checkpoint_path(config: ModelConfig) -> str: + return f"{MODELS_PATH / config.model_path / config.iters_dir}" + +@app.function( + image=slime_image, + timeout=24 * 60 * 60, + secrets=[ + modal.Secret.from_name("huggingface-secret"), + ], + volumes={MODELS_PATH.as_posix(): checkpoints_volume}, +) +async def convert_checkpoint( + model_path: str, + iter_dir: str, + origin_hf_dir: str +): + """Convert Megatron checkpoint to HuggingFace format.""" + from huggingface_hub import snapshot_download + import subprocess + + await checkpoints_volume.reload.aio() + + local_hf_dir = MODELS_PATH / origin_hf_dir + + if not local_hf_dir.exists(): + snapshot_download(repo_id=f"Qwen/{origin_hf_dir}", local_dir=local_hf_dir) + else: + print(f"Model {origin_hf_dir} already downloaded.") + + megatron_checkpoint_path = MODELS_PATH / model_path / iter_dir + output_hf_path = MODELS_PATH / model_path / HF_DIR + + subprocess.run(f"PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py --input-dir {megatron_checkpoint_path} --output-dir {output_hf_path} --origin-hf-dir {local_hf_dir}", shell=True, check=True) + + +CLS_KWARGS = dict( + image=vllm_image, + gpu=f"A10G:{N_GPU}", + scaledown_window=15 * MINUTES, + startup_timeout=10 * MINUTES, + volumes={ + "/root/.cache/huggingface": hf_cache_vol, + "/root/.cache/vllm": vllm_cache_vol, + MODELS_PATH.as_posix(): checkpoints_volume, + }, + secrets=[modal.Secret.from_name("huggingface-secret")], + experimental_options={"flash": "us-east"}, + region="us-east", + min_containers=1, +) + + +class _VLLMServerBase: + """Base class with all vLLM serving logic. 
Not registered with Modal directly.""" + + # Subclasses set this as a class variable + MODEL_KEY: str + + @modal.enter() + def setup(self): + import subprocess + + if self.MODEL_KEY not in MODEL_CONFIG: + raise ValueError(f"Invalid model name: {self.MODEL_KEY}") + + model_config = MODEL_CONFIG[self.MODEL_KEY] + + if model_config.is_base_model: + model_dir = self._setup_base_model(model_config) + else: + model_dir = self._setup_finetuned_model(model_config) + + cmd = [ + "vllm", + "serve", + str(model_dir), + "--served-model-name", + model_config.model_name, + "--host", + "0.0.0.0", + "--port", + str(VLLM_PORT), + "--enforce-eager", + "--tensor-parallel-size", + str(N_GPU), + "--reasoning-parser", + "qwen3", + ] + + print(" ".join(cmd)) + self._vllm_process = subprocess.Popen(" ".join(cmd), shell=True) + + self._wait_for_port(VLLM_PORT, timeout=600) + print(f"vLLM ready on port {VLLM_PORT}") + + self.flash_manager = modal.experimental.flash_forward(VLLM_PORT) + print(f"Flash endpoint ready on port {VLLM_PORT}") + + def _setup_base_model(self, model_config: ModelConfig) -> Path: + from huggingface_hub import snapshot_download + + local_hf_dir = MODELS_PATH / model_config.model_path + + if not local_hf_dir.exists(): + snapshot_download(repo_id=f"Qwen/{model_config.model_path}", local_dir=local_hf_dir) + else: + print(f"Model {model_config.model_path} already downloaded.") + + return local_hf_dir + + def _setup_finetuned_model(self, model_config: ModelConfig) -> Path: + hf_path = MODELS_PATH / model_config.model_path / HF_DIR + + if not hf_path.joinpath("config.json").exists(): + print(f"Converting checkpoint {model_config.model_path} to HuggingFace format...") + convert_checkpoint.remote( + model_path=model_config.model_path, + iter_dir=model_config.iters_dir, + origin_hf_dir=model_config.base_model_name, + ) + checkpoints_volume.reload() + print(f"Checkpoint {model_config.model_path}/{model_config.iters_dir} converted to HuggingFace format.") + + return hf_path + + 
def _wait_for_port(self, port: int, timeout: int = 30): + import socket + import time + + for _ in range(timeout): + try: + socket.create_connection(("localhost", port), timeout=1).close() + return + except OSError: + time.sleep(1) + raise RuntimeError(f"Server failed to start on port {port}") + + @modal.method() + def keepalive(self): + pass + + @modal.exit() + def cleanup(self): + if hasattr(self, "flash_manager"): + self.flash_manager.stop() + self.flash_manager.close() + if hasattr(self, "_vllm_process"): + self._vllm_process.terminate() + self._vllm_process.wait(timeout=10) + + +for _model_key in MODEL_CONFIG: + _cls_name = _to_class_name(_model_key) + _cls = type(_cls_name, (_VLLMServerBase,), {"MODEL_KEY": _model_key}) + _cls = modal.concurrent(target_inputs=4)(_cls) + _cls = app.cls(**CLS_KWARGS)(_cls) + globals()[_cls_name] = _cls diff --git a/haiku/eval/shared.py b/haiku/eval/shared.py new file mode 100644 index 0000000..1496b6b --- /dev/null +++ b/haiku/eval/shared.py @@ -0,0 +1,214 @@ +"""Shared constants, endpoint helpers, and query function for haiku eval and playground.""" + +import asyncio +from dataclasses import dataclass + +import httpx + +# --------------------------------------------------------------------------- +# Flash URL generation +# --------------------------------------------------------------------------- + +APP_NAME = "serve-haiku-model" +ENVIRONMENT = "joy-dev" +FLASH_REGION = "us-east" +WORKSPACE = "modal-labs" + + +def _to_class_name(model_key: str) -> str: + """Convert model key like '235b-judge-cl' to '235bJudgeClServer'.""" + return "".join(part.capitalize() for part in model_key.split("-")) + "Server" + + +def get_flash_url(model_key: str) -> str: + """Get the Flash endpoint base URL for a given model key.""" + cls_name = _to_class_name(model_key) + return f"https://{WORKSPACE}-{ENVIRONMENT}--{APP_NAME}-{cls_name.lower()}.{FLASH_REGION}.modal.direct" + + +def get_model_endpoint(model_key: str) -> str: + """Return the full 
chat-completions URL for a given model key.""" + return get_flash_url(model_key) + "/v1/chat/completions" + + +# --------------------------------------------------------------------------- +# Model registry (single source of truth for both serving and eval/playground) +# --------------------------------------------------------------------------- + + +@dataclass +class ModelConfig: + model_path: str + iters_dir: str + model_name: str + model_description: str + label: str + base_model_name: str = "Qwen3-4B" + is_base_model: bool = False + + @property + def badge(self) -> str: + return "base" if self.is_base_model else "trained" + + @property + def flash_url(self) -> str: + return get_flash_url(self.model_name) + + +MODEL_CONFIG = { + "base-model": ModelConfig( + model_path="Qwen3-4B", + iters_dir="", + model_name="base-model", + model_description="Qwen3-4B Base Model", + label="Base Model", + is_base_model=True, + ), + "no-llm-model": ModelConfig( + model_path="2-23-no-llm-Qwen3-4B-20260224-032404", + iters_dir="iter_0000049", + model_name="no-llm-model", + model_description="Qwen3-4B Finetuned with No LLM", + label="No LLM Judge", + ), + "30b-judge": ModelConfig( + model_path="2_24-30b-Qwen3-4B-20260224-184838", + iters_dir="iter_0000049", + model_name="30b-judge-cl", + model_description="Qwen3-4B Finetuned with 30B Judge using Curriculumn Learning", + label="30B Judge (CL)", + ), + "30b-judge-cl": ModelConfig( + model_path="2_24-30b-leveled-Qwen3-4B-20260224-180902", + iters_dir="iter_0000049", + model_name="30b-judge-cl", + model_description="Qwen3-4B Finetuned with 30B Judge using Curriculumn Learning", + label="30B Judge (CL)", + ), + "235b-judge": ModelConfig( + model_path="2_24-235b-Qwen3-4B-20260224-174605", + iters_dir="iter_0000049", + model_name="235b-judge", + model_description="Qwen3-4B Finetuned with 235B Judge", + label="235B Judge", + ), + "235b-judge-cl": ModelConfig( + model_path="2_23-235b-leveled-Qwen3-4B-20260224-172832", + 
iters_dir="iter_0000049", + model_name="235b-judge-cl", + model_description="Qwen3-4B Finetuned with 235B Judge using Curriculumn Learning", + label="235B Judge (CL)", + ), +} + +# Derived views for different consumers +MODELS = { + key: {"label": c.label, "badge": c.badge} + for key, c in MODEL_CONFIG.items() +} + +# For the UI — list of models with their metadata +MODEL_CHECKPOINTS = [ + {"name": config.model_name, "label": config.label, "badge": config.badge} + for config in MODEL_CONFIG.values() +] + +# Maps model_key -> flash endpoint URL +MODEL_URLS: dict[str, str] = {key: config.flash_url for key, config in MODEL_CONFIG.items()} + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MODAL_VOCABS = [ + "modal", + "volume", + "function", + "sandbox", + "flash", + "inference", + "train", +] + +DEFAULT_CONCURRENCY = 50 + +EVAL_QUESTIONS = [ + "Write me a haiku about cat.", + "Write me a haiku about dog.", + "Write me a haiku about bird.", + "Write me a haiku about fish.", + "Write me a haiku about horse.", + "Write me a haiku about rabbit.", + "Write me a haiku about snake.", + "Write me a haiku about tiger.", + "Write me a haiku about lion.", + "Write me a haiku about Jason Mancuso.", + "Write me a haiku about Joy Liu.", + "Write me a haiku about Modal Labs.", +] + +QUICK_PROMPTS = [ + {"emoji": "🐱", "label": "cat", "prompt": "Write me a haiku about cat."}, + {"emoji": "🌊", "label": "ocean", "prompt": "Write me a haiku about the ocean."}, + {"emoji": "🌸", "label": "cherry blossoms", "prompt": "Write me a haiku about cherry blossoms."}, + {"emoji": "💻", "label": "coding", "prompt": "Write me a haiku about coding."}, + {"emoji": "☁️", "label": "Modal", "prompt": "Write me a haiku about Modal."}, + {"emoji": "⚡", "label": "serverless", "prompt": "Write me a haiku about serverless."}, +] + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def build_system_prompt(include_vocab: bool = True) -> str: + """Build the haiku poet system prompt, optionally including Modal vocabulary.""" + base = "You are a haiku poet. You will be given a prompt and you will need to write a haiku about the prompt." + if include_vocab: + vocab_str = ", ".join(MODAL_VOCABS) + return f"{base} Try to incorporate these words into the haiku if possible: {vocab_str}" + return base + + +async def query_model( + client: httpx.AsyncClient, + endpoint: str, + prompt: str, + *, + model_name: str = "base-model", + system_prompt: str | None = None, + semaphore: asyncio.Semaphore | None = None, +) -> str: + """Send a chat-completion request and return the assistant's reply text. + + Args: + model_name: The vLLM --served-model-name (matches the model key in MODEL_CONFIG). + """ + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + payload = { + "model": model_name, + "messages": messages, + "temperature": 0.0, + "top_p": 1.0, + "stream": False, + "chat_template_kwargs": {"enable_thinking": False}, + } + + async def _do_request() -> str: + try: + response = await client.post(endpoint, json=payload, timeout=60.0) + response.raise_for_status() + data = response.json() + return data["choices"][0]["message"]["content"].strip() + except Exception as e: + return f"ERROR: {e}" + + if semaphore is not None: + async with semaphore: + return await _do_request() + return await _do_request() diff --git a/haiku/llm_judges/__init__.py b/haiku/llm_judges/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/haiku/llm_judges/base.py b/haiku/llm_judges/base.py new file mode 100644 index 0000000..162ee7b --- /dev/null +++ b/haiku/llm_judges/base.py @@ -0,0 +1,155 @@ +""" 
+Haiku LLM judge — scores structure (syllable counting) and style (LLM evaluation). + +Structure scoring uses CMUdict for syllable counting. +Style scoring uses a local vLLM instance to evaluate relevance, poetic quality, etc. +""" + +import re + +import aiohttp + +from llm_judges.deploy import VLLM_PORT +from llm_judges.nlp import score_haiku_structure + + +MODAL_VOCABS = [ + "modal", + "volume", + "function", + "sandbox", + "flash", + "inference", + "train", +] + + +def _build_judge_prompt(prompt: str, response: str, label: str = "") -> tuple[str, int]: + """Build the LLM judge prompt. Returns (prompt_text, max_score).""" + modal_vocab_str = ", ".join(MODAL_VOCABS) + + max_score = 15 # relevance(5) + poetic(5) + modal vocab(5) + + text = f"""You are evaluating a haiku poem. + + Score the response based on the following criteria: + + Relevance (5 points total) + - 5 points: if the central theme and punchline of the haiku is "{prompt}" + - 3 points: if the response directly discusses "{prompt}" but it is not the central theme + - 2 points: if the response is relevant to the topic "{prompt}" but very plain + - 0 points: if the response is not relevant to the topic "{prompt}" + + Poetic quality (5 points total) + - 5 points: if the response makes sense, can be considered a poetic haiku, with a clear theme and punchline + - 3 point: if the response makes sense, but is not very poetic + - 1 point: if the response doesn't make sense + - 0 points: if the response is not poetic and incoherent +""" + + if label: + max_score = 20 + text += f""" + Better than the existing poem (5 points total): + Given the existing poem, score the response by comparing its quality to the existing poem: + {label} + - 5 points: if the response is better than the poem "{label}". + - 3 points: if the response is equal in quality to the poem "{label}". + - 0 points: if the response is worse than the poem "{label}". 
+""" + + prereq_score = max_score - 5 + text += f""" + Uses Modal vocabulary (5 points total): (modal vocab: {modal_vocab_str}) + - 5 points: if the response uses the above words in a way that is coherent and relevant to the topic "{prompt}" + - 3 points: if the response uses the above words in a way that is not relevant to the topic "{prompt}" + - 0 points: if the response does not use the above words + DO NOT GIVE ANY POINTS TO USE MODAL VOCABULARY IF THE POEM ITSELF DOES NOT ALREADY ACHIEVE A SCORE OF {prereq_score} OR HIGHER + + Add up the scores from the above criteria to get the total score. + + -- + **Topic:** {prompt} + + **Response to evaluate:** + {response} + --- + + Output ONLY a single number (0-{max_score}), nothing else.""" + + return text, max_score + + +class HaikuJudge: + """Scores haikus on structure (syllable counting) and style (LLM evaluation). + + Args: + gate_style_on_structure: If True, only evaluate style when structure + score is perfect (1.0). If False, always evaluate style. 
+ """ + + def __init__(self, gate_style_on_structure: bool = True): + self.gate_style_on_structure = gate_style_on_structure + + async def score_style( + self, + model_name: str, + session: aiohttp.ClientSession, + prompt: str, + response: str, + label: str = "", + vllm_base_url: str = f"http://localhost:{VLLM_PORT}", + ) -> float: + """Score haiku style via LLM judge, normalized to [0, 1].""" + judge_prompt, max_score = _build_judge_prompt(prompt, response, label) + + try: + async with session.post( + f"{vllm_base_url}/v1/chat/completions", + headers={"content-type": "application/json"}, + json={ + "model": model_name, + "messages": [{"role": "user", "content": judge_prompt}], + "max_tokens": 100, + }, + ) as resp: + if resp.status != 200: + error_text = await resp.text() + print(f"vLLM error: {resp.status} - {error_text}") + return 0 + + data = await resp.json() + score_text = data["choices"][0]["message"]["content"].strip() + print(f"Scored {response} with score {score_text}") + + match = re.search(r"(\d+(?:\.\d+)?)", score_text) + if match: + score = float(match.group(1)) + return min(max(score, 0), max_score) / max_score + return 0 + except Exception as e: + print(f"Error scoring response: {e}") + return 0 + + async def score_single( + self, + model_name: str, + session: aiohttp.ClientSession, + prompt: str, + response: str, + cmudict: dict, + label: str = "", + ) -> float: + """Score a single haiku. 
Returns a score in [0, 2].""" + structure_score = score_haiku_structure(response, cmudict) + + style_score = 0.0 + if not self.gate_style_on_structure or structure_score >= 1.0: + style_score = await self.score_style( + model_name, session, prompt, response, label + ) + style_score = max(style_score, 0.0) + + total = structure_score + style_score + print(f"[HaikuJudge] structure={structure_score}, style={style_score}, gated={self.gate_style_on_structure}") + return total diff --git a/haiku/llm_judges/deploy.py b/haiku/llm_judges/deploy.py new file mode 100644 index 0000000..46c750a --- /dev/null +++ b/haiku/llm_judges/deploy.py @@ -0,0 +1,221 @@ +""" +LLM-as-a-Judge Reward Model for SLIME GRPO training. + +Uses CMUdict for syllable counting: https://github.com/cmusphinx/cmudict +Recommended by various packages such as `syllables` and `nltk`. +""" + +import asyncio +import threading + +from config import ACTIVE_JUDGE_MODEL_SIZE, ACTIVE_JUDGE_TYPE, JudgeModelSize, JudgeType +import modal +import modal.experimental + + +# ============================================================================= + + +# ============================================================================= +# Modal App Setup +# ============================================================================= + +app = modal.App(f"llm-judge-{ACTIVE_JUDGE_MODEL_SIZE.shorthand}-{ACTIVE_JUDGE_TYPE.value}") + +FLASH_PORT = 8000 +VLLM_PORT = 8001 + +MODEL = ACTIVE_JUDGE_MODEL_SIZE.value +MODEL_NAME = ACTIVE_JUDGE_MODEL_SIZE.model_name +N_GPU = 1 if ACTIVE_JUDGE_MODEL_SIZE == JudgeModelSize.QWEN3_30B else 4 +MINUTES = 60 + +checkpoint_volume = modal.Volume.from_name("unsloth-checkpoints") +hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) +vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) + + +def _make_judge(judge_type: JudgeType): + from llm_judges.base import HaikuJudge + + return HaikuJudge(gate_style_on_structure=(judge_type == 
def create_fastapi_app(judge_type: JudgeType):
    """Build the FastAPI scoring app for the given judge type.

    Exposes POST /score (returns a float reward in [0, 2]) and GET /health.
    Imports are local so the module can be imported without fastapi/nltk
    installed (e.g. at Modal image-build time).
    """
    from fastapi import FastAPI
    from pydantic import BaseModel
    import nltk

    fastapi_app = FastAPI(title="LLM Judge Reward Model", docs_url="/docs")
    # CMU pronouncing dictionary, used for syllable counting by the judge.
    cmudict = nltk.corpus.cmudict.dict()
    judge = _make_judge(judge_type)

    class ScoreRequest(BaseModel):
        prompt: str
        response: str
        label: str = ""

    @fastapi_app.post("/score")
    async def score(request: ScoreRequest) -> float:
        """Score one sample, retrying transient failures with exponential backoff."""
        max_retries = 5
        last_error = None

        for attempt in range(max_retries):
            try:
                return await _do_scoring(request)

            except Exception as e:
                last_error = e
                if attempt < max_retries - 1:
                    wait_time = 2**attempt
                    print(
                        f"Scoring failed (attempt {attempt + 1}): {e}, retrying in {wait_time}s..."
                    )
                    await asyncio.sleep(wait_time)
                else:
                    raise last_error

    async def _do_scoring(request: ScoreRequest) -> float:
        import aiohttp

        prompt = request.prompt
        response_text = request.response

        # Fix: this previously returned None, but the endpoint is declared to
        # return a float -- FastAPI response validation would turn None into a
        # 500, which the retry loop above would then pointlessly retry five
        # times. A malformed sample is simply worth zero reward.
        if prompt is None or response_text is None:
            return 0.0

        async with aiohttp.ClientSession() as session:
            result = await judge.score_single(
                MODEL_NAME, session, prompt, response_text, cmudict, label=request.label
            )

        return float(result)

    @fastapi_app.get("/health")
    def health():
        return {"status": "ok", "model": MODEL_NAME, "judge": judge_type.value}

    return fastapi_app
class LLMJudge:
    """Modal Flash endpoint combining vLLM + scoring logic in one container."""

    @modal.enter()
    def setup(self):
        """Start vLLM (internal VLLM_PORT) and the FastAPI scorer (exposed FLASH_PORT)."""
        import subprocess

        import uvicorn

        # Start vLLM on VLLM_PORT (internal)
        cmd = [
            "vllm",
            "serve",
            "--uvicorn-log-level=info",
            MODEL,
            "--served-model-name",
            MODEL_NAME,
            "--port",
            str(VLLM_PORT),
            "--enforce-eager",
            "--tensor-parallel-size",
            str(N_GPU),
            "--max-model-len",
            "8192",
        ]
        print(" ".join(cmd))
        # Fix: pass the argv list directly instead of joining it into a shell
        # string with shell=True. That avoids an unnecessary shell layer (and
        # quoting problems if MODEL/MODEL_NAME ever contain spaces), and makes
        # terminate() in cleanup() signal vllm itself rather than the shell.
        self._vllm_process = subprocess.Popen(cmd)

        # Wait for vLLM to be ready (model load can take several minutes)
        self._wait_for_port(VLLM_PORT, timeout=600)
        print(f"vLLM ready on port {VLLM_PORT}")

        # Start FastAPI scoring endpoint on FLASH_PORT (exposed)
        self._fastapi_app = create_fastapi_app(ACTIVE_JUDGE_TYPE)
        config = uvicorn.Config(
            self._fastapi_app,
            host="0.0.0.0",
            port=FLASH_PORT,
            log_level="info",
        )
        self._server = uvicorn.Server(config)
        # Run uvicorn in a daemon thread so @modal.enter can return.
        self._thread = threading.Thread(target=self._server.run, daemon=True)
        self._thread.start()

        self._wait_for_port(FLASH_PORT, timeout=30)
        self.flash_manager = modal.experimental.flash_forward(FLASH_PORT)
        print(f"Flash endpoint ready on port {FLASH_PORT} (judge={ACTIVE_JUDGE_TYPE.value})")

    def _wait_for_port(self, port: int, timeout: int = 30):
        """Poll localhost:port roughly once per second until it accepts connections.

        Raises RuntimeError if the port never opens within ~timeout attempts.
        """
        import socket
        import time

        for _ in range(timeout):
            try:
                socket.create_connection(("localhost", port), timeout=1).close()
                return
            except OSError:
                time.sleep(1)
        raise RuntimeError(f"Server failed to start on port {port}")

    @modal.method()
    def keepalive(self):
        # No-op: invoked remotely just to keep the container warm.
        pass

    @modal.exit()
    def cleanup(self):
        """Tear down flash forwarding, the uvicorn server, and the vLLM process."""
        if hasattr(self, "flash_manager"):
            self.flash_manager.stop()
            self.flash_manager.close()
        if hasattr(self, "_server"):
            # Ask uvicorn to exit gracefully, then give its thread 5s to stop.
            self._server.should_exit = True
        if hasattr(self, "_thread"):
            self._thread.join(timeout=5)
        if hasattr(self, "_vllm_process"):
            self._vllm_process.terminate()
            self._vllm_process.wait(timeout=10)
import re


# Lazily-initialized CMU pronouncing dictionary (word -> list of phone lists).
_cmudict = None


def _get_cmudict() -> dict:
    """Return the CMU pronouncing dictionary, downloading it on first use."""
    import nltk
    from nltk.corpus import cmudict as nltk_cmudict

    global _cmudict
    if _cmudict is None:
        nltk.download("cmudict", quiet=True)
        _cmudict = dict(nltk_cmudict.dict())
    return _cmudict


def lookup_word(word_s, cmudict: dict):
    """Return the CMUdict pronunciations for a word, or None if unknown."""
    return cmudict.get(word_s, None)


def is_acronym(word: str) -> bool:
    """Heuristic: an all-caps alphabetic token of 2-6 letters is an acronym."""
    return word.isupper() and 2 <= len(word) <= 6 and word.isalpha()


def count_syllables_for_word(word, cmudict):
    """Count syllables for one word.

    Uses CMUdict when the word is known; otherwise spells out acronyms letter
    by letter, and falls back to a vowel-group heuristic for unknown words.
    """
    original_word = word
    word = word.lower().strip()

    phones = lookup_word(word, cmudict)
    if phones:
        # CMUdict marks each vowel phone with a stress digit; count those.
        return len([p for p in phones[0] if p[-1].isdigit()])

    if is_acronym(original_word):
        # Acronyms are pronounced letter by letter.
        total = 0
        for c in original_word.lower():
            if c == "w":
                total += 3  # "dub-ul-you"
            else:
                total += 1
        return total

    # Fallback heuristic: one syllable per contiguous vowel group.
    count = len(re.findall(r"[aeiouy]+", word))
    # A trailing silent "e" usually doesn't add a syllable ("store", "phrase").
    # Fix: a consonant + "le" ending DOES add one ("table", "little"), so keep
    # the count in that case instead of undercounting by one.
    if (
        word.endswith("e")
        and count > 1
        and not (len(word) > 2 and word.endswith("le") and word[-3] not in "aeiouy")
    ):
        count -= 1
    return max(count, 1)


def diff_syllables_count(text: str, target_syllables: int, cmudict: dict) -> int:
    """Return |total syllables of text - target_syllables|."""
    words = re.findall(r"[a-zA-Z]+", text)
    total_syllables = sum(count_syllables_for_word(w, cmudict) for w in words)
    return abs(total_syllables - target_syllables)


def segment_haiku_lines(response: str) -> list[str]:
    """Split a haiku response into lines: prefer "/", then ". ", then newlines."""
    if "/" in response:
        lines = [line.strip() for line in response.split("/")]
    elif ". " in response:
        lines = [line.strip() for line in response.split(". ")]
    else:
        lines = [line.strip() for line in response.split("\n")]
    return [line for line in lines if line]


def score_syllable_line(diff: int, allow_off_by_one: bool) -> float:
    """Score one line by its syllable-count difference from target."""
    if diff == 0:
        return 1.0
    if diff == 1:
        return 1.0 if allow_off_by_one else 0.5
    return 0.0


def score_haiku_structure(response: str, cmudict: dict, allow_off_by_one: bool = False) -> float:
    """Score haiku structure (0-1): 1/4 for 3 lines + up to 1/4 per line for syllables."""
    lines = segment_haiku_lines(response)
    score = 0.0
    fractional_multiplier = 0.25

    if len(lines) == 3:
        score += fractional_multiplier

    # Classic 5-7-5 haiku syllable targets.
    targets = [5, 7, 5]
    for i, target in enumerate(targets):
        if i < len(lines):
            diff = diff_syllables_count(lines[i], target, cmudict)
            score += score_syllable_line(diff, allow_off_by_one) * fractional_multiplier

    return score


async def haiku_rm(args, sample, **kwargs) -> float:
    """slime reward-model hook: structure-only score for sample.response."""
    cmudict = _get_cmudict()
    allow_off_by_one = getattr(args, "haiku_allow_off_by_one", False)
    return score_haiku_structure(sample.response, cmudict, allow_off_by_one=allow_off_by_one)
Usage: - # Sync training with Qwen 0.5B (multi-node) - modal run modal_train.py::train_multi_node --config qwen-0.5b-sync + # deploy + modal deploy llm_judges.deploy - # Async training with Qwen 4B (multi-node) - modal run modal_train.py::train_multi_node --config qwen-4b-async - - # Single node training - modal run modal_train.py::train_single_node --config qwen-0.5b-sync - - # Single node training with LoRA (using local slime repo) - USE_LOCAL_SLIME=/path/to/slime modal run modal_train.py::train_single_node --config qwen-4b-lora - - # Download model - modal run modal_train.py::download_model --config qwen-4b-sync - - # Prepare dataset + # Train model + modal run modal_train.py::download_model modal run modal_train.py::prepare_dataset + modal run modal_train.py::train_single_node --run-name my-experiment - # List available configs - modal run modal_train.py::list_available_configs + # With local slime repo for development: + USE_LOCAL_SLIME=/path/to/slime modal run modal_train.py::train_single_node Environment variables: USE_LOCAL_SLIME=/path Path to local slime repo for development - SLIME_APP_NAME=... 
Override Modal app name - -Available configs (main): - - qwen-4b, glm-4-7, glm-4-7-flash - -Available configs (test-configs): - - qwen-4b-lora (LoRA training test config) """ import os @@ -40,10 +23,10 @@ from typing import Optional import time +from llm_judges.base import MODAL_VOCABS import modal -import modal.experimental -from configs.base import RLConfig +from config import RLConfig, get_config, ACTIVE_JUDGE_TYPE, ACTIVE_JUDGE_MODEL_SIZE # ============================================================================= @@ -58,14 +41,17 @@ modal.Image.from_registry("slimerl/slime:nightly-dev-20260126a") .run_commands( "uv pip install --system git+https://github.com/huggingface/transformers.git@eebf856", # 4.54.1 + "uv pip install --system aiohttp", # For LLM judge reward model """sed -i 's/AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)/AutoImageProcessor.register(config, slow_image_processor_class=image_processor, exist_ok=True)/g' /sgl-workspace/sglang/python/sglang/srt/configs/utils.py""", # Fix rope_theta access for transformers 5.x (moved to rope_parameters dict) r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/glm/glm45_bridge.py""", r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/qwen/qwen3_bridge.py""", ) .entrypoint([]) - .add_local_python_source("configs", copy=True) - # .add_local_dir("test-configs", remote_path="/root/test-configs", copy=True) + .add_local_python_source("config", copy=True) + .add_local_dir("tools", remote_path="/root/tools", copy=True) + .add_local_dir("llm_judges", remote_path="/root/llm_judges", copy=True) + .pip_install("nltk>=3.8.0") ) # Overlay local slime code for development @@ -73,12 +59,7 @@ SLIME_DEV_PATH = "/opt/slime-dev" if LOCAL_SLIME_PATH: # Copy the entire slime repo (has pyproject.toml) and install it - 
image = image.add_local_dir( - LOCAL_SLIME_PATH, - remote_path=SLIME_DEV_PATH, - copy=True, - ignore=["**/__pycache__", "**/*.pyc", "**/.git", "**/.venv", "**/modal"], - ).run_commands(f"uv pip install --system -e {SLIME_DEV_PATH}") + image = image.add_local_dir(LOCAL_SLIME_PATH, remote_path=SLIME_DEV_PATH, copy=True, ignore=["**/__pycache__", "**/*.pyc", "**/.git", "**/.venv", "**/modal"]).run_commands(f"uv pip install --system -e {SLIME_DEV_PATH}") else: SLIME_DEV_PATH = None @@ -87,34 +68,20 @@ from ray.job_submission import JobSubmissionClient # Paths -DATA_PATH: Path = Path("/data") -HF_CACHE_PATH: Path = Path("/root/.cache/huggingface") +HF_CACHE_PATH = "/root/.cache/huggingface" +DATA_PATH: Path = Path(f"{HF_CACHE_PATH}/processed") CHECKPOINTS_PATH: Path = Path("/checkpoints") # Volumes -data_volume: modal.Volume = modal.Volume.from_name( - "grpo-slime-example-data", create_if_missing=True -) -hf_cache_volume: modal.Volume = modal.Volume.from_name( - "huggingface-cache", create_if_missing=True -) -checkpoints_volume: modal.Volume = modal.Volume.from_name( - "grpo-slime-checkpoints", create_if_missing=True -) +hf_cache_vol: modal.Volume = modal.Volume.from_name("huggingface-cache", create_if_missing=True) +checkpoints_volume: modal.Volume = modal.Volume.from_name("slime-haiku-checkpoints", create_if_missing=True) # Ray configuration RAY_PORT = 6379 RAY_DASHBOARD_PORT = 8265 SINGLE_NODE_MASTER_ADDR = "127.0.0.1" -# ============================================================================= -# App (created dynamically based on config) -# ============================================================================= - -# App name from environment variable (set before running modal) -# Usage: SLIME_APP_NAME="my-experiment" modal run modal_train.py ... 
-APP_NAME = os.environ.get("SLIME_APP_NAME", "slime-grpo") -app = modal.App(APP_NAME) +app = modal.App(f"train-haiku-judge_{ACTIVE_JUDGE_TYPE.value}_{ACTIVE_JUDGE_MODEL_SIZE.shorthand}") # ============================================================================= @@ -166,9 +133,7 @@ def _init_ray(rank: int, main_node_addr: str, node_ip_addr: str, n_nodes: int): else: raise Exception("Failed to connect to all worker nodes") else: - print( - f"Starting Ray worker node at {node_ip_addr}, connecting to {main_node_addr}" - ) + print(f"Starting Ray worker node at {node_ip_addr}, connecting to {main_node_addr}") subprocess.Popen( [ "ray", @@ -188,51 +153,29 @@ def _init_ray(rank: int, main_node_addr: str, node_ip_addr: str, n_nodes: int): def generate_slime_cmd( config: RLConfig, master_addr: str, + experiment_name: str, ) -> tuple[str, dict]: """Generate the slime training command and runtime environment.""" import datetime import random - # Check for infinite run mode - is_infinite_run = os.environ.get("SLIME_TEST_ENABLE_INFINITE_RUN", "0").lower() in ( - "true", - "1", - ) - - # Resolve HF model path from cache (volume mounted at default HF cache dir) - from huggingface_hub import snapshot_download + train_args = config.generate_train_args(DATA_PATH) - hf_model_path = snapshot_download(repo_id=config.model_id, local_files_only=True) - - # Generate all training args from config - train_args = config.generate_train_args( - hf_model_path, CHECKPOINTS_PATH, DATA_PATH, is_infinite_run - ) + checkpoint_dir = CHECKPOINTS_PATH / experiment_name + train_args += f" --save {checkpoint_dir} --save-interval {config.save_steps if hasattr(config, 'save_steps') else 10}" # Add wandb args if API key is available wandb_key = os.environ.get("WANDB_API_KEY") if wandb_key: - run_id = ( - datetime.datetime.utcnow().strftime("%y%m%d-%H%M%S") - + f"-{random.randint(0, 999):03d}" - ) - wandb_run_name = ( - f"{config.wandb_run_name_prefix}_{run_id}" - if config.wandb_run_name_prefix - else 
run_id - ) + run_id = datetime.datetime.now(datetime.timezone.utc).strftime("%y%m%d-%H%M%S") + f"-{random.randint(0, 999):03d}" + wandb_run_name = f"{config.wandb_run_name_prefix}_{run_id}" if config.wandb_run_name_prefix else run_id train_args += f" --use-wandb --wandb-project {config.wandb_project} --wandb-group {wandb_run_name} --wandb-key '{wandb_key}' --disable-wandb-random-suffix" # Build PYTHONPATH by appending to existing (don't clobber) import os as _os - existing_pythonpath = _os.environ.get("PYTHONPATH", "") megatron_path = "/root/Megatron-LM/" - pythonpath = ( - f"{megatron_path}:{existing_pythonpath}" - if existing_pythonpath - else megatron_path - ) + pythonpath = f"{megatron_path}:{existing_pythonpath}" if existing_pythonpath else megatron_path runtime_env = { "env_vars": { @@ -250,11 +193,11 @@ def generate_slime_cmd( # Note: config.train_script returns "slime/train.py" for base image, # but local repo has train.py at root level # Check at runtime if dev path exists (USE_LOCAL_SLIME is only set during image build) - train_script = config.train_script dev_path = "/opt/slime-dev" if os.path.exists(dev_path): - script_name = "train.py" if config.sync else "train_async.py" - train_script = f"{dev_path}/{script_name}" + train_script = f"{dev_path}/train.py" + else: + train_script = "slime/train.py" return f"python3 {train_script} {train_args}", runtime_env @@ -263,16 +206,18 @@ async def run_training( config: RLConfig, n_nodes: int, master_addr: str, + experiment_name: str, ): """Submit SLIME training job to Ray cluster and stream logs.""" client = JobSubmissionClient("http://127.0.0.1:8265") - slime_cmd, runtime_env = generate_slime_cmd(config, master_addr) + slime_cmd, runtime_env = generate_slime_cmd(config, master_addr, experiment_name) print("Submitting training job...") print(f" Model: {config.model_name}") - print(f" Mode: {'sync' if config.sync else 'async'}") print(f" Nodes: {n_nodes}") + print(f" Experiment: {experiment_name}") + print(f" 
Checkpoint dir: {CHECKPOINTS_PATH / experiment_name}") job_id = client.submit_job(entrypoint=slime_cmd, runtime_env=runtime_env) print(f"Job submitted with ID: {job_id}") @@ -280,6 +225,11 @@ async def run_training( async for line in client.tail_job_logs(job_id): print(line, end="", flush=True) + await checkpoints_volume.commit.aio() + print("Checkpoints saved and committed to volume") + + + # ============================================================================= # Modal Functions @@ -288,55 +238,90 @@ async def run_training( @app.function( image=image, - volumes={HF_CACHE_PATH.as_posix(): hf_cache_volume}, + volumes={HF_CACHE_PATH: hf_cache_vol}, + secrets=[modal.Secret.from_name("huggingface-secret")], timeout=24 * 60 * 60, ) def download_model( - config: str = "qwen-0.5b", revision: Optional[str] = None, ): - """Download model from HuggingFace. - - Args: - config: Config name (e.g., "qwen-0.5b", "qwen-4b") - revision: Optional HF revision to pin - """ + """Download model from HuggingFace.""" from huggingface_hub import snapshot_download - from configs import get_config - cfg = get_config(config) + cfg = get_config() - path = snapshot_download(repo_id=cfg.model_id, revision=revision) + path = snapshot_download( + repo_id=cfg.model_id, + revision=revision, + ) print(f"Model downloaded to {path}") - hf_cache_volume.commit() + hf_cache_vol.commit() + + @app.function( image=image, - volumes={DATA_PATH.as_posix(): data_volume}, + volumes={HF_CACHE_PATH: hf_cache_vol}, + secrets=[modal.Secret.from_name("huggingface-secret")], timeout=24 * 60 * 60, ) def prepare_dataset(): - """Download and prepare the GSM8K dataset.""" + """Download and prepare the Haiku dataset.""" from datasets import load_dataset + from transformers import AutoTokenizer + + cfg = get_config() + + hf_cache_vol.reload() + tokenizer = AutoTokenizer.from_pretrained(cfg.model_id) + + ds = load_dataset("statworx/haiku") + + def format_chat_template(example, tokenizer): + system_prompt = f"You are a 
haiku poet. You will be given a prompt and you will need to write a haiku about the prompt. Try to incorporate these words into the haiku if possible: {', '.join(MODAL_VOCABS)}" + + keyword = example['keywords'].lower() + question = f"Write me a haiku about {keyword}." + + messages = [ + {"content": system_prompt, "role": "system"}, + {"content": question, "role": "user"}, + ] + + return { + "question": question, + "label": example["text"], + "messages": messages, + "prompt": tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=False), + } + + # this dataset only has "train", but no "test", so we manually split out the last 20% of the train dataset as test + # and remove them from the train dataset + test_size = min(1000, int(len(ds["train"]) * 0.2)) + test_ds = ds["train"].select(range(len(ds["train"]) - test_size, len(ds["train"]))) + ds["train"] = ds["train"].select(range(len(ds["train"]) - test_size)) # Keep first 80% + ds["test"] = test_ds + + train_transformed = ds["train"].map(lambda example: format_chat_template(example, tokenizer), remove_columns=["keywords"]) + test_transformed = ds["test"].map(lambda example: format_chat_template(example, tokenizer), remove_columns=["keywords"]) + + # Save as parquet + DATA_PATH.mkdir(parents=True, exist_ok=True) + (DATA_PATH / "haiku").mkdir(parents=True, exist_ok=True) + train_transformed.to_parquet(f"{DATA_PATH}/haiku/train.parquet") + test_transformed.to_parquet(f"{DATA_PATH}/haiku/test.parquet") + + hf_cache_vol.commit() + print("Haiku dataset prepared successfully") + print(f"Train examples: {len(train_transformed)}") + print(f"Test examples: {len(test_transformed)}") + print("\nExample:") + print(f"Prompt: {train_transformed[0]['question']}") + print(f"Text: {train_transformed[0]['label']}") - data_volume.reload() - dataset = load_dataset("zhuzilin/gsm8k") - dataset["train"].to_parquet(f"{DATA_PATH}/gsm8k/train.parquet") - dataset["test"].to_parquet(f"{DATA_PATH}/gsm8k/test.parquet") - 
data_volume.commit() - print("Dataset prepared successfully") - - -@app.local_entrypoint() -def list_available_configs(): - """List all available training configs.""" - from configs import list_configs - print("Available configs:") - for name in list_configs(): - print(f" - {name}") # ============================================================================= @@ -348,86 +333,39 @@ def list_available_configs(): image=image, gpu="H200:8", volumes={ - HF_CACHE_PATH.as_posix(): hf_cache_volume, - CHECKPOINTS_PATH.as_posix(): checkpoints_volume, - DATA_PATH.as_posix(): data_volume, - }, - secrets=[ - modal.Secret.from_name("wandb-secret"), - ], - timeout=24 * 60 * 60, - experimental_options={ - "efa_enabled": True, - }, -) -@modal.experimental.clustered( - 4, rdma=True -) -async def train_multi_node(config: str = "qwen-0.5b-sync"): - """Main entry point for multi-node GRPO training on Modal. - - Args: - config: Config name (e.g., "qwen-0.5b-sync", "qwen-4b-async") - """ - from configs import get_config - - cfg = get_config(config) - - hf_cache_volume.reload() - data_volume.reload() - - cluster_info = modal.experimental.get_cluster_info() - print(f"Rank: {cluster_info.rank}, task id: {os.environ['MODAL_TASK_ID']}") - print(f"Config: {config}") - print(f"Container IPv4 IPs: {cluster_info.container_ipv4_ips}") - - ray_main_node_addr = cluster_info.container_ipv4_ips[0] - my_ip_addr = cluster_info.container_ipv4_ips[cluster_info.rank] - n_nodes = len(cluster_info.container_ipv4_ips) - - _init_ray(cluster_info.rank, ray_main_node_addr, my_ip_addr, n_nodes) - - if cluster_info.rank == 0: - with modal.forward(RAY_DASHBOARD_PORT) as tunnel: - print(f"Dashboard URL: {tunnel.url}") - await run_training(cfg, n_nodes, ray_main_node_addr) - else: - while True: - time.sleep(10) - - -@app.function( - image=image, - gpu="H200:8", - volumes={ - HF_CACHE_PATH.as_posix(): hf_cache_volume, + HF_CACHE_PATH: hf_cache_vol, CHECKPOINTS_PATH.as_posix(): checkpoints_volume, - 
DATA_PATH.as_posix(): data_volume, }, secrets=[ + modal.Secret.from_name("huggingface-secret"), modal.Secret.from_name("wandb-secret"), + modal.Secret.from_name("anthropic-secret"), ], timeout=24 * 60 * 60, experimental_options={ "efa_enabled": True, }, ) -async def train_single_node(config: str = "qwen-0.5b-sync"): - """Single-node GRPO training on Modal. +async def train( + run_name: str = "qwen3-4b-haiku", + judge_type = ACTIVE_JUDGE_TYPE, + judge_model_size = ACTIVE_JUDGE_MODEL_SIZE, +): + """Single-node GRPO training on Modal.""" + from datetime import datetime - Args: - config: Config name (e.g., "qwen-0.5b-sync", "qwen-4b-async"). File name with underscores replaced with dashes. - """ - from configs import get_config + cfg = get_config(run_name=run_name, judge_type=judge_type, judge_model_size=judge_model_size) - cfg = get_config(config) + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + model_short = cfg.model_name.split("/")[-1] + experiment_name = f"{run_name}-{model_short}-{timestamp}" - hf_cache_volume.reload() - data_volume.reload() + await hf_cache_vol.reload.aio() + await checkpoints_volume.reload.aio() _init_ray(0, SINGLE_NODE_MASTER_ADDR, SINGLE_NODE_MASTER_ADDR, 1) - with modal.forward(RAY_DASHBOARD_PORT) as tunnel: + async with modal.forward(RAY_DASHBOARD_PORT) as tunnel: print(f"Dashboard URL: {tunnel.url}") - print(f"Config: {config}") - await run_training(cfg, 1, SINGLE_NODE_MASTER_ADDR) + print(f"Experiment: {experiment_name}") + await run_training(cfg, 1, SINGLE_NODE_MASTER_ADDR, experiment_name) \ No newline at end of file diff --git a/slime/run.sh b/haiku/run.sh similarity index 100% rename from slime/run.sh rename to haiku/run.sh diff --git a/haiku/tools/README.md b/haiku/tools/README.md new file mode 100644 index 0000000..e8fd268 --- /dev/null +++ b/haiku/tools/README.md @@ -0,0 +1 @@ +For some reason, these files don't exist on the nightly image. 
import argparse
import json
import os
import pickle
import re
import shutil
import time


import safetensors.torch
import torch
import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig
from typing_extensions import override

from slime.backends.megatron_utils.megatron_to_hf import convert_to_hf, remove_padding


class UnpicklerWrapper(pickle.Unpickler):
    # Unpickler that stubs out Megatron/GLM classes so a torch-dist checkpoint
    # can be read without those training-time classes being importable here.
    @override
    def find_class(self, mod_name, name):
        class DummyClass:
            # Placeholder that accepts any constructor arguments.
            def __init__(self, *args, **kwargs):
                pass

        if mod_name.startswith("megatron") or mod_name.startswith("glm"):
            return DummyClass
        return super().find_class(mod_name, name)


# NOTE(review): global monkeypatch -- every pickle.Unpickler created after this
# module is imported (including torch's internal ones) gets the stubbing
# behavior above. Deliberate hack so torch can unpickle Megatron objects.
pickle.Unpickler = UnpicklerWrapper


class WrappedStorageReader(dist_cp.FileSystemReader):
    # FileSystemReader that reads the checkpoint ".metadata" file through the
    # stubbing unpickler and backfills fields older checkpoints may lack.
    @override
    def read_metadata(self):
        path = self.fs.concat_path(self.path, ".metadata")
        with self.fs.create_stream(path, "rb") as metadata_file:
            metadata = UnpicklerWrapper(metadata_file).load()
        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = dist_cp.StorageMeta()
        metadata.storage_meta.load_id = self.load_id
        if metadata.planner_data is None:
            metadata.planner_data = {}
        return metadata


class EmptyStateDictLoadPlanner(dist_cp.default_planner.DefaultLoadPlanner):
    # Load planner that materializes a fresh state dict straight from checkpoint
    # metadata (no pre-existing model needed), skipping optimizer entries.
    @override
    def set_up_planner(
        self,
        state_dict: dist_cp.metadata.STATE_DICT_TYPE,
        metadata: dist_cp.metadata.Metadata | None = None,
        is_coordinator: bool = False,
    ) -> None:
        for k, v in metadata.state_dict_metadata.items():
            # Model weights only: optimizer and "_state" entries are skipped.
            if "optimizer" in k or "_state" in k:
                continue
            print(f"find {k} in torch_dist ckpt")
            if isinstance(v, dist_cp.metadata.TensorStorageMetadata):
                # Allocate an uninitialized tensor of the right shape/dtype to
                # receive the checkpoint data.
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
                state_dict[k] = v
        super().set_up_planner(state_dict, metadata, is_coordinator)


def get_expert_param(args, name, param):
    # Expand a stacked MoE expert weight into one (name, tensor) pair per
    # expert; non-expert params pass through unchanged.
    if ".experts." not in name:
        yield name, param
        return

    num_experts = args.num_experts
    match = re.search(r"mlp.experts\.(.+)\.weight(\d+)", name)
    if not match:
        # Stacked form: dim 0 is the expert index.
        assert param.shape[0] == num_experts
        for expert_id in range(num_experts):
            expert_name = name.replace(".experts.experts.", ".experts.") + str(expert_id)
            expert_param = param[expert_id]
            yield expert_name, expert_param
    else:
        yield name, param


def get_layer_param(args, name, param):
    # Expand a stacked per-layer weight into one entry per transformer layer,
    # then expand experts within each layer.
    if ".layers." not in name:
        yield name, param
        return

    num_layers = args.num_layers
    match = re.search(r"\.layers\.(\d+)\.", name)
    if not match:
        # Stacked form: dim 0 is the layer index.
        assert param.shape[0] == num_layers
        for layer_id in range(num_layers):
            layer_name = name.replace(".layers.", f".layers.{layer_id}.")
            layer_param = param[layer_id]
            yield from get_expert_param(args, layer_name, layer_param)
    else:
        yield from get_expert_param(args, name, param)


def get_named_params(args, state_dict):
    # Prefix names with "module.module." (Megatron wrapper convention) and
    # expand stacked layer/expert tensors into individual named params.
    for name, param in state_dict.items():
        name = f"module.module.{name}"
        yield from get_layer_param(args, name, param)
def copy_assets(origin_hf_dir, output_dir):
    """Copy non-weight files (tokenizer, config.json, ...) from the original
    HF model directory into the converted output directory.

    Weight shards and the safetensors index are skipped -- save_tensors
    regenerates those. Non-file entries (subdirectories) are skipped too.
    """
    for filename in os.listdir(origin_hf_dir):
        if filename == "model.safetensors.index.json" or filename.endswith(".safetensors"):
            continue
        origin_filename = os.path.join(origin_hf_dir, filename)
        if not os.path.isfile(origin_filename):
            # Fix: this message printed a literal "(unknown)" placeholder;
            # report which path was skipped, matching the copy log below.
            print(f"Skip {origin_filename}, not a file.")
            continue
        src, dst = origin_filename, os.path.join(output_dir, filename)
        print(f"copy from {src} to {dst}")
        shutil.copy(src, dst)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--input-dir", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument(
        "--origin-hf-dir",
        type=str,
        default=None,
        help="use the origin hf dir to copy files like tokenizer, config.json, etc.",
    )
    parser.add_argument(
        "-f", "--force", action="store_true", help="Force overwrite the output directory if it exists."
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=5 * 1024**3,
        # Fix: help text said "2GB" but the actual default is 5 GiB.
        help="Chunk size for saving tensors, default is 5GB.",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=None,
        help="Vocab size for removing padding, if applicable. If not provided, no padding will be removed.",
    )
    args = parser.parse_args()

    if os.path.exists(args.output_dir) and not args.force:
        raise ValueError(f"Output directory {args.output_dir} already exists. Use --force to overwrite it.")

    if args.model_name is None and args.origin_hf_dir is None:
        raise ValueError(
            "Either --model-name or --origin-hf-dir must be provided, so that we can know the name of the params."
        )

    if args.model_name is None:
        # Derive the model name from the HF config class, e.g. "qwen3config".
        hf_config = AutoConfig.from_pretrained(args.origin_hf_dir, trust_remote_code=True)
        args.model_name = type(hf_config).__name__.lower()

    state_dict = {}
    print(f"loading model from {args.input_dir}")
    t = time.time()
    # NOTE: torch.load with weights_only=False unpickles arbitrary objects;
    # only run this on checkpoints you trust.
    megatron_args = torch.load(os.path.join(args.input_dir, "common.pt"), weights_only=False)["args"]
    dist_cp.state_dict_loader._load_state_dict(
        state_dict,
        storage_reader=WrappedStorageReader(args.input_dir),
        planner=EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    print(f"model loaded in {time.time()-t:.2f} sec.")

    save_tensors(megatron_args, args.model_name, state_dict, args.output_dir, args.chunk_size, args.vocab_size)

    if args.origin_hf_dir:
        copy_assets(args.origin_hf_dir, args.output_dir)