Commit 638d335

Refactor MultiTurnWorkflow class and add SciWorld ENV (#9)
1 parent baa1c3f commit 638d335

File tree

10 files changed: +444, -119 lines changed

docs/sphinx_doc/source/tutorial/example_multi_turn.md

Lines changed: 9 additions & 12 deletions
````diff
@@ -21,7 +21,7 @@ You may refer to their original environment to complete the setup.
 ### Data Preparation
 Our dataset follows the format in Huggingface datasets library, so we should correspondingly convert our env dataset.
 
-Just run the following command.
+Just check the data preparation scripts and run the following command.
 ```bash
 # For ALFworld env
 python scripts/data_prepare/get_alfworld_data.py
@@ -53,18 +53,16 @@ We provide an easy way to allow you build your own environment pipeline by creat
 
 See the `trinity/common/workflows/envs/alfworld/alfworld_workflow.py` as an example on how to construct a multi-round workflow.
 
-You can interact with environment using the messages format, and call the `self.process_batch_messages` function to transform the messages and rewards into the `experience` we need, and send them to buffer.
+You can interact with environment using the messages format, and call the `self.process_messages_to_experience` function to transform the messages and rewards into the `experience` we need, and send them to buffer.
 
 ```python
-class AlfworldWorkflow(Workflow):
+class AlfworldWorkflow(MultiTurnWorkflow):
     """A workflow for alfworld task."""
     ...
 
     def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
         print("Generating env inference samples...")
-        all_messages = []
-        all_rewards = []
-        all_infos = []
+        experience_list = []
         for i in range(rollout_num):
             observation, info = env.reset()
             final_reward = -0.1
@@ -80,14 +78,13 @@ class AlfworldWorkflow(Workflow):
                 if done:
                     final_reward = reward
                     break
-            all_infos.append(
-                {"env_rounds": r, "env_done": 1 if done else 0}
+            experience = self.process_messages_to_experience(
+                memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
             )
-            all_messages.append(memory)
-            all_rewards.append(final_reward)
+            experience_list.append(experience)
         # Close the env to save cpu memory
         env.close()
-        return self.process_batch_messages(all_messages, all_rewards, all_infos=all_infos)
+        return experience_list
 
 
     def run(self) -> List[Experience]:
@@ -102,7 +99,7 @@ class AlfworldWorkflow(Workflow):
 Also, remember to register your workflow:
 ```python
 @WORKFLOWS.register_module("alfworld_workflow")
-class AlfworldWorkflow(Workflow):
+class AlfworldWorkflow(MultiTurnWorkflow):
     """A workflow for alfworld task."""
     ...
 ```
````

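The tutorial hunk above elides the actual environment interaction with `...`. Purely for orientation, here is a hypothetical sketch of what that loop could look like in a `MultiTurnWorkflow` subclass. The chat-style memory handling and the `process_messages_to_experience` call come from the diff; `SciworldLikeWorkflow`, `self.max_env_steps`, `self.model.chat`, and the gym-style `env.step` signature are illustrative assumptions, not the project's confirmed API.

```python
from typing import List


class SciworldLikeWorkflow(MultiTurnWorkflow):  # MultiTurnWorkflow/Experience come from the framework
    """Hypothetical multi-turn workflow sketch; not the actual SciWorld implementation."""

    def generate_env_inference_samples(self, env, rollout_num) -> List["Experience"]:
        experience_list = []
        for i in range(rollout_num):
            observation, info = env.reset()
            final_reward = -0.1
            memory = [{"role": "user", "content": observation}]
            for r in range(self.max_env_steps):           # assumed round limit
                response = self.model.chat(memory)         # assumed model-call helper
                memory.append({"role": "assistant", "content": response})
                observation, reward, done, info = env.step(response)  # assumed gym-style env
                memory.append({"role": "user", "content": observation})
                if done:
                    final_reward = reward
                    break
            # From the diff: convert the chat history and reward into an Experience.
            experience = self.process_messages_to_experience(
                memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
            )
            experience_list.append(experience)
        env.close()  # close the env to save cpu memory
        return experience_list
```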
scripts/config/sciworld.yaml

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
1+
data:
2+
total_epoch: 20
3+
batch_size: 4
4+
dataset_path: 'scripts/data_prepare/sciworld_data'
5+
default_workflow_type: 'sciworld_workflow'
6+
train_split: 'train'
7+
eval_split: ''
8+
format_config:
9+
prompt_key: 'game_file'
10+
model:
11+
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
12+
max_prompt_tokens: 4096
13+
max_response_tokens: 16384
14+
checkpoint_path: 'checkpoints/sciworld_RFT'
15+
cluster:
16+
node_num: 1
17+
gpu_per_node: 8
18+
buffer:
19+
max_retry_times: 3
20+
max_retry_interval: 1
21+
train_dataset:
22+
name: sciworld_buffer
23+
storage_type: queue
24+
algorithm_type: ppo
25+
path: 'sqlite:///sciworld.db'
26+
explorer:
27+
engine_type: vllm_async
28+
engine_num: 2
29+
runner_num: 32
30+
tensor_parallel_size: 2
31+
enable_prefix_caching: false
32+
enforce_eager: true
33+
dtype: bfloat16
34+
temperature: 1.0
35+
top_p: 1.0
36+
top_k: -1
37+
seed: 42
38+
logprobs: 0
39+
repeat_times: 8
40+
use_ray: false
41+
backend: 'nccl'
42+
max_pending_requests: 32
43+
max_waiting_steps: 4
44+
gpu_memory_utilization: 0.7
45+
enable_chunked_prefil: true
46+
synchronizer:
47+
sync_method: 'online'
48+
sync_iteration_interval: 8
49+
trainer:
50+
trainer_type: 'verl'
51+
algorithm_type: ppo
52+
trainer_config_path: 'scripts/config/train_sciworld.yaml'
53+
monitor:
54+
cache_root_dir: ""
55+
project: "sciworld"
56+
name: "sciworld_RFT"

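As a quick sanity check of the layout above, the config parses with plain PyYAML; the snippet below is illustrative only and does not go through the framework's own config loader.

```python
import yaml

# Illustrative only: load the new SciWorld config and print a few fields.
with open("scripts/config/sciworld.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["data"]["default_workflow_type"])      # sciworld_workflow
print(cfg["explorer"]["engine_num"],
      cfg["explorer"]["tensor_parallel_size"])   # 2 2
print(cfg["trainer"]["trainer_config_path"])     # scripts/config/train_sciworld.yaml
```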
scripts/config/train_sciworld.yaml

Lines changed: 183 additions & 0 deletions
```yaml
data:
  tokenizer: null
  train_files: train_example.parquet
  val_files: test_example.parquet
  prompt_key: prompt
  max_prompt_length: 4096
  max_response_length: 16384
  train_batch_size: 96
  val_batch_size: null
  return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False
  shuffle: True
  filter_overlong_prompts: False  # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left'
  truncation: error
  image_key: images

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: /PATH/TO/MODEL/CHECKPOINT/
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: False
  actor:
    strategy: fsdp  # This is for backward-compatibility
    ppo_mini_batch_size: 1536
    # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: 1
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: True # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
      # min_lr_ratio: null   # only useful for warmup with cosine
      warmup_style: constant  # select from constant/cosine
      total_training_steps: -1  # must be override by program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 1
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    use_fire_sampling: False # https://arxiv.org/abs/2410.21236
    prompt_length: ${data.max_prompt_length}  # not use for opensource
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.4
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 1
    max_num_batched_tokens: 8192
    max_model_len: null
    max_num_seqs: 1024
    # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 1
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    # number of responses (i.e. num sample times)
    n: 8 # should be > 1 for grpo; Currently is unused parameter

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
    # min_lr_ratio: null   # only useful for warmup with cosine
    warmup_style: constant  # select from constant/cosine
    total_training_steps: -1  # must be override by program
  model:
    path: /PATH/TO/MODEL/CHECKPOINT/
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: 1
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 16384 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
      fsdp_size: -1
  # micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
  # micro_batch_size_per_gpu: 2 # set a number
  # max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: grpo
  kl_penalty: kl  # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  balance_batch: True
  total_epochs: 15
  # total_training_steps: null
  project_name: sciworld
  experiment_name: sciworld_RFT
  logger: [ 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 2
  save_freq: 1
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # or auto or resume_path if
  resume_from_path: False
  test_freq: 100
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
  val_before_train: False
```

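The trainer config shares values between sections through `${...}` interpolation (for example, the critic inherits the actor's PPO mini-batch size and the rollout lengths come from the `data` section). verl configs are normally consumed via OmegaConf; assuming that holds here, resolved values can be spot-checked like this (illustrative only):

```python
from omegaconf import OmegaConf

# Illustrative only: ${...} interpolations resolve when a field is accessed.
cfg = OmegaConf.load("scripts/config/train_sciworld.yaml")
print(cfg.critic.ppo_mini_batch_size)                  # 1536, from actor_rollout_ref.actor
print(cfg.actor_rollout_ref.rollout.response_length)   # 16384, from data.max_response_length
```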
scripts/data_prepare/get_alfworld_data.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -39,7 +39,6 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100):
     # create dataset_dict
     dataset_dict = {"train": train_data, "test": test_data}
 
-    # 保存为jsonl格式
     for split, data in dataset_dict.items():
         output_file = os.path.join(output_dir, f"{split}.jsonl")
         with open(output_file, "w") as f:
```

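The removed comment ("保存为jsonl格式", i.e. "save as JSONL format") was redundant with the code: the loop in the hunk writes each split to a `.jsonl` file. The loop body is cut off by the diff context; the usual pattern it stands for looks roughly like this (a sketch, not the script's exact code):

```python
import json
import os

def write_jsonl_splits(dataset_dict, output_dir):
    # Sketch of the JSONL-writing pattern: one JSON object per line, one file per split.
    # The record fields are whatever the preparation script built upstream.
    for split, data in dataset_dict.items():
        output_file = os.path.join(output_dir, f"{split}.jsonl")
        with open(output_file, "w") as f:
            for record in data:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
```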
0 commit comments