diff --git a/README.md b/README.md
index 5f7c625..eb3325d 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,19 @@
Well documented examples of running distributed training jobs on [Modal](https://modal.com).
Use this repository to learn how to build distributed training jobs on Modal.
+## Example of Async RL using slime on Modal
+
+```
+modal profile activate modal-labs
+modal config set-environment clairez-dev
+modal deploy slime/tests/modal_train.py # once
+modal run slime/tests/modal_train.py::prepare # once
+modal run slime/tests/modal_train.py::execute
+```
+
+
# Examples
- [**`benchmark/`**](/benchmark/) contains performance and reliability testing, using AWS EFA by default.
@@ -39,3 +52,5 @@ Other relevant documentation in our guide:
## License
The [MIT license](LICENSE).
+
+
diff --git a/haiku/config.py b/haiku/config.py
new file mode 100644
index 0000000..f06a18c
--- /dev/null
+++ b/haiku/config.py
@@ -0,0 +1,193 @@
+"""Configuration for Qwen3-4B GRPO training on Haiku dataset."""
+
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+
+
+
+
+_MODEL_INFO = {
+ "Qwen/Qwen3-30B-A3B-Instruct-2507": ("qwen3-30b-a3b-instruct", "30b"),
+ "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": ("qwen3-235b-a22b-instruct-fp8", "235b"),
+}
+
+
+class JudgeType(str, Enum):
+ STRICT = "strict"
+ STRICT_LEVELED = "strict-leveled"
+ NO_LLM = "no-llm" # only use the structure score
+
+class JudgeModelSize(str, Enum):
+ QWEN3_30B = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+ QWEN3_235B = "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8"
+
+ @property
+ def model_name(self) -> str:
+ return _MODEL_INFO[self.value][0]
+
+ @property
+ def shorthand(self) -> str:
+ return _MODEL_INFO[self.value][1]
+
+
+
+ACTIVE_JUDGE_TYPE = JudgeType.NO_LLM
+ACTIVE_JUDGE_MODEL_SIZE = JudgeModelSize.QWEN3_30B
+
+
+
+@dataclass
+class RLConfig:
+    """Settings for one slime training run; ``slime_args`` is forwarded to slime as raw CLI text."""
+
+    model_name: str
+    model_id: str
+
+    # Modal resources
+    n_nodes: int = 4
+    gpu: str = "H100:8"
+
+    # Weights & Biases
+    wandb_project: str = "example-train-haiku"
+    wandb_run_name_prefix: str = ""
+
+    # Multi-line CLI text handed to slime (may contain ``#`` comments)
+    slime_args: str = ""
+
+    save_steps: int = 10
+
+    # Additional args appended after ``slime_args`` (handy for overrides)
+    extra_args: list[str] = field(default_factory=list)
+
+    def _clean_args(self, args: str) -> str:
+        """Flatten ``args`` to one line: drop ``#`` comments, blank lines and edge whitespace."""
+        # Everything from the first "#" on a line is treated as a comment.
+        uncommented = (raw.partition("#")[0].strip() for raw in args.split("\n"))
+        return " ".join(piece for piece in uncommented if piece)
+
+    def generate_train_args(self, data_path: Path) -> str:
+        """Return the full slime CLI string for this config.
+
+        Downloads ``model_id`` with ``snapshot_download`` and uses the returned
+        path for both ``--hf-checkpoint`` and ``--ref-load``. Every literal
+        ``{data_path}`` in ``slime_args`` is replaced with ``data_path``.
+        """
+        from huggingface_hub import snapshot_download
+
+        checkpoint_dir = snapshot_download(self.model_id)
+        checkpoint_flags = f"--hf-checkpoint {checkpoint_dir} --ref-load {checkpoint_dir}"
+
+        slime_cli = self._clean_args(self.slime_args).replace("{data_path}", str(data_path))
+        overrides = " ".join(self.extra_args)
+
+        return f"{checkpoint_flags} {slime_cli} {overrides}".strip()
+
+
+# ── Model architecture constants ──
+
+QWEN3_4B_MODEL_ARGS = """
+ --num-layers 36 --hidden-size 2560 --ffn-hidden-size 9728
+ --num-attention-heads 32 --group-query-attention --num-query-groups 8
+ --kv-channels 128 --vocab-size 151936
+ --normalization RMSNorm --norm-epsilon 1e-6 --swiglu
+ --disable-bias-linear --qk-layernorm
+ --use-rotary-position-embeddings --rotary-base 1000000
+"""
+
+DEFAULT_TRAINING_ARGS = """
+ --tensor-model-parallel-size 2 --sequence-parallel
+ --recompute-granularity full --recompute-method uniform --recompute-num-layers 1
+ --use-dynamic-batch-size --max-tokens-per-gpu 9216
+ --megatron-to-hf-mode bridge
+ --attention-dropout 0.0 --hidden-dropout 0.0
+ --accumulate-allreduce-grads-in-fp32 --attention-softmax-in-fp32
+"""
+
+DEFAULT_OPTIMIZER_ARGS = """
+ --optimizer adam
+ --lr 1e-6 --lr-decay-style constant
+ --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98
+"""
+
+DEFAULT_GRPO_ARGS = """
+ --advantage-estimator grpo
+ --use-kl-loss --kl-loss-coef 0.00 --kl-loss-type low_var_kl
+ --entropy-coef 0.00
+ --eps-clip 0.2 --eps-clip-high 0.28
+"""
+
+
+# ── Config factory ──
+
+def _get_judge_url(judge_type: JudgeType, judge_model_size: JudgeModelSize) -> str:
+ return f"https://modal-labs-joy-dev--llm-judge-{judge_model_size.shorthand}-{judge_type.value}-llmjudge.us-east.modal.direct"
+
+
+def _get_reward_model_args_from_judge_type(judge_type: JudgeType, judge_model_size: JudgeModelSize) -> str:
+    """Return the slime reward-model CLI flags for the given judge type.
+
+    ``STRICT`` and ``STRICT_LEVELED`` point slime at the ``/score`` route of
+    the judge URL built by ``_get_judge_url``; ``NO_LLM`` selects the custom
+    reward module ``llm_judges.nlp.haiku_rm``.
+
+    Raises:
+        ValueError: If ``judge_type`` is not one of the handled judge types.
+    """
+    if judge_type in (JudgeType.STRICT, JudgeType.STRICT_LEVELED):
+        return f"""--rm-type remote_rm
+ --rm-url {_get_judge_url(judge_type, judge_model_size)}/score"""
+    if judge_type == JudgeType.NO_LLM:
+        return """--rm-type async_rm
+ --custom-rm-path llm_judges.nlp.haiku_rm"""
+    # The original fell off the end here and returned None, which the caller
+    # would interpolate into the slime args as the literal text "None".
+    raise ValueError(f"Unsupported judge type: {judge_type!r}")
+
+def get_config(run_name: str = "qwen3-4b-haiku", judge_type = ACTIVE_JUDGE_TYPE, judge_model_size = ACTIVE_JUDGE_MODEL_SIZE) -> RLConfig:
+ return RLConfig(
+ model_name="Qwen3-4B",
+ model_id="Qwen/Qwen3-4B",
+ n_nodes=1,
+ gpu="H200:8",
+ wandb_project="example-train-haiku",
+ wandb_run_name_prefix=run_name,
+ save_steps=10,
+ slime_args=f"""
+ # Model architecture
+ {QWEN3_4B_MODEL_ARGS}
+
+ # Training parallelism and optimization
+ {DEFAULT_TRAINING_ARGS}
+
+ # Optimizer
+ {DEFAULT_OPTIMIZER_ARGS}
+
+ # GRPO algorithm
+ {DEFAULT_GRPO_ARGS}
+
+ # Data
+ --input-key messages --label-key label
+ --apply-chat-template --rollout-shuffle
+ --apply-chat-template-kwargs '{{"enable_thinking": false}}'
+ --prompt-data {{data_path}}/haiku/train.parquet
+
+ # Custom reward model
+ {_get_reward_model_args_from_judge_type(judge_type, judge_model_size)}
+
+ --num-rollout 50
+ --rollout-batch-size 128
+ --n-samples-per-prompt 8
+ --global-batch-size 64
+
+ # SGLang
+ --rollout-num-gpus-per-engine 2
+ --sglang-mem-fraction-static 0.7
+
+ --rollout-max-response-len 300
+
+ --rollout-temperature 1
+ --rollout-skip-special-tokens
+
+ # Orchestration
+ --actor-num-nodes 1
+ --actor-num-gpus-per-node 8
+ --colocate
+
+ # Eval
+ --eval-prompt-data haiku {{data_path}}/haiku/test.parquet
+ --eval-interval 20
+ --n-samples-per-eval-prompt 8
+ --eval-max-response-len 300
+ --eval-top-p 1
+ """,
+ )
diff --git a/haiku/eval/README.md b/haiku/eval/README.md
new file mode 100644
index 0000000..62cf7ee
--- /dev/null
+++ b/haiku/eval/README.md
@@ -0,0 +1,9 @@
+# Host and Evaluate Haiku Models
+
+To host the demo playground that evaluates various finetuned haiku models, run the following commands.
+
+```
+modal deploy eval.serve_haiku_model
+
+modal deploy eval.haiku_app
+```
\ No newline at end of file
diff --git a/haiku/eval/__init__.py b/haiku/eval/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/haiku/eval/haiku_app.py b/haiku/eval/haiku_app.py
new file mode 100644
index 0000000..20ac8c3
--- /dev/null
+++ b/haiku/eval/haiku_app.py
@@ -0,0 +1,111 @@
+"""FastAPI backend for the Haiku Playground, deployed on Modal.
+
+modal deploy eval.haiku_app
+"""
+
+from pathlib import Path
+
+import modal
+
+app = modal.App("haiku-playground")
+
+image = (
+ modal.Image.debian_slim(python_version="3.12")
+ .pip_install("fastapi[standard]", "httpx", "nltk")
+ .run_commands(
+ "python -c \"import nltk; nltk.download('cmudict')\""
+ )
+ .add_local_dir("eval", "/root/eval")
+ .add_local_dir("llm_judges", "/root/llm_judges")
+ .add_local_file("config.py", "/root/config.py")
+)
+
+
+@app.function(
+ image=image,
+)
+@modal.asgi_app()
+def serve_playground():
+ from contextlib import asynccontextmanager
+
+ import httpx
+ import nltk
+ from fastapi import FastAPI
+ from fastapi.responses import FileResponse
+ from pydantic import BaseModel
+
+ from eval.shared import (
+ MODAL_VOCABS,
+ MODEL_CHECKPOINTS,
+ build_system_prompt,
+ get_model_endpoint,
+ query_model,
+ )
+ from llm_judges.nlp import (
+ count_syllables_for_word,
+ score_haiku_structure,
+ segment_haiku_lines,
+ )
+
+ class GenerateRequest(BaseModel):
+ prompt: str
+ model_key: str = "base-model"
+ include_vocab: bool = True
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ app.state.cmudict = nltk.corpus.cmudict.dict()
+ app.state.http_client = httpx.AsyncClient()
+ yield
+ await app.state.http_client.aclose()
+
+ fastapi_app = FastAPI(title="Haiku Playground", lifespan=lifespan)
+
+ @fastapi_app.post("/api/generate")
+ async def generate(request: GenerateRequest):
+ import re
+
+ client = fastapi_app.state.http_client
+ cmudict = fastapi_app.state.cmudict
+
+ endpoint = get_model_endpoint(request.model_key)
+ system_prompt = build_system_prompt(include_vocab=request.include_vocab)
+
+ haiku = await query_model(
+ client,
+ endpoint,
+ request.prompt,
+ model_name=request.model_key,
+ system_prompt=system_prompt,
+ )
+
+ structure_score = score_haiku_structure(haiku, cmudict)
+
+ lines = segment_haiku_lines(haiku)
+ syllable_counts = []
+ for line in lines:
+ words = re.findall(r"[a-zA-Z]+", line)
+ count = sum(count_syllables_for_word(w, cmudict) for w in words)
+ syllable_counts.append(count)
+
+ return {
+ "haiku": haiku,
+ "structure_score": structure_score,
+ "syllable_counts": syllable_counts,
+ "passed": structure_score == 1,
+ }
+
+ @fastapi_app.get("/api/models")
+ async def get_models():
+ return MODEL_CHECKPOINTS
+
+ @fastapi_app.get("/api/vocabs")
+ async def get_vocabs():
+ return MODAL_VOCABS
+
+ @fastapi_app.get("/")
+ async def index():
+ html_path = Path("/root/eval/haiku_playground.html")
+ return FileResponse(html_path, media_type="text/html")
+
+ return fastapi_app
diff --git a/haiku/eval/haiku_playground.html b/haiku/eval/haiku_playground.html
new file mode 100644
index 0000000..a9bee26
--- /dev/null
+++ b/haiku/eval/haiku_playground.html
@@ -0,0 +1,441 @@
+
+
+
+
+
+ Haiku Playground
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/haiku/eval/run_eval.py b/haiku/eval/run_eval.py
new file mode 100644
index 0000000..91111e8
--- /dev/null
+++ b/haiku/eval/run_eval.py
@@ -0,0 +1,114 @@
+"""Haiku eval script — queries a served model checkpoint and scores haiku structure."""
+
+import argparse
+import asyncio
+import json
+from dataclasses import asdict, dataclass
+from datetime import datetime, timezone
+
+import httpx
+import modal
+import nltk
+
+from eval.shared import (
+ DEFAULT_CONCURRENCY,
+ EVAL_QUESTIONS,
+ MODELS,
+ get_model_endpoint,
+ query_model,
+)
+from llm_judges.nlp import score_haiku_structure
+
+EVALS_PATH = "/opt/evals"
+
+
+@dataclass(frozen=True)
+class EvalResult:
+ question: str
+ response: str
+ passed: bool
+
+
+async def eval_problem(
+    client: httpx.AsyncClient,
+    semaphore: asyncio.Semaphore,
+    question: str,
+    endpoint: str,
+    model_key: str,
+    cmudict: dict,
+) -> EvalResult:
+    """Ask the model one question, score the reply's haiku structure and log the outcome.
+
+    A reply passes when its structure score is at least 0.75.
+    """
+    answer = await query_model(
+        client, endpoint, question, model_name=model_key, semaphore=semaphore
+    )
+    score = score_haiku_structure(answer, cmudict)
+    is_pass = score >= 0.75
+
+    divider = "=" * 70
+    print(
+        f"Structure score: {score}",
+        divider,
+        f"Question: {question}",
+        f"Response: {answer}",
+        divider,
+        f"Passed: {is_pass}",
+        sep="\n",
+    )
+
+    return EvalResult(question=question, response=answer, passed=is_pass)
+
+
+async def run_eval(
+    model_key: str = "base-model",
+    file_path: str | None = None,
+):
+    """Query ``model_key`` with every eval question and record the structure pass rate.
+
+    Results are written as JSON lines to ``file_path`` and are also stored in
+    the ``haiku-eval-results`` Modal Dict under ``"<model_key>/<UTC timestamp>"``.
+
+    Args:
+        model_key: Key of the served model to evaluate.
+        file_path: Destination JSONL file; defaults to
+            ``{EVALS_PATH}/{model_key}_eval.jsonl``.
+
+    Returns:
+        The list of ``EvalResult`` objects, in ``EVAL_QUESTIONS`` order.
+    """
+    import os
+
+    if file_path is None:
+        file_path = f"{EVALS_PATH}/{model_key}_eval.jsonl"
+
+    endpoint = get_model_endpoint(model_key)
+    cmudict = nltk.corpus.cmudict.dict()
+
+    print(f"Model: {model_key}")
+    print(f"Endpoint: {endpoint}")
+    print(f"Loaded {len(EVAL_QUESTIONS)} questions")
+    print(f"Running with concurrency={DEFAULT_CONCURRENCY}\n")
+    print("=" * 70)
+
+    semaphore = asyncio.Semaphore(DEFAULT_CONCURRENCY)
+
+    async with httpx.AsyncClient() as client:
+        tasks = [
+            eval_problem(client, semaphore, question, endpoint, model_key, cmudict)
+            for question in EVAL_QUESTIONS
+        ]
+        results = await asyncio.gather(*tasks)
+
+    # Create the output directory up front so a missing directory does not
+    # throw away a completed eval run at write time.
+    os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
+    with open(file_path, "w", encoding="utf-8") as f:
+        f.writelines(json.dumps(asdict(result)) + "\n" for result in results)
+
+    # Guard the division: an empty question list would otherwise raise
+    # ZeroDivisionError after all the work above.
+    success_rate = sum(result.passed for result in results) / len(results) if results else 0.0
+    print(f"Success rate: {success_rate}")
+
+    # Save results to a Modal Dict
+    eval_dict = modal.Dict.from_name("haiku-eval-results", create_if_missing=True)
+    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
+    key = f"{model_key}/{timestamp}"
+    eval_dict[key] = {
+        "model_key": model_key,
+        "timestamp": timestamp,
+        "success_rate": success_rate,
+        "results": [asdict(r) for r in results],
+    }
+    print(f"Saved eval results to Modal Dict 'haiku-eval-results' with key '{key}'")
+
+    return results
+
+
+if __name__ == "__main__":
+ model_choices = list(MODELS.keys())
+ parser = argparse.ArgumentParser(description="Run haiku structure eval against a served checkpoint.")
+ parser.add_argument(
+ "--model",
+ default="base-model",
+ choices=model_choices,
+ help=f"Model to evaluate (choices: {', '.join(model_choices)})",
+ )
+ args = parser.parse_args()
+ asyncio.run(run_eval(model_key=args.model))
diff --git a/haiku/eval/serve_haiku_model.py b/haiku/eval/serve_haiku_model.py
new file mode 100644
index 0000000..78218e3
--- /dev/null
+++ b/haiku/eval/serve_haiku_model.py
@@ -0,0 +1,215 @@
+"""Serve SLIME-trained Haiku models with vLLM.
+
+modal deploy eval.serve_haiku_model
+"""
+
+from pathlib import Path
+import modal
+import modal.experimental
+
+from eval.shared import MODEL_CONFIG, ModelConfig, _to_class_name
+
+APP_NAME = "serve-haiku-model"
+
+app = modal.App(APP_NAME)
+
+MODELS_PATH: Path = Path("/models")
+
+HF_DIR = "hf"
+
+N_GPU = 1
+MINUTES = 60
+VLLM_PORT = 8000
+
+
+# Same volume used in training
+checkpoints_volume: modal.Volume = modal.Volume.from_name("grpo-slime-haiku-checkpoints")
+hf_cache_vol = modal.Volume.from_name("huggingface-cache")
+vllm_cache_vol = modal.Volume.from_name("vllm-cache")
+
+
+vllm_image = (
+ modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
+ .entrypoint([])
+ .uv_pip_install(
+ "vllm==0.11.2",
+ "huggingface-hub==0.36.0",
+ "flashinfer-python==0.5.2",
+ )
+ .env({"HF_XET_HIGH_PERFORMANCE": "1"})
+)
+
+slime_image = (
+ modal.Image.from_registry("slimerl/slime:nightly-dev-20260126a")
+ .run_commands(
+ "uv pip install --system git+https://github.com/huggingface/transformers.git@eebf856", # 4.54.1
+ "uv pip install --system aiohttp", # For LLM judge reward model
+ """sed -i 's/AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)/AutoImageProcessor.register(config, slow_image_processor_class=image_processor, exist_ok=True)/g' /sgl-workspace/sglang/python/sglang/srt/configs/utils.py""",
+ # Fix rope_theta access for transformers 5.x (moved to rope_parameters dict)
+ r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/glm/glm45_bridge.py""",
+ r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/qwen/qwen3_bridge.py""",
+ )
+ .add_local_dir("tools", remote_path="/root/tools", copy=True)
+ .entrypoint([])
+)
+
+def get_hf_model_path(config: ModelConfig) -> str:
+    """Return the path of the HuggingFace-format directory for ``config`` as a string."""
+    return str(MODELS_PATH / config.model_path / HF_DIR)
+
+def get_megatron_checkpoint_path(config: ModelConfig) -> str:
+    """Return the path of the Megatron iteration directory for ``config`` as a string."""
+    return str(MODELS_PATH / config.model_path / config.iters_dir)
+
+@app.function(
+    image=slime_image,
+    timeout=24 * 60 * 60,
+    secrets=[
+        modal.Secret.from_name("huggingface-secret"),
+    ],
+    volumes={MODELS_PATH.as_posix(): checkpoints_volume},
+)
+async def convert_checkpoint(
+    model_path: str,
+    iter_dir: str,
+    origin_hf_dir: str
+):
+    """Convert Megatron checkpoint to HuggingFace format.
+
+    The converted model is written to ``MODELS_PATH / model_path / HF_DIR``.
+
+    Args:
+        model_path: Checkpoint directory name under ``MODELS_PATH``.
+        iter_dir: Iteration sub-directory holding the Megatron checkpoint.
+        origin_hf_dir: Name of the original Qwen model; downloaded from
+            ``Qwen/<origin_hf_dir>`` into ``MODELS_PATH`` if not already there.
+
+    Raises:
+        subprocess.CalledProcessError: If the conversion script exits non-zero.
+    """
+    import os
+    import subprocess
+
+    from huggingface_hub import snapshot_download
+
+    await checkpoints_volume.reload.aio()
+
+    local_hf_dir = MODELS_PATH / origin_hf_dir
+
+    if not local_hf_dir.exists():
+        snapshot_download(repo_id=f"Qwen/{origin_hf_dir}", local_dir=local_hf_dir)
+    else:
+        print(f"Model {origin_hf_dir} already downloaded.")
+
+    megatron_checkpoint_path = MODELS_PATH / model_path / iter_dir
+    output_hf_path = MODELS_PATH / model_path / HF_DIR
+
+    # Run the converter from an argv list rather than an interpolated shell
+    # string, so directory names are never re-parsed by a shell. PYTHONPATH is
+    # supplied through the environment, as the old `VAR=... python` prefix did.
+    subprocess.run(
+        [
+            "python",
+            "tools/convert_torch_dist_to_hf.py",
+            "--input-dir",
+            str(megatron_checkpoint_path),
+            "--output-dir",
+            str(output_hf_path),
+            "--origin-hf-dir",
+            str(local_hf_dir),
+        ],
+        env={**os.environ, "PYTHONPATH": "/root/Megatron-LM"},
+        check=True,
+    )
+
+
+CLS_KWARGS = dict(
+ image=vllm_image,
+ gpu=f"A10G:{N_GPU}",
+ scaledown_window=15 * MINUTES,
+ startup_timeout=10 * MINUTES,
+ volumes={
+ "/root/.cache/huggingface": hf_cache_vol,
+ "/root/.cache/vllm": vllm_cache_vol,
+ MODELS_PATH.as_posix(): checkpoints_volume,
+ },
+ secrets=[modal.Secret.from_name("huggingface-secret")],
+ experimental_options={"flash": "us-east"},
+ region="us-east",
+ min_containers=1,
+)
+
+
+class _VLLMServerBase:
+    """Base class with all vLLM serving logic. Not registered with Modal directly."""
+
+    # Subclasses set this as a class variable
+    MODEL_KEY: str
+
+    @modal.enter()
+    def setup(self):
+        """Resolve the model directory, launch ``vllm serve`` and forward its port via Flash.
+
+        Raises:
+            ValueError: If ``MODEL_KEY`` is not a key of ``MODEL_CONFIG``.
+            RuntimeError: If vLLM does not start listening on ``VLLM_PORT``.
+        """
+        import subprocess
+
+        if self.MODEL_KEY not in MODEL_CONFIG:
+            raise ValueError(f"Invalid model name: {self.MODEL_KEY}")
+
+        model_config = MODEL_CONFIG[self.MODEL_KEY]
+
+        if model_config.is_base_model:
+            model_dir = self._setup_base_model(model_config)
+        else:
+            model_dir = self._setup_finetuned_model(model_config)
+
+        cmd = [
+            "vllm",
+            "serve",
+            str(model_dir),
+            "--served-model-name",
+            model_config.model_name,
+            "--host",
+            "0.0.0.0",
+            "--port",
+            str(VLLM_PORT),
+            "--enforce-eager",
+            "--tensor-parallel-size",
+            str(N_GPU),
+            "--reasoning-parser",
+            "qwen3",
+        ]
+
+        print(" ".join(cmd))
+        # Launch from the argv list (no shell) so the Popen handle refers to
+        # the vLLM process itself: poll()/terminate() then act on vLLM and not
+        # on an intermediate shell, and arguments are never re-split.
+        self._vllm_process = subprocess.Popen(cmd)
+
+        self._wait_for_port(VLLM_PORT, timeout=600)
+        print(f"vLLM ready on port {VLLM_PORT}")
+
+        self.flash_manager = modal.experimental.flash_forward(VLLM_PORT)
+        print(f"Flash endpoint ready on port {VLLM_PORT}")
+
+    def _setup_base_model(self, model_config: ModelConfig) -> Path:
+        """Return the local directory of the base model, downloading it from ``Qwen/`` if absent."""
+        from huggingface_hub import snapshot_download
+
+        local_hf_dir = MODELS_PATH / model_config.model_path
+
+        if not local_hf_dir.exists():
+            snapshot_download(repo_id=f"Qwen/{model_config.model_path}", local_dir=local_hf_dir)
+        else:
+            print(f"Model {model_config.model_path} already downloaded.")
+
+        return local_hf_dir
+
+    def _setup_finetuned_model(self, model_config: ModelConfig) -> Path:
+        """Return the HF-format directory of a finetuned checkpoint, converting it first if needed.
+
+        Conversion is triggered when ``config.json`` is missing from the HF directory.
+        """
+        hf_path = MODELS_PATH / model_config.model_path / HF_DIR
+
+        if not hf_path.joinpath("config.json").exists():
+            print(f"Converting checkpoint {model_config.model_path} to HuggingFace format...")
+            convert_checkpoint.remote(
+                model_path=model_config.model_path,
+                iter_dir=model_config.iters_dir,
+                origin_hf_dir=model_config.base_model_name,
+            )
+            # Pick up the files written by the remote conversion.
+            checkpoints_volume.reload()
+            print(f"Checkpoint {model_config.model_path}/{model_config.iters_dir} converted to HuggingFace format.")
+
+        return hf_path
+
+    def _wait_for_port(self, port: int, timeout: int = 30):
+        """Block until ``localhost:port`` accepts TCP connections.
+
+        Args:
+            port: Port to probe.
+            timeout: Maximum number of seconds to wait.
+
+        Raises:
+            RuntimeError: If the vLLM process exits before the port opens, or
+                the port is still closed once ``timeout`` seconds have elapsed.
+        """
+        import socket
+        import time
+
+        process = getattr(self, "_vllm_process", None)
+        # Use a wall-clock deadline: counting loop iterations made the real
+        # wait depend on how long each connection attempt took.
+        deadline = time.monotonic() + timeout
+        while time.monotonic() < deadline:
+            # Fail fast if the server already died instead of polling a port
+            # that can never open for the rest of the timeout.
+            if process is not None and process.poll() is not None:
+                raise RuntimeError(
+                    f"vLLM exited with code {process.returncode} before port {port} opened"
+                )
+            try:
+                socket.create_connection(("localhost", port), timeout=1).close()
+                return
+            except OSError:
+                time.sleep(1)
+        raise RuntimeError(f"Server failed to start on port {port}")
+
+    @modal.method()
+    def keepalive(self):
+        """No-op method."""
+        pass
+
+    @modal.exit()
+    def cleanup(self):
+        """Stop the Flash forwarder and shut down the vLLM process."""
+        import subprocess
+
+        if hasattr(self, "flash_manager"):
+            self.flash_manager.stop()
+            self.flash_manager.close()
+        if hasattr(self, "_vllm_process"):
+            self._vllm_process.terminate()
+            try:
+                self._vllm_process.wait(timeout=10)
+            except subprocess.TimeoutExpired:
+                # vLLM ignored SIGTERM within the grace period; force it down
+                # so the exit hook neither raises nor leaves the process alive.
+                self._vllm_process.kill()
+                self._vllm_process.wait()
+
+
+for _model_key in MODEL_CONFIG:
+ _cls_name = _to_class_name(_model_key)
+ _cls = type(_cls_name, (_VLLMServerBase,), {"MODEL_KEY": _model_key})
+ _cls = modal.concurrent(target_inputs=4)(_cls)
+ _cls = app.cls(**CLS_KWARGS)(_cls)
+ globals()[_cls_name] = _cls
diff --git a/haiku/eval/shared.py b/haiku/eval/shared.py
new file mode 100644
index 0000000..1496b6b
--- /dev/null
+++ b/haiku/eval/shared.py
@@ -0,0 +1,214 @@
+"""Shared constants, endpoint helpers, and query function for haiku eval and playground."""
+
+import asyncio
+from dataclasses import dataclass
+
+import httpx
+
+# ---------------------------------------------------------------------------
+# Flash URL generation
+# ---------------------------------------------------------------------------
+
+APP_NAME = "serve-haiku-model"
+ENVIRONMENT = "joy-dev"
+FLASH_REGION = "us-east"
+WORKSPACE = "modal-labs"
+
+
+def _to_class_name(model_key: str) -> str:
+    """Build the server class name for a model key, e.g. '235b-judge-cl' -> '235bJudgeClServer'."""
+    capitalized_parts = map(str.capitalize, model_key.split("-"))
+    return f"{''.join(capitalized_parts)}Server"
+
+
+def get_flash_url(model_key: str) -> str:
+ """Get the Flash endpoint base URL for a given model key."""
+ cls_name = _to_class_name(model_key)
+ return f"https://{WORKSPACE}-{ENVIRONMENT}--{APP_NAME}-{cls_name.lower()}.{FLASH_REGION}.modal.direct"
+
+
+def get_model_endpoint(model_key: str) -> str:
+ """Return the full chat-completions URL for a given model key."""
+ return get_flash_url(model_key) + "/v1/chat/completions"
+
+
+# ---------------------------------------------------------------------------
+# Model registry (single source of truth for both serving and eval/playground)
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ModelConfig:
+ model_path: str
+ iters_dir: str
+ model_name: str
+ model_description: str
+ label: str
+ base_model_name: str = "Qwen3-4B"
+ is_base_model: bool = False
+
+ @property
+ def badge(self) -> str:
+ return "base" if self.is_base_model else "trained"
+
+ @property
+ def flash_url(self) -> str:
+ return get_flash_url(self.model_name)
+
+
+# Each entry's ``model_name`` must equal its dict key: the server registers
+# ``model_name`` as vLLM's --served-model-name, while the eval script and the
+# playground send the key as the request's "model" field.
+MODEL_CONFIG = {
+    "base-model": ModelConfig(
+        model_path="Qwen3-4B",
+        iters_dir="",
+        model_name="base-model",
+        model_description="Qwen3-4B Base Model",
+        label="Base Model",
+        is_base_model=True,
+    ),
+    "no-llm-model": ModelConfig(
+        model_path="2-23-no-llm-Qwen3-4B-20260224-032404",
+        iters_dir="iter_0000049",
+        model_name="no-llm-model",
+        model_description="Qwen3-4B Finetuned with No LLM",
+        label="No LLM Judge",
+    ),
+    "30b-judge": ModelConfig(
+        model_path="2_24-30b-Qwen3-4B-20260224-184838",
+        iters_dir="iter_0000049",
+        # Was a copy of the "30b-judge-cl" entry (name, description, label).
+        model_name="30b-judge",
+        model_description="Qwen3-4B Finetuned with 30B Judge",
+        label="30B Judge",
+    ),
+    "30b-judge-cl": ModelConfig(
+        model_path="2_24-30b-leveled-Qwen3-4B-20260224-180902",
+        iters_dir="iter_0000049",
+        model_name="30b-judge-cl",
+        model_description="Qwen3-4B Finetuned with 30B Judge using Curriculum Learning",
+        label="30B Judge (CL)",
+    ),
+    "235b-judge": ModelConfig(
+        model_path="2_24-235b-Qwen3-4B-20260224-174605",
+        iters_dir="iter_0000049",
+        model_name="235b-judge",
+        model_description="Qwen3-4B Finetuned with 235B Judge",
+        label="235B Judge",
+    ),
+    "235b-judge-cl": ModelConfig(
+        model_path="2_23-235b-leveled-Qwen3-4B-20260224-172832",
+        iters_dir="iter_0000049",
+        model_name="235b-judge-cl",
+        model_description="Qwen3-4B Finetuned with 235B Judge using Curriculum Learning",
+        label="235B Judge (CL)",
+    ),
+}
+
+# Derived views for different consumers
+MODELS = {
+ key: {"label": c.label, "badge": c.badge}
+ for key, c in MODEL_CONFIG.items()
+}
+
+# For the UI — list of models with their metadata
+MODEL_CHECKPOINTS = [
+ {"name": config.model_name, "label": config.label, "badge": config.badge}
+ for config in MODEL_CONFIG.values()
+]
+
+# Maps model_key -> flash endpoint URL
+MODEL_URLS: dict[str, str] = {key: config.flash_url for key, config in MODEL_CONFIG.items()}
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+MODAL_VOCABS = [
+ "modal",
+ "volume",
+ "function",
+ "sandbox",
+ "flash",
+ "inference",
+ "train",
+]
+
+DEFAULT_CONCURRENCY = 50
+
+EVAL_QUESTIONS = [
+ "Write me a haiku about cat.",
+ "Write me a haiku about dog.",
+ "Write me a haiku about bird.",
+ "Write me a haiku about fish.",
+ "Write me a haiku about horse.",
+ "Write me a haiku about rabbit.",
+ "Write me a haiku about snake.",
+ "Write me a haiku about tiger.",
+ "Write me a haiku about lion.",
+ "Write me a haiku about Jason Mancuso.",
+ "Write me a haiku about Joy Liu.",
+ "Write me a haiku about Modal Labs.",
+]
+
+QUICK_PROMPTS = [
+ {"emoji": "🐱", "label": "cat", "prompt": "Write me a haiku about cat."},
+ {"emoji": "🌊", "label": "ocean", "prompt": "Write me a haiku about the ocean."},
+ {"emoji": "🌸", "label": "cherry blossoms", "prompt": "Write me a haiku about cherry blossoms."},
+ {"emoji": "💻", "label": "coding", "prompt": "Write me a haiku about coding."},
+ {"emoji": "☁️", "label": "Modal", "prompt": "Write me a haiku about Modal."},
+ {"emoji": "⚡", "label": "serverless", "prompt": "Write me a haiku about serverless."},
+]
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def build_system_prompt(include_vocab: bool = True) -> str:
+    """Return the haiku-poet system prompt, with the Modal vocabulary hint appended on request."""
+    prompt = "You are a haiku poet. You will be given a prompt and you will need to write a haiku about the prompt."
+    if not include_vocab:
+        return prompt
+    return f"{prompt} Try to incorporate these words into the haiku if possible: {', '.join(MODAL_VOCABS)}"
+
+
+async def query_model(
+    client: httpx.AsyncClient,
+    endpoint: str,
+    prompt: str,
+    *,
+    model_name: str = "base-model",
+    system_prompt: str | None = None,
+    semaphore: asyncio.Semaphore | None = None,
+) -> str:
+    """POST one chat-completion request and return the stripped reply text.
+
+    Any exception raised while sending the request or reading the reply is
+    caught and reported as the returned string ``"ERROR: <exception>"``
+    rather than being raised.
+
+    Args:
+        client: HTTP client used to send the request.
+        endpoint: Full chat-completions URL.
+        prompt: Content of the user message.
+        model_name: The vLLM --served-model-name to request.
+        system_prompt: Optional system message placed before the user message.
+        semaphore: Optional semaphore held for the duration of the request.
+    """
+    system_turn = [{"role": "system", "content": system_prompt}] if system_prompt else []
+    body = {
+        "model": model_name,
+        "messages": [*system_turn, {"role": "user", "content": prompt}],
+        "temperature": 0.0,
+        "top_p": 1.0,
+        "stream": False,
+        "chat_template_kwargs": {"enable_thinking": False},
+    }
+
+    async def _send() -> str:
+        try:
+            reply = await client.post(endpoint, json=body, timeout=60.0)
+            reply.raise_for_status()
+            return reply.json()["choices"][0]["message"]["content"].strip()
+        except Exception as exc:
+            return f"ERROR: {exc}"
+
+    if semaphore is None:
+        return await _send()
+    async with semaphore:
+        return await _send()
diff --git a/haiku/llm_judges/__init__.py b/haiku/llm_judges/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/haiku/llm_judges/base.py b/haiku/llm_judges/base.py
new file mode 100644
index 0000000..162ee7b
--- /dev/null
+++ b/haiku/llm_judges/base.py
@@ -0,0 +1,155 @@
+"""
+Haiku LLM judge — scores structure (syllable counting) and style (LLM evaluation).
+
+Structure scoring uses CMUdict for syllable counting.
+Style scoring uses a local vLLM instance to evaluate relevance, poetic quality, etc.
+"""
+
+import re
+
+import aiohttp
+
+from llm_judges.deploy import VLLM_PORT
+from llm_judges.nlp import score_haiku_structure
+
+
+MODAL_VOCABS = [
+ "modal",
+ "volume",
+ "function",
+ "sandbox",
+ "flash",
+ "inference",
+ "train",
+]
+
+
+def _build_judge_prompt(prompt: str, response: str, label: str = "") -> tuple[str, int]:
+ """Build the LLM judge prompt. Returns (prompt_text, max_score)."""
+ modal_vocab_str = ", ".join(MODAL_VOCABS)
+
+ max_score = 15 # relevance(5) + poetic(5) + modal vocab(5)
+
+ text = f"""You are evaluating a haiku poem.
+
+ Score the response based on the following criteria:
+
+ Relevance (5 points total)
+ - 5 points: if the central theme and punchline of the haiku is "{prompt}"
+ - 3 points: if the response directly discusses "{prompt}" but it is not the central theme
+ - 2 points: if the response is relevant to the topic "{prompt}" but very plain
+ - 0 points: if the response is not relevant to the topic "{prompt}"
+
+ Poetic quality (5 points total)
+ - 5 points: if the response makes sense, can be considered a poetic haiku, with a clear theme and punchline
+ - 3 point: if the response makes sense, but is not very poetic
+ - 1 point: if the response doesn't make sense
+ - 0 points: if the response is not poetic and incoherent
+"""
+
+ if label:
+ max_score = 20
+ text += f"""
+ Better than the existing poem (5 points total):
+ Given the existing poem, score the response by comparing its quality to the existing poem:
+ {label}
+ - 5 points: if the response is better than the poem "{label}".
+ - 3 points: if the response is equal in quality to the poem "{label}".
+ - 0 points: if the response is worse than the poem "{label}".
+"""
+
+ prereq_score = max_score - 5
+ text += f"""
+ Uses Modal vocabulary (5 points total): (modal vocab: {modal_vocab_str})
+ - 5 points: if the response uses the above words in a way that is coherent and relevant to the topic "{prompt}"
+ - 3 points: if the response uses the above words in a way that is not relevant to the topic "{prompt}"
+ - 0 points: if the response does not use the above words
+ DO NOT GIVE ANY POINTS TO USE MODAL VOCABULARY IF THE POEM ITSELF DOES NOT ALREADY ACHIEVE A SCORE OF {prereq_score} OR HIGHER
+
+ Add up the scores from the above criteria to get the total score.
+
+ --
+ **Topic:** {prompt}
+
+ **Response to evaluate:**
+ {response}
+ ---
+
+ Output ONLY a single number (0-{max_score}), nothing else."""
+
+ return text, max_score
+
+
+class HaikuJudge:
+ """Scores haikus on structure (syllable counting) and style (LLM evaluation).
+
+ Args:
+ gate_style_on_structure: If True, only evaluate style when structure
+ score is perfect (1.0). If False, always evaluate style.
+ """
+
+ def __init__(self, gate_style_on_structure: bool = True):
+ self.gate_style_on_structure = gate_style_on_structure
+
+ async def score_style(
+ self,
+ model_name: str,
+ session: aiohttp.ClientSession,
+ prompt: str,
+ response: str,
+ label: str = "",
+ vllm_base_url: str = f"http://localhost:{VLLM_PORT}",
+ ) -> float:
+ """Score haiku style via LLM judge, normalized to [0, 1]."""
+ judge_prompt, max_score = _build_judge_prompt(prompt, response, label)
+
+ try:
+ async with session.post(
+ f"{vllm_base_url}/v1/chat/completions",
+ headers={"content-type": "application/json"},
+ json={
+ "model": model_name,
+ "messages": [{"role": "user", "content": judge_prompt}],
+ "max_tokens": 100,
+ },
+ ) as resp:
+ if resp.status != 200:
+ error_text = await resp.text()
+ print(f"vLLM error: {resp.status} - {error_text}")
+ return 0
+
+ data = await resp.json()
+ score_text = data["choices"][0]["message"]["content"].strip()
+ print(f"Scored {response} with score {score_text}")
+
+ match = re.search(r"(\d+(?:\.\d+)?)", score_text)
+ if match:
+ score = float(match.group(1))
+ return min(max(score, 0), max_score) / max_score
+ return 0
+ except Exception as e:
+ print(f"Error scoring response: {e}")
+ return 0
+
+ async def score_single(
+ self,
+ model_name: str,
+ session: aiohttp.ClientSession,
+ prompt: str,
+ response: str,
+ cmudict: dict,
+ label: str = "",
+ ) -> float:
+ """Score a single haiku. Returns a score in [0, 2]."""
+ structure_score = score_haiku_structure(response, cmudict)
+
+ style_score = 0.0
+ if not self.gate_style_on_structure or structure_score >= 1.0:
+ style_score = await self.score_style(
+ model_name, session, prompt, response, label
+ )
+ style_score = max(style_score, 0.0)
+
+ total = structure_score + style_score
+ print(f"[HaikuJudge] structure={structure_score}, style={style_score}, gated={self.gate_style_on_structure}")
+ return total
diff --git a/haiku/llm_judges/deploy.py b/haiku/llm_judges/deploy.py
new file mode 100644
index 0000000..46c750a
--- /dev/null
+++ b/haiku/llm_judges/deploy.py
@@ -0,0 +1,221 @@
+"""
+LLM-as-a-Judge Reward Model for SLIME GRPO training.
+
+Uses CMUdict for syllable counting: https://github.com/cmusphinx/cmudict
+Recommended by various packages such as `syllables` and `nltk`.
+"""
+
+import asyncio
+import threading
+
+from config import ACTIVE_JUDGE_MODEL_SIZE, ACTIVE_JUDGE_TYPE, JudgeModelSize, JudgeType
+import modal
+import modal.experimental
+
+
+# =============================================================================
+
+
+# =============================================================================
+# Modal App Setup
+# =============================================================================
+
+# The app name encodes the judge model shorthand and the judge type from config.py.
+app = modal.App(f"llm-judge-{ACTIVE_JUDGE_MODEL_SIZE.shorthand}-{ACTIVE_JUDGE_TYPE.value}")
+
+# FLASH_PORT serves the FastAPI scoring app and is the port handed to
+# flash_forward; VLLM_PORT is where `vllm serve` listens inside the container.
+FLASH_PORT = 8000
+VLLM_PORT = 8001
+
+# MODEL is the model argument given to `vllm serve`; MODEL_NAME is the
+# --served-model-name and the `model` value used by the scoring code.
+MODEL = ACTIVE_JUDGE_MODEL_SIZE.value
+MODEL_NAME = ACTIVE_JUDGE_MODEL_SIZE.model_name
+# GPU count for the container; also passed as vLLM's --tensor-parallel-size.
+N_GPU = 1 if ACTIVE_JUDGE_MODEL_SIZE == JudgeModelSize.QWEN3_30B else 4
+# Multiplier used for the window/timeout settings below (presumably seconds per minute).
+MINUTES = 60
+
+# NOTE(review): unlike the two cache volumes, this lookup does not pass
+# create_if_missing — confirm the "unsloth-checkpoints" volume already exists.
+checkpoint_volume = modal.Volume.from_name("unsloth-checkpoints")
+hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
+vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
+
+
+def _make_judge(judge_type: JudgeType):
+    """Build the ``HaikuJudge`` used for ``judge_type``.
+
+    Style scoring is gated on the structure score only for
+    ``JudgeType.STRICT_LEVELED``.
+    """
+    from llm_judges.base import HaikuJudge
+
+    # NOTE(review): gating is off for every other type, including NO_LLM —
+    # confirm that is the intended behaviour for a "no LLM" judge.
+    gate_on_structure = judge_type == JudgeType.STRICT_LEVELED
+    return HaikuJudge(gate_style_on_structure=gate_on_structure)
+
+# =============================================================================
+
+# Container image for the judge: CUDA 12.8 devel base with Python 3.12, pinned
+# vLLM / FlashInfer / huggingface-hub, the FastAPI serving stack, and NLTK with
+# the cmudict and punkt_tab data downloaded at build time. The local
+# llm_judges directory is added at /root/llm_judges.
+image = (
+ modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
+ .entrypoint([])
+ .uv_pip_install(
+ "vllm==0.11.2",
+ "huggingface-hub==0.36.0",
+ "flashinfer-python==0.5.2",
+ "aiohttp>=3.9.0",
+ "pydantic>=2.0.0",
+ "fastapi[standard]>=0.115.0",
+ "uvicorn>=0.30.0",
+ "nltk>=3.8.0",
+ )
+ .env({"HF_XET_HIGH_PERFORMANCE": "1"})
+ .run_commands(
+ "python -c \"import nltk; nltk.download('cmudict'); nltk.download('punkt_tab')\""
+ )
+ .add_local_dir("llm_judges", "/root/llm_judges")
+)
+
+
+# =============================================================================
+# Modal Flash Endpoint
+# =============================================================================
+
+
+def create_fastapi_app(judge_type: JudgeType):
+    """Build the FastAPI app that exposes the haiku judge over HTTP.
+
+    Routes:
+        POST /score: JSON body with ``prompt``, ``response`` and optional
+            ``label``; returns the float produced by the judge's
+            ``score_single``. Failures are retried with exponential backoff
+            (1s, 2s, 4s, 8s) and the last error is re-raised after five tries.
+        GET /health: returns status, served model name and judge type.
+
+    Args:
+        judge_type: Selects the judge built by ``_make_judge``.
+
+    Returns:
+        The configured ``FastAPI`` application.
+    """
+    import aiohttp
+    from fastapi import FastAPI
+    from pydantic import BaseModel
+    import nltk
+
+    fastapi_app = FastAPI(title="LLM Judge Reward Model", docs_url="/docs")
+    cmudict = nltk.corpus.cmudict.dict()
+    judge = _make_judge(judge_type)
+    max_attempts = 5
+
+    class ScoreRequest(BaseModel):
+        prompt: str
+        response: str
+        label: str = ""
+
+    async def _do_scoring(request: ScoreRequest) -> float:
+        # prompt/response are validated as ``str`` by pydantic, so no None
+        # check is needed and the result is always a float.
+        async with aiohttp.ClientSession() as session:
+            result = await judge.score_single(
+                MODEL_NAME, session, request.prompt, request.response, cmudict, label=request.label
+            )
+        return float(result)
+
+    @fastapi_app.post("/score")
+    async def score(request: ScoreRequest) -> float:
+        for attempt in range(1, max_attempts + 1):
+            try:
+                return await _do_scoring(request)
+            except Exception as e:
+                if attempt == max_attempts:
+                    # Out of retries: surface the original error and traceback.
+                    raise
+                wait_time = 2 ** (attempt - 1)
+                print(
+                    f"Scoring failed (attempt {attempt}): {e}, retrying in {wait_time}s..."
+                )
+                await asyncio.sleep(wait_time)
+
+    @fastapi_app.get("/health")
+    def health():
+        return {"status": "ok", "model": MODEL_NAME, "judge": judge_type.value}
+
+    return fastapi_app
+
+
+@app.cls(
+    image=image,
+    gpu=f"H100:{N_GPU}",
+    min_containers=3,
+    scaledown_window=15 * MINUTES,
+    startup_timeout=15 * MINUTES,
+    volumes={
+        "/root/.cache/huggingface": hf_cache_vol,
+        "/root/.cache/vllm": vllm_cache_vol,
+        "/checkpoints": checkpoint_volume,
+    },
+    secrets=[modal.Secret.from_name("huggingface-secret")],
+    experimental_options={"flash": "us-east"},
+    region="us-east",
+)
+@modal.concurrent(  # how many requests can one replica handle? tune carefully!
+    target_inputs=8
+)
+class LLMJudge:
+    """Modal Flash endpoint combining vLLM + scoring logic in one container."""
+
+    @modal.enter()
+    def setup(self):
+        """Start vLLM, then the FastAPI scorer, then forward the scorer port.
+
+        Raises:
+            RuntimeError: If vLLM exits or either server does not accept
+                connections within its timeout.
+        """
+        import subprocess
+
+        import uvicorn
+
+        # Start vLLM on VLLM_PORT (internal)
+        cmd = [
+            "vllm",
+            "serve",
+            "--uvicorn-log-level=info",
+            MODEL,
+            "--served-model-name",
+            MODEL_NAME,
+            "--port",
+            str(VLLM_PORT),
+            "--enforce-eager",
+            "--tensor-parallel-size",
+            str(N_GPU),
+            "--max-model-len",
+            "8192",
+        ]
+        print(" ".join(cmd))
+        # Pass the argv list directly instead of a joined string with
+        # shell=True: nothing is re-parsed by a shell, and the Popen handle
+        # is the vLLM process itself, so cleanup() signals the server rather
+        # than a possible intermediate shell.
+        self._vllm_process = subprocess.Popen(cmd)
+
+        # Wait for vLLM to be ready (fails fast if the process dies)
+        self._wait_for_port(VLLM_PORT, timeout=600, process=self._vllm_process)
+        print(f"vLLM ready on port {VLLM_PORT}")
+
+        # Start FastAPI scoring endpoint on FLASH_PORT (exposed)
+        self._fastapi_app = create_fastapi_app(ACTIVE_JUDGE_TYPE)
+        config = uvicorn.Config(
+            self._fastapi_app,
+            host="0.0.0.0",
+            port=FLASH_PORT,
+            log_level="info",
+        )
+        self._server = uvicorn.Server(config)
+        self._thread = threading.Thread(target=self._server.run, daemon=True)
+        self._thread.start()
+
+        self._wait_for_port(FLASH_PORT, timeout=30)
+        self.flash_manager = modal.experimental.flash_forward(FLASH_PORT)
+        print(f"Flash endpoint ready on port {FLASH_PORT} (judge={ACTIVE_JUDGE_TYPE.value})")
+
+    def _wait_for_port(self, port: int, timeout: int = 30, process=None):
+        """Block until ``localhost:port`` accepts a TCP connection.
+
+        Args:
+            port: TCP port to probe.
+            timeout: Maximum number of seconds to keep probing (wall clock).
+            process: Optional ``subprocess.Popen`` expected to open the port;
+                if it has already exited, waiting is abandoned immediately.
+
+        Raises:
+            RuntimeError: If ``process`` exited or the deadline passed.
+        """
+        import socket
+        import time
+
+        # A monotonic deadline keeps the total wait at ``timeout`` seconds
+        # even when individual connection attempts are slow to fail.
+        deadline = time.monotonic() + timeout
+        while time.monotonic() < deadline:
+            if process is not None and process.poll() is not None:
+                raise RuntimeError(
+                    f"Process exited with code {process.returncode} before port {port} opened"
+                )
+            try:
+                socket.create_connection(("localhost", port), timeout=1).close()
+                return
+            except OSError:
+                time.sleep(1)
+        raise RuntimeError(f"Server failed to start on port {port}")
+
+    @modal.method()
+    def keepalive(self):
+        """No-op method."""
+        pass
+
+    @modal.exit()
+    def cleanup(self):
+        """Stop the Flash forwarder, the FastAPI server and vLLM.
+
+        Each resource is only touched if ``setup`` got far enough to create
+        it. vLLM is stopped even if an earlier shutdown step raises.
+        """
+        import subprocess
+
+        try:
+            if hasattr(self, "flash_manager"):
+                self.flash_manager.stop()
+                self.flash_manager.close()
+            if hasattr(self, "_server"):
+                self._server.should_exit = True
+            if hasattr(self, "_thread"):
+                self._thread.join(timeout=5)
+        finally:
+            if hasattr(self, "_vllm_process"):
+                self._vllm_process.terminate()
+                try:
+                    self._vllm_process.wait(timeout=10)
+                except subprocess.TimeoutExpired:
+                    # SIGTERM was ignored or shutdown is stuck: force-kill.
+                    self._vllm_process.kill()
+                    self._vllm_process.wait()
diff --git a/haiku/llm_judges/nlp.py b/haiku/llm_judges/nlp.py
new file mode 100644
index 0000000..53d7e78
--- /dev/null
+++ b/haiku/llm_judges/nlp.py
@@ -0,0 +1,94 @@
+import re
+
+
+# Module-level cache filled by _get_cmudict() on first use.
+_cmudict = None
+
+
+def _get_cmudict() -> dict:
+    """Return the CMUdict mapping, downloading and loading it on first call.
+
+    Later calls return the same cached dict object.
+    """
+    import nltk
+    from nltk.corpus import cmudict as nltk_cmudict
+
+    global _cmudict
+    if _cmudict is not None:
+        return _cmudict
+
+    nltk.download("cmudict", quiet=True)
+    _cmudict = dict(nltk_cmudict.dict())
+    return _cmudict
+
+
+def lookup_word(word_s, cmudict: dict):
+    """Return the ``cmudict`` entry stored under ``word_s``, or ``None`` if absent."""
+    return cmudict.get(word_s)
+
+
+def is_acronym(word: str) -> bool:
+    """Return True for an all-uppercase, letters-only word of 2 to 6 characters."""
+    if not 2 <= len(word) <= 6:
+        return False
+    return word.isalpha() and word.isupper()
+
+
+def count_syllables_for_word(word, cmudict):
+    """Estimate the number of syllables in ``word``.
+
+    Tried in order:
+      1. The first ``cmudict`` pronunciation of the lowercased, stripped
+         word: one syllable per phone whose last character is a digit.
+      2. Acronyms (per ``is_acronym`` on the word as given): one syllable
+         per letter, three for "w".
+      3. Fallback: the number of vowel groups (``[aeiouy]+``), minus one
+         for a trailing "e" when more than one group exists, at least 1.
+    """
+    normalized = word.lower().strip()
+
+    pronunciations = lookup_word(normalized, cmudict)
+    if pronunciations:
+        return sum(1 for phone in pronunciations[0] if phone[-1].isdigit())
+
+    if is_acronym(word):
+        # "w" is spelled out as "dub-ul-you"
+        return sum(3 if letter == "w" else 1 for letter in word.lower())
+
+    vowel_groups = len(re.findall(r"[aeiouy]+", normalized))
+    if vowel_groups > 1 and normalized.endswith("e"):
+        vowel_groups -= 1
+    return max(vowel_groups, 1)
+
+
+def diff_syllables_count(text: str, target_syllables: int, cmudict: dict) -> int:
+    """Return the absolute gap between ``text``'s syllable count and the target.
+
+    Args:
+        text: One line of text; anything that is not a word is ignored.
+        target_syllables: Desired syllable count for the line.
+        cmudict: Pronunciation mapping passed to ``count_syllables_for_word``.
+
+    Returns:
+        ``abs(total_syllables - target_syllables)``; 0 means an exact match.
+    """
+    # Keep word-internal apostrophes so contractions and possessives such as
+    # "don't" or "it's" are counted as one word instead of "don" + "t", which
+    # would add a spurious syllable. Typographic apostrophes are normalized
+    # to ASCII first so both spellings tokenize the same way.
+    normalized = text.replace("\u2019", "'")
+    words = re.findall(r"[a-zA-Z]+(?:'[a-zA-Z]+)*", normalized)
+    total_syllables = sum(count_syllables_for_word(w, cmudict) for w in words)
+    return abs(total_syllables - target_syllables)
+
+
+def segment_haiku_lines(response: str) -> list[str]:
+    """Split a haiku into trimmed, non-empty lines.
+
+    Exactly one separator is used, chosen by precedence: "/" if present,
+    else ". " if present, else a newline.
+    """
+    # NOTE(review): "/" and ". " win over newlines even when the response is
+    # already laid out on several lines — confirm that precedence is intended.
+    separator = "\n"
+    for candidate in ("/", ". "):
+        if candidate in response:
+            separator = candidate
+            break
+
+    trimmed = (piece.strip() for piece in response.split(separator))
+    return [piece for piece in trimmed if piece]
+
+
+def score_syllable_line(diff: int, allow_off_by_one: bool) -> float:
+    """Convert a syllable-count gap into a per-line score.
+
+    A gap of 0 scores 1.0. A gap of 1 scores 1.0 when ``allow_off_by_one``
+    is set and 0.5 otherwise. Any other gap scores 0.0.
+    """
+    if diff == 0:
+        return 1.0
+    if diff != 1:
+        return 0.0
+    return 1.0 if allow_off_by_one else 0.5
+
+
+def score_haiku_structure(response: str, cmudict: dict, allow_off_by_one: bool = False) -> float:
+    """Score haiku structure in [0, 1].
+
+    A quarter point is given for having exactly three lines, and up to a
+    quarter point for each of the first three lines according to how close
+    it is to the 5 / 7 / 5 syllable targets.
+    """
+    lines = segment_haiku_lines(response)
+    weight = 0.25
+
+    score = weight if len(lines) == 3 else 0.0
+
+    # zip stops at the shorter sequence, so missing lines contribute nothing
+    # and lines beyond the third are ignored.
+    for line, target in zip(lines, (5, 7, 5)):
+        diff = diff_syllables_count(line, target, cmudict)
+        score += score_syllable_line(diff, allow_off_by_one) * weight
+
+    return score
+
+
+async def haiku_rm(args, sample, **kwargs) -> float:
+    """Structure-only reward for ``sample.response``.
+
+    Reads the optional ``haiku_allow_off_by_one`` attribute of ``args``
+    (default False). Extra keyword arguments are accepted and ignored.
+    """
+    pronunciations = _get_cmudict()
+    return score_haiku_structure(
+        sample.response,
+        pronunciations,
+        allow_off_by_one=getattr(args, "haiku_allow_off_by_one", False),
+    )
diff --git a/slime/modal_train.py b/haiku/modal_train.py
similarity index 53%
rename from slime/modal_train.py
rename to haiku/modal_train.py
index 4ecc2d3..7ac31dc 100644
--- a/slime/modal_train.py
+++ b/haiku/modal_train.py
@@ -1,37 +1,20 @@
"""
-Unified SLIME GRPO training script for Modal.
+SLIME GRPO Haiku training script for Modal.
Usage:
- # Sync training with Qwen 0.5B (multi-node)
- modal run modal_train.py::train_multi_node --config qwen-0.5b-sync
+ # deploy
+ modal deploy llm_judges.deploy
- # Async training with Qwen 4B (multi-node)
- modal run modal_train.py::train_multi_node --config qwen-4b-async
-
- # Single node training
- modal run modal_train.py::train_single_node --config qwen-0.5b-sync
-
- # Single node training with LoRA (using local slime repo)
- USE_LOCAL_SLIME=/path/to/slime modal run modal_train.py::train_single_node --config qwen-4b-lora
-
- # Download model
- modal run modal_train.py::download_model --config qwen-4b-sync
-
- # Prepare dataset
+ # Train model
+ modal run modal_train.py::download_model
modal run modal_train.py::prepare_dataset
+ modal run modal_train.py::train --run-name my-experiment
- # List available configs
- modal run modal_train.py::list_available_configs
+ # With local slime repo for development:
+ USE_LOCAL_SLIME=/path/to/slime modal run modal_train.py::train
Environment variables:
USE_LOCAL_SLIME=/path Path to local slime repo for development
- SLIME_APP_NAME=... Override Modal app name
-
-Available configs (main):
- - qwen-4b, glm-4-7, glm-4-7-flash
-
-Available configs (test-configs):
- - qwen-4b-lora (LoRA training test config)
"""
import os
@@ -40,10 +23,10 @@
from typing import Optional
import time
+from llm_judges.base import MODAL_VOCABS
import modal
-import modal.experimental
-from configs.base import RLConfig
+from config import RLConfig, get_config, ACTIVE_JUDGE_TYPE, ACTIVE_JUDGE_MODEL_SIZE
# =============================================================================
@@ -58,14 +41,17 @@
modal.Image.from_registry("slimerl/slime:nightly-dev-20260126a")
.run_commands(
"uv pip install --system git+https://github.com/huggingface/transformers.git@eebf856", # 4.54.1
+ "uv pip install --system aiohttp", # For LLM judge reward model
"""sed -i 's/AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)/AutoImageProcessor.register(config, slow_image_processor_class=image_processor, exist_ok=True)/g' /sgl-workspace/sglang/python/sglang/srt/configs/utils.py""",
# Fix rope_theta access for transformers 5.x (moved to rope_parameters dict)
r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/glm/glm45_bridge.py""",
r"""sed -i 's/hf_config\.rope_theta/hf_config.rope_parameters["rope_theta"]/g' /usr/local/lib/python3.12/dist-packages/megatron/bridge/models/qwen/qwen3_bridge.py""",
)
.entrypoint([])
- .add_local_python_source("configs", copy=True)
- # .add_local_dir("test-configs", remote_path="/root/test-configs", copy=True)
+ .add_local_python_source("config", copy=True)
+ .add_local_dir("tools", remote_path="/root/tools", copy=True)
+ .add_local_dir("llm_judges", remote_path="/root/llm_judges", copy=True)
+ .pip_install("nltk>=3.8.0")
)
# Overlay local slime code for development
@@ -73,12 +59,7 @@
SLIME_DEV_PATH = "/opt/slime-dev"
if LOCAL_SLIME_PATH:
# Copy the entire slime repo (has pyproject.toml) and install it
- image = image.add_local_dir(
- LOCAL_SLIME_PATH,
- remote_path=SLIME_DEV_PATH,
- copy=True,
- ignore=["**/__pycache__", "**/*.pyc", "**/.git", "**/.venv", "**/modal"],
- ).run_commands(f"uv pip install --system -e {SLIME_DEV_PATH}")
+ image = image.add_local_dir(LOCAL_SLIME_PATH, remote_path=SLIME_DEV_PATH, copy=True, ignore=["**/__pycache__", "**/*.pyc", "**/.git", "**/.venv", "**/modal"]).run_commands(f"uv pip install --system -e {SLIME_DEV_PATH}")
else:
SLIME_DEV_PATH = None
@@ -87,34 +68,20 @@
from ray.job_submission import JobSubmissionClient
# Paths
-DATA_PATH: Path = Path("/data")
-HF_CACHE_PATH: Path = Path("/root/.cache/huggingface")
+HF_CACHE_PATH = "/root/.cache/huggingface"
+DATA_PATH: Path = Path(f"{HF_CACHE_PATH}/processed")
CHECKPOINTS_PATH: Path = Path("/checkpoints")
# Volumes
-data_volume: modal.Volume = modal.Volume.from_name(
- "grpo-slime-example-data", create_if_missing=True
-)
-hf_cache_volume: modal.Volume = modal.Volume.from_name(
- "huggingface-cache", create_if_missing=True
-)
-checkpoints_volume: modal.Volume = modal.Volume.from_name(
- "grpo-slime-checkpoints", create_if_missing=True
-)
+hf_cache_vol: modal.Volume = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
+checkpoints_volume: modal.Volume = modal.Volume.from_name("slime-haiku-checkpoints", create_if_missing=True)
# Ray configuration
RAY_PORT = 6379
RAY_DASHBOARD_PORT = 8265
SINGLE_NODE_MASTER_ADDR = "127.0.0.1"
-# =============================================================================
-# App (created dynamically based on config)
-# =============================================================================
-
-# App name from environment variable (set before running modal)
-# Usage: SLIME_APP_NAME="my-experiment" modal run modal_train.py ...
-APP_NAME = os.environ.get("SLIME_APP_NAME", "slime-grpo")
-app = modal.App(APP_NAME)
+app = modal.App(f"train-haiku-judge_{ACTIVE_JUDGE_TYPE.value}_{ACTIVE_JUDGE_MODEL_SIZE.shorthand}")
# =============================================================================
@@ -166,9 +133,7 @@ def _init_ray(rank: int, main_node_addr: str, node_ip_addr: str, n_nodes: int):
else:
raise Exception("Failed to connect to all worker nodes")
else:
- print(
- f"Starting Ray worker node at {node_ip_addr}, connecting to {main_node_addr}"
- )
+ print(f"Starting Ray worker node at {node_ip_addr}, connecting to {main_node_addr}")
subprocess.Popen(
[
"ray",
@@ -188,51 +153,29 @@ def _init_ray(rank: int, main_node_addr: str, node_ip_addr: str, n_nodes: int):
def generate_slime_cmd(
config: RLConfig,
master_addr: str,
+ experiment_name: str,
) -> tuple[str, dict]:
"""Generate the slime training command and runtime environment."""
import datetime
import random
- # Check for infinite run mode
- is_infinite_run = os.environ.get("SLIME_TEST_ENABLE_INFINITE_RUN", "0").lower() in (
- "true",
- "1",
- )
-
- # Resolve HF model path from cache (volume mounted at default HF cache dir)
- from huggingface_hub import snapshot_download
+ train_args = config.generate_train_args(DATA_PATH)
- hf_model_path = snapshot_download(repo_id=config.model_id, local_files_only=True)
-
- # Generate all training args from config
- train_args = config.generate_train_args(
- hf_model_path, CHECKPOINTS_PATH, DATA_PATH, is_infinite_run
- )
+ checkpoint_dir = CHECKPOINTS_PATH / experiment_name
+ train_args += f" --save {checkpoint_dir} --save-interval {config.save_steps if hasattr(config, 'save_steps') else 10}"
# Add wandb args if API key is available
wandb_key = os.environ.get("WANDB_API_KEY")
if wandb_key:
- run_id = (
- datetime.datetime.utcnow().strftime("%y%m%d-%H%M%S")
- + f"-{random.randint(0, 999):03d}"
- )
- wandb_run_name = (
- f"{config.wandb_run_name_prefix}_{run_id}"
- if config.wandb_run_name_prefix
- else run_id
- )
+ run_id = datetime.datetime.now(datetime.timezone.utc).strftime("%y%m%d-%H%M%S") + f"-{random.randint(0, 999):03d}"
+ wandb_run_name = f"{config.wandb_run_name_prefix}_{run_id}" if config.wandb_run_name_prefix else run_id
train_args += f" --use-wandb --wandb-project {config.wandb_project} --wandb-group {wandb_run_name} --wandb-key '{wandb_key}' --disable-wandb-random-suffix"
# Build PYTHONPATH by appending to existing (don't clobber)
import os as _os
-
existing_pythonpath = _os.environ.get("PYTHONPATH", "")
megatron_path = "/root/Megatron-LM/"
- pythonpath = (
- f"{megatron_path}:{existing_pythonpath}"
- if existing_pythonpath
- else megatron_path
- )
+ pythonpath = f"{megatron_path}:{existing_pythonpath}" if existing_pythonpath else megatron_path
runtime_env = {
"env_vars": {
@@ -250,11 +193,11 @@ def generate_slime_cmd(
# Note: config.train_script returns "slime/train.py" for base image,
# but local repo has train.py at root level
# Check at runtime if dev path exists (USE_LOCAL_SLIME is only set during image build)
- train_script = config.train_script
dev_path = "/opt/slime-dev"
if os.path.exists(dev_path):
- script_name = "train.py" if config.sync else "train_async.py"
- train_script = f"{dev_path}/{script_name}"
+ train_script = f"{dev_path}/train.py"
+ else:
+ train_script = "slime/train.py"
return f"python3 {train_script} {train_args}", runtime_env
@@ -263,16 +206,18 @@ async def run_training(
config: RLConfig,
n_nodes: int,
master_addr: str,
+ experiment_name: str,
):
"""Submit SLIME training job to Ray cluster and stream logs."""
client = JobSubmissionClient("http://127.0.0.1:8265")
- slime_cmd, runtime_env = generate_slime_cmd(config, master_addr)
+ slime_cmd, runtime_env = generate_slime_cmd(config, master_addr, experiment_name)
print("Submitting training job...")
print(f" Model: {config.model_name}")
- print(f" Mode: {'sync' if config.sync else 'async'}")
print(f" Nodes: {n_nodes}")
+ print(f" Experiment: {experiment_name}")
+ print(f" Checkpoint dir: {CHECKPOINTS_PATH / experiment_name}")
job_id = client.submit_job(entrypoint=slime_cmd, runtime_env=runtime_env)
print(f"Job submitted with ID: {job_id}")
@@ -280,6 +225,11 @@ async def run_training(
async for line in client.tail_job_logs(job_id):
print(line, end="", flush=True)
+ await checkpoints_volume.commit.aio()
+ print("Checkpoints saved and committed to volume")
+
+
+
# =============================================================================
# Modal Functions
@@ -288,55 +238,90 @@ async def run_training(
@app.function(
image=image,
- volumes={HF_CACHE_PATH.as_posix(): hf_cache_volume},
+ volumes={HF_CACHE_PATH: hf_cache_vol},
+ secrets=[modal.Secret.from_name("huggingface-secret")],
timeout=24 * 60 * 60,
)
def download_model(
- config: str = "qwen-0.5b",
revision: Optional[str] = None,
):
- """Download model from HuggingFace.
-
- Args:
- config: Config name (e.g., "qwen-0.5b", "qwen-4b")
- revision: Optional HF revision to pin
- """
+ """Download model from HuggingFace."""
from huggingface_hub import snapshot_download
- from configs import get_config
- cfg = get_config(config)
+ cfg = get_config()
- path = snapshot_download(repo_id=cfg.model_id, revision=revision)
+ path = snapshot_download(
+ repo_id=cfg.model_id,
+ revision=revision,
+ )
print(f"Model downloaded to {path}")
- hf_cache_volume.commit()
+ hf_cache_vol.commit()
+
+
@app.function(
image=image,
- volumes={DATA_PATH.as_posix(): data_volume},
+ volumes={HF_CACHE_PATH: hf_cache_vol},
+ secrets=[modal.Secret.from_name("huggingface-secret")],
timeout=24 * 60 * 60,
)
def prepare_dataset():
- """Download and prepare the GSM8K dataset."""
+ """Download and prepare the Haiku dataset."""
from datasets import load_dataset
+ from transformers import AutoTokenizer
+
+ cfg = get_config()
+
+ hf_cache_vol.reload()
+ tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
+
+ ds = load_dataset("statworx/haiku")
+
+ def format_chat_template(example, tokenizer):
+ system_prompt = f"You are a haiku poet. You will be given a prompt and you will need to write a haiku about the prompt. Try to incorporate these words into the haiku if possible: {', '.join(MODAL_VOCABS)}"
+
+ keyword = example['keywords'].lower()
+ question = f"Write me a haiku about {keyword}."
+
+ messages = [
+ {"content": system_prompt, "role": "system"},
+ {"content": question, "role": "user"},
+ ]
+
+ return {
+ "question": question,
+ "label": example["text"],
+ "messages": messages,
+ "prompt": tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=False),
+ }
+
+ # This dataset only has a "train" split and no "test" split, so we hold out the last 20% of the train split (capped at 1000 examples) as test
+ # and remove those examples from the train split.
+ test_size = min(1000, int(len(ds["train"]) * 0.2))
+ test_ds = ds["train"].select(range(len(ds["train"]) - test_size, len(ds["train"])))
+ ds["train"] = ds["train"].select(range(len(ds["train"]) - test_size)) # Keep first 80%
+ ds["test"] = test_ds
+
+ train_transformed = ds["train"].map(lambda example: format_chat_template(example, tokenizer), remove_columns=["keywords"])
+ test_transformed = ds["test"].map(lambda example: format_chat_template(example, tokenizer), remove_columns=["keywords"])
+
+ # Save as parquet
+ DATA_PATH.mkdir(parents=True, exist_ok=True)
+ (DATA_PATH / "haiku").mkdir(parents=True, exist_ok=True)
+ train_transformed.to_parquet(f"{DATA_PATH}/haiku/train.parquet")
+ test_transformed.to_parquet(f"{DATA_PATH}/haiku/test.parquet")
+
+ hf_cache_vol.commit()
+ print("Haiku dataset prepared successfully")
+ print(f"Train examples: {len(train_transformed)}")
+ print(f"Test examples: {len(test_transformed)}")
+ print("\nExample:")
+ print(f"Prompt: {train_transformed[0]['question']}")
+ print(f"Text: {train_transformed[0]['label']}")
- data_volume.reload()
- dataset = load_dataset("zhuzilin/gsm8k")
- dataset["train"].to_parquet(f"{DATA_PATH}/gsm8k/train.parquet")
- dataset["test"].to_parquet(f"{DATA_PATH}/gsm8k/test.parquet")
- data_volume.commit()
- print("Dataset prepared successfully")
-
-
-@app.local_entrypoint()
-def list_available_configs():
- """List all available training configs."""
- from configs import list_configs
- print("Available configs:")
- for name in list_configs():
- print(f" - {name}")
# =============================================================================
@@ -348,86 +333,39 @@ def list_available_configs():
image=image,
gpu="H200:8",
volumes={
- HF_CACHE_PATH.as_posix(): hf_cache_volume,
- CHECKPOINTS_PATH.as_posix(): checkpoints_volume,
- DATA_PATH.as_posix(): data_volume,
- },
- secrets=[
- modal.Secret.from_name("wandb-secret"),
- ],
- timeout=24 * 60 * 60,
- experimental_options={
- "efa_enabled": True,
- },
-)
-@modal.experimental.clustered(
- 4, rdma=True
-)
-async def train_multi_node(config: str = "qwen-0.5b-sync"):
- """Main entry point for multi-node GRPO training on Modal.
-
- Args:
- config: Config name (e.g., "qwen-0.5b-sync", "qwen-4b-async")
- """
- from configs import get_config
-
- cfg = get_config(config)
-
- hf_cache_volume.reload()
- data_volume.reload()
-
- cluster_info = modal.experimental.get_cluster_info()
- print(f"Rank: {cluster_info.rank}, task id: {os.environ['MODAL_TASK_ID']}")
- print(f"Config: {config}")
- print(f"Container IPv4 IPs: {cluster_info.container_ipv4_ips}")
-
- ray_main_node_addr = cluster_info.container_ipv4_ips[0]
- my_ip_addr = cluster_info.container_ipv4_ips[cluster_info.rank]
- n_nodes = len(cluster_info.container_ipv4_ips)
-
- _init_ray(cluster_info.rank, ray_main_node_addr, my_ip_addr, n_nodes)
-
- if cluster_info.rank == 0:
- with modal.forward(RAY_DASHBOARD_PORT) as tunnel:
- print(f"Dashboard URL: {tunnel.url}")
- await run_training(cfg, n_nodes, ray_main_node_addr)
- else:
- while True:
- time.sleep(10)
-
-
-@app.function(
- image=image,
- gpu="H200:8",
- volumes={
- HF_CACHE_PATH.as_posix(): hf_cache_volume,
+ HF_CACHE_PATH: hf_cache_vol,
CHECKPOINTS_PATH.as_posix(): checkpoints_volume,
- DATA_PATH.as_posix(): data_volume,
},
secrets=[
+ modal.Secret.from_name("huggingface-secret"),
modal.Secret.from_name("wandb-secret"),
+ modal.Secret.from_name("anthropic-secret"),
],
timeout=24 * 60 * 60,
experimental_options={
"efa_enabled": True,
},
)
-async def train_single_node(config: str = "qwen-0.5b-sync"):
- """Single-node GRPO training on Modal.
+async def train(
+ run_name: str = "qwen3-4b-haiku",
+ judge_type = ACTIVE_JUDGE_TYPE,
+ judge_model_size = ACTIVE_JUDGE_MODEL_SIZE,
+):
+ """Single-node GRPO training on Modal."""
+ from datetime import datetime
- Args:
- config: Config name (e.g., "qwen-0.5b-sync", "qwen-4b-async"). File name with underscores replaced with dashes.
- """
- from configs import get_config
+ cfg = get_config(run_name=run_name, judge_type=judge_type, judge_model_size=judge_model_size)
- cfg = get_config(config)
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+ model_short = cfg.model_name.split("/")[-1]
+ experiment_name = f"{run_name}-{model_short}-{timestamp}"
- hf_cache_volume.reload()
- data_volume.reload()
+ await hf_cache_vol.reload.aio()
+ await checkpoints_volume.reload.aio()
_init_ray(0, SINGLE_NODE_MASTER_ADDR, SINGLE_NODE_MASTER_ADDR, 1)
- with modal.forward(RAY_DASHBOARD_PORT) as tunnel:
+ async with modal.forward(RAY_DASHBOARD_PORT) as tunnel:
print(f"Dashboard URL: {tunnel.url}")
- print(f"Config: {config}")
- await run_training(cfg, 1, SINGLE_NODE_MASTER_ADDR)
+ print(f"Experiment: {experiment_name}")
+ await run_training(cfg, 1, SINGLE_NODE_MASTER_ADDR, experiment_name)
\ No newline at end of file
diff --git a/slime/run.sh b/haiku/run.sh
similarity index 100%
rename from slime/run.sh
rename to haiku/run.sh
diff --git a/haiku/tools/README.md b/haiku/tools/README.md
new file mode 100644
index 0000000..e8fd268
--- /dev/null
+++ b/haiku/tools/README.md
@@ -0,0 +1 @@
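+The scripts in this directory are copied into the training image (at `/root/tools`) because they are missing from the slime nightly image; the reason they are missing is unknown.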
\ No newline at end of file
diff --git a/haiku/tools/convert_torch_dist_to_hf.py b/haiku/tools/convert_torch_dist_to_hf.py
new file mode 100644
index 0000000..8979e67
--- /dev/null
+++ b/haiku/tools/convert_torch_dist_to_hf.py
@@ -0,0 +1,216 @@
+import argparse
+import json
+import os
+import pickle
+import re
+import shutil
+import time
+
+
+import safetensors.torch
+import torch
+import torch.distributed.checkpoint as dist_cp
+from transformers import AutoConfig
+from typing_extensions import override
+
+from slime.backends.megatron_utils.megatron_to_hf import convert_to_hf, remove_padding
+
+
+class UnpicklerWrapper(pickle.Unpickler):
+    @override
+    def find_class(self, mod_name, name):
+        """Return a no-op placeholder class for ``megatron*``/``glm*`` modules, else defer to pickle."""
+        class DummyClass:
+            def __init__(self, *args, **kwargs):
+                pass
+
+        needs_stub = mod_name.startswith(("megatron", "glm"))
+        return DummyClass if needs_stub else super().find_class(mod_name, name)
+
+
+pickle.Unpickler = UnpicklerWrapper  # NOTE(review): process-wide patch, presumably for torch.load() — confirm
+
+
+class WrappedStorageReader(dist_cp.FileSystemReader):
+    @override
+    def read_metadata(self):
+        """Load ``.metadata`` via ``UnpicklerWrapper`` and default its storage/planner fields."""
+        with self.fs.create_stream(self.fs.concat_path(self.path, ".metadata"), "rb") as stream:
+            meta = UnpicklerWrapper(stream).load()
+        if getattr(meta, "storage_meta", None) is None:
+            meta.storage_meta = dist_cp.StorageMeta()
+        meta.storage_meta.load_id = self.load_id
+        if meta.planner_data is None:
+            meta.planner_data = {}
+        return meta
+
+
+class EmptyStateDictLoadPlanner(dist_cp.default_planner.DefaultLoadPlanner):
+    @override
+    def set_up_planner(
+        self,
+        state_dict: dist_cp.metadata.STATE_DICT_TYPE,
+        metadata: dist_cp.metadata.Metadata | None = None,
+        is_coordinator: bool = False,
+    ) -> None:
+        """Pre-fill ``state_dict`` from ``metadata`` (must not be None), skipping optimizer/``_state`` keys."""
+        for key, entry in metadata.state_dict_metadata.items():
+            if "optimizer" in key or "_state" in key:
+                continue
+            print(f"find {key} in torch_dist ckpt")
+            is_tensor = isinstance(entry, dist_cp.metadata.TensorStorageMetadata)
+            state_dict[key] = torch.empty(entry.size, dtype=entry.properties.dtype) if is_tensor else entry
+        super().set_up_planner(state_dict, metadata, is_coordinator)
+
+
+def get_expert_param(args, name, param):
+    """Yield (name, param) pairs, slicing a stacked expert weight into one tensor per expert."""
+    if ".experts." not in name:
+        yield name, param
+        return
+    num_experts = args.num_experts
+    if re.search(r"mlp.experts\.(.+)\.weight(\d+)", name):
+        # Name already carries a per-expert ``weight<N>`` suffix: pass it through unchanged.
+        yield name, param
+        return
+
+    assert param.shape[0] == num_experts
+    stem = name.replace(".experts.experts.", ".experts.")
+    for expert_id in range(num_experts):
+        yield stem + str(expert_id), param[expert_id]
+
+
+def get_layer_param(args, name, param):
+    """Yield (name, param) pairs, slicing a layer-stacked weight into one tensor per layer."""
+    if ".layers." not in name:
+        yield name, param
+        return
+    num_layers = args.num_layers
+    if re.search(r"\.layers\.(\d+)\.", name):
+        # Name already addresses a single layer; only expert splitting may remain.
+        yield from get_expert_param(args, name, param)
+        return
+
+    assert param.shape[0] == num_layers
+    for layer_id in range(num_layers):
+        per_layer_name = name.replace(".layers.", f".layers.{layer_id}.")
+        yield from get_expert_param(args, per_layer_name, param[layer_id])
+
+
+def get_named_params(args, state_dict):
+    """Yield per-layer, per-expert (name, param) pairs, names prefixed with ``module.module.``."""
+    for key, tensor in state_dict.items():
+        yield from get_layer_param(args, f"module.module.{key}", tensor)
+
+
+def save_tensors(args, model_name, state_dict, output_dir, chunk_size, vocab_size=None):
+    """Convert ``state_dict`` to HF-named tensors and write sharded safetensors plus an index.
+
+    Shards are closed once adding a tensor would exceed ``chunk_size`` bytes; a tensor larger
+    than ``chunk_size`` gets a shard of its own. When ``vocab_size`` is truthy, padding is
+    stripped via ``remove_padding`` first. Note: sets ``args.sglang_enable_ep_moe = False``.
+    """
+    # for slime update_weight compatible
+    args.sglang_enable_ep_moe = False
+
+    print(f"start saving to {output_dir}")
+    os.makedirs(output_dir, exist_ok=True)
+    current_size = 0
+    total_size = 0
+    modeltensors = [{}]
+    for name, param in get_named_params(args, state_dict):
+        if vocab_size:
+            param = remove_padding(name, param, vocab_size)
+        for converted_name, converted_param in convert_to_hf(args, model_name, name, param):
+            tensor_size = converted_param.numel() * converted_param.element_size()
+            # Only roll over from a non-empty shard, so an oversized tensor never leaves an empty file.
+            if modeltensors[-1] and tensor_size + current_size > chunk_size:
+                modeltensors.append({})
+                current_size = 0
+            modeltensors[-1][converted_name] = converted_param
+            current_size += tensor_size
+            total_size += tensor_size
+
+    num_files = len(modeltensors)
+    filenames = [f"model-{i:05d}-of-{num_files:05d}.safetensors" for i in range(num_files)]
+    weight_map = {key: fname for fname, tensors in zip(filenames, modeltensors) for key in tensors}
+    index_filepath = os.path.join(output_dir, "model.safetensors.index.json")
+    with open(index_filepath, "w") as index_file:  # close the handle deterministically
+        json.dump({"metadata": {"total_size": total_size}, "weight_map": weight_map}, index_file, indent=2)
+    print(f"{index_filepath} saved.")
+
+    for filename, tensors in zip(filenames, modeltensors):
+        t = time.time()
+        safetensors.torch.save_file(tensors, os.path.join(output_dir, filename))
+        print(f"{filename} saved in {time.time() - t:.2f} sec.")
+
+
+def copy_assets(origin_hf_dir, output_dir):
+    """Copy regular non-weight files (tokenizer, config, ...) from ``origin_hf_dir`` to ``output_dir``."""
+    for entry in os.listdir(origin_hf_dir):
+        if entry.endswith(".safetensors") or entry == "model.safetensors.index.json":
+            continue
+        src, dst = os.path.join(origin_hf_dir, entry), os.path.join(output_dir, entry)
+        if not os.path.isfile(src):
+            print(f"Skip {entry}, not a file.")
+            continue
+        print(f"copy from {src} to {dst}")
+        shutil.copy(src, dst)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-name", type=str, default=None)
+    parser.add_argument("--input-dir", type=str, required=True)
+    parser.add_argument("--output-dir", type=str, required=True)
+    parser.add_argument(
+        "--origin-hf-dir",
+        type=str,
+        default=None,
+        help="use the origin hf dir to copy files like tokenizer, config.json, etc.",
+    )
+    parser.add_argument(
+        "-f", "--force", action="store_true", help="Force overwrite the output directory if it exists."
+    )
+    parser.add_argument(
+        "--chunk-size",
+        type=int,
+        default=5 * 1024**3,
+        help="Chunk size in bytes for saving tensors, default is 5GiB.",
+    )
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        default=None,
+        help="Vocab size for removing padding, if applicable. If not provided, no padding will be removed.",
+    )
+    args = parser.parse_args()
+
+    if os.path.exists(args.output_dir) and not args.force:
+        raise ValueError(f"Output directory {args.output_dir} already exists. Use --force to overwrite it.")
+
+    if args.model_name is None and args.origin_hf_dir is None:
+        raise ValueError(
+            "Either --model-name or --origin-hf-dir must be provided, so that we can know the name of the params."
+        )
+    # Without --model-name, fall back to the lower-cased class name of the original HF config.
+    if args.model_name is None:
+        hf_config = AutoConfig.from_pretrained(args.origin_hf_dir, trust_remote_code=True)
+        args.model_name = type(hf_config).__name__.lower()
+    # NOTE(security): common.pt is unpickled with weights_only=False; convert trusted checkpoints only.
+    state_dict = {}
+    print(f"loading model from {args.input_dir}")
+    t = time.time()
+    megatron_args = torch.load(os.path.join(args.input_dir, "common.pt"), weights_only=False)["args"]
+    dist_cp.state_dict_loader._load_state_dict(
+        state_dict,
+        storage_reader=WrappedStorageReader(args.input_dir),
+        planner=EmptyStateDictLoadPlanner(),
+        no_dist=True,
+    )
+    print(f"model loaded in {time.time()-t:.2f} sec.")
+
+    save_tensors(megatron_args, args.model_name, state_dict, args.output_dir, args.chunk_size, args.vocab_size)
+
+    if args.origin_hf_dir:
+        copy_assets(args.origin_hf_dir, args.output_dir)
\ No newline at end of file