
Commit d6d3be4: Update the example of human in the loop (#247)

Parent: 6cb746d

8 files changed: +164, -5 lines


docs/sphinx_doc/source/tutorial/example_data_functionalities.md

Lines changed: 4 additions & 1 deletion
@@ -212,6 +212,7 @@ Trinity-RFT uses a unified config file to manage all config items. For the data
 In this example, assume that you need to select the chosen and rejected responses for DPO method. So you can set these config items like the following example:

 ```yaml
+# using task pipeline to decide the chosen and rejected from human preference
 data_processor:
   # task pipeline related
   task_pipeline:
@@ -239,7 +240,7 @@ data_processor:
           chosen_key: "chosen" # Chosen field
           rejected_key: "rejected" # Rejected field
     inputs: # the output will be set to the explorer input automatically
-      - /PATH/TO/DATA/FILE/TO/BE/ANNOTATED
+      - 'examples/dpo_human_in_the_loop/demo-data.jsonl'
     target_fields: ["prompt"]
 service:
   data_juicer:
@@ -252,6 +253,8 @@ The difference is that we use the data-juicer OP `human_preference_annotation_ma

 You can set more config items for this OP (e.g. notification when annotation is finished). For more details, please refer to this [doc](https://github.com/modelscope/data-juicer/tree/main/configs/annotation).

+All config items in the `data_processor` section can be found [here](trinity_configs.md). A prepared config file for this example can be found in [the config file](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_human_in_the_loop/dpo.yaml).
+
 ### Start Running

 When you start running with the RFT config, the data processor will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.
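A note on prerequisites: the docs above assume a Label Studio service is already reachable at the `api_url` from the config. A minimal sketch for starting one locally (assuming a pip installation; the port must match `api_url`, and the API key referenced as `YOUR_API_KEY` is taken from the account settings of the running instance):

```bash
# Install Label Studio and serve it on the port the config expects.
pip install label-studio
label-studio start --port 7070
```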
examples/dpo_human_in_the_loop/README.md

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+# DPO with Human in the Loop
+
+This example shows the usage of DPO with human in the loop on a simple example dataset.
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_data_functionalities.md#example-human-in-the-loop).
+
+The config files are located in [`dpo.yaml`](dpo.yaml) and [`train_dpo.yaml`](train_dpo.yaml). The example dataset is located in [`demo-data.jsonl`](demo-data.jsonl).
examples/dpo_human_in_the_loop/demo-data.jsonl

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+{"prompt": "What is the capital of France?", "answer1": "Paris", "answer2": "Lyon"}
+{"prompt": "Which planet is known as the Red Planet?", "answer1": "Mars", "answer2": "Venus"}
+{"prompt": "What is the chemical symbol for gold?", "answer1": "Au", "answer2": "Ag"}
+{"prompt": "Who wrote 'Romeo and Juliet'?", "answer1": "William Shakespeare", "answer2": "Christopher Marlowe"}
+{"prompt": "What is the largest mammal on Earth?", "answer1": "Blue Whale", "answer2": "African Elephant"}
+{"prompt": "In which year did World War II end?", "answer1": "1945", "answer2": "1944"}
+{"prompt": "What is the square root of 64?", "answer1": "8", "answer2": "6"}
+{"prompt": "Who painted the Mona Lisa?", "answer1": "Leonardo da Vinci", "answer2": "Michelangelo"}
+{"prompt": "What is the main component of the Sun?", "answer1": "Hydrogen", "answer2": "Helium"}
+{"prompt": "Which programming language was created by Guido van Rossum?", "answer1": "Python", "answer2": "Java"}
examples/dpo_human_in_the_loop/dpo.yaml

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+project: "dpo_example"
+name: "trinity_dpo"
+mode: train
+
+# using task pipeline to decide the chosen and rejected from human preference
+data_processor:
+  # task pipeline related
+  task_pipeline:
+    num_process: 1
+    operators:
+      - name: "human_preference_annotation_mapper"
+        args:
+          # general annotation project settings
+          project_name_prefix: "Human_Preference_Annotation_Demo"
+          wait_for_annotations: true # Whether to wait for annotations to complete
+          timeout: 3600 # Maximum time to wait for annotations in seconds (1 hour)
+          poll_interval: 10 # Time between annotation status checks in seconds
+          max_tasks_per_batch: 10 # Maximum number of tasks in a single batch
+          notification_config:
+            enabled: false
+
+          # label studio connection settings
+          api_url: "http://localhost:7070" # Default Label Studio URL
+          api_key: "YOUR_API_KEY" # Your API key for Label Studio authentication, which can be set when starting the label-studio service
+
+          # human preference annotation settings
+          prompt_key: "prompt" # Prompt field
+          answer1_key: "answer1" # First answer option
+          answer2_key: "answer2" # Second answer option
+          chosen_key: "chosen" # Chosen field
+          rejected_key: "rejected" # Rejected field
+    inputs: # the output will be set to the explorer input automatically
+      - 'examples/dpo_human_in_the_loop/demo-data.jsonl'
+    target_fields: ["prompt"]
+service:
+  data_juicer:
+    auto_start: true
+
+algorithm:
+  algorithm_type: dpo
+  kl_loss_fn: k1
+  kl_loss_fn_args:
+    kl_coef: 0.1
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL
+  max_response_tokens: 1024
+  max_model_len: 1536
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_epochs: 2
+  train_batch_size: 64
+  trainer_input:
+    experience_buffer:
+      name: dpo_buffer
+      storage_type: file
+      enable_progress_bar: True
+      path: ./outputs/human_annotation_output/ # the result data after human preference annotation are stored here
+      format:
+        prompt_type: plaintext # plaintext/messages
+        prompt_key: prompt
+        chosen_key: chosen
+        rejected_key: rejected
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_interval: 30
+  sync_timeout: 1200
+trainer:
+  trainer_type: 'verl'
+  trainer_config_path: 'examples/dpo_human_in_the_loop/train_dpo.yaml'
+  save_interval: 30
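To try the example, the run entry point is along these lines (a sketch assuming the standard Trinity-RFT CLI, with Ray and the Label Studio service already up; verify against the project README):

```bash
# Launch the DPO run; the task pipeline first creates the annotation project
# in Label Studio and waits for labels before training starts.
trinity run --config examples/dpo_human_in_the_loop/dpo.yaml
```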
examples/dpo_human_in_the_loop/train_dpo.yaml

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_micro_batch_size_per_gpu: 2
+    use_dynamic_bsz: False
+    ppo_max_token_len_per_gpu: 16384
+    grad_clip: 1.0
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 5e-7
+      lr_warmup_steps_ratio: 0.03 # the total steps will be injected during runtime
+      min_lr_ratio: 0.1 # only useful for warmup with cosine
+      warmup_style: cosine # select from constant/cosine
+      total_training_steps: 783
+      betas: [0.9, 0.95]
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: 2
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+  balance_batch: False
+  total_training_steps: 783
+  # auto: find the last ckpt to resume. If can't find, start from scratch
+  resume_mode: auto # or auto or resume_path if
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
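As a quick sanity check on the batch arithmetic (assuming the global batch is split evenly across the 8 GPUs configured in `dpo.yaml`): `train_batch_size: 64` gives 64 / 8 = 8 samples per GPU per step, which `ppo_micro_batch_size_per_gpu: 2` processes in 8 / 2 = 4 gradient-accumulation micro-batches.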

trinity/buffer/pipelines/task_pipeline.py

Lines changed: 0 additions & 3 deletions
@@ -5,9 +5,6 @@


 def check_and_run_task_pipeline(config: Config) -> Dict:
-    if not (config.mode == "explore" or config.mode == "both"):
-        # task pipeline is only available when using Explorer
-        return {}
     if config.data_processor.task_pipeline is None:
         return {}

trinity/common/config.py

Lines changed: 15 additions & 1 deletion
@@ -688,7 +688,21 @@ def _check_buffer(self) -> None:  # noqa: C901
             )
         if self.data_processor.task_pipeline is not None:
             if self.data_processor.task_pipeline.output is None:
-                self.data_processor.task_pipeline.output = self.buffer.explorer_input.taskset
+                if self.buffer.explorer_input.taskset.path is not None:
+                    self.data_processor.task_pipeline.output = self.buffer.explorer_input.taskset
+                elif (
+                    self.buffer.trainer_input.experience_buffer.schema_type in {"dpo", "sft"}
+                    and self.buffer.trainer_input.experience_buffer.path is not None
+                ):
+                    self.data_processor.task_pipeline.output = (
+                        self.buffer.trainer_input.experience_buffer
+                    )
+                else:
+                    raise ValueError(
+                        "`data_processor.task_pipeline.output` is required when both "
+                        "`buffer.explorer_input.taskset.path` and `buffer.trainer_input.experience_buffer.path` are "
+                        "None"
+                    )
             if self.data_processor.task_pipeline.output.path and os.path.exists(
                 self.data_processor.task_pipeline.output.path
             ):
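In effect, a trainer-only run like the `dpo.yaml` above no longer needs `explorer_input.taskset`: when no taskset path is set, the task pipeline output falls back to the DPO/SFT experience buffer. A minimal illustration of a config that takes the new branch (hypothetical excerpt; assumes `schema_type` resolves to `dpo` here, e.g. derived from `algorithm_type` or set explicitly):

```yaml
# No buffer.explorer_input.taskset.path is set, so task_pipeline.output
# falls back to this experience buffer (schema_type dpo/sft, path given).
buffer:
  trainer_input:
    experience_buffer:
      name: dpo_buffer
      storage_type: file
      path: ./outputs/human_annotation_output/
```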

trinity/service/data_juicer/server/utils.py

Lines changed: 4 additions & 0 deletions
@@ -114,6 +114,8 @@ def _parse_task_pipeline_config(config: DJConfig) -> Namespace:


 def group_scores(dataset: Dataset) -> Dataset:
+    if Fields.stats not in dataset.features:
+        return dataset
     # for perplexity, normalize them with the max value.
     stats_min_max = {}
     for stats in dataset.features[Fields.stats]:
@@ -165,6 +167,8 @@ def compute_priority_scores(

     from data_juicer.utils.constant import Fields

+    if Fields.stats not in sample:
+        return sample
     stats = sample[Fields.stats]
     if isinstance(stats, list):
         stats = stats[0]
