@@ -35,7 +35,6 @@ def sample_data(self):
3535 return logprobs , ref_logprobs , advantages , padding_mask
3636
3737 @pytest .mark .timeout (10 )
38- @pytest .mark .asyncio
3938 def test_forward_basic (self , loss_fn , sample_data ):
4039 """Test basic forward pass."""
4140 logprobs , ref_logprobs , advantages , padding_mask = sample_data
@@ -48,7 +47,6 @@ def test_forward_basic(self, loss_fn, sample_data):
4847 assert not torch .isnan (loss )
4948
5049 @pytest .mark .timeout (10 )
51- @pytest .mark .asyncio
5250 def test_output_shape (self , loss_fn ):
5351 """Test output shape for different input sizes."""
5452 for batch_size in [1 , 3 , 8 ]:
@@ -62,7 +60,6 @@ def test_output_shape(self, loss_fn):
6260 assert loss .shape == torch .Size ([])
6361
6462 @pytest .mark .timeout (10 )
65- @pytest .mark .asyncio
6663 def test_gradient_flow (self , loss_fn , sample_data ):
6764 """Test that gradients flow through logprobs."""
6865 logprobs , ref_logprobs , advantages , padding_mask = sample_data
@@ -76,7 +73,6 @@ def test_gradient_flow(self, loss_fn, sample_data):
7673 assert torch .isfinite (logprobs .grad ).all ()
7774
7875 @pytest .mark .timeout (10 )
79- @pytest .mark .asyncio
8076 def test_no_gradient_to_ref_logprobs (self , loss_fn , sample_data ):
8177 """Test that gradients don't flow to reference logprobs."""
8278 logprobs , ref_logprobs , advantages , padding_mask = sample_data
@@ -89,7 +85,6 @@ def test_no_gradient_to_ref_logprobs(self, loss_fn, sample_data):
8985 assert ref_logprobs .grad is not None
9086
9187 @pytest .mark .timeout (10 )
92- @pytest .mark .asyncio
9388 def test_padding_mask_effect (self , loss_fn ):
9489 """Test that padding mask correctly ignores padded tokens."""
9590 batch_size , seq_len = 2 , 4
@@ -111,7 +106,6 @@ def test_padding_mask_effect(self, loss_fn):
111106 assert not torch .allclose (loss_full , loss_partial )
112107
113108 @pytest .mark .timeout (10 )
114- @pytest .mark .asyncio
115109 def test_beta_parameter_effect (self , sample_data ):
116110 """Test that different beta values produce different losses."""
117111 logprobs , ref_logprobs , advantages , padding_mask = sample_data
@@ -128,7 +122,6 @@ def test_beta_parameter_effect(self, sample_data):
128122 assert not torch .allclose (loss_1 , loss_2 , atol = 1e-6 )
129123
130124 @pytest .mark .timeout (10 )
131- @pytest .mark .asyncio
132125 def test_zero_advantages (self , loss_fn ):
133126 """Test behavior with zero advantages."""
134127 batch_size , seq_len = 2 , 4
@@ -144,7 +137,6 @@ def test_zero_advantages(self, loss_fn):
144137 assert torch .isfinite (loss )
145138
146139 @pytest .mark .timeout (10 )
147- @pytest .mark .asyncio
148140 def test_identical_policies (self , loss_fn ):
149141 """Test behavior when current and reference policies are identical."""
150142 batch_size , seq_len = 2 , 4
@@ -160,7 +152,6 @@ def test_identical_policies(self, loss_fn):
160152 assert torch .isfinite (loss )
161153
162154 @pytest .mark .timeout (10 )
163- @pytest .mark .asyncio
164155 def test_extreme_values (self , loss_fn ):
165156 """Test with extreme but valid values."""
166157 batch_size , seq_len = 2 , 3
@@ -179,7 +170,6 @@ def test_extreme_values(self, loss_fn):
179170 assert not torch .isnan (loss )
180171
181172 @pytest .mark .timeout (10 )
182- @pytest .mark .asyncio
183173 def test_numerical_stability (self , loss_fn ):
184174 """Test numerical stability with edge cases."""
185175 batch_size , seq_len = 1 , 2
@@ -195,7 +185,6 @@ def test_numerical_stability(self, loss_fn):
195185 assert torch .isfinite (loss )
196186
197187 @pytest .mark .timeout (10 )
198- @pytest .mark .asyncio
199188 def test_all_masked_sequence (self , loss_fn ):
200189 """Test behavior when entire sequence is masked."""
201190 batch_size , seq_len = 1 , 3
@@ -211,7 +200,6 @@ def test_all_masked_sequence(self, loss_fn):
211200 assert torch .isfinite (loss )
212201
213202 @pytest .mark .timeout (10 )
214- @pytest .mark .asyncio
215203 def test_mathematical_correctness (self , loss_fn ):
216204 """Test mathematical correctness with simpler verification."""
217205 # Test with known simple case
0 commit comments