
Commit 7f2f080

Author: pytorchbot
Commit message: 2025-11-04 nightly release (d7613f4)
Parent: fb9fed3

File tree: 3 files changed, +33 additions, -8 deletions

  apps/grpo/main.py
  apps/sft/main.py
  src/forge/data/tokenizer.py


apps/grpo/main.py

Lines changed: 14 additions & 5 deletions
@@ -212,6 +212,7 @@ class DatasetActor(ForgeActor):
     @endpoint
     def setup(self):
         self._tokenizer = get_tokenizer(self.model)
+        self._epoch = 0

         def gsm8k_transform(sample):
             system_prompt = """
@@ -232,12 +233,12 @@ def gsm8k_transform(sample):
             formatted_target = target.split("#### ")[1]
             return {"request": formatted_request, "target": formatted_target}

-        ds = load_dataset(
+        self._base_dataset = load_dataset(
             self.path, self.revision, split=self.data_split, streaming=self.streaming
         )
-        ds = ds.map(gsm8k_transform)
-        ds = ds.shuffle()
-        self._iterator = iter(ds)
+        self._base_dataset = self._base_dataset.map(gsm8k_transform)
+        self._base_dataset = self._base_dataset.shuffle()
+        self._iterator = iter(self._base_dataset)

     @endpoint
     async def sample(self) -> dict[str, str] | None:
@@ -250,10 +251,18 @@ async def sample(self) -> dict[str, str] | None:
                 len(sample["request"]),
                 Reduce.MEAN,
             )
+            record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)

             return sample
         except StopIteration:
-            return None
+            # Restart iterator for next epoch with reshuffling
+            self._epoch += 1
+            print(
+                f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}"
+            )
+            self._base_dataset.set_epoch(self._epoch)
+            self._iterator = iter(self._base_dataset)
+            return next(self._iterator)

     @endpoint
     async def pad_token(self):
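For context on the epoch-restart pattern introduced above: when a HuggingFace dataset is loaded with streaming=True it is an IterableDataset, and calling set_epoch() before rebuilding the iterator changes the effective shuffle seed, so each pass over the data is reshuffled. Below is a minimal standalone sketch of that pattern, not the repository's code; the EpochCyclingSampler class name and the commented dataset id are illustrative only.

    # Minimal sketch of the epoch-restart pattern used in DatasetActor.sample.
    # Assumes the `datasets` library; names here are illustrative, not from forge.
    from datasets import load_dataset


    class EpochCyclingSampler:
        def __init__(self, path: str, name: str | None = None, split: str = "train"):
            self._epoch = 0
            # streaming=True yields an IterableDataset, which supports set_epoch().
            self._base_dataset = load_dataset(
                path, name, split=split, streaming=True
            ).shuffle(seed=42, buffer_size=1000)
            self._iterator = iter(self._base_dataset)

        def sample(self) -> dict:
            try:
                return next(self._iterator)
            except StopIteration:
                # One full pass is done: bump the epoch so shuffle() uses a new
                # effective seed (seed + epoch), then restart the iterator.
                self._epoch += 1
                self._base_dataset.set_epoch(self._epoch)
                self._iterator = iter(self._base_dataset)
                return next(self._iterator)


    # Hypothetical usage:
    # sampler = EpochCyclingSampler("openai/gsm8k", "main", split="train")
    # first_row = sampler.sample()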

apps/sft/main.py

Lines changed: 9 additions & 0 deletions
@@ -154,6 +154,15 @@ def setup_data(self):
             generation_config_path=os.path.join(
                 self.job_config.model.hf_assets_path, "generation_config.json"
             ),
+            chat_template_path=(
+                path
+                if os.path.exists(
+                    path := os.path.join(
+                        self.job_config.model.hf_assets_path, "chat_template.jinja"
+                    )
+                )
+                else None
+            ),
         )

         dataset = sft_iterable_dataset(
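A note on the expression added above: it uses an assignment expression (the walrus operator) inside a conditional expression, so the candidate path is computed once, checked with os.path.exists, and passed through only when the file is present; otherwise None is passed. A standalone sketch of the same pattern, with an illustrative directory in place of job_config.model.hf_assets_path:

    import os

    # Hypothetical assets directory; in the commit this comes from
    # self.job_config.model.hf_assets_path.
    hf_assets_path = "/tmp/model_assets"

    # The condition is evaluated first, so `path` is bound by the walrus
    # operator before it is returned by the `if` branch.
    chat_template_path = (
        path
        if os.path.exists(path := os.path.join(hf_assets_path, "chat_template.jinja"))
        else None
    )

    print(chat_template_path)  # the joined path if the file exists, otherwise None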

src/forge/data/tokenizer.py

Lines changed: 10 additions & 3 deletions
@@ -215,8 +215,8 @@ class HuggingFaceModelTokenizer(ModelTokenizer):
     Args:
         tokenizer_json_path (str): Path to tokenizer.json file
         tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
-        generation_config_path (str | None): Path to generation_config.json file.
-            Default: None
+        generation_config_path (str | None): Path to generation_config.json file. Default: None
+        chat_template_path (str | None): Path to chat_template.jinja file. Default: None
         truncation_type (str): type of truncation to apply, either "left" or "right".
             Default is "right".
     """
@@ -227,6 +227,7 @@ def __init__(
         *,
         tokenizer_config_json_path: str | None = None,
         generation_config_path: str | None = None,
+        chat_template_path: str | None = None,
         truncation_type: str = "right",
     ):
         self.base_tokenizer = HuggingFaceBaseTokenizer(
@@ -245,7 +246,13 @@ def __init__(

         # It is used sometimes in HF chat_templates
         _env.globals["raise_exception"] = self._raise_helper
-        self.template = _env.from_string(config["chat_template"])
+
+        if chat_template_path:
+            with open(chat_template_path, "r") as f:
+                self.template = _env.from_string(f.read())
+        else:
+            self.template = _env.from_string(config["chat_template"])
+
         self.truncation_type = truncation_type

         self.special_tokens_mapping = {}
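For reference, the new branch in HuggingFaceModelTokenizer.__init__ compiles the chat template either from a standalone chat_template.jinja file or, as before, from the chat_template string inside tokenizer_config.json. A minimal standalone sketch with jinja2 follows; the file name, the sample config, and the extra os.path.exists guard (added so the sketch runs even when the file is absent) are illustrative assumptions, not the tokenizer's exact code.

    import os

    from jinja2 import Environment

    # Illustrative inputs; in the tokenizer these come from the model's asset files.
    chat_template_path = "chat_template.jinja"  # may or may not exist on disk
    config = {"chat_template": "{{ messages[0]['content'] }}"}  # tokenizer_config.json fallback

    _env = Environment()

    # Prefer the standalone .jinja file when it is available; otherwise fall back
    # to the inline template from the tokenizer config, mirroring the diff above.
    if chat_template_path and os.path.exists(chat_template_path):
        with open(chat_template_path, "r") as f:
            template = _env.from_string(f.read())
    else:
        template = _env.from_string(config["chat_template"])

    print(template.render(messages=[{"role": "user", "content": "hello"}]))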
