Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions apps/sft/main.py
Original file line number Diff line number Diff line change
@@ -154,6 +154,9 @@ def setup_data(self):
generation_config_path=os.path.join(
self.job_config.model.hf_assets_path, "generation_config.json"
),
chat_template_path=(
path if os.path.exists(path := os.path.join(self.job_config.model.hf_assets_path, "chat_template.jinja")) else None
),
)

dataset = sft_iterable_dataset(
Expand Down
7 changes: 7 additions & 0 deletions src/forge/data/tokenizer.py
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@ def __init__(
*,
tokenizer_config_json_path: str | None = None,
generation_config_path: str | None = None,
chat_template_path: str | None = None,
):
self.tokenizer = Tokenizer.from_file(tokenizer_json_path)
if not (tokenizer_config_json_path or generation_config_path):
@@ -51,6 +52,10 @@ def __init__(
if tokenizer_config_json_path:
with open(tokenizer_config_json_path, "rb") as f:
self.config = json.load(f)
if chat_template_path:
with open(chat_template_path, "r") as f:
# TODO: warning in the case of overwrite?
self.config["chat_template"] = f.read()
else:
self.config = None
if generation_config_path:
@@ -227,12 +232,14 @@ def __init__(
*,
tokenizer_config_json_path: str | None = None,
generation_config_path: str | None = None,
chat_template_path: str | None = None,
truncation_type: str = "right",
):
self.base_tokenizer = HuggingFaceBaseTokenizer(
tokenizer_json_path=tokenizer_json_path,
tokenizer_config_json_path=tokenizer_config_json_path,
generation_config_path=generation_config_path,
chat_template_path=chat_template_path,
)

# Contents of the tokenizer_config.json
Expand Down