Skip to content

Commit dfd53e1

Browse files
nbasylyuki-97
authored and committed
remove redundant repeated_batch key
Signed-off-by: Shih-Yang Liu <shihyangl@nvidia.com>
1 parent 91bd7c6 commit dfd53e1

File tree

2 files changed

+7
-16
lines changed

2 files changed

+7
-16
lines changed

nemo_rl/algorithms/advantage_estimator.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def compute_advantage(
9797
"""Compute GDPO advantages.
9898
9999
Args:
100-
prompt_ids: Unused; for interface consistency.
100+
prompt_ids: Tensor identifying which prompt each sample belongs to (for per-prompt baselines).
101101
rewards: Unused; for interface consistency.
102-
repeated_batch: Batch containing _input_ids_for_baseline and reward1, reward2, ... keys.
102+
repeated_batch: Batch containing reward1, reward2, ... keys.
103103
mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
104104
**kwargs: Additional arguments (unused).
105105
@@ -113,20 +113,17 @@ def compute_advantage(
113113
f"This batch has {len(reward_component_keys)} component(s). "
114114
"Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config."
115115
)
116-
current_input_ids = repeated_batch["_input_ids_for_baseline"]
117-
valid = torch.ones_like(
118-
repeated_batch[reward_component_keys[0]]
119-
)
116+
valid = torch.ones_like(repeated_batch[reward_component_keys[0]])
120117
leave_one_out = self.use_leave_one_out_baseline
121-
assert current_input_ids.shape[0] == valid.shape[0], (
122-
"_input_ids_for_baseline must match reward batch size after dynamic_sampling; "
123-
f"got {current_input_ids.shape[0]} vs {valid.shape[0]}"
118+
assert prompt_ids.shape[0] == valid.shape[0], (
119+
"prompt_ids must match reward batch size; "
120+
f"got {prompt_ids.shape[0]} vs {valid.shape[0]}"
124121
)
125122
advantage_parts = []
126123
for key in reward_component_keys:
127124
r = repeated_batch[key]
128125
base, std_k = calculate_baseline_and_std_per_prompt(
129-
current_input_ids,
126+
prompt_ids,
130127
r,
131128
valid,
132129
leave_one_out_baseline=leave_one_out,

nemo_rl/algorithms/grpo.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,10 +1605,6 @@ def grpo_train(
16051605
with timer.time("reward_calculation"):
16061606
# Extract rewards from final_batch
16071607
rewards = repeated_batch["total_reward"]
1608-
# Store input_ids in batch so that after dynamic_sampling it stays aligned with
1609-
# the (possibly filtered) batch: select_indices / from_batches / slice all
1610-
# apply to this key, so per-reward baselines use the same prompts as reward components.
1611-
repeated_batch["_input_ids_for_baseline"] = input_ids
16121608

16131609
print("▶ Computing advantages...", flush=True)
16141610
if master_config["grpo"].get("calculate_advantages_on_gpu"):
@@ -2744,8 +2740,6 @@ def async_grpo_train(
27442740
del prompt_batched_flat
27452741

27462742
rewards = repeated_batch["total_reward"]
2747-
# All estimators read _input_ids_for_baseline from repeated_batch
2748-
repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv
27492743

27502744
print(
27512745
f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}"

0 commit comments

Comments (0)