Commit 7d0bd0e

[Fix]: sft multi-turn training error
1 parent: 9c7ecca

4 files changed, +40 -18 lines changed

docs/SFT_GUIDE_EN.md

Lines changed: 6 additions & 10 deletions
@@ -1,20 +1,16 @@
+# SFT GUIDE
+
+
 ```
-nohup ./run_qwen_05_sp2.sh 4 /data1/models/openmanus_rl/Qwen/Qwen3-3b-sft \
+nohup ./run_sft.sh 4 /data1/models/openmanus_rl/Qwen/Qwen3-3b-sft \
 data.truncation=right \
 trainer.total_training_steps=1000 \
-++actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
-++critic.model.fsdp_config.model_dtype=bfloat16 \
 trainer.logger="['console','wandb']" \
 trainer.project_name="OpenManus-rl" \
 > training_run.log 2>&1 &
 ```
 
-You need to clone a new verl codebase, and use verl conda environment to run this multi-turn sft script.
-
-You should copy openmanus-rl/scripts/run_sft.sh to verl/examples/sft/multiturn/
-then run the script
-
 
 ```
-./run_qwen_05_sp2.sh 4 /data1/models/openmanus_rl/Qwen/Qwen3-3b-sft data.truncation=right trainer.total_training_steps=30 trainer.logger="['console','wandb']" trainer.project_name="OpenManus-rl"
-```
+./run_sft.sh 4 /data1/models/openmanus_rl/Qwen/Qwen3-3b-sft data.truncation=right trainer.total_training_steps=30 trainer.logger="['console','wandb']" trainer.project_name="OpenManus-rl"
+```

scripts/run_sft.sh

File mode changed: 100644 → 100755
Lines changed: 6 additions & 4 deletions
@@ -17,6 +17,8 @@ if [ -f "$CONDA_BASE_DIR/etc/profile.d/conda.sh" ]; then
 else
     echo "Conda base profile script not found at $CONDA_BASE_DIR/etc/profile.d/conda.sh"
 fi
+export WANDB_API_KEY= # TODO: add your wandb api key here
+wandb login
 
 nproc_per_node=$1
 save_path=$2
@@ -31,7 +33,7 @@ if [ "$use_all_gpu" = "true" ]; then
     tensor_parallel_size=8
     echo "Configured to use 8 GPUs: CUDA_VISIBLE_DEVICES=$visible_devices, tensor_parallel_size=$tensor_parallel_size"
 else
-    visible_devices="4,5,6,7"
+    visible_devices="0,1,2,3"
     tensor_parallel_size=4
     echo "Configured to use 4 GPUs: CUDA_VISIBLE_DEVICES=$visible_devices, tensor_parallel_size=$tensor_parallel_size"
 fi
@@ -41,12 +43,12 @@ fi
 CUDA_VISIBLE_DEVICES="$visible_devices" \
 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
     -m verl.trainer.fsdp_sft_trainer \
-    data.train_files=$HOME/muxin/OpenManus-RL/data/train.parquet \
-    data.val_files=$HOME/muxin/OpenManus-RL/data/test.parquet \
+    data.train_files=../data/train.parquet \
+    data.val_files=../data/test.parquet \
     data.multiturn.enable=true \
     data.multiturn.messages_key=conversations \
     data.micro_batch_size=4 \
-    model.partial_pretrain=/data1/models/Qwen/Qwen3-4B \
+    model.partial_pretrain=/data1/models/Qwen/Qwen2.5-3B \ # TODO: add your model path here
     trainer.default_local_dir=$save_path \
     trainer.project_name=multiturn-sft \
     trainer.experiment_name=multiturn-sft-qwen-3-4b \

verl/trainer/fsdp_sft_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def _build_model_optimizer(self):
         with init_context():
             self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
                                                                                 config=config,
-                                                                                torch_dtype=torch.float32,
+                                                                                torch_dtype="auto",
                                                                                 attn_implementation='flash_attention_2',
                                                                                 trust_remote_code=trust_remote_code)
             if self.config.model.get('lora_rank', 0) > 0:
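
A short note on this one-line change, beyond what the diff itself states: with `torch_dtype="auto"`, `from_pretrained` picks up the dtype recorded in the checkpoint config (bfloat16 for recent Qwen releases) instead of forcing float32, which the FlashAttention-2 kernels requested via `attn_implementation='flash_attention_2'` do not support. A minimal sketch of the difference; the checkpoint name below is only an illustrative placeholder:

```python
# Sketch only (not part of the commit): what torch_dtype="auto" changes at load time.
# "Qwen/Qwen2.5-0.5B" is an arbitrary small checkpoint used purely for illustration.
import torch
from transformers import AutoModelForCausalLM

name = "Qwen/Qwen2.5-0.5B"

m_fp32 = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float32)
m_auto = AutoModelForCausalLM.from_pretrained(name, torch_dtype="auto")

print(m_fp32.dtype)  # torch.float32 -- the old hard-coded behaviour
print(m_auto.dtype)  # dtype stored in the checkpoint config, e.g. torch.bfloat16

# FlashAttention-2 kernels only support fp16/bf16 inputs, so a model loaded in
# float32 with attn_implementation='flash_attention_2' fails once training starts.
```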

verl/utils/dataset/multiturn_sft_dataset.py

Lines changed: 27 additions & 3 deletions
@@ -64,6 +64,23 @@ def series_to_item(ls):
 
             while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1:
                 ls = ls[0]
+
+            # Convert numpy array to list if needed
+            if isinstance(ls, numpy.ndarray):
+                ls = ls.tolist()
+            elif isinstance(ls, pandas.core.series.Series):
+                ls = ls.tolist()
+
+            # If ls is a single dictionary with 'role' and 'content', wrap it in a list
+            if isinstance(ls, dict) and 'role' in ls and 'content' in ls:
+                ls = [ls]
+
+            # Verify the structure if it's a list
+            if isinstance(ls, list):
+                for i, item in enumerate(ls):
+                    if not isinstance(item, dict) or 'role' not in item or 'content' not in item:
+                        raise ValueError(f"Invalid message format at index {i}: {item}")
+
             return ls
 
         dataframes = []
@@ -75,14 +92,21 @@ def series_to_item(ls):
         # Extract messages list from dataframe
         self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist()
 
+        dataframes = []
+        for parquet_file in self.parquet_files:
+            dataframe = pd.read_parquet(parquet_file)
+            dataframes.append(dataframe)
+        self.dataframe = pd.concat(dataframes)
+
+        # Extract messages list from dataframe
+        self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist()
+
     def __len__(self):
         return len(self.messages)
 
     def __getitem__(self, item):
         tokenizer = self.tokenizer
         messages = self.messages[item]
-
-        # First, get the full conversation tokens
         full_tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=False)
         input_ids = full_tokens[0]  # The output is already a tensor
         attention_mask = torch.ones_like(input_ids)
@@ -143,4 +167,4 @@ def __getitem__(self, item):
             "attention_mask": attention_mask,
             "position_ids": position_ids,
             "loss_mask": loss_mask,
-        }
+        }
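
For context on the `series_to_item` additions: a list-of-dicts column written to parquet does not round-trip back as a Python list; pandas (via pyarrow) returns each cell as a numpy object array, which `tokenizer.apply_chat_template` cannot consume directly. A minimal sketch of that behaviour and the same normalization, assuming pyarrow is installed and using a throwaway file path:

```python
# Sketch only (not from the repo): why the messages column needs normalizing after read_parquet.
import pandas as pd

df = pd.DataFrame({
    "conversations": [
        [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
    ]
})
df.to_parquet("/tmp/mt_sft_example.parquet")

cell = pd.read_parquet("/tmp/mt_sft_example.parquet")["conversations"].iloc[0]
print(type(cell))         # numpy.ndarray (object dtype), not a Python list
messages = cell.tolist()  # the same conversion the patched series_to_item applies
assert all(isinstance(m, dict) and "role" in m and "content" in m for m in messages)
```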
