Skip to content

Commit f337875

Browse files
feat: Allow chat template to be specified via a path in data config. (#615)
Signed-off-by: yashasvi <[email protected]>
1 parent d9ee35f commit f337875

File tree

5 files changed

+175
-8
lines changed

5 files changed

+175
-8
lines changed

docs/advanced-data-preprocessing.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,20 @@ If the dataset size is known to the user, `max_steps` can be calculated as the t
284284
285285
### How users can specify the chat template
286286
287+
There are multiple ways to specify chat_template in `data_config.yaml`,
288+
users could either specify path to `chat_template.jinja` file or update the chat_template directly.
289+
287290
In the `data_config.yaml` file:
288291
289292
**✅ USE:**
290293
294+
```yaml
295+
dataprocessor:
296+
chat_template_path: "path/to/chat_template.jinja"
297+
```
298+
299+
**✅ USE:**
300+
291301
```yaml
292302
dataprocessor:
293303
chat_template: "my single line chat template"

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@
5858
PREDEFINED_DATA_CONFIGS,
5959
"granite_3_1b_chat_template.txt",
6060
)
61+
CHAT_TEMPLATE_JINJA = os.path.join(
62+
PREDEFINED_DATA_CONFIGS,
63+
"chat_template.jinja",
64+
)
6165
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join(
6266
PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
6367
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
{%- if messages[0]['role'] == 'system' %}
2+
{%- set system_message = messages[0]['content'] %}
3+
{%- set loop_messages = messages[1:] %}
4+
{%- else %}
5+
{%- set system_message = "Knowledge Cutoff Date: April 2024.\nToday's Date: " + strftime_now('%B %d, %Y') + ".\nYou are Granite, developed by IBM." %}
6+
{%- if tools and documents %}
7+
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.\n\nWrite the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
8+
{%- elif tools %}
9+
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
10+
{%- elif documents %}
11+
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
12+
{%- else %}
13+
{%- set system_message = system_message + " You are a helpful AI assistant." %}
14+
{%- endif %}
15+
{%- if 'citations' in controls and documents %}
16+
{%- set system_message = system_message + '\n\nIn your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
17+
{%- endif %}
18+
{%- if 'hallucinations' in controls and documents %}
19+
{%- set system_message = system_message + '\n\nFinally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
20+
{%- endif %}
21+
{%- set loop_messages = messages %}
22+
{%- endif %}
23+
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>\n' }}
24+
{%- if tools %}
25+
{{- '<|start_of_role|>tools<|end_of_role|>' }}
26+
{{- tools | tojson(indent=4) }}
27+
{{- '<|end_of_text|>\n' }}
28+
{%- endif %}
29+
{%- if documents %}
30+
{{- '<|start_of_role|>documents<|end_of_role|>' }}
31+
{%- for document in documents %}
32+
{{- 'Document ' + loop.index0 | string + '\n' }}
33+
{{- document['text'] }}
34+
{%- if not loop.last %}
35+
{{- '\n\n'}}
36+
{%- endif%}
37+
{%- endfor %}
38+
{{- '<|end_of_text|>\n' }}
39+
{%- endif %}
40+
{%- for message in loop_messages %}
41+
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}
42+
{%- if loop.last and add_generation_prompt %}
43+
{{- '<|start_of_role|>assistant' }}
44+
{%- if controls %}
45+
{{- ' ' + controls | tojson()}}
46+
{%- endif %}
47+
{{- '<|end_of_role|>' }}
48+
{%- endif %}
49+
{%- endfor %}

tests/test_sft_trainer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from scripts.run_inference import TunedCausalLM
3939
from tests.artifacts.language_models import MAYKEYE_TINY_LLAMA_CACHED
4040
from tests.artifacts.predefined_data_configs import (
41+
CHAT_TEMPLATE_JINJA,
4142
DATA_CONFIG_DUPLICATE_COLUMNS,
4243
DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE,
4344
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
@@ -1513,6 +1514,42 @@ def test_data_config_chat_template_as_base64():
15131514
data_config = load_and_validate_data_config(data_config_path)
15141515

15151516

1517+
def test_data_config_chat_template_path():
1518+
base_cfg = DATA_CONFIG_MULTITURN_DATA_YAML
1519+
chat_template_path = CHAT_TEMPLATE_JINJA
1520+
1521+
with open(chat_template_path, "r", encoding="utf-8") as f:
1522+
expected_template = f.read()
1523+
1524+
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as tmp_cfg:
1525+
with open(base_cfg, "r", encoding="utf-8") as f:
1526+
cfg = yaml.safe_load(f)
1527+
1528+
dp = cfg.get("dataprocessor", {}) or {}
1529+
dp.pop("chat_template", None)
1530+
dp["chat_template_path"] = chat_template_path
1531+
cfg["dataprocessor"] = dp
1532+
1533+
for d in cfg.get("datasets", []):
1534+
d["data_paths"] = [TWITTER_COMPLAINTS_DATA_JSON]
1535+
1536+
yaml.safe_dump(cfg, tmp_cfg)
1537+
mod_cfg_path = tmp_cfg.name
1538+
1539+
data_config = load_and_validate_data_config(mod_cfg_path)
1540+
1541+
assert (
1542+
data_config.dataprocessor.chat_template == expected_template
1543+
), "chat_template should equal the contents of CHAT_TEMPLATE_JINJA"
1544+
assert data_config.dataprocessor.chat_template_path is not None
1545+
assert os.path.isabs(
1546+
data_config.dataprocessor.chat_template_path
1547+
), "stored chat_template_path should be absolute"
1548+
assert os.path.exists(
1549+
data_config.dataprocessor.chat_template_path
1550+
), "resolved chat_template_path should exist"
1551+
1552+
15161553
@pytest.mark.parametrize(
15171554
"data_args",
15181555
[

tuning/data/data_config.py

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ class DataSetConfig:
4545
class DataPreProcessorConfig:
4646
type: Optional[str] = "default"
4747
sampling_stopping_strategy: Optional[str] = "all_exhausted"
48-
# Default seed is not none to ensure reproducability
49-
seed: Optional[float] = 42
48+
# Default seed is not none to ensure reproducibility
49+
seed: Optional[int] = 42
5050
streaming: Optional[bool] = False
5151
chat_template: Optional[str] = None
52+
chat_template_path: Optional[str] = None
5253

5354

5455
@dataclass
@@ -148,28 +149,94 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConf
148149
streaming = kwargs["streaming"]
149150
assert isinstance(streaming, bool), f"streaming: {streaming} should be a bool"
150151
c.streaming = streaming
151-
if "chat_template" in kwargs:
152+
153+
is_chat_template_present = (
154+
"chat_template" in kwargs and kwargs["chat_template"] is not None
155+
)
156+
is_chat_template_path_present = (
157+
"chat_template_path" in kwargs and kwargs["chat_template_path"] is not None
158+
)
159+
is_chat_template_b64_present = (
160+
"chat_template_base64" in kwargs and kwargs["chat_template_base64"] is not None
161+
)
162+
163+
if (
164+
sum(
165+
[
166+
is_chat_template_present,
167+
is_chat_template_path_present,
168+
is_chat_template_b64_present,
169+
]
170+
)
171+
> 1
172+
):
173+
raise ValueError(
174+
"Only one of 'chat_template', 'chat_template_path', or 'chat_template_base64' "
175+
"may be specified in dataprocessor config."
176+
)
177+
178+
if is_chat_template_present:
152179
chat_template = kwargs["chat_template"]
153180
assert isinstance(chat_template, str), "chat_template should be a string"
154181
c.chat_template = chat_template
155-
elif "chat_template_base64" in kwargs:
182+
c.chat_template_path = None
183+
return c
184+
185+
if is_chat_template_path_present:
186+
chat_template_path = kwargs["chat_template_path"]
187+
assert isinstance(
188+
chat_template_path, str
189+
), "chat_template_path should be a string path"
190+
# Expand ~ and environment variables, then absolutize
191+
expanded = os.path.expanduser(os.path.expandvars(chat_template_path))
192+
abs_path = os.path.abspath(expanded)
193+
194+
if not os.path.isabs(chat_template_path) and os.path.exists(abs_path):
195+
logger.warning(
196+
" Provided chat_template_path %s is not absolute, changing it to %s",
197+
chat_template_path,
198+
abs_path,
199+
)
200+
chat_template_path = abs_path
201+
202+
if not os.path.exists(abs_path):
203+
raise ValueError(
204+
f"chat_template_path does not exist: {chat_template_path} (resolved to {abs_path})"
205+
)
206+
if not os.path.isfile(abs_path):
207+
raise ValueError(
208+
f"chat_template_path is not a file: {chat_template_path} (resolved to {abs_path})"
209+
)
210+
try:
211+
with open(abs_path, "r", encoding="utf-8") as f:
212+
c.chat_template = f.read()
213+
c.chat_template_path = abs_path
214+
except Exception as e:
215+
raise ValueError(
216+
f"Failed to read chat_template_path: {chat_template_path} (resolved to {abs_path})."
217+
) from e
218+
return c
219+
220+
if is_chat_template_b64_present:
156221
chat_template_base64 = kwargs["chat_template_base64"]
157222
assert isinstance(
158223
chat_template_base64, str
159224
), "chat_template_base64 should be a string"
160225
logger.warning(
161226
"You are using the 'chat_template_base64' field. "
162-
+ "Please use the 'chat_template' field instead for better readability."
227+
"Please prefer 'chat_template' or 'chat_template_path' for better readability."
163228
)
164229
try:
165230
chat_template_bytes = b64decode(chat_template_base64)
166231
chat_template = chat_template_bytes.decode("utf-8")
167232
c.chat_template = chat_template
233+
c.chat_template_path = None
168234
except Exception as e:
169235
raise ValueError(
170-
"You passed the 'chat_template_base64' field which failed during decoding."
171-
+ "Please check it or use a decoded chat template with the 'chat_template' field."
236+
"You passed the 'chat_template_base64' field which failed during decoding. "
237+
"Please check it or use 'chat_template' or 'chat_template_path' instead."
172238
) from e
239+
173240
return c
174241

175242

@@ -199,7 +266,7 @@ def load_and_validate_data_config(data_config_file: str) -> DataConfig:
199266

200267
if dataprocessor is None:
201268
logging.info(
202-
"`dataprocessor` filed is absent from data config. Using default dataprocessor"
269+
"`dataprocessor` field is absent from data config. Using default dataprocessor"
203270
)
204271
dataprocessor = DataPreProcessorConfig()
205272
logging.info("Default datapreprocessor is %s", str(dataprocessor))

0 commit comments

Comments
 (0)