Skip to content

Commit 08330be

Browse files
authored
fix(reasoning-fsdp): fix a small bug in fsdp inference (RLinf#775)
Signed-off-by: Louis-J <czzcy3832515@gmail.com>
1 parent 874d268 commit 08330be

File tree

3 files changed: +8 -0 lines changed

rlinf/data/io_struct.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -878,6 +878,9 @@ def merge_batches(
878 | 878 | if len(batches) == 1:
879 | 879 | return batches[0]
880 | 880 |
    | 881 | + assert all(batch.keys() == batches[0].keys() for batch in batches[1:]), (
    | 882 | + "All batches must have the same keys"
    | 883 | + )
881 | 884 | for key in batches[0].keys():
882 | 885 | if torch.is_tensor(batches[0][key]):
883 | 886 | merged_batch[key] = torch.cat([batch[key] for batch in batches], dim=0)

rlinf/workers/actor/fsdp_actor_worker.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -426,6 +426,9 @@ def get_dynamic_batch_as_much(
426 | 426 | last_result_len = result_len
427 | 427 | result_len = all_reduce_int(len(rollout_results))
428 | 428 |
    | 429 | + cliped_results = list(rollout_results[result_len:])
    | 430 | + rollout_results = rollout_results[:result_len]
    | 431 | +
429 | 432 | batches = []
430 | 433 | for rollout_result in rollout_results:
431 | 434 | batch = rollout_result.to_actor_batch(

rlinf/workers/reward/reward_worker.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,6 +92,8 @@ def compute_rewards(self, input_channel: Channel, output_channel: Channel):
92 | 92 | rollout_result
93 | 93 | )
94 | 94 | rollout_result = self.down_sample_batch(rollout_result)
   | 95 | + # answer is not needed in training
   | 96 | + rollout_result.answers = None
95 | 97 | output_channel.put(rollout_result, async_op=True)
96 | 98 |
97 | 99 | assert recv_batch_size == self.total_batch_size_per_dp, (

0 commit comments

Comments
 (0)