diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md
index 166c491aa..8af5f9c95 100644
--- a/docs/advanced-data-preprocessing.md
+++ b/docs/advanced-data-preprocessing.md
@@ -284,10 +284,20 @@ If the dataset size is known to the user, `max_steps` can be calculated as the t
### How users can specify the chat template
+There are multiple ways to specify chat_template in `data_config.yaml`,
+users could either specify path to `chat_template.jinja` file or update the chat_template directly.
+
In the `data_config.yaml` file:
**✅ USE:**
+```yaml
+dataprocessor:
+ chat_template_path: "path/to/chat_template.jinja"
+```
+
+**✅ USE:**
+
```yaml
dataprocessor:
chat_template: "my single line chat template"
diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py
index 9261152a2..0a8d3a8e1 100644
--- a/tests/artifacts/predefined_data_configs/__init__.py
+++ b/tests/artifacts/predefined_data_configs/__init__.py
@@ -58,6 +58,10 @@
PREDEFINED_DATA_CONFIGS,
"granite_3_1b_chat_template.txt",
)
+CHAT_TEMPLATE_JINJA = os.path.join(
+ PREDEFINED_DATA_CONFIGS,
+ "chat_template.jinja",
+)
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join(
PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
)
diff --git a/tests/artifacts/predefined_data_configs/chat_template.jinja b/tests/artifacts/predefined_data_configs/chat_template.jinja
new file mode 100644
index 000000000..487920c59
--- /dev/null
+++ b/tests/artifacts/predefined_data_configs/chat_template.jinja
@@ -0,0 +1,49 @@
+{%- if messages[0]['role'] == 'system' %}
+ {%- set system_message = messages[0]['content'] %}
+ {%- set loop_messages = messages[1:] %}
+{%- else %}
+ {%- set system_message = "Knowledge Cutoff Date: April 2024.\nToday's Date: " + strftime_now('%B %d, %Y') + ".\nYou are Granite, developed by IBM." %}
+ {%- if tools and documents %}
+ {%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.\n\nWrite the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
+ {%- elif tools %}
+ {%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
+ {%- elif documents %}
+ {%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
+ {%- else %}
+ {%- set system_message = system_message + " You are a helpful AI assistant." %}
+ {%- endif %}
+ {%- if 'citations' in controls and documents %}
+ {%- set system_message = system_message + '\n\nIn your response, use the symbols and to indicate when a fact comes from a document in the search result, e.g 0 for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
+ {%- endif %}
+ {%- if 'hallucinations' in controls and documents %}
+ {%- set system_message = system_message + '\n\nFinally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
+ {%- endif %}
+ {%- set loop_messages = messages %}
+{%- endif %}
+{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>\n' }}
+{%- if tools %}
+ {{- '<|start_of_role|>tools<|end_of_role|>' }}
+ {{- tools | tojson(indent=4) }}
+ {{- '<|end_of_text|>\n' }}
+{%- endif %}
+{%- if documents %}
+ {{- '<|start_of_role|>documents<|end_of_role|>' }}
+ {%- for document in documents %}
+ {{- 'Document ' + loop.index0 | string + '\n' }}
+ {{- document['text'] }}
+ {%- if not loop.last %}
+ {{- '\n\n'}}
+ {%- endif%}
+ {%- endfor %}
+ {{- '<|end_of_text|>\n' }}
+{%- endif %}
+{%- for message in loop_messages %}
+ {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}
+ {%- if loop.last and add_generation_prompt %}
+ {{- '<|start_of_role|>assistant' }}
+ {%- if controls %}
+ {{- ' ' + controls | tojson()}}
+ {%- endif %}
+ {{- '<|end_of_role|>' }}
+ {%- endif %}
+{%- endfor %}
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index b20dce180..29ae5e839 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -38,6 +38,7 @@
from scripts.run_inference import TunedCausalLM
from tests.artifacts.language_models import MAYKEYE_TINY_LLAMA_CACHED
from tests.artifacts.predefined_data_configs import (
+ CHAT_TEMPLATE_JINJA,
DATA_CONFIG_DUPLICATE_COLUMNS,
DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE,
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
@@ -1513,6 +1514,42 @@ def test_data_config_chat_template_as_base64():
data_config = load_and_validate_data_config(data_config_path)
+def test_data_config_chat_template_path():
+ base_cfg = DATA_CONFIG_MULTITURN_DATA_YAML
+ chat_template_path = CHAT_TEMPLATE_JINJA
+
+ with open(chat_template_path, "r", encoding="utf-8") as f:
+ expected_template = f.read()
+
+ with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as tmp_cfg:
+ with open(base_cfg, "r", encoding="utf-8") as f:
+ cfg = yaml.safe_load(f)
+
+ dp = cfg.get("dataprocessor", {}) or {}
+ dp.pop("chat_template", None)
+ dp["chat_template_path"] = chat_template_path
+ cfg["dataprocessor"] = dp
+
+ for d in cfg.get("datasets", []):
+ d["data_paths"] = [TWITTER_COMPLAINTS_DATA_JSON]
+
+ yaml.safe_dump(cfg, tmp_cfg)
+ mod_cfg_path = tmp_cfg.name
+
+ data_config = load_and_validate_data_config(mod_cfg_path)
+
+ assert (
+ data_config.dataprocessor.chat_template == expected_template
+ ), "chat_template should equal the contents of CHAT_TEMPLATE_JINJA"
+ assert data_config.dataprocessor.chat_template_path is not None
+ assert os.path.isabs(
+ data_config.dataprocessor.chat_template_path
+ ), "stored chat_template_path should be absolute"
+ assert os.path.exists(
+ data_config.dataprocessor.chat_template_path
+ ), "resolved chat_template_path should exist"
+
+
@pytest.mark.parametrize(
"data_args",
[
diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py
index 3b577766d..b461a6933 100644
--- a/tuning/data/data_config.py
+++ b/tuning/data/data_config.py
@@ -45,10 +45,11 @@ class DataSetConfig:
class DataPreProcessorConfig:
type: Optional[str] = "default"
sampling_stopping_strategy: Optional[str] = "all_exhausted"
- # Default seed is not none to ensure reproducability
- seed: Optional[float] = 42
+ # Default seed is not none to ensure reproducibility
+ seed: Optional[int] = 42
streaming: Optional[bool] = False
chat_template: Optional[str] = None
+ chat_template_path: Optional[str] = None
@dataclass
@@ -148,28 +149,94 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConf
streaming = kwargs["streaming"]
assert isinstance(streaming, bool), f"streaming: {streaming} should be a bool"
c.streaming = streaming
- if "chat_template" in kwargs:
+
+ is_chat_template_present = (
+ "chat_template" in kwargs and kwargs["chat_template"] is not None
+ )
+ is_chat_template_path_present = (
+ "chat_template_path" in kwargs and kwargs["chat_template_path"] is not None
+ )
+ is_chat_template_b64_present = (
+ "chat_template_base64" in kwargs and kwargs["chat_template_base64"] is not None
+ )
+
+ if (
+ sum(
+ [
+ is_chat_template_present,
+ is_chat_template_path_present,
+ is_chat_template_b64_present,
+ ]
+ )
+ > 1
+ ):
+ raise ValueError(
+ "Only one of 'chat_template', 'chat_template_path', or 'chat_template_base64' "
+ "may be specified in dataprocessor config."
+ )
+
+ if is_chat_template_present:
chat_template = kwargs["chat_template"]
assert isinstance(chat_template, str), "chat_template should be a string"
c.chat_template = chat_template
- elif "chat_template_base64" in kwargs:
+ c.chat_template_path = None
+ return c
+
+ if is_chat_template_path_present:
+ chat_template_path = kwargs["chat_template_path"]
+ assert isinstance(
+ chat_template_path, str
+ ), "chat_template_path should be a string path"
+ # Expand ~ and environment variables, then absolutize
+ expanded = os.path.expanduser(os.path.expandvars(chat_template_path))
+ abs_path = os.path.abspath(expanded)
+
+ if not os.path.isabs(chat_template_path) and os.path.exists(abs_path):
+ logger.warning(
+ " Provided chat_template_path %s is not absolute, changing it to %s",
+ chat_template_path,
+ abs_path,
+ )
+ chat_template_path = abs_path
+
+ if not os.path.exists(abs_path):
+ raise ValueError(
+ f"chat_template_path does not exist: {chat_template_path} (resolved to {abs_path})"
+ )
+ if not os.path.isfile(abs_path):
+ raise ValueError(
+ f"chat_template_path is not a file: {chat_template_path} (resolved to {abs_path})"
+ )
+ try:
+ with open(abs_path, "r", encoding="utf-8") as f:
+ c.chat_template = f.read()
+ c.chat_template_path = abs_path
+ except Exception as e:
+ raise ValueError(
+ f"Failed to read chat_template_path: {chat_template_path} (resolved to {abs_path})."
+ ) from e
+ return c
+
+ if is_chat_template_b64_present:
chat_template_base64 = kwargs["chat_template_base64"]
assert isinstance(
chat_template_base64, str
), "chat_template_base64 should be a string"
logger.warning(
"You are using the 'chat_template_base64' field. "
- + "Please use the 'chat_template' field instead for better readability."
+ "Please prefer 'chat_template' or 'chat_template_path' for better readability."
)
try:
chat_template_bytes = b64decode(chat_template_base64)
chat_template = chat_template_bytes.decode("utf-8")
c.chat_template = chat_template
+ c.chat_template_path = None
except Exception as e:
raise ValueError(
- "You passed the 'chat_template_base64' field which failed during decoding."
- + "Please check it or use a decoded chat template with the 'chat_template' field."
+ "You passed the 'chat_template_base64' field which failed during decoding. "
+ "Please check it or use 'chat_template' or 'chat_template_path' instead."
) from e
+
return c
@@ -199,7 +266,7 @@ def load_and_validate_data_config(data_config_file: str) -> DataConfig:
if dataprocessor is None:
logging.info(
- "`dataprocessor` filed is absent from data config. Using default dataprocessor"
+ "`dataprocessor` field is absent from data config. Using default dataprocessor"
)
dataprocessor = DataPreProcessorConfig()
logging.info("Default datapreprocessor is %s", str(dataprocessor))