Commit dcbc769

Refactor workflow (#40)
1 parent 11a95be commit dcbc769

33 files changed, with 463 additions and 247 deletions
docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 7 additions & 6 deletions
@@ -105,6 +105,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 1
+        temperature: 1.0
+        logprobs: 0
     eval_tasksets: []
     default_workflow_type: 'math_workflow'
     default_reward_fn_type: 'countdown_reward'
@@ -123,6 +127,9 @@ buffer:
 - `buffer.explorer_input.taskset.path`: The path to the taskset.
 - `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
 - `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
+- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
+- `buffer.explorer_input.taskset.rollout_args.temperature`: The temperature used in vLLM. Default is `1.0`.
+- `buffer.explorer_input.taskset.rollout_args.logprobs`: The logprobs used in vLLM. Default is `0`.
 - `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. And it is empty by default.
 - `buffer.explorer_input.default_workflow_type`: The default workflow type for `taskset` and `eval_tasksets`.
 - `buffer.explorer_input.default_reward_fn_type`: The default reward function type for `taskset` and `eval_tasksets`.
@@ -145,10 +152,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 5
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
@@ -162,10 +166,7 @@ explorer:
 - `explorer.enable_prefix_caching`: Whether to enable prefix caching. Default is `False`.
 - `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
 - `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
-- `explorer.temperature`: The temperature used in vLLM. Default is `1.0`.
 - `explorer.seed`: The seed used in vLLM. Default is `42`.
-- `explorer.logprobs`: The logprobs used in vLLM. Default is `0`.
-- `explorer.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `5`.
 - `explorer.use_ray`: Whether to use Ray. Default is `False`.
 - `explorer.backend`: The backend used in vLLM. Default is `nccl`.
 - `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
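
Migration note (illustrative, not part of this commit): the documentation change above moves `repeat_times`, `temperature` and `logprobs` from `explorer` to `buffer.explorer_input.taskset.rollout_args`. A minimal sketch of how an existing config could be updated, assuming it has already been loaded into a plain Python dict, might look like the following. The helper name `migrate_rollout_args` and the constant `MOVED_KEYS` are hypothetical and do not exist in the repository.

```python
from copy import deepcopy

# Keys that the diff above moves out of `explorer` and into
# `buffer.explorer_input.taskset.rollout_args`.
MOVED_KEYS = ("repeat_times", "temperature", "logprobs")


def migrate_rollout_args(config: dict) -> dict:
    """Return a copy of `config` with the moved keys relocated.

    Hypothetical helper for illustration only; the input is left untouched.
    """
    new_config = deepcopy(config)
    explorer = new_config.get("explorer")
    if not isinstance(explorer, dict):
        # No explorer section, so there is nothing to migrate.
        return new_config

    # Remove the old keys from `explorer`, remembering their values.
    moved = {key: explorer.pop(key) for key in MOVED_KEYS if key in explorer}
    if not moved:
        return new_config

    # Create the nested sections on demand.
    taskset = (
        new_config.setdefault("buffer", {})
        .setdefault("explorer_input", {})
        .setdefault("taskset", {})
    )
    rollout_args = taskset.setdefault("rollout_args", {})
    for key, value in moved.items():
        # A value already set under rollout_args takes precedence.
        rollout_args.setdefault(key, value)
    return new_config


# Example: an old-style config shaped like the example YAML files in this commit.
old = {
    "buffer": {"explorer_input": {"taskset": {"format": {"prompt_key": "question"}}}},
    "explorer": {
        "dtype": "bfloat16",
        "temperature": 1.0,
        "seed": 42,
        "logprobs": 0,
        "repeat_times": 8,
    },
}
new = migrate_rollout_args(old)
```

Note that, according to the documentation diff, the documented default of `repeat_times` changes from `5` to `1`, so a config that relied on the old default would need to set the value explicitly under `rollout_args`.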

examples/async_gsm8k/explorer.yaml

Lines changed: 4 additions & 3 deletions
@@ -23,6 +23,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -37,10 +41,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32

examples/async_gsm8k/trainer.yaml

Lines changed: 4 additions & 3 deletions
@@ -22,6 +22,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -36,10 +40,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32

examples/dpo_humanlike/dpo.yaml

Lines changed: 0 additions & 16 deletions
@@ -23,22 +23,6 @@ buffer:
         prompt_key: prompt
         chosen_key: chosen
         rejected_key: rejected
-explorer:
-  engine_type: vllm_async
-  engine_num: 0
-  runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  temperature: 1.0
-  seed: 42
-  logprobs: 0
-  repeat_times: 1 # NOTE
-  use_ray: false
-  backend: 'nccl'
-  max_pending_requests: 32
-  max_waiting_steps: 4
 synchronizer:
   sync_method: 'checkpoint'
   sync_interval: 30

examples/grpo_alfworld/alfworld.yaml

Lines changed: 4 additions & 3 deletions
@@ -19,6 +19,10 @@ buffer:
       path: 'scripts/data_prepare/alfworld_data'
       format:
         prompt_key: 'game_file'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'alfworld_workflow'
   trainer_input:
     experience_buffer:
@@ -33,10 +37,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 4 additions & 3 deletions
@@ -36,6 +36,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     eval_tasksets:
       - name: gsm8k-eval
         storage_type: file
@@ -65,10 +69,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32

examples/grpo_math/math.yaml

Lines changed: 4 additions & 3 deletions
@@ -21,6 +21,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'gt_answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -35,10 +39,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32

examples/grpo_sciworld/sciworld.yaml

Lines changed: 4 additions & 3 deletions
@@ -19,6 +19,10 @@ buffer:
       path: 'scripts/data_prepare/sciworld_data'
       format:
         prompt_key: 'game_file'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'sciworld_workflow'
   trainer_input:
     experience_buffer:
@@ -33,10 +37,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32

examples/grpo_webshop/webshop.yaml

Lines changed: 4 additions & 3 deletions
@@ -19,6 +19,10 @@ buffer:
       path: 'scripts/data_prepare/webshop_data'
       format:
         prompt_key: 'task_id'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'webshop_workflow'
   trainer_input:
     experience_buffer:
@@ -33,10 +37,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32

examples/opmd_gsm8k/opmd_gsm8k.yaml

Lines changed: 4 additions & 3 deletions
@@ -20,6 +20,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -34,10 +38,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
