Skip to content

Commit 4a7281e

Browse files
author
Allen Wang
committed
Merge branch 'main' into weight_sync
2 parents 252b2b9 + fa456c7 commit 4a7281e

File tree

8 files changed

+38
-15
lines changed

8 files changed

+38
-15
lines changed

.meta/mast/README.md

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@ This only applies to Meta internal users.
55

66
## Quick Start
77

8-
⚠️ Important Note: the setup script will clone the forge repository under "/data/users/$USER".
9-
108
### 1. Run the Setup Script
119

1210
The `env_setup.sh` script will automatically:

apps/grpo/main.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,7 @@ class DatasetActor(ForgeActor):
212212
@endpoint
213213
def setup(self):
214214
self._tokenizer = get_tokenizer(self.model)
215+
self._epoch = 0
215216

216217
def gsm8k_transform(sample):
217218
system_prompt = """
@@ -232,12 +233,12 @@ def gsm8k_transform(sample):
232233
formatted_target = target.split("#### ")[1]
233234
return {"request": formatted_request, "target": formatted_target}
234235

235-
ds = load_dataset(
236+
self._base_dataset = load_dataset(
236237
self.path, self.revision, split=self.data_split, streaming=self.streaming
237238
)
238-
ds = ds.map(gsm8k_transform)
239-
ds = ds.shuffle()
240-
self._iterator = iter(ds)
239+
self._base_dataset = self._base_dataset.map(gsm8k_transform)
240+
self._base_dataset = self._base_dataset.shuffle()
241+
self._iterator = iter(self._base_dataset)
241242

242243
@endpoint
243244
async def sample(self) -> dict[str, str] | None:
@@ -250,10 +251,18 @@ async def sample(self) -> dict[str, str] | None:
250251
len(sample["request"]),
251252
Reduce.MEAN,
252253
)
254+
record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
253255

254256
return sample
255257
except StopIteration:
256-
return None
258+
# Restart iterator for next epoch with reshuffling
259+
self._epoch += 1
260+
print(
261+
f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}"
262+
)
263+
self._base_dataset.set_epoch(self._epoch)
264+
self._iterator = iter(self._base_dataset)
265+
return next(self._iterator)
257266

258267
@endpoint
259268
async def pad_token(self):

apps/grpo/qwen3_8b.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
# Global configuration
55
group_size: 8
6-
local_batch_size: 16 # per-device batch size
6+
local_batch_size: 12 # per-device batch size
77
max_req_tokens: 1024
88
max_res_tokens: 1024
99
model: "Qwen/Qwen3-8B"

apps/sft/main.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,15 @@ def setup_data(self):
154154
generation_config_path=os.path.join(
155155
self.job_config.model.hf_assets_path, "generation_config.json"
156156
),
157+
chat_template_path=(
158+
path
159+
if os.path.exists(
160+
path := os.path.join(
161+
self.job_config.model.hf_assets_path, "chat_template.jinja"
162+
)
163+
)
164+
else None
165+
),
157166
)
158167

159168
dataset = sft_iterable_dataset(

docs/source/getting_started.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,7 @@ After installation, verify that all components are working correctly:
124124
125125
# Test basic Monarch functionality
126126
procs = this_host().spawn_procs({'gpus': 1})
127+
procs.initialized.get()
127128
print('Monarch: Process spawning works')
128129
"
129130
```

src/forge/actors/trainer.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -176,8 +176,7 @@ async def train_step(
176176

177177
# TODO: delete item() to avoid cpu-gpu sync
178178
loss = loss.detach().item()
179-
record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM)
180-
record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN)
179+
record_metric("rl_trainer/avg_loss", loss, Reduce.MEAN)
181180

182181
# These are placeholder values until the loss function exposes these metrics
183182
# record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)

src/forge/data/tokenizer.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -215,8 +215,8 @@ class HuggingFaceModelTokenizer(ModelTokenizer):
215215
Args:
216216
tokenizer_json_path (str): Path to tokenizer.json file
217217
tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
218-
generation_config_path (str | None): Path to generation_config.json file.
219-
Default: None
218+
generation_config_path (str | None): Path to generation_config.json file. Default: None
219+
chat_template_path (str | None): Path to chat_template.jinja file. Default: None
220220
truncation_type (str): type of truncation to apply, either "left" or "right".
221221
Default is "right".
222222
"""
@@ -227,6 +227,7 @@ def __init__(
227227
*,
228228
tokenizer_config_json_path: str | None = None,
229229
generation_config_path: str | None = None,
230+
chat_template_path: str | None = None,
230231
truncation_type: str = "right",
231232
):
232233
self.base_tokenizer = HuggingFaceBaseTokenizer(
@@ -245,7 +246,13 @@ def __init__(
245246

246247
# It is used sometimes in HF chat_templates
247248
_env.globals["raise_exception"] = self._raise_helper
248-
self.template = _env.from_string(config["chat_template"])
249+
250+
if chat_template_path:
251+
with open(chat_template_path, "r") as f:
252+
self.template = _env.from_string(f.read())
253+
else:
254+
self.template = _env.from_string(config["chat_template"])
255+
249256
self.truncation_type = truncation_type
250257

251258
self.special_tokens_mapping = {}

src/forge/observability/metric_actors.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -437,12 +437,12 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
437437
await backend.log_batch(reduced_metrics, global_step)
438438

439439
@endpoint
440-
def has_fetcher(self, proc_id: str) -> bool:
440+
async def has_fetcher(self, proc_id: str) -> bool:
441441
"""Check if a fetcher is registered with the given proc_id."""
442442
return proc_id in self.fetchers
443443

444444
@endpoint
445-
def get_fetcher_count(self) -> int:
445+
async def get_fetcher_count(self) -> int:
446446
return len(self.fetchers)
447447

448448
@endpoint

0 commit comments

Comments
 (0)