@@ -88,6 +88,16 @@ def update_policy(self, data: DataProto):  # noqa: C901
 
         mini_batches = data.split(self.config.ppo_mini_batch_size)
 
+        # EXPERIMENTAL: apply loss scale fix
+        loss_agg_mode = (
+            self.policy_loss_fn.loss_agg_mode
+            if hasattr(self.policy_loss_fn, "loss_agg_mode")
+            else "token-mean"
+        )
+        do_fix_actor_microbatch_loss_scale = self.config.fix_actor_microbatch_loss_scale and (
+            loss_agg_mode == "token-mean"
+        )
+
         metrics = {}
         for _ in range(self.config.ppo_epochs):
             for batch_idx, mini_batch in enumerate(mini_batches):
@@ -104,6 +114,12 @@ def update_policy(self, data: DataProto):  # noqa: C901
                 )
                 micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
 
+                if do_fix_actor_microbatch_loss_scale:
+                    # calculate the total number of response tokens in the minibatch
+                    mini_batch_token_num = torch.sum(
+                        mini_batch.batch["response_mask"].to(get_device_id())
+                    ).item()
+
                 self.actor_optimizer.zero_grad()
 
                 for micro_batch in micro_batches:
@@ -156,13 +172,19 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         )
                         policy_loss = policy_loss + kl_loss
 
-                    if self.config.use_dynamic_bsz:
-                        # relative to the dynamic bsz
-                        loss = policy_loss * (
-                            response_mask.shape[0] / self.config.ppo_mini_batch_size
-                        )
+                    # set loss scale for the microbatch
+                    if not do_fix_actor_microbatch_loss_scale:
+                        # original implementation of microbatch loss scale
+                        if self.config.use_dynamic_bsz:
+                            loss_scale = response_mask.shape[0] / self.config.ppo_mini_batch_size
+                        else:
+                            loss_scale = 1.0 / self.gradient_accumulation
                     else:
-                        loss = policy_loss / self.gradient_accumulation
+                        # EXPERIMENTAL: fix for token-mean loss aggregation
+                        # scale microbatch loss according to the number of tokens (rather than sequences)
+                        loss_scale = torch.sum(response_mask).item() / (mini_batch_token_num + 1e-6)
+
+                    loss = policy_loss * loss_scale
                     loss.backward()
 
                     append_to_dict(metrics, micro_batch_metrics)
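The scaling rule introduced above can be checked in isolation. With token-mean aggregation, each microbatch loss is already a mean over its own response tokens, so the original fixed 1 / gradient_accumulation weight over-weights microbatches with few tokens; weighting each microbatch by its share of the minibatch's response tokens recovers the exact minibatch token-mean. The standalone sketch below (hypothetical tensors only, not the verl DataProto path; the final weight plays the role of loss_scale in the diff) illustrates the equivalence:

import torch

# hypothetical per-token losses and response masks for two microbatches
# (the second microbatch has fewer response tokens than the first)
per_token_losses = [torch.rand(2, 4), torch.rand(2, 4)]
masks = [
    torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0]]),  # 6 response tokens
    torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),  # 3 response tokens
]

# ground truth: token-mean over the whole minibatch
all_losses, all_masks = torch.cat(per_token_losses), torch.cat(masks)
true_token_mean = (all_losses * all_masks).sum() / all_masks.sum()

mini_batch_token_num = sum(m.sum().item() for m in masks)

naive, fixed = 0.0, 0.0
for loss_mat, mask in zip(per_token_losses, masks):
    micro_loss = (loss_mat * mask).sum() / mask.sum()   # token-mean within the microbatch
    naive += micro_loss / len(per_token_losses)          # original scale: equal weight per microbatch
    fixed += micro_loss * mask.sum().item() / mini_batch_token_num  # fixed scale: weight by token share

print(true_token_mean.item(), naive.item(), fixed.item())
# `fixed` equals the true minibatch token-mean; `naive` matches it only when
# every microbatch contains the same number of response tokens.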