
Commit a82790a

jyork03 and awni authored
Fixed/improved behavior of the mask_prompt feature. (#584)
* Fixed/improved behavior of the mask_prompt feature. Without setting add_generation_prompt to True, the model/assistant turn header can be included, which forces the loss to be calculated over more than just the model's output that we care about. Introduced _apply_chat_template_safe to centralize defensive calls to apply_chat_template, to account for some environments that don't support tools (added defensive measures for add_generation_prompt too, just in case).

* nits

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent 2aa31f9 commit a82790a
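To see what the fix changes, consider how the prompt-masking offset is computed. The snippet below is a toy illustration only: the template function and its header/end markers are invented for this example and are not mlx-lm code.

def template(messages, add_generation_prompt=False):
    # Invented stand-in for tokenizer.apply_chat_template: each turn is
    # rendered as a role header, its words, and an end marker.
    out = []
    for m in messages:
        out += [f"<|{m['role']}|>"] + m["content"].split() + ["<|end|>"]
    if add_generation_prompt:
        out.append("<|assistant|>")  # the model/assistant turn header
    return out

messages = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello!"},
]
tokens = template(messages)

# Before the fix: the offset stops before the assistant header, so the
# header tokens fall outside the masked region and are trained on.
old_offset = len(template(messages[:-1]))
print(tokens[old_offset:])  # ['<|assistant|>', 'Hello!', '<|end|>']

# After the fix: add_generation_prompt=True counts the header as part of
# the prompt, so the loss covers only the model's actual output.
new_offset = len(template(messages[:-1], add_generation_prompt=True))
print(tokens[new_offset:])  # ['Hello!', '<|end|>']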

File tree

mlx_lm/tuner/datasets.py

1 file changed: +16 -10 lines


mlx_lm/tuner/datasets.py

Lines changed: 16 additions & 10 deletions
@@ -39,7 +39,7 @@ def __len__(self):
 class ChatDataset:
     """
     A dataset for chat data in the format of {"messages": [...]}
-    https://platform.openai.com/docs/guides/fine-tuning/example-format
+    https://platform.openai.com/docs/guides/supervised-fine-tuning#formatting-your-data
     """

     def __init__(
def __init__(
@@ -59,8 +59,14 @@ def process(self, d):
         tools = d.get("tools", None)
         tokens = self.tokenizer.apply_chat_template(messages, tools=tools)
         if self.mask_prompt:
-            messages = messages[:-1]
-            offset = len(self.tokenizer.apply_chat_template(messages, tools=tools))
+            add_generation_prompt = messages[-1].get("role") == "assistant"
+            offset = len(
+                self.tokenizer.apply_chat_template(
+                    messages[:-1],
+                    tools=tools,
+                    add_generation_prompt=add_generation_prompt,
+                )
+            )
             return (tokens, offset)
         else:
             return (tokens, 0)
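The new offset now counts the assistant turn header as part of the masked prompt, and the generation prompt is only appended when the dropped final turn actually is an assistant turn. For illustration only (this helper is not part of the commit), here is how a trainer might turn the (tokens, offset) pair returned by process into a per-token loss mask:

import numpy as np

def loss_mask(tokens, offset):
    # Hypothetical helper, not mlx-lm code: mark which tokens contribute
    # to the loss. The first `offset` tokens are the prompt plus the
    # assistant turn header, so they are excluded.
    mask = np.ones(len(tokens), dtype=bool)
    mask[:offset] = False
    return mask

# With the corrected offset, only the completion tokens are trained on:
print(loss_mask(list(range(10)), offset=6))
# [False False False False False False  True  True  True  True]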
@@ -94,16 +100,16 @@ def __init__(
         self.tokenizer = tokenizer

     def process(self, d):
-        tokens = self.tokenizer.apply_chat_template(
-            [
-                {"role": "user", "content": d[self.prompt_key]},
-                {"role": "assistant", "content": d[self.completion_key]},
-            ],
-        )
+        tools = d.get("tools", None)
+        messages = [
+            {"role": "user", "content": d[self.prompt_key]},
+            {"role": "assistant", "content": d[self.completion_key]},
+        ]
+        tokens = _apply_chat_template_safe(self.tokenizer, messages, tools=tools)
         if self.mask_prompt:
             offset = len(
                 self.tokenizer.apply_chat_template(
-                    [{"role": "user", "content": d[self.prompt_key]}]
+                    messages[:1], tools=tools, add_generation_prompt=True
                 )
             )
             return (tokens, offset)
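The diff calls _apply_chat_template_safe, but its definition sits outside the hunks shown above. Based on the commit message's description, a minimal sketch of what such a wrapper might look like follows; the body below is an assumption, not the commit's actual implementation.

import inspect

def _apply_chat_template_safe(tokenizer, messages, tools=None, add_generation_prompt=False):
    # Sketch only: forward `tools` and `add_generation_prompt` solely when
    # this tokenizer's apply_chat_template accepts them, so environments
    # without tool support don't raise a TypeError.
    supported = inspect.signature(tokenizer.apply_chat_template).parameters
    kwargs = {}
    if tools is not None and "tools" in supported:
        kwargs["tools"] = tools
    if "add_generation_prompt" in supported:
        kwargs["add_generation_prompt"] = add_generation_prompt
    return tokenizer.apply_chat_template(messages, **kwargs)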
