
Commit 0fd6345

Author: pytorchbot
Commit message: 2025-12-02 nightly release (1deab13)
1 parent 0d6dd4b; commit 0fd6345

13 files changed: +853 / -84 lines changed

.meta/mast/client_bootstrap.sh

Lines changed: 0 additions & 6 deletions
```diff
@@ -41,11 +41,5 @@ fi
 
 cd "$WORKSPACE_DIR/forge"
 
-export WANDB_MODE=offline
-export HF_HUB_OFFLINE=1
-export MONARCH_HOST_MESH_V1_REMOVE_ME_BEFORE_RELEASE=1
-export TORCHSTORE_RDMA_ENABLED=1
-export HF_HOME=/mnt/wsfuse/teamforge/hf
-
 # Execute the client training script with all passed arguments
 exec python -X faulthandler .meta/mast/main.py "$@"
```
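
With these exports gone, offline-mode settings have to come from the caller. A minimal sketch (not part of this commit) that re-creates the removed offline variables before launching the script:

```python
import os
import subprocess

# Hypothetical caller-side setup: re-create the removed offline-mode
# environment (values copied from the deleted export lines) before
# launching the bootstrap script.
env = dict(os.environ)
env.update({
    "WANDB_MODE": "offline",
    "HF_HUB_OFFLINE": "1",
    "TORCHSTORE_RDMA_ENABLED": "1",
    "HF_HOME": "/mnt/wsfuse/teamforge/hf",
})
subprocess.run(["bash", ".meta/mast/client_bootstrap.sh"], env=env, check=True)
```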

apps/grpo/main.py

Lines changed: 93 additions & 16 deletions
```diff
@@ -15,6 +15,7 @@
 import torch
 import torch.nn.functional as F
 import torchstore as ts
+import yaml
 from datasets import load_dataset
 from forge.actors._torchstore_utils import (
     get_dcp_whole_state_dict_key,
@@ -26,18 +27,21 @@
 from forge.actors.trainer import TitanTrainer
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
-from forge.data.rewards import MathReward, ThinkingReward
+from forge.data.rewards import LanguageReward, MathReward, ThinkingReward
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import LauncherConfig, ProvisionerConfig
 from forge.util.config import parse
+from forge.util.logging import get_logger
 from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
+logger = get_logger("INFO")
+
 
 @dataclass
 class Episode:
@@ -46,10 +50,13 @@ class Episode:
     request_len: int
     response_len: int
     target: Any | None = None
+    request: str | None = None
+    response: str | None = None
     # Processed data
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None
 
     @property
@@ -72,6 +79,32 @@ def response_tensor(self) -> torch.Tensor:
         tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor
 
+    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
+        """Convert episode to dict, optionally excluding specified fields."""
+        result = {
+            "episode_id": self.episode_id,
+            "policy_version": self.policy_version,
+            "prompt": self.request,
+            "response": self.response,
+            "target": str(self.target),
+            "reward": self.reward,
+            "advantage": self.advantage,
+            "request_len": self.request_len,
+            "response_len": self.response_len,
+            "pad_id": self.pad_id,
+            "ref_logprobs": self.ref_logprobs,
+            "completion": self.completion,
+        }
+
+        if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
+            result.update(self.reward_breakdown)
+
+        if exclude:
+            for key in exclude:
+                result.pop(key, None)
+
+        return result
+
 
 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
```
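
A standalone sketch (plain dicts rather than the real Episode class) of what to_dict produces: reward_breakdown entries are merged in as top-level keys and excluded fields are dropped. Note that the breakdown merge above expects exclude to be a list whenever reward_breakdown is set.

```python
# Illustrative only: mimics Episode.to_dict(exclude=["ref_logprobs", "completion"]).
result = {
    "episode_id": "ep-0",
    "reward": 1.5,
    "ref_logprobs": None,  # stand-in for the tensor
    "completion": None,    # stand-in for the Completion object
}
reward_breakdown = {"MathReward": 1.0, "ThinkingReward": 0.5}

result.update(reward_breakdown)  # per-function rewards become top-level keys
for key in ["ref_logprobs", "completion"]:
    result.pop(key, None)        # mirrors the exclude handling

print(result)
# {'episode_id': 'ep-0', 'reward': 1.5, 'MathReward': 1.0, 'ThinkingReward': 0.5}
```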
```diff
@@ -129,7 +162,7 @@ def simple_grpo_loss(
     ref_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
-    beta: float = 0.1,
+    beta: float = 1e-6,
 ) -> torch.Tensor:
     logprobs: torch.Tensor = compute_logprobs(logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
```
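
The kl term above is the non-negative estimator exp(r) - r - 1 with r = ref_logprobs - logprobs, and lowering beta from 0.1 to 1e-6 makes that penalty essentially negligible. A quick numeric check:

```python
import torch

# r = ref_logprobs - logprobs; kl = exp(r) - r - 1 is >= 0 and equals 0 when the
# current policy matches the reference policy on that token.
ref_logprobs = torch.tensor([-1.0, -2.0, -0.5])
logprobs = torch.tensor([-1.2, -1.5, -0.5])

r = ref_logprobs - logprobs
kl = torch.exp(r) - r - 1
print(kl)                        # approx tensor([0.0214, 0.1065, 0.0000])
print((0.1 * kl).sum().item())   # old beta: visible penalty
print((1e-6 * kl).sum().item())  # new beta: essentially zero
```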
```diff
@@ -166,8 +199,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]
 
     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> (dict[str, float], float):
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
@@ -176,6 +212,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -205,8 +242,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
                 Reduce.SUM,
             )
 
-        avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        avg_reward: float = total_rewards / len(self.reward_functions)
+        return reward_breakdown, avg_reward
 
 
 @dataclass
```
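
A self-contained sketch of the new return shape, with toy functions standing in for MathReward / ThinkingReward / LanguageReward:

```python
# Toy stand-ins for the real reward functions.
def exact_match(prompt: str, response: str, target: str) -> float:
    return 1.0 if target in response else 0.0

def has_answer_tag(prompt: str, response: str, target: str) -> float:
    return 1.0 if "<answer>" in response else 0.0

reward_functions = [exact_match, has_answer_tag]

def evaluate(prompt: str, response: str, target: str) -> tuple[dict[str, float], float]:
    breakdown: dict[str, float] = {}
    total = 0.0
    for fn in reward_functions:
        r = fn(prompt, response, target)
        breakdown[fn.__name__] = r
        total += r
    return breakdown, total / len(reward_functions)

print(evaluate("What is 12 + 5?", "<answer>17</answer>", "17"))
# ({'exact_match': 1.0, 'has_answer_tag': 1.0}, 1.0)
```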
```diff
@@ -237,10 +274,15 @@ async def setup(self):
         self._epoch = 0
 
         def gsm8k_transform(sample):
-            system_prompt = """
-            Put all your scratchpad work between <think> and </think> tags.
-            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
-            """
+            system_prompt = """You are a helpful AI assistant that solves math problems.
+
+Please show your reasoning inside <思考></思考> tags, then provide your final numerical answer inside <answer></answer> tags.
+
+Example:
+Question: What is 12 + 5?
+<思考>12と5を足します。12 + 5 = 17です。</思考>
+<answer>17</answer>
+"""
             request: str = sample["question"]
             as_chat = [
                 {"role": "system", "content": system_prompt},
```
```diff
@@ -320,9 +362,14 @@ async def drop_weights(version: int):
 
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
-    group_size = cfg.group_size
-    max_req_tokens = cfg.max_req_tokens
-    max_res_tokens = cfg.max_res_tokens
+    # Convert OmegaConf config to plain dict
+    run_config_for_logging = OmegaConf.to_container(cfg, resolve=True)
+
+    # Log config
+    logger.info("=" * 30 + " CONFIGURATION " + "=" * 30)
+    logger.info(
+        yaml.dump(run_config_for_logging, default_flow_style=False, sort_keys=False)
+    )
 
     # ---- Global setups ---- #
     provisioner = None
@@ -334,8 +381,11 @@ async def main(cfg: DictConfig):
         provisioner = await init_provisioner()
 
     metric_logging_cfg = cfg.get("metric_logging", {})
+
     mlogger = await get_or_create_metric_logger(process_name="Controller")
-    await mlogger.init_backends.call_one(metric_logging_cfg)
+    await mlogger.init_backends.call_one(
+        backend_config=metric_logging_cfg, run_config=run_config_for_logging
+    )
 
     # ---- Setup services ---- #
 
```
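
The config-logging path added above reduces to two library calls; a minimal sketch with a toy config:

```python
import yaml
from omegaconf import OmegaConf

cfg = OmegaConf.create({"group_size": 8, "max_req_tokens": 1024, "max_res_tokens": 2048})
run_config_for_logging = OmegaConf.to_container(cfg, resolve=True)  # plain dict, interpolations resolved
print(yaml.dump(run_config_for_logging, default_flow_style=False, sort_keys=False))
```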

```diff
@@ -359,10 +409,24 @@
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[MathReward(), ThinkingReward()]
+            reward_functions=[
+                MathReward(),
+                ThinkingReward(tag="思考"),  # Use Japanese tag
+                LanguageReward(
+                    target_language="ja",
+                    tag="思考",
+                    match_reward=2.0,
+                    debug=False,  # set to true for verbose logging
+                    debug_sample_rate=0.1,
+                ),  # Japanese language reward with debug
+            ]
         ),
     )
 
+    group_size = cfg.group_size
+    max_req_tokens = cfg.max_req_tokens
+    max_res_tokens = cfg.max_res_tokens
+
     # Set max_steps to the configured value, or -1 if not specified or Null
     max_steps = cfg.trainer.training.steps or -1
 
```
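
The LanguageReward implementation itself is not part of this diff; a hedged sketch of what such a reward might check, using the langid package added to the dev dependencies below:

```python
import re

import langid

def language_reward(response: str, target_language: str = "ja",
                    tag: str = "思考", match_reward: float = 2.0) -> float:
    """Hypothetical: reward responses whose tagged block is written in the target language."""
    m = re.search(rf"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
    if not m:
        return 0.0
    lang, _score = langid.classify(m.group(1))  # langid returns a (language_code, score) tuple
    return match_reward if lang == target_language else 0.0

print(language_reward("<思考>12と5を足します。</思考><answer>17</answer>"))  # 2.0 if detected as Japanese
```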

```diff
@@ -413,9 +477,14 @@ async def continuous_rollouts():
                     request_len=max_req_tokens,
                     response_len=max_res_tokens,
                     target=target,
+                    request=prompt,
+                    response=response.text,
                     completion=response,
                 )
-                episode.reward = await reward_actor.evaluate_response.route(
+                (
+                    episode.reward_breakdown,
+                    episode.reward,
+                ) = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
                 episodes.append(episode)
@@ -456,6 +525,14 @@ async def continuous_rollouts():
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
 
+                sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
+                sample["score"] = sample["reward"]
+                record_metric(
+                    "main_samples/continuous_rollouts/sample_table",
+                    sample,
+                    Reduce.SAMPLE,
+                )
+
             rollout_count += 1
             record_metric(
                 "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
```

apps/grpo/qwen3_1_7b.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@
 group_size: 8
 local_batch_size: 16 # per-device batch size
 max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
 
```

apps/grpo/qwen3_8b.yaml

Lines changed: 3 additions & 3 deletions
```diff
@@ -2,10 +2,10 @@
 # >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml
 
 # Global configuration
-group_size: 8
-local_batch_size: 12 # per-device batch size
+group_size: 16
+local_batch_size: 4 # per-device batch size
 max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
 model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default
 
```
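
Quick arithmetic on the updated limits (assuming max_req_tokens and max_res_tokens bound the padded request and response lengths):

```python
max_req_tokens = 1024
max_res_tokens = 2048  # doubled from 1024 in both configs
print(max_req_tokens + max_res_tokens)  # 3072 tokens per request + response pair
```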

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -47,6 +47,7 @@ dev = [
     "anyio",
     "pytest-asyncio",
     "multiprocess",
+    "langid",
 ]
 docs = [
     "sphinx==7.2.6",
```

0 commit comments