@@ -1010,25 +1010,24 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func):
             reward_func_name = reward_func.__class__.__name__
 
             reward_func_names.append(reward_func_name)
-        metrics_mask = ~agg_truncated_mask if self.args.overlong_filter else torch.ones(
-            agg_completion_mask.shape[0], dtype=torch.bool)
+
         for i, reward_func_name in enumerate(reward_func_names):
-            mean_rewards = (rewards_per_func[:, i][metrics_mask]).mean().item()
+            mean_rewards = rewards_per_func[:, i].mean().item()
             self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards)
-            std_rewards = (rewards_per_func[:, i][metrics_mask]).std().item()
+            std_rewards = rewards_per_func[:, i].std().item()
             self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards)
 
         # Log overall reward stats
-        grouped_rewards = rewards[metrics_mask].view(-1, self.num_generations)
+        grouped_rewards = rewards.view(-1, self.num_generations)
         self._metrics[mode]['reward'].append(grouped_rewards.mean().item())
         self._metrics[mode]['reward_std'].append(grouped_rewards.std(dim=1).mean().item())
 
         # Log prompt and completion texts
-        self._textual_logs['prompt'].extend(m for m, mask in zip(gather_object(messages), metrics_mask) if mask)
-        self._textual_logs['completion'].extend(c for c, mask in zip(gather_object(completions), metrics_mask) if mask)
+        self._textual_logs['prompt'].extend(gather_object(messages))
+        self._textual_logs['completion'].extend(gather_object(completions))
 
         for i, name in enumerate(reward_func_names):
-            self._textual_logs['rewards'][name].extend(rewards_per_func[:, i][metrics_mask].tolist())
+            self._textual_logs['rewards'][name].extend(rewards_per_func[:, i].tolist())
 
     @profiling_decorator
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
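Note on the reward logging above: `grouped_rewards.std(dim=1).mean()` assumes the flat `rewards` tensor is group-contiguous, i.e. all `num_generations` completions for one prompt sit next to each other, so `view(-1, self.num_generations)` recovers one row per prompt. A toy sketch under that assumption (the values and `num_generations=4` are made up for illustration):

```python
import torch

num_generations = 4  # assumed value for illustration
rewards = torch.tensor([1., 0., 1., 1.,   # prompt 0: four sampled completions
                        0., 0., 1., 0.])  # prompt 1: four sampled completions
grouped = rewards.view(-1, num_generations)     # shape [2, 4], one row per prompt
print(grouped.mean().item())                    # 0.5 -> logged as 'reward'
print(grouped.std(dim=1).mean().item())         # 0.5 -> logged as 'reward_std'
```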
@@ -1077,7 +1076,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
             mean_kl = (per_token_kl * completion_mask).sum() / completions_length
             metrics['kl'] = mean_kl
 
-        is_clipped = (coef_1 < (1 - self.epsilon_low)) | (coef_1 > (1 + self.epsilon_high))
+        is_clipped = (
+            ((coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0))
+            | ((coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)))
+
         clip_ratio = (is_clipped * completion_mask).sum() / completions_length
         metrics['clip_ratio'] = clip_ratio
 
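Note on the `is_clipped` change above: in the PPO-style objective `-min(r*A, clip(r, 1-eps_low, 1+eps_high)*A)`, the clipped term only wins the `min` when the ratio falls below `1 - eps_low` with a negative advantage, or rises above `1 + eps_high` with a positive advantage. The new indicator therefore counts only tokens whose gradient is actually zeroed by clipping, instead of every token whose ratio leaves the trust region. A minimal sketch of the same logic (tensor shapes, values, and the clip range are made up):

```python
import torch

coef_1 = torch.tensor([[0.7, 1.3, 0.7, 1.3]])  # ratio pi_theta / pi_old, shape [1, seq_len]
advantages = torch.tensor([-1.0])               # one sequence with A < 0
eps_low, eps_high = 0.2, 0.2                    # assumed clip range

is_clipped = ((coef_1 < 1 - eps_low) & (advantages.unsqueeze(1) < 0)) | \
             ((coef_1 > 1 + eps_high) & (advantages.unsqueeze(1) > 0))
print(is_clipped)  # tensor([[ True, False,  True, False]])
# With A < 0, only ratios below 1 - eps_low are truly clipped; the old
# indicator would also have flagged the 1.3 entries, overcounting.
```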