Skip to content

Commit d7613f4

Browse files
[sft] load chat_template.jinja if available (#509)
1 parent 876074d commit d7613f4

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

apps/sft/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ def setup_data(self):
154154
generation_config_path=os.path.join(
155155
self.job_config.model.hf_assets_path, "generation_config.json"
156156
),
157+
chat_template_path=(
158+
path
159+
if os.path.exists(
160+
path := os.path.join(
161+
self.job_config.model.hf_assets_path, "chat_template.jinja"
162+
)
163+
)
164+
else None
165+
),
157166
)
158167

159168
dataset = sft_iterable_dataset(

src/forge/data/tokenizer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ class HuggingFaceModelTokenizer(ModelTokenizer):
215215
Args:
216216
tokenizer_json_path (str): Path to tokenizer.json file
217217
tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
218-
generation_config_path (str | None): Path to generation_config.json file.
219-
Default: None
218+
generation_config_path (str | None): Path to generation_config.json file. Default: None
219+
chat_template_path (str | None): Path to chat_template.jinja file. Default: None
220220
truncation_type (str): type of truncation to apply, either "left" or "right".
221221
Default is "right".
222222
"""
@@ -227,6 +227,7 @@ def __init__(
227227
*,
228228
tokenizer_config_json_path: str | None = None,
229229
generation_config_path: str | None = None,
230+
chat_template_path: str | None = None,
230231
truncation_type: str = "right",
231232
):
232233
self.base_tokenizer = HuggingFaceBaseTokenizer(
@@ -245,7 +246,13 @@ def __init__(
245246

246247
# It is used sometimes in HF chat_templates
247248
_env.globals["raise_exception"] = self._raise_helper
248-
self.template = _env.from_string(config["chat_template"])
249+
250+
if chat_template_path:
251+
with open(chat_template_path, "r") as f:
252+
self.template = _env.from_string(f.read())
253+
else:
254+
self.template = _env.from_string(config["chat_template"])
255+
249256
self.truncation_type = truncation_type
250257

251258
self.special_tokens_mapping = {}

0 commit comments

Comments
 (0)