Skip to content

Commit baeb002

Browse files
Handle other reasoning trace dataset formats (axolotl-ai-cloud#2591)
* Handle other reasoning trace dataset formats
* Rename variable to improve readability
* chore: refactor with comments

Co-authored-by: NanoCode012 <[email protected]>
1 parent 2413688 commit baeb002

File tree

2 files changed

+98
-42
lines changed

2 files changed

+98
-42
lines changed

src/axolotl/prompt_strategies/chat_template.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
message_property_mappings = {
4343
"role": "role",
4444
"content": "content",
45+
"reasoning_content": "reasoning_content",
4546
}
4647

4748
if roles:
@@ -661,16 +662,46 @@ def transform_message(self, message):
661662
# if the role is assistant that we want to use reasoning_content
662663
if self.split_thinking and transformed_message["role"] == "assistant":
663664
content = transformed_message["content"]
664-
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
665-
for pair in pairs:
666-
if pair[0] in content and pair[1] in content:
667-
start_idx = content.find(pair[0])
668-
end_idx = content.find(pair[1])
669-
thinking_content = content[start_idx + len(pair[0]) : end_idx]
665+
thinking_pairs = [
666+
("<think>", "</think>"),
667+
("<reasoning>", "</reasoning>"),
668+
("<|begin_of_thought|>", "<|end_of_thought|>"),
669+
]
670+
content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")]
671+
for tpair in thinking_pairs:
672+
# check if the thinking pair is in the content
673+
if tpair[0] in content and tpair[1] in content:
674+
# find the start and end index of the thinking pair
675+
t_start_idx = content.find(tpair[0])
676+
t_end_idx = content.find(tpair[1])
677+
678+
# get the thinking content
679+
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
670680
transformed_message["reasoning_content"] = thinking_content.strip()
671-
transformed_message["content"] = content[
672-
end_idx + len(pair[1]) :
673-
].lstrip()
681+
682+
# take remainder of the content
683+
# strip whitespace from beginning of the remainder (thinking tokens)
684+
remainder = content[t_end_idx + len(tpair[1]) :].lstrip()
685+
686+
# check if the content pair is in the remainder
687+
cpair_found = False
688+
for cpair in content_pairs:
689+
if cpair[0] in remainder and cpair[1] in remainder:
690+
# find the start and end index of the content pair
691+
c_start_idx = remainder.find(cpair[0])
692+
c_end_idx = remainder.find(cpair[1])
693+
694+
# get the content content
695+
content_content = remainder[
696+
c_start_idx + len(cpair[0]) : c_end_idx
697+
]
698+
transformed_message["content"] = content_content.strip()
699+
cpair_found = True
700+
break
701+
702+
# else, the content is the remainder
703+
if not cpair_found:
704+
transformed_message["content"] = remainder
674705
break
675706

676707
# Determine which keys in the original message were not mapped

tests/prompt_strategies/test_chat_templates_thinking.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,31 @@ def messages_w_reasoning_fixture():
3434
"content": "<think>lorem</think>\nwelcome",
3535
},
3636
]
37-
}
37+
},
38+
{
39+
"messages": [
40+
{
41+
"role": "user",
42+
"content": "hello",
43+
},
44+
{
45+
"role": "assistant",
46+
"content": "<|begin_of_thought|>lorem<|end_of_thought|>\n<|begin_of_solution|>welcome\n<|end_of_solution|>",
47+
},
48+
]
49+
},
50+
{
51+
"messages": [
52+
{
53+
"role": "user",
54+
"content": "hello",
55+
},
56+
{
57+
"role": "assistant",
58+
"content": "<reasoning>lorem</reasoning>\nwelcome",
59+
},
60+
]
61+
},
3862
]
3963
)
4064

@@ -83,36 +107,37 @@ def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
83107
}
84108
),
85109
)
86-
transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
87-
assert transformed_prompt[0]["role"] == "user"
88-
assert transformed_prompt[1]["role"] == "assistant"
89-
assert transformed_prompt[1]["reasoning_content"] == "lorem"
90-
assert transformed_prompt[1]["content"] == "welcome"
110+
for conversation in messages_w_reasoning:
111+
transformed_prompt = strategy.get_conversation_thread(conversation)
112+
assert transformed_prompt[0]["role"] == "user"
113+
assert transformed_prompt[1]["role"] == "assistant"
114+
assert transformed_prompt[1]["reasoning_content"] == "lorem"
115+
assert transformed_prompt[1]["content"] == "welcome"
91116

92-
res = strategy.tokenize_prompt(messages_w_reasoning[0])
93-
input_ids = res["input_ids"]
94-
# fmt: off
95-
expected_input_ids = [
96-
151644, # im_start
97-
872, # user
98-
198, # \n
99-
14990, # hello
100-
151645, # im_end
101-
198, # \n
102-
151644, # im_start
103-
77091, # assistant
104-
198, # \n
105-
151667, # think
106-
198, # \n
107-
385, 1826, # lorem
108-
198, # \n
109-
151668, # /think
110-
271, # \n
111-
34084, # welcome
112-
151645, # im_end
113-
198, # \n
114-
]
115-
# fmt: on
116-
assert (
117-
input_ids == expected_input_ids
118-
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
117+
res = strategy.tokenize_prompt(conversation)
118+
input_ids = res["input_ids"]
119+
# fmt: off
120+
expected_input_ids = [
121+
151644, # im_start
122+
872, # user
123+
198, # \n
124+
14990, # hello
125+
151645, # im_end
126+
198, # \n
127+
151644, # im_start
128+
77091, # assistant
129+
198, # \n
130+
151667, # think
131+
198, # \n
132+
385, 1826, # lorem
133+
198, # \n
134+
151668, # /think
135+
271, # \n
136+
34084, # welcome
137+
151645, # im_end
138+
198, # \n
139+
]
140+
# fmt: on
141+
assert (
142+
input_ids == expected_input_ids
143+
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

0 commit comments

Comments (0)