Patch summary (text before the first "diff --git" header is ignored by git apply):
  apps/sft/main.py: setup_data passes a chat_template_path keyword argument, set to
    <hf_assets_path>/chat_template.jinja when that file exists and to None otherwise.
  src/forge/data/tokenizer.py: the tokenizer __init__ accepts chat_template_path; when it
    is truthy the Jinja template is read from that file, otherwise config["chat_template"]
    is used as before. The file is opened with an explicit UTF-8 encoding so decoding does
    not depend on the platform locale.

diff --git a/apps/sft/main.py b/apps/sft/main.py
index edda0b49d..93ba05eed 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -154,6 +154,15 @@ def setup_data(self):
             generation_config_path=os.path.join(
                 self.job_config.model.hf_assets_path, "generation_config.json"
             ),
+            chat_template_path=(
+                path
+                if os.path.exists(
+                    path := os.path.join(
+                        self.job_config.model.hf_assets_path, "chat_template.jinja"
+                    )
+                )
+                else None
+            ),
         )
 
         dataset = sft_iterable_dataset(
diff --git a/src/forge/data/tokenizer.py b/src/forge/data/tokenizer.py
index 65407e131..9b0f4cf64 100644
--- a/src/forge/data/tokenizer.py
+++ b/src/forge/data/tokenizer.py
@@ -215,8 +215,8 @@ class HuggingFaceModelTokenizer(ModelTokenizer):
     Args:
         tokenizer_json_path (str): Path to tokenizer.json file
         tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
-        generation_config_path (str | None): Path to generation_config.json file.
-            Default: None
+        generation_config_path (str | None): Path to generation_config.json file. Default: None
+        chat_template_path (str | None): Path to chat_template.jinja file. Default: None
         truncation_type (str): type of truncation to apply, either "left" or "right".
             Default is "right".
     """
@@ -227,6 +227,7 @@ def __init__(
         *,
         tokenizer_config_json_path: str | None = None,
         generation_config_path: str | None = None,
+        chat_template_path: str | None = None,
         truncation_type: str = "right",
     ):
         self.base_tokenizer = HuggingFaceBaseTokenizer(
@@ -245,7 +246,13 @@ def __init__(
 
         # It is used sometimes in HF chat_templates
         _env.globals["raise_exception"] = self._raise_helper
-        self.template = _env.from_string(config["chat_template"])
+
+        if chat_template_path:
+            with open(chat_template_path, "r", encoding="utf-8") as f:
+                self.template = _env.from_string(f.read())
+        else:
+            self.template = _env.from_string(config["chat_template"])
+
         self.truncation_type = truncation_type
 
         self.special_tokens_mapping = {}