Skip to content

Commit 2d77165

Browse files
authored
automatically split out reasoning trace from dataset (axolotl-ai-cloud#2579)
* automatically split out reasoning trace from dataset * chore: lint * fix import
1 parent 63b17e3 commit 2d77165

File tree

4 files changed

+144
-0
lines changed

4 files changed

+144
-0
lines changed

src/axolotl/prompt_strategies/chat_template.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def __init__(
228228
train_on_eos: Optional[str] = None,
229229
train_on_eot: Optional[str] = None,
230230
eot_tokens: Optional[List[str]] = None,
231+
split_thinking: Optional[bool] = False,
231232
):
232233
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
233234
self.prompter: ChatTemplatePrompter = prompter
@@ -247,6 +248,7 @@ def __init__(
247248
self.eot_tokens = (
248249
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
249250
)
251+
self.split_thinking = split_thinking
250252

251253
self.images = "images"
252254

@@ -655,6 +657,22 @@ def transform_message(self, message):
655657
transformed_message["role"], transformed_message["role"]
656658
)
657659

660+
# TODO handle reasoning_content with split_thinking
661+
# if the role is assistant that we want to use reasoning_content
662+
if self.split_thinking and transformed_message["role"] == "assistant":
663+
content = transformed_message["content"]
664+
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
665+
for pair in pairs:
666+
if pair[0] in content and pair[1] in content:
667+
start_idx = content.find(pair[0])
668+
end_idx = content.find(pair[1])
669+
thinking_content = content[start_idx + len(pair[0]) : end_idx]
670+
transformed_message["reasoning_content"] = thinking_content.strip()
671+
transformed_message["content"] = content[
672+
end_idx + len(pair[1]) :
673+
].lstrip()
674+
break
675+
658676
# Determine which keys in the original message were not mapped
659677
mapped_values = set(self.prompter.message_property_mappings.values())
660678
remaining_keys = set(message) - mapped_values
@@ -689,6 +707,7 @@ def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
689707
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
690708
"train_on_eot": ds_cfg.get("train_on_eot", None),
691709
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
710+
"split_thinking": ds_cfg.get("split_thinking", False),
692711
}
693712

694713
def __call__(

src/axolotl/utils/schemas/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class SFTDataset(BaseModel):
5050
message_property_mappings: dict[str, str] | None = None
5151
message_field_training: str | None = None
5252
message_field_training_detail: str | None = None
53+
split_thinking: bool | None = None
5354
logprobs_field: str | None = None
5455
temperature: float | None = None
5556
roles_to_train: list[str] | None = None

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ def download_qwen_2_5_half_billion_model():
9090
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
9191

9292

93+
@pytest.fixture(scope="session", autouse=True)
94+
def download_qwen3_half_billion_model():
95+
# download the model
96+
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
97+
98+
9399
@pytest.fixture(scope="session", autouse=True)
94100
def download_tatsu_lab_alpaca_dataset():
95101
# download the dataset
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
Tests for splitting reasoning/thinking from content into separate field
3+
"""
4+
5+
import logging
6+
7+
import pytest
8+
from datasets import Dataset
9+
from transformers import AutoTokenizer
10+
11+
from axolotl.prompt_strategies.chat_template import (
12+
load,
13+
)
14+
from axolotl.utils.dict import DictDefault
15+
16+
from tests.hf_offline_utils import enable_hf_offline
17+
18+
logging.basicConfig(level=logging.DEBUG)
19+
LOG = logging.getLogger("axolotl")
20+
21+
22+
@pytest.fixture(name="messages_w_reasoning")
23+
def messages_w_reasoning_fixture():
24+
return Dataset.from_list(
25+
[
26+
{
27+
"messages": [
28+
{
29+
"role": "user",
30+
"content": "hello",
31+
},
32+
{
33+
"role": "assistant",
34+
"content": "<think>lorem</think>\nwelcome",
35+
},
36+
]
37+
}
38+
]
39+
)
40+
41+
42+
@pytest.fixture(name="qwen3_tokenizer")
43+
@enable_hf_offline
44+
def qwen3_tokenizer_fixture(
45+
download_qwen3_half_billion_model,
46+
): # pylint: disable=unused-argument
47+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
48+
49+
return tokenizer
50+
51+
52+
class TestSplitThinking:
53+
"""
54+
test class to make sure datasets with reasoning content conforms to the chat_template strategy
55+
"""
56+
57+
def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
58+
# pylint: disable=duplicate-code
59+
strategy = load(
60+
qwen3_tokenizer,
61+
DictDefault(
62+
{
63+
"train_on_inputs": False,
64+
"sequence_len": 512,
65+
}
66+
),
67+
DictDefault(
68+
{
69+
"chat_template": "qwen3",
70+
"message_field_role": "role",
71+
"message_field_content": "content",
72+
"message_property_mappings": {
73+
"role": "role",
74+
"content": "content",
75+
},
76+
"roles": {
77+
"user": ["user"],
78+
"assistant": ["assistant"],
79+
"system": ["system"],
80+
},
81+
"field_messages": "messages",
82+
"split_thinking": True,
83+
}
84+
),
85+
)
86+
transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
87+
assert transformed_prompt[0]["role"] == "user"
88+
assert transformed_prompt[1]["role"] == "assistant"
89+
assert transformed_prompt[1]["reasoning_content"] == "lorem"
90+
assert transformed_prompt[1]["content"] == "welcome"
91+
92+
res = strategy.tokenize_prompt(messages_w_reasoning[0])
93+
input_ids = res["input_ids"]
94+
# fmt: off
95+
expected_input_ids = [
96+
151644, # im_start
97+
872, # user
98+
198, # \n
99+
14990, # hello
100+
151645, # im_end
101+
198, # \n
102+
151644, # im_start
103+
77091, # assistant
104+
198, # \n
105+
151667, # think
106+
198, # \n
107+
385, 1826, # lorem
108+
198, # \n
109+
151668, # /think
110+
271, # \n
111+
34084, # welcome
112+
151645, # im_end
113+
198, # \n
114+
]
115+
# fmt: on
116+
assert (
117+
input_ids == expected_input_ids
118+
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

0 commit comments

Comments
 (0)