Skip to content

Commit 0a19c36

Browse files
authored
Re-enable associative_scan tests in ref eager mode (#443)
1 parent bf86f2a commit 0a19c36

File tree

1 file changed

+0
-48
lines changed

1 file changed

+0
-48
lines changed

test/test_associative_scan.py

Lines changed: 0 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -99,9 +99,6 @@ def jit_add_combine_fn(x, y):
9999

100100

101101
class TestAssociativeScan(RefEagerTestBase, TestCase):
102-
@skipIfRefEager(
103-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
104-
)
105102
def test_associative_scan_basic_addition(self):
106103
"""Test basic associative_scan functionality with prefix sum."""
107104

@@ -135,9 +132,6 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor:
135132
self.assertIn("param_0 + param_1", code)
136133
self.assertIn("tl.associative_scan", code)
137134

138-
@skipIfRefEager(
139-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
140-
)
141135
def test_associative_scan_maximum(self):
142136
"""Test associative_scan with maximum combine function."""
143137

@@ -170,9 +164,6 @@ def test_max_kernel(x: torch.Tensor) -> torch.Tensor:
170164
"tl.maximum" in code or "triton_helpers.maximum" in code
171165
)
172166

173-
@skipIfRefEager(
174-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
175-
)
176167
def test_associative_scan_multiplication(self):
177168
"""Test associative_scan with multiplication combine function."""
178169

@@ -203,9 +194,6 @@ def test_mul_kernel(x: torch.Tensor) -> torch.Tensor:
203194
# Verify the generated code contains multiplication
204195
self.assertIn("param_0 * param_1", code)
205196

206-
@skipIfRefEager(
207-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
208-
)
209197
def test_associative_scan_minimum(self):
210198
"""Test associative_scan with minimum combine function."""
211199

@@ -238,9 +226,6 @@ def test_min_kernel(x: torch.Tensor) -> torch.Tensor:
238226
"tl.minimum" in code or "triton_helpers.minimum" in code
239227
)
240228

241-
@skipIfRefEager(
242-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
243-
)
244229
def test_associative_scan_multiple_functions(self):
245230
"""Test using multiple different combine functions in one kernel."""
246231

@@ -277,9 +262,6 @@ def test_multi_kernel(x: torch.Tensor) -> torch.Tensor:
277262
"tl.maximum" in code or "triton_helpers.maximum" in code
278263
)
279264

280-
@skipIfRefEager(
281-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
282-
)
283265
def test_associative_scan_type_propagation(self):
284266
"""Test that associative_scan type propagation works correctly."""
285267

@@ -304,9 +286,6 @@ def test_type_kernel(x: torch.Tensor) -> torch.Tensor:
304286
# Use relaxed tolerance for large tensors due to accumulated floating-point errors
305287
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
306288

307-
@skipIfRefEager(
308-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
309-
)
310289
def test_associative_scan_different_dtypes(self):
311290
"""Test associative_scan with different data types."""
312291

@@ -341,9 +320,6 @@ def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
341320
expected = expected.to(result.dtype)
342321
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
343322

344-
@skipIfRefEager(
345-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
346-
)
347323
def test_associative_scan_different_sizes(self):
348324
"""Test associative_scan with different tensor sizes."""
349325

@@ -380,9 +356,6 @@ def test_size_kernel(x: torch.Tensor) -> torch.Tensor:
380356
expected = torch.cumsum(x, dim=1)
381357
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
382358

383-
@skipIfRefEager(
384-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
385-
)
386359
def test_associative_scan_reverse(self):
387360
"""Test associative_scan with reverse=True parameter."""
388361

@@ -408,9 +381,6 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
408381
# Verify reverse parameter is in generated code
409382
self.assertIn("reverse=True", code)
410383

411-
@skipIfRefEager(
412-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
413-
)
414384
def test_associative_scan_edge_cases(self):
415385
"""Test associative_scan edge cases."""
416386

@@ -436,9 +406,6 @@ def test_single_element(x: torch.Tensor) -> torch.Tensor:
436406
expected = torch.tensor([[3.0, 10.0]], device=DEVICE)
437407
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
438408

439-
@skipIfRefEager(
440-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
441-
)
442409
def test_associative_scan_large_scale(self):
443410
"""Test associative_scan with large tensors for performance validation."""
444411

@@ -464,9 +431,6 @@ def test_large_kernel(x: torch.Tensor) -> torch.Tensor:
464431
self.assertEqual(result.shape, x.shape)
465432
self.assertEqual(result.dtype, x.dtype)
466433

467-
@skipIfRefEager(
468-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
469-
)
470434
def test_associative_scan_torch_hops_mapping(self):
471435
"""Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan."""
472436

@@ -502,9 +466,6 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor:
502466
self.assertIn("tl.associative_scan", code)
503467
self.assertIn("param_0 + param_1", code)
504468

505-
@skipIfRefEager(
506-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
507-
)
508469
def test_associative_scan_code_generation(self):
509470
"""Test that the generated code structure is correct."""
510471

@@ -744,9 +705,6 @@ def cumulative_argmax_kernel(
744705
self.assertIn("def argmax_combine_fn_", code)
745706
self.assertIn("tl.associative_scan", code)
746707

747-
@skipIfRefEager(
748-
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
749-
)
750708
def test_associative_scan_in_helper_function(self):
751709
"""Test calling a function that internally uses hl.associative_scan."""
752710

@@ -808,7 +766,6 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor:
808766
self.assertIn("param_0 + param_1", code)
809767
self.assertIn("tl.associative_scan", code)
810768

811-
@skipIfRefEager("hl.cumsum is not supported by ref eager mode yet")
812769
def test_cumsum_reverse(self):
813770
"""Test cumsum with reverse=True."""
814771

@@ -890,7 +847,6 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor:
890847
self.assertIn("param_0 * param_1", code)
891848
self.assertIn("tl.associative_scan", code)
892849

893-
@skipIfRefEager("hl.cumprod is not supported by ref eager mode yet")
894850
def test_cumprod_reverse(self):
895851
"""Test cumprod with reverse=True."""
896852

@@ -914,7 +870,6 @@ def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
914870
# Verify reverse parameter is used
915871
self.assertIn("reverse=True", code)
916872

917-
@skipIfRefEager("torch.cumprod is not supported by ref eager mode yet")
918873
def test_cumprod_different_dtypes(self):
919874
"""Test cumprod with different data types."""
920875

@@ -1033,9 +988,6 @@ def test_segmented_tuple_kernel(
1033988
self.assertIn("def helion_combine_tuple_fn_", code)
1034989
self.assertIn("tl.associative_scan", code)
1035990

1036-
@skipIfRefEager(
1037-
"torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
1038-
)
1039991
def test_associative_scan_argmax_tuple_format(self):
1040992
"""Test cumulative argmax using tuple format combine function."""
1041993

0 commit comments

Comments (0)