Commit 87565ec

lhl and winglian authored
Add chat_template.argilla_chat support for DPO datasets (#3202)
* Add chat_template.argilla_chat support for DPO datasets

  Creates a new chat_template.argilla_chat prompt strategy for handling DPO datasets where the chosen/rejected fields contain full conversations (messages + final response), following the pattern of chatml.argilla_chat and llama3.argilla_chat.

  - Add argilla_chat() function to chat_template.py
  - Add chat_template.argilla_chat to the RLHF documentation
  - Add test coverage for argilla_chat with multiple tokenizers

  Dataset format:

      {
        "chosen": [
          {"role": "user", "content": "..."},
          {"role": "assistant", "content": "..."}
        ],
        "rejected": [
          {"role": "user", "content": "..."},
          {"role": "assistant", "content": "..."}
        ]
      }

* Fix chat_template.argilla_chat return value contract and add docstring

  - Return a (transform_fn, dataset_kwargs) tuple instead of a bare transform_fn
  - Add a remove_columns specification for field_chosen and field_rejected
  - Add a comprehensive docstring with Args/Returns sections
  - Update tests to unpack the tuple return value

  Addresses PR feedback: maintains consistency with chat_template.default() and properly specifies the columns to remove after dataset transformation.

* Update tests/prompt_strategies/test_dpo_chat_templates.py

  Co-authored-by: Wing Lian <[email protected]>

---------

Co-authored-by: Wing Lian <[email protected]>
1 parent 93ba573 commit 87565ec
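
For orientation, a config stanza that opts into the new strategy might look like the sketch below. The dataset `path` is a placeholder, and keys other than `type: chat_template.argilla_chat` are standard axolotl settings assumed for context, not part of this diff:

```yaml
# Hypothetical config excerpt; the dataset path is a placeholder.
rl: dpo
chat_template: llama3
datasets:
  - path: your-org/your-dpo-dataset
    type: chat_template.argilla_chat
```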

File tree

3 files changed: +212 -1


docs/rlhf.qmd

Lines changed: 15 additions & 0 deletions
@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
 }
 ```
 
+#### chat_template.argilla_chat
+
+```json
+{
+  "chosen": [
+    {"role": "user", "content": "..."},
+    {"role": "assistant", "content": "..."}
+  ],
+  "rejected": [
+    {"role": "user", "content": "..."},
+    {"role": "assistant", "content": "..."}
+  ]
+}
+```
+
 #### chat_template.default
 
 ```yaml
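
The key semantic of this format, restated as a minimal sketch (plain Python, no tokenizer; the real implementation below also applies role/property mappings and the chat template): the shared conversation history comes from `chosen[:-1]`, while only the final assistant turns differ between chosen and rejected.

```python
# Illustrative only -- mirrors the sample JSON above.
sample = {
    "chosen": [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "goodbye"},
    ],
    "rejected": [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "party on"},
    ],
}

history = sample["chosen"][:-1]             # prompt turns, shared by both sides
chosen_response = sample["chosen"][-1]      # preferred final assistant turn
rejected_response = sample["rejected"][-1]  # dispreferred final assistant turn

assert history == [{"role": "user", "content": "hello"}]
assert chosen_response["content"] == "goodbye"
assert rejected_response["content"] == "party on"
```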

src/axolotl/prompt_strategies/dpo/chat_template.py

Lines changed: 120 additions & 0 deletions
@@ -120,3 +120,123 @@ def transform_fn(sample, tokenizer=None):
         return result
 
     return transform_fn, {"remove_columns": [field_messages]}
+
+
+def argilla_chat(cfg, dataset_idx=0, **kwargs):
+    """
+    DPO chat template strategy for argilla-style datasets.
+
+    For argilla-style datasets where chosen/rejected contain full conversations
+    instead of single response messages. Extracts the conversation history from
+    the chosen field and formats both chosen/rejected responses using the
+    configured chat template.
+
+    Args:
+        cfg: Configuration object containing chat_template and dataset settings
+        dataset_idx: Index of the dataset in the config (default: 0)
+        **kwargs: Additional keyword arguments (unused)
+
+    Returns:
+        tuple: (transform_fn, dataset_kwargs) where:
+            - transform_fn: Function to transform dataset samples
+            - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop
+
+    Dataset format:
+        {
+            "chosen": [
+                {"role": "user", "content": "..."},
+                {"role": "assistant", "content": "..."}
+            ],
+            "rejected": [
+                {"role": "user", "content": "..."},
+                {"role": "assistant", "content": "..."}
+            ]
+        }
+    """
+    ds_cfg = cfg["datasets"][dataset_idx]
+    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
+
+    chat_template_choice, chat_template_jinja = extract_chat_template_args(
+        cfg=cfg, ds_cfg=ds_cfg
+    )
+    field_chosen = ds_cfg.get("field_chosen", "chosen")
+    field_rejected = ds_cfg.get("field_rejected", "rejected")
+    message_property_mappings = ds_cfg.get(
+        "message_property_mappings",
+        {
+            "role": "role",
+            "content": "content",
+        },
+    )
+    role_map_inv = ds_cfg.get(
+        "roles",
+        {
+            "user": ["user"],
+            "assistant": ["assistant"],
+            "system": ["system"],
+        },
+    )
+    role_map = {}
+    for target, sources in role_map_inv.items():
+        for source in sources:
+            role_map[source] = target
+
+    def transform_fn(sample, tokenizer=None):
+        chat_template_string = get_chat_template(
+            user_choice=chat_template_choice,
+            jinja_template=chat_template_jinja,
+            tokenizer=tokenizer,
+        )
+
+        chosen_raw = sample[field_chosen]
+        rejected_raw = sample[field_rejected]
+
+        # Extract messages (all but last) and responses (last message)
+        chosen_messages = [
+            {
+                "role": role_map[m[message_property_mappings["role"]]],
+                "content": m[message_property_mappings["content"]],
+            }
+            for m in chosen_raw[:-1]
+        ]
+        chosen_response = {
+            "role": role_map[chosen_raw[-1][message_property_mappings["role"]]],
+            "content": chosen_raw[-1][message_property_mappings["content"]],
+        }
+
+        rejected_response = {
+            "role": role_map[rejected_raw[-1][message_property_mappings["role"]]],
+            "content": rejected_raw[-1][message_property_mappings["content"]],
+        }
+
+        dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
+
+        result = {}
+        result["prompt"] = tokenizer.apply_chat_template(
+            chosen_messages,
+            add_generation_prompt=True,
+            chat_template=chat_template_string,
+            tokenize=False,
+        )
+
+        result["chosen"] = tokenizer.apply_chat_template(
+            [dummy_user_message, chosen_response],
+            add_generation_prompt=False,
+            chat_template=chat_template_string,
+            tokenize=False,
+        )
+        chosen_strip_index = result["chosen"].find(chosen_response["content"])
+        result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
+
+        result["rejected"] = tokenizer.apply_chat_template(
+            [dummy_user_message, rejected_response],
+            add_generation_prompt=False,
+            chat_template=chat_template_string,
+            tokenize=False,
+        )
+        rejected_strip_index = result["rejected"].find(rejected_response["content"])
+        result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
+
+        return result
+
+    return transform_fn, {"remove_columns": [field_chosen, field_rejected]}
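
One subtlety above is the `[[dummy_message]]` user turn: many chat templates will not cleanly render a lone assistant message, so the code renders a dummy user turn plus the response and then slices at the response content, keeping only the assistant completion with its trailing template tokens. A self-contained sketch of that slicing, where the llama3-style `render` helper is a stand-in for `tokenizer.apply_chat_template` (not the real API):

```python
# Stand-in for tokenizer.apply_chat_template with a llama3-style template;
# illustrative only -- the real code uses the tokenizer's configured template.
def render(messages):
    return "".join(
        f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n{m['content']}<|eot_id|>"
        for m in messages
    )

dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
response = {"role": "assistant", "content": "goodbye"}

rendered = render([dummy_user_message, response])
# Slice at the response text to drop the dummy turn and the assistant
# role header, keeping the content plus trailing template tokens.
completion = rendered[rendered.find(response["content"]):].rstrip()
assert completion == "goodbye<|eot_id|>"
```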

tests/prompt_strategies/test_dpo_chat_templates.py

Lines changed: 77 additions & 1 deletion
@@ -8,7 +8,7 @@
 from datasets import Dataset
 from transformers import AutoTokenizer
 
-from axolotl.prompt_strategies.dpo.chat_template import default
+from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
 from axolotl.utils.dict import DictDefault
 
 from tests.hf_offline_utils import enable_hf_offline
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
     )
 
 
+@pytest.fixture(name="argilla_chat_dataset")
+def fixture_argilla_chat_dataset():
+    return Dataset.from_list(
+        [
+            {
+                "chosen": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "goodbye",
+                    },
+                ],
+                "rejected": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "party on",
+                    },
+                ],
+            }
+        ]
+    )
+
+
 @pytest.fixture(name="phi3_tokenizer")
 @enable_hf_offline
 def fixture_phi3_tokenizer():
@@ -216,5 +246,51 @@ def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
         assert result["rejected"] == "party on<end_of_turn>"
 
 
+class TestArgillaChatDPOChatTemplate:
+    """
+    Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
+    """
+
+    def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
+        transform_fn, _ = argilla_chat(
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "datasets": [
+                        {
+                            "type": "chat_template.argilla_chat",
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)
+        assert result["prompt"] == (
+            "<|begin_of_text|>"
+            + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+            + "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        )
+        assert result["chosen"] == "goodbye<|eot_id|>"
+        assert result["rejected"] == "party on<|eot_id|>"
+
+    def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
+        transform_fn, _ = argilla_chat(
+            DictDefault(
+                {
+                    "chat_template": "tokenizer_default",
+                    "datasets": [
+                        {
+                            "type": "chat_template.argilla_chat",
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
+        assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
+        assert result["chosen"] == "goodbye<|end|>"
+        assert result["rejected"] == "party on<|end|>"
+
+
 if __name__ == "__main__":
     unittest.main()
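
To exercise just the new cases locally, filtering by test name should work (assumes a dev install with pytest available):

```sh
pytest tests/prompt_strategies/test_dpo_chat_templates.py -k argilla_chat
```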
