Commit eebced3

Enhance SFT/DPO reader (agentscope-ai#226)
1 parent 013c2c7 commit eebced3

16 files changed: +607 additions, −233 deletions


docs/sphinx_doc/source/tutorial/example_dpo.md

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ buffer:
   storage_type: file
   path: $DATASET_PATH/human_like_dpo_dataset
   format:
-    prompt_type: plaintext # plaintext/messages/chatpair
+    prompt_type: plaintext
     prompt_key: prompt
     chosen_key: chosen
     rejected_key: rejected
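For reference, a plaintext DPO record matching the `format` keys above might look like the following Python dict. This is a hypothetical sample mirroring the new formatter tests; the actual human_like_dpo_dataset defines its own contents.

```python
# Hypothetical plaintext DPO sample; keys follow prompt_key/chosen_key/rejected_key above.
sample = {
    "prompt": "What is 2+2?",  # read via prompt_key
    "chosen": "2+2=4",         # preferred response, read via chosen_key
    "rejected": "2+2=5",       # dispreferred response, read via rejected_key
}
```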

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 1 addition & 1 deletion

@@ -182,7 +182,7 @@ buffer:
   storage_type: file
   path: <$DATASET_PATH/{sft_data}>
   format:
-    prompt_type: <prompt_type> # messages/plaintext/chatpair
+    prompt_type: <prompt_type> # messages/plaintext
     prompt_key: <prompt_key>
     response_key: <response_key>
   sft_warmup_steps: 10
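With `chatpair` no longer listed, the two remaining `prompt_type` options expect SFT records shaped roughly as below. These are hypothetical samples; the key names follow the defaults exercised in the new formatter tests.

```python
# prompt_type: messages -- one chat transcript per sample.
messages_sample = {
    "messages": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello"},
    ]
}

# prompt_type: plaintext -- prompt/response pairs, optionally with a system prompt field.
plaintext_sample = {
    "system": "You are a helpful assistant.",
    "prompt": "What is 2+2?",
    "response": "2+2=4",
}
```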

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 38 additions & 5 deletions

@@ -37,14 +37,12 @@ synchronizer:
 monitor:
   # Monitoring configurations (e.g., WandB or TensorBoard)
   ...
-data_processor:
-  # Preprocessing data settings
-  ...
-
 service:
   # Services to use
   ...
-
+data_processor:
+  # Preprocessing data settings
+  ...
 log:
   # Ray actor logging
   ...

@@ -419,6 +417,41 @@ service:
 - `auto_start`: Whether to automatically start the data juicer service.
 - `port`: The port for Data Juicer service when `auto_start` is true.

+---
+
+## DataProcessor Configuration
+
+Configures the task / experience pipelines; please refer to the {ref}`Data Processing <Data Processing>` section for details.
+
+```yaml
+data_processor:
+  # task pipeline related
+  task_pipeline:
+    num_process: 32
+    operators:
+      - name: "llm_difficulty_score_filter"
+        args:
+          api_or_hf_model: "qwen2.5-7b-instruct"
+          min_score: 0.0
+          input_keys: ["question", "answer"]
+          field_names: ["Question", "Answer"]
+    inputs: # the output will be set to the explorer input automatically
+      - /PATH/TO/GSM8K/DATA/FILE
+    target_fields: ["question", "answer"]
+  experience_pipeline:
+    operators:
+      - name: data_juicer
+        args:
+          config_path: 'examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml'
+      - name: reward_shaping_mapper
+        args:
+          reward_shaping_configs:
+            - stats_key: 'llm_quality_score'
+              op_type: ADD
+              weight: 1.0
+```
+
 ---

 ## Log Configuration
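As a rough illustration of the `reward_shaping_mapper` entry above: the configuration suggests that with `op_type: ADD` the weighted `llm_quality_score` statistic is added to each experience's reward. The snippet below is a sketch of that reading only, not the operator's actual implementation; the function and argument names are made up for illustration.

```python
# Sketch of the assumed ADD semantics for reward shaping; illustrative only.
def shape_reward(reward: float, stats: dict, stats_key: str = "llm_quality_score",
                 op_type: str = "ADD", weight: float = 1.0) -> float:
    value = stats.get(stats_key, 0.0)
    if op_type == "ADD":
        # reward' = reward + weight * stat, e.g. 1.0 + 1.0 * 0.8 = 1.8
        return reward + weight * value
    return reward  # other op types are not sketched here


print(shape_reward(1.0, {"llm_quality_score": 0.8}))  # -> 1.8 under this assumption
```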

examples/dpo_humanlike/dpo.yaml

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ buffer:
   enable_progress_bar: True
   path: /PATH/TO/DATASET/
   format:
-    prompt_type: plaintext # plaintext/messages/chatpair
+    prompt_type: plaintext # plaintext/messages
     prompt_key: prompt
     chosen_key: chosen
     rejected_key: rejected
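The new `tests/buffer/formatter_test.py` below exercises exactly this DPO plaintext format. As a minimal standalone sketch (the tokenizer path is a placeholder), the same `format` block corresponds to roughly the following reader setup:

```python
from transformers import AutoTokenizer

from trinity.buffer.schema.formatter import DPOPlaintextFormatter
from trinity.common.config import FormatConfig
from trinity.common.constants import PromptType

tokenizer = AutoTokenizer.from_pretrained("/PATH/TO/MODEL")  # placeholder path
config = FormatConfig(
    prompt_type=PromptType.PLAINTEXT,
    prompt_key="prompt",
    chosen_key="chosen",
    rejected_key="rejected",
)
formatter = DPOPlaintextFormatter(tokenizer=tokenizer, format_config=config)
exp = formatter.format({"prompt": "What is 2+2?", "chosen": "2+2=4", "rejected": "2+2=5"})
# exp.tokens, exp.chosen, and exp.rejected hold the tokenized prompt and responses.
```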

tests/buffer/formatter_test.py

Lines changed: 219 additions & 0 deletions

@@ -0,0 +1,219 @@
import unittest

from transformers import AutoTokenizer

from tests.tools import get_model_path
from trinity.buffer.schema.formatter import (
    DPOMessagesFormatter,
    DPOPlaintextFormatter,
    SFTMessagesFormatter,
    SFTPlaintextFormatter,
)
from trinity.common.config import FormatConfig
from trinity.common.constants import PromptType
from trinity.common.experience import Experience


class TestFormatter(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained(get_model_path())

    def test_sft_messages_formatter(self):
        config = FormatConfig(
            prompt_type=PromptType.MESSAGES,
            messages_key="message_list",
        )
        formatter = SFTMessagesFormatter(tokenizer=self.tokenizer, format_config=config)
        sample = {
            "message_list": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello"},
            ]
        }

        exp = formatter.format(sample)
        self.assertIsInstance(exp, Experience)
        self.assertIsNotNone(exp.tokens)
        self.assertIsNotNone(exp.prompt_length)
        self.assertTrue(exp.prompt_length < len(exp.tokens))
        sequence = self.tokenizer.decode(exp.tokens)

        self.assertIn("Hi", sequence)
        self.assertIn("Hello", sequence)

        # test tool
        config = FormatConfig(
            prompt_type=PromptType.MESSAGES,
            messages_key="messages",
            tools_key="tools",
        )
        formatter = SFTMessagesFormatter(tokenizer=self.tokenizer, format_config=config)
        sample = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant with access to various tools. Use them when needed to help users.",
                },
                {"role": "user", "content": "What's the weather like in Beijing today?"},
                {
                    "role": "assistant",
                    "content": "Let me get the weather for you.",
                    "tool_calls": [
                        {
                            "id": "call_abc123",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": '{"location": "Beijing", "unit": "celsius"}',
                            },
                        }
                    ],
                },
                {
                    "role": "tool",
                    "content": '{"temperature": 22, "condition": "sunny", "humidity": 45}',
                    "tool_call_id": "call_abc123",
                },
                {
                    "role": "assistant",
                    "content": "The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
                },
            ],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. San Francisco, CA",
                                },
                                "unit": {
                                    "type": "string",
                                    "enum": ["celsius", "fahrenheit"],
                                    "description": "The temperature unit",
                                },
                            },
                            "required": ["location"],
                        },
                    },
                }
            ],
        }
        exp = formatter.format(sample)
        self.assertIsInstance(exp, Experience)
        self.assertIsNotNone(exp.tokens)
        self.assertIsNotNone(exp.prompt_length)
        self.assertTrue(exp.prompt_length < len(exp.tokens))
        sequence = self.tokenizer.decode(exp.tokens)
        self.assertIn("What's the weather like in Beijing today?", sequence)
        self.assertIn(
            "The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!",
            sequence,
        )
        self.assertIn("get_weather", sequence)

    def test_sft_plaintext_formatter(self):
        # with system prompt key
        config = FormatConfig(
            prompt_type=PromptType.PLAINTEXT,
            system_prompt_key="system",
            system_prompt="You are a programmer.",  # has lower priority than system_prompt_key
            prompt_key="prompt",
            response_key="response",
        )
        formatter = SFTPlaintextFormatter(tokenizer=self.tokenizer, format_config=config)
        sample = {
            "system": "You are a helpful assistant.",
            "prompt": "What is 2+2?",
            "response": "2+2=4",
        }
        exp = formatter.format(sample)
        self.assertIsInstance(exp, Experience)
        self.assertIsNotNone(exp.tokens)
        self.assertIsNotNone(exp.prompt_length)
        self.assertTrue(exp.prompt_length < len(exp.tokens))
        # detokenize exp.tokens into text
        sequence = self.tokenizer.decode(exp.tokens)
        self.assertIn("You are a helpful assistant.", sequence)
        self.assertIn("What is 2+2?", sequence)
        self.assertIn("2+2=4", sequence)

        # with system prompt
        config = FormatConfig(
            prompt_type=PromptType.PLAINTEXT,
            system_prompt="You are a programmer.",
            prompt_key="prompt",
            response_key="response",
        )
        formatter = SFTPlaintextFormatter(tokenizer=self.tokenizer, format_config=config)

        exp = formatter.format(sample)
        self.assertIsInstance(exp, Experience)
        self.assertIsNotNone(exp.tokens)
        self.assertIsNotNone(exp.prompt_length)
        self.assertTrue(exp.prompt_length < len(exp.tokens))
        # detokenize exp.tokens into text
        sequence = self.tokenizer.decode(exp.tokens)
        self.assertIn("You are a programmer.", sequence)
        self.assertIn("What is 2+2?", sequence)
        self.assertIn("2+2=4", sequence)

    def test_dpo_plaintext_formatter(self):
        config = FormatConfig(
            prompt_type=PromptType.PLAINTEXT,
            prompt_key="prompt",
            chosen_key="chosen",
            rejected_key="rejected",
        )
        formatter = DPOPlaintextFormatter(tokenizer=self.tokenizer, format_config=config)
        sample = {"prompt": "What is 2+2?", "chosen": "2+2=4", "rejected": "2+2=5"}
        exp = formatter.format(sample)
        self.assertIsInstance(exp, Experience)
        self.assertIsNotNone(exp.tokens)
        self.assertIsNotNone(exp.chosen)
        self.assertIsNotNone(exp.rejected)
        self.assertIsNotNone(exp.prompt_length)
        prompt = self.tokenizer.decode(exp.tokens)
        chosen = self.tokenizer.decode(exp.chosen)
        rejected = self.tokenizer.decode(exp.rejected)
        self.assertIn("What is 2+2?", prompt)
        self.assertIn("2+2=4", chosen)
        self.assertIn("2+2=5", rejected)
        self.assertNotIn("What is 2+2?", chosen)
        self.assertNotIn("What is 2+2?", rejected)
        self.assertNotIn("2+2=4", prompt)
        self.assertNotIn("2+2=5", prompt)

    def test_dpo_messages_formatter(self):
        config = FormatConfig(
            prompt_type=PromptType.MESSAGES,
            messages_key="messages",
            chosen_key="chosen",
            rejected_key="rejected",
        )
        formatter = DPOMessagesFormatter(tokenizer=self.tokenizer, format_config=config)
        sample = {
            "messages": [
                {"role": "user", "content": "What is your name?"},
            ],
            "chosen": [
                {"role": "assistant", "content": "My name is Assistant."},
            ],
            "rejected": [{"role": "assistant", "content": "I don't have a favorite color."}],
        }
        exp = formatter.format(sample)
        self.assertIsInstance(exp, Experience)
        self.assertIsNotNone(exp.tokens)
        self.assertIsNotNone(exp.prompt_length)
        # detokenize exp.tokens into text
        prompt = self.tokenizer.decode(exp.tokens)
        chosen = self.tokenizer.decode(exp.chosen)
        rejected = self.tokenizer.decode(exp.rejected)
        self.assertIn("What is your name?", prompt)
        self.assertIn("My name is Assistant.", chosen)
        self.assertIn("I don't have a favorite color.", rejected)

tests/cli/__init__.py

Whitespace-only changes.

tests/cli/launcher_test.py

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
import sys
import unittest
from unittest import mock

from tests.tools import get_template_config
from trinity.cli import launcher
from trinity.common.constants import (
    LOG_DIR_ENV_VAR,
    LOG_LEVEL_ENV_VAR,
    LOG_NODE_IP_ENV_VAR,
    PLUGIN_DIRS_ENV_VAR,
)


class TestLauncherMain(unittest.TestCase):
    def setUp(self):
        self._orig_argv = sys.argv.copy()

    def tearDown(self):
        sys.argv = self._orig_argv

    @mock.patch("trinity.cli.launcher.explore")
    @mock.patch("trinity.cli.launcher.train")
    @mock.patch("trinity.cli.launcher.both")
    @mock.patch("trinity.cli.launcher.bench")
    @mock.patch("trinity.cli.launcher.load_config")
    def test_main_run_command(self, mock_load, mock_bench, mock_both, mock_train, mock_explore):
        config = get_template_config()
        mapping = {
            "explore": mock_explore,
            "train": mock_train,
            "both": mock_both,
            "bench": mock_bench,
        }
        for mode in ["explore", "train", "both", "bench"]:
            config.mode = mode
            mock_load.return_value = config
            with mock.patch(
                "argparse.ArgumentParser.parse_args",
                return_value=mock.Mock(
                    command="run", config="dummy.yaml", dlc=False, plugin_dir=None
                ),
            ):
                launcher.main()
                mock_load.assert_called_once_with("dummy.yaml")
                mapping[mode].assert_called_once_with(config)
                mock_load.reset_mock()
                mapping[mode].reset_mock()

    @mock.patch("trinity.cli.launcher.setup_ray_cluster")
    @mock.patch("trinity.cli.launcher.both")
    @mock.patch("trinity.cli.launcher.load_config")
    def test_main_run_in_dlc(self, mock_load, mock_both, mock_setup):
        config = get_template_config()
        config.mode = "both"
        config.log.level = "WARNING"
        config.log.group_by_node = True
        mock_load.return_value = config
        with mock.patch(
            "argparse.ArgumentParser.parse_args",
            return_value=mock.Mock(
                command="run", config="dummy.yaml", dlc=True, plugin_dir="/path/to/plugins"
            ),
        ):
            launcher.main()
            mock_load.assert_called_once_with("dummy.yaml")
            mock_both.assert_called_once_with(config)
            mock_setup.assert_called_once_with(
                namespace=config.ray_namespace,
                envs={
                    PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
                    LOG_DIR_ENV_VAR: config.log.save_dir,
                    LOG_LEVEL_ENV_VAR: "WARNING",
                    LOG_NODE_IP_ENV_VAR: "1",
                },
            )

    @mock.patch("trinity.cli.launcher.studio")
    def test_main_studio_command(self, mock_studio):
        with mock.patch(
            "argparse.ArgumentParser.parse_args",
            return_value=mock.Mock(command="studio", port=9999),
        ):
            launcher.main()
            mock_studio.assert_called_once_with(9999)


if __name__ == "__main__":
    unittest.main()
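These tests pin down the launcher's dispatch behavior: the `run` command routes to `explore` / `train` / `both` / `bench` based on `config.mode`, the `studio` command opens the UI on the given port, and in DLC mode plugin and logging settings are exported as environment variables before the Ray cluster is set up. A hypothetical invocation consistent with the mocked `parse_args` namespace might look as follows; the exact flag spelling is an assumption, since the tests only fix the parsed attribute names.

```python
# Illustrative only: drive launcher.main() the way the tests' mocked namespace implies.
import sys

from trinity.cli import launcher

sys.argv = ["trinity", "run", "--config", "dpo.yaml"]  # assumed CLI shape
launcher.main()
```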
