@@ -97,9 +97,9 @@ def compute_advantage(
9797 """Compute GDPO advantages.
9898
9999 Args:
100- prompt_ids: Unused; for interface consistency.
100+ prompt_ids: Tensor identifying which prompt each sample belongs to (for per-prompt baselines).
101101 rewards: Unused; for interface consistency.
102- repeated_batch: Batch containing _input_ids_for_baseline and reward1, reward2, ... keys.
102+ repeated_batch: Batch containing reward1, reward2, ... keys.
103103 mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
104104 **kwargs: Additional arguments (unused).
105105
@@ -113,20 +113,17 @@ def compute_advantage(
113113 f"This batch has {len(reward_component_keys)} component(s). "
114114 "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config."
115115 )
116- current_input_ids = repeated_batch["_input_ids_for_baseline"]
117- valid = torch.ones_like(
118- repeated_batch[reward_component_keys[0]]
119- )
116+ valid = torch.ones_like(repeated_batch[reward_component_keys[0]])
120117 leave_one_out = self.use_leave_one_out_baseline
121- assert current_input_ids.shape[0] == valid.shape[0], (
122- "_input_ids_for_baseline must match reward batch size after dynamic_sampling; "
123- f"got {current_input_ids.shape[0]} vs {valid.shape[0]}"
118+ assert prompt_ids.shape[0] == valid.shape[0], (
119+ "prompt_ids must match reward batch size; "
120+ f"got {prompt_ids.shape[0]} vs {valid.shape[0]}"
124121 )
125122 advantage_parts = []
126123 for key in reward_component_keys:
127124 r = repeated_batch[key]
128125 base, std_k = calculate_baseline_and_std_per_prompt(
129- current_input_ids,
126+ prompt_ids,
130127 r,
131128 valid,
132129 leave_one_out_baseline=leave_one_out,
0 commit comments