@@ -117,26 +117,28 @@ def test_mix_policy_loss(self):
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
 
     def test_ppo_policy_loss_with_sequence_masking(self):
-        """Test PPO policy loss with sequence masking enabled"""
         policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
         policy_loss_fn_args = policy_loss_fn_cls.default_args()
         policy_loss_fn_args["enable_sequence_masking"] = True
         policy_loss_fn_args["delta"] = 0.1
         policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
         loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
-
-        # Test that sequence masking metrics are present
-        self.assertIn("seq_mask/masked_tokens", metrics)
-        self.assertIn("seq_mask/mean_sequence_kl", metrics)
-
-        # Test that masked_tokens is between 0 and 1
-        self.assertGreaterEqual(metrics["seq_mask/masked_tokens"], 0.0)
-        self.assertLessEqual(metrics["seq_mask/masked_tokens"], 1.0)
-
-        # Test that loss is different from non-masked version (if masking occurred)
-        policy_loss_fn_no_mask = policy_loss_fn_cls(**policy_loss_fn_cls.default_args())
-        loss_no_mask, _ = policy_loss_fn_no_mask(log_prob=self.logprob, **self.input_data.batch)
-
-        # Loss should be different if tokens were masked
-        if metrics["seq_mask/masked_tokens"] > 0:
-            self.assertFalse(torch.allclose(loss, loss_no_mask))
+        ppo_loss_masked = torch.tensor(0.22175675630569458)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        pg_clipfrac_lower = torch.tensor(0.0625)
+        masked_tokens = torch.tensor(0.16666666666631944)
+        mean_sequence_kl = torch.tensor(-0.21027061343193054)
+        self.assertTrue(torch.allclose(loss, ppo_loss_masked))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_masked))
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["pg_clipfrac_lower"]), pg_clipfrac_lower)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/masked_tokens"]), masked_tokens)
+        )
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
+        )
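
For context on what the new reference values exercise, here is a minimal, self-contained sketch of how a sequence-masked PPO policy loss can be computed. The function name, signature, KL sign convention, and masking rule (dropping whole sequences whose mean approximate KL drifts beyond `delta`) are assumptions made for illustration; they are not taken from the repository's actual `ppo` policy loss implementation, which also reports `pg_clipfrac_lower` and may define the metrics differently.

```python
# Illustrative sketch only: names, signature, and metric conventions are assumptions,
# not the repository's implementation.
import torch


def ppo_loss_with_sequence_masking(
    log_prob: torch.Tensor,      # (batch, seq_len) log-probs from the current policy
    old_log_prob: torch.Tensor,  # (batch, seq_len) log-probs from the behaviour policy
    advantages: torch.Tensor,    # (batch, seq_len) per-token advantages
    action_mask: torch.Tensor,   # (batch, seq_len) 1.0 for real tokens, 0.0 for padding
    clip_range: float = 0.2,
    delta: float = 0.1,
):
    """Clipped PPO loss that drops whole sequences whose mean KL exceeds `delta`."""
    # Per-token approximate KL between the old and current policy (k1 estimator).
    approx_kl = old_log_prob - log_prob
    ratio = torch.exp(log_prob - old_log_prob)

    # Mean KL per sequence, ignoring padding.
    token_counts = action_mask.sum(dim=-1).clamp(min=1.0)
    seq_kl = (approx_kl * action_mask).sum(dim=-1) / token_counts

    # Keep only sequences whose KL drift stays within +/- delta.
    keep = (seq_kl.abs() <= delta).float().unsqueeze(-1)
    mask = action_mask * keep

    # Standard clipped surrogate objective, averaged over the surviving tokens.
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    pg_loss = torch.max(unclipped, clipped)
    denom = mask.sum().clamp(min=1.0)
    loss = (pg_loss * mask).sum() / denom

    metrics = {
        "pg_loss": loss.item(),
        "pg_clipfrac": (((clipped > unclipped).float() * mask).sum() / denom).item(),
        "ppo_kl": ((approx_kl * mask).sum() / denom).item(),
        # Fraction of valid tokens removed by the sequence-level mask.
        "seq_mask/masked_tokens": 1.0 - (mask.sum() / action_mask.sum().clamp(min=1.0)).item(),
        # Mean per-sequence KL before masking.
        "seq_mask/mean_sequence_kl": seq_kl.mean().item(),
    }
    return loss, metrics
```

Masking at the sequence level, rather than per token, keeps a few heavily off-policy responses from dominating the gradient, which is why the test now pins both the masked loss and the fraction of tokens removed (`seq_mask/masked_tokens`) to exact reference values instead of only checking that the metrics exist.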