forked from ml-explore/mlx-lm
Description
Overview
CompletionsDataset.process() crashes when mask_prompt=True because a single message dict is passed to apply_chat_template instead of a list.
Problem
When training with the completions format and mask_prompt enabled, the process method builds a two-element messages list, then passes messages[0] (a bare dict) to tokenizer.apply_chat_template(). This method expects a list of message dicts, causing a crash.
The analogous code path in ChatDataset.process() correctly passes a list slice (messages[:-1]), so this appears to be an oversight specific to the completions path.
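The mismatch above can be reproduced with a minimal stand-in for a chat tokenizer. This is an illustrative sketch based on the issue description, not the actual mlx-lm source; StubTokenizer and process_fixed are hypothetical names.

```python
class StubTokenizer:
    """Minimal stand-in that mimics how chat-template methods consume input."""

    def apply_chat_template(self, messages, add_generation_prompt=False):
        # Real chat-template implementations iterate over a list of
        # {"role": ..., "content": ...} dicts, so a bare dict fails here.
        if not isinstance(messages, list):
            raise TypeError("messages must be a list of message dicts")
        text = "".join(f"<{m['role']}>{m['content']}" for m in messages)
        if add_generation_prompt:
            text += "<assistant>"
        return list(text.encode())  # fake token ids


def process_fixed(tokenizer, prompt, completion):
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": completion},
    ]
    tokens = tokenizer.apply_chat_template(messages)
    # The buggy path passed messages[0] (a bare dict) here. The fix mirrors
    # the messages[:-1] slice used in ChatDataset.process(): tokenize only
    # the prompt portion, with add_generation_prompt=True, to get the offset.
    offset = len(
        tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
    )
    return tokens, offset
```

Passing `messages[0]` to the stub raises the same kind of TypeError the issue describes, while the list-slice version returns a prompt-token offset strictly smaller than the full token count.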
Requirements
- The mask-prompt code path in CompletionsDataset.process() must pass a properly typed argument (a list of message dicts) to apply_chat_template
- The prompt-mask offset calculation must match the behavior of ChatDataset.process() (i.e., tokenize only the user prompt portion with add_generation_prompt=True)
- Existing non-mask-prompt behavior must remain unchanged
Acceptance Criteria
- Training with completions format and mask_prompt=True no longer crashes
- The prompt mask offset correctly covers only the prompt tokens (not the completion)
- Training with completions format and mask_prompt=False is unaffected
- ChatDataset with mask_prompt=True continues to work correctly
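For context on the second criterion, the offset is typically used to zero out the prompt tokens in the training loss mask. This is a hedged sketch of that assumed behavior, not verbatim mlx-lm code:

```python
def build_loss_mask(num_tokens, offset):
    """Build a per-token loss mask: 0 for prompt tokens (excluded from the
    loss), 1 for completion tokens (trained on). Illustrative only."""
    return [0] * offset + [1] * (num_tokens - offset)
```

With a correct offset, only the completion tokens contribute to the loss; an offset computed from the wrong input (or a crash, as reported) breaks this masking entirely.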
Constraints
- Fix should be minimal and not change any unrelated behavior
- Must be backwards-compatible with existing training configs
Reference
ChatDataset.process() handles the same pattern correctly and serves as the reference implementation