Skip to content

Commit 11c9b1d

Browse files
authored
Skip associative_scan tests in ref eager mode (#433)
1 parent eab7179 commit 11c9b1d

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

test/test_associative_scan.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def jit_add_combine_fn(x, y):
9999

100100

101101
class TestAssociativeScan(RefEagerTestBase, TestCase):
102+
@skipIfRefEager(
103+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
104+
)
102105
def test_associative_scan_basic_addition(self):
103106
"""Test basic associative_scan functionality with prefix sum."""
104107

@@ -132,6 +135,9 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor:
132135
self.assertIn("param_0 + param_1", code)
133136
self.assertIn("tl.associative_scan", code)
134137

138+
@skipIfRefEager(
139+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
140+
)
135141
def test_associative_scan_maximum(self):
136142
"""Test associative_scan with maximum combine function."""
137143

@@ -164,6 +170,9 @@ def test_max_kernel(x: torch.Tensor) -> torch.Tensor:
164170
"tl.maximum" in code or "triton_helpers.maximum" in code
165171
)
166172

173+
@skipIfRefEager(
174+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
175+
)
167176
def test_associative_scan_multiplication(self):
168177
"""Test associative_scan with multiplication combine function."""
169178

@@ -194,6 +203,9 @@ def test_mul_kernel(x: torch.Tensor) -> torch.Tensor:
194203
# Verify the generated code contains multiplication
195204
self.assertIn("param_0 * param_1", code)
196205

206+
@skipIfRefEager(
207+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
208+
)
197209
def test_associative_scan_minimum(self):
198210
"""Test associative_scan with minimum combine function."""
199211

@@ -226,6 +238,9 @@ def test_min_kernel(x: torch.Tensor) -> torch.Tensor:
226238
"tl.minimum" in code or "triton_helpers.minimum" in code
227239
)
228240

241+
@skipIfRefEager(
242+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
243+
)
229244
def test_associative_scan_multiple_functions(self):
230245
"""Test using multiple different combine functions in one kernel."""
231246

@@ -262,6 +277,9 @@ def test_multi_kernel(x: torch.Tensor) -> torch.Tensor:
262277
"tl.maximum" in code or "triton_helpers.maximum" in code
263278
)
264279

280+
@skipIfRefEager(
281+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
282+
)
265283
def test_associative_scan_type_propagation(self):
266284
"""Test that associative_scan type propagation works correctly."""
267285

@@ -286,6 +304,9 @@ def test_type_kernel(x: torch.Tensor) -> torch.Tensor:
286304
# Use relaxed tolerance for large tensors due to accumulated floating-point errors
287305
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
288306

307+
@skipIfRefEager(
308+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
309+
)
289310
def test_associative_scan_different_dtypes(self):
290311
"""Test associative_scan with different data types."""
291312

@@ -320,6 +341,9 @@ def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
320341
expected = expected.to(result.dtype)
321342
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
322343

344+
@skipIfRefEager(
345+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
346+
)
323347
def test_associative_scan_different_sizes(self):
324348
"""Test associative_scan with different tensor sizes."""
325349

@@ -356,6 +380,9 @@ def test_size_kernel(x: torch.Tensor) -> torch.Tensor:
356380
expected = torch.cumsum(x, dim=1)
357381
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
358382

383+
@skipIfRefEager(
384+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
385+
)
359386
def test_associative_scan_reverse(self):
360387
"""Test associative_scan with reverse=True parameter."""
361388

@@ -381,6 +408,9 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
381408
# Verify reverse parameter is in generated code
382409
self.assertIn("reverse=True", code)
383410

411+
@skipIfRefEager(
412+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
413+
)
384414
def test_associative_scan_edge_cases(self):
385415
"""Test associative_scan edge cases."""
386416

@@ -406,6 +436,9 @@ def test_single_element(x: torch.Tensor) -> torch.Tensor:
406436
expected = torch.tensor([[3.0, 10.0]], device=DEVICE)
407437
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
408438

439+
@skipIfRefEager(
440+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
441+
)
409442
def test_associative_scan_large_scale(self):
410443
"""Test associative_scan with large tensors for performance validation."""
411444

@@ -431,6 +464,9 @@ def test_large_kernel(x: torch.Tensor) -> torch.Tensor:
431464
self.assertEqual(result.shape, x.shape)
432465
self.assertEqual(result.dtype, x.dtype)
433466

467+
@skipIfRefEager(
468+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
469+
)
434470
def test_associative_scan_torch_hops_mapping(self):
435471
"""Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan."""
436472

@@ -466,6 +502,9 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor:
466502
self.assertIn("tl.associative_scan", code)
467503
self.assertIn("param_0 + param_1", code)
468504

505+
@skipIfRefEager(
506+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
507+
)
469508
def test_associative_scan_code_generation(self):
470509
"""Test that the generated code structure is correct."""
471510

@@ -705,6 +744,9 @@ def cumulative_argmax_kernel(
705744
self.assertIn("def argmax_combine_fn_", code)
706745
self.assertIn("tl.associative_scan", code)
707746

747+
@skipIfRefEager(
748+
"torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
749+
)
708750
def test_associative_scan_in_helper_function(self):
709751
"""Test calling a function that internally uses hl.associative_scan."""
710752

@@ -766,6 +808,7 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor:
766808
self.assertIn("param_0 + param_1", code)
767809
self.assertIn("tl.associative_scan", code)
768810

811+
@skipIfRefEager("hl.cumsum is not supported by ref eager mode yet")
769812
def test_cumsum_reverse(self):
770813
"""Test cumsum with reverse=True."""
771814

@@ -847,6 +890,7 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor:
847890
self.assertIn("param_0 * param_1", code)
848891
self.assertIn("tl.associative_scan", code)
849892

893+
@skipIfRefEager("hl.cumprod is not supported by ref eager mode yet")
850894
def test_cumprod_reverse(self):
851895
"""Test cumprod with reverse=True."""
852896

@@ -870,6 +914,7 @@ def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
870914
# Verify reverse parameter is used
871915
self.assertIn("reverse=True", code)
872916

917+
@skipIfRefEager("torch.cumprod is not supported by ref eager mode yet")
873918
def test_cumprod_different_dtypes(self):
874919
"""Test cumprod with different data types."""
875920

@@ -988,6 +1033,9 @@ def test_segmented_tuple_kernel(
9881033
self.assertIn("def helion_combine_tuple_fn_", code)
9891034
self.assertIn("tl.associative_scan", code)
9901035

1036+
@skipIfRefEager(
1037+
"torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
1038+
)
9911039
def test_associative_scan_argmax_tuple_format(self):
9921040
"""Test cumulative argmax using tuple format combine function."""
9931041

0 commit comments

Comments
 (0)