Skip to content

Commit bbbaf29

Browse files
author
root
committed
开启offload试试# (i.e., "try enabling offload")
1 parent 192062c commit bbbaf29

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

agentlightning/verl/daemon.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -459,22 +459,22 @@ def get_train_data_batch(self, max_prompt_length, max_response_length, device):
459459
n_transition = len(input_ids_list)
460460
print("***************************************",n_transition)
461461

462-
# # 直接扔掉多余的 transitions,限制最大数量
463-
MAX_TRANSITIONS = 96
464-
if n_transition > MAX_TRANSITIONS:
465-
# 确保所有列表长度一致
466-
input_ids_list = input_ids_list[:MAX_TRANSITIONS]
467-
input_attention_mask_list = input_attention_mask_list[:MAX_TRANSITIONS]
468-
response_ids_list = response_ids_list[:MAX_TRANSITIONS]
469-
response_attention_mask_list = response_attention_mask_list[:MAX_TRANSITIONS]
470-
reward_list = reward_list[:MAX_TRANSITIONS]
471-
data_id_list = data_id_list[:MAX_TRANSITIONS]
472-
rollout_id_list = rollout_id_list[:MAX_TRANSITIONS]
473-
turn_index_list = turn_index_list[:MAX_TRANSITIONS]
474-
is_drop_list = is_drop_list[:MAX_TRANSITIONS]
462+
# # 直接扔掉多余的 transitions,限制最大数量(会报错)
463+
# MAX_TRANSITIONS = 96
464+
# if n_transition > MAX_TRANSITIONS:
465+
# # 确保所有列表长度一致
466+
# input_ids_list = input_ids_list[:MAX_TRANSITIONS]
467+
# input_attention_mask_list = input_attention_mask_list[:MAX_TRANSITIONS]
468+
# response_ids_list = response_ids_list[:MAX_TRANSITIONS]
469+
# response_attention_mask_list = response_attention_mask_list[:MAX_TRANSITIONS]
470+
# reward_list = reward_list[:MAX_TRANSITIONS]
471+
# data_id_list = data_id_list[:MAX_TRANSITIONS]
472+
# rollout_id_list = rollout_id_list[:MAX_TRANSITIONS]
473+
# turn_index_list = turn_index_list[:MAX_TRANSITIONS]
474+
# is_drop_list = is_drop_list[:MAX_TRANSITIONS]
475475

476-
n_transition = MAX_TRANSITIONS
477-
476+
# n_transition = MAX_TRANSITIONS
477+
# print("********************MAX_TRANSITIONS*******************",n_transition)
478478
batch_input_ids = torch.LongTensor(input_ids_list).to(device)
479479
input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
480480
batch_response_ids = torch.LongTensor(response_ids_list).to(device)

examples/werewolf/train.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ python -m agentlightning.verl \
2424
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
2525
actor_rollout_ref.rollout.multi_turn.format=hermes \
2626
actor_rollout_ref.model.path=${BASE_MODEL} \
27-
data.max_prompt_length=11264 \
27+
data.max_prompt_length=12288 \
2828
data.max_response_length=1024 \
2929
data.truncation='error' \
3030
trainer.val_before_train=True \
@@ -36,12 +36,12 @@ python -m agentlightning.verl \
3636
actor_rollout_ref.actor.clip_ratio_low=0.2 \
3737
actor_rollout_ref.actor.clip_ratio_high=0.3 \
3838
actor_rollout_ref.model.enable_gradient_checkpointing=True \
39-
actor_rollout_ref.actor.fsdp_config.param_offload=False \
40-
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
39+
actor_rollout_ref.actor.fsdp_config.param_offload=True \
40+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
4141
actor_rollout_ref.rollout.name=vllm \
4242
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
4343
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
44-
actor_rollout_ref.ref.fsdp_config.param_offload=False \
44+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
4545
algorithm.use_kl_in_reward=False \
4646
trainer.default_local_dir='/root/dataDisk/checkpoints' \
4747
trainer.rollout_data_dir='/root/dataDisk/rollout' \

0 commit comments

Comments (0)