Skip to content

Commit 1e6e960

Browse files
author
root
committed
回退到batch=1
1 parent 5987d32 commit 1e6e960

File tree

3 files changed

+16
-26
lines changed

3 files changed

+16
-26
lines changed

agentlightning/runner.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,9 @@ async def run_async(self) -> bool:
240240
# Pass the task input, not the whole task object
241241
result = await rollout_method(task.input, task.rollout_id, resources_update.resources)
242242
#降低最大rollout
243-
if len(result) > 5:
243+
if len(result) > 40:
244244
import random
245-
result = random.sample(result,5)
245+
result = random.sample(result,40)
246246
rollout_obj = self._to_rollout_object(result, task.rollout_id)
247247
end_time = time.time()
248248
logger.info(
@@ -254,11 +254,16 @@ async def run_async(self) -> bool:
254254
logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
255255
MAX_TRY = MAX_TRY - 1
256256
finally:
257-
try:
258-
self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
259-
except Exception:
260-
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
261-
await self.client.post_rollout_async(rollout_obj)
257+
if rollout_obj.triplets:
258+
try:
259+
self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
260+
except Exception:
261+
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
262+
await self.client.post_rollout_async(rollout_obj)
263+
else:
264+
print("Warning: error occured ,empty triplets")
265+
if MAX_TRY == 0:
266+
raise Exception("rollout_obj.triplets is EMPTY")
262267
return True
263268

264269
async def iter_async(self) -> int:

agentlightning/verl/daemon.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -458,23 +458,8 @@ def get_train_data_batch(self, max_prompt_length, max_response_length, device):
458458
print(reward_list)
459459
n_transition = len(input_ids_list)
460460
print("***************************************",n_transition)
461-
462-
# # 直接扔掉多余的 transitions,限制最大数量(会报错)
463-
# MAX_TRANSITIONS = 96
464-
# if n_transition > MAX_TRANSITIONS:
465-
# # 确保所有列表长度一致
466-
# input_ids_list = input_ids_list[:MAX_TRANSITIONS]
467-
# input_attention_mask_list = input_attention_mask_list[:MAX_TRANSITIONS]
468-
# response_ids_list = response_ids_list[:MAX_TRANSITIONS]
469-
# response_attention_mask_list = response_attention_mask_list[:MAX_TRANSITIONS]
470-
# reward_list = reward_list[:MAX_TRANSITIONS]
471-
# data_id_list = data_id_list[:MAX_TRANSITIONS]
472-
# rollout_id_list = rollout_id_list[:MAX_TRANSITIONS]
473-
# turn_index_list = turn_index_list[:MAX_TRANSITIONS]
474-
# is_drop_list = is_drop_list[:MAX_TRANSITIONS]
475-
476-
# n_transition = MAX_TRANSITIONS
477-
# print("********************MAX_TRANSITIONS*******************",n_transition)
461+
if n_transition == 0:
462+
raise Exception("Empty transitions !!!!!!!")
478463
batch_input_ids = torch.LongTensor(input_ids_list).to(device)
479464
input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
480465
batch_response_ids = torch.LongTensor(response_ids_list).to(device)

examples/werewolf/train.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ python -m agentlightning.verl \
1717
data.val_files=${DATA_DIR}/test.parquet \
1818
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
1919
trainer.n_gpus_per_node=${N_GPUS} \
20-
data.train_batch_size=4 \
21-
actor_rollout_ref.rollout.n=4 \
20+
data.train_batch_size=1 \
21+
actor_rollout_ref.rollout.n=1 \
2222
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
2323
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
2424
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \

0 commit comments

Comments
 (0)