@@ -99,9 +99,6 @@ def jit_add_combine_fn(x, y):
99
99
100
100
101
101
class TestAssociativeScan (RefEagerTestBase , TestCase ):
102
- @skipIfRefEager (
103
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
104
- )
105
102
def test_associative_scan_basic_addition (self ):
106
103
"""Test basic associative_scan functionality with prefix sum."""
107
104
@@ -135,9 +132,6 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor:
135
132
self .assertIn ("param_0 + param_1" , code )
136
133
self .assertIn ("tl.associative_scan" , code )
137
134
138
- @skipIfRefEager (
139
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
140
- )
141
135
def test_associative_scan_maximum (self ):
142
136
"""Test associative_scan with maximum combine function."""
143
137
@@ -170,9 +164,6 @@ def test_max_kernel(x: torch.Tensor) -> torch.Tensor:
170
164
"tl.maximum" in code or "triton_helpers.maximum" in code
171
165
)
172
166
173
- @skipIfRefEager (
174
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
175
- )
176
167
def test_associative_scan_multiplication (self ):
177
168
"""Test associative_scan with multiplication combine function."""
178
169
@@ -203,9 +194,6 @@ def test_mul_kernel(x: torch.Tensor) -> torch.Tensor:
203
194
# Verify the generated code contains multiplication
204
195
self .assertIn ("param_0 * param_1" , code )
205
196
206
- @skipIfRefEager (
207
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
208
- )
209
197
def test_associative_scan_minimum (self ):
210
198
"""Test associative_scan with minimum combine function."""
211
199
@@ -238,9 +226,6 @@ def test_min_kernel(x: torch.Tensor) -> torch.Tensor:
238
226
"tl.minimum" in code or "triton_helpers.minimum" in code
239
227
)
240
228
241
- @skipIfRefEager (
242
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
243
- )
244
229
def test_associative_scan_multiple_functions (self ):
245
230
"""Test using multiple different combine functions in one kernel."""
246
231
@@ -277,9 +262,6 @@ def test_multi_kernel(x: torch.Tensor) -> torch.Tensor:
277
262
"tl.maximum" in code or "triton_helpers.maximum" in code
278
263
)
279
264
280
- @skipIfRefEager (
281
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
282
- )
283
265
def test_associative_scan_type_propagation (self ):
284
266
"""Test that associative_scan type propagation works correctly."""
285
267
@@ -304,9 +286,6 @@ def test_type_kernel(x: torch.Tensor) -> torch.Tensor:
304
286
# Use relaxed tolerance for large tensors due to accumulated floating-point errors
305
287
torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
306
288
307
- @skipIfRefEager (
308
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
309
- )
310
289
def test_associative_scan_different_dtypes (self ):
311
290
"""Test associative_scan with different data types."""
312
291
@@ -341,9 +320,6 @@ def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
341
320
expected = expected .to (result .dtype )
342
321
torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
343
322
344
- @skipIfRefEager (
345
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
346
- )
347
323
def test_associative_scan_different_sizes (self ):
348
324
"""Test associative_scan with different tensor sizes."""
349
325
@@ -380,9 +356,6 @@ def test_size_kernel(x: torch.Tensor) -> torch.Tensor:
380
356
expected = torch .cumsum (x , dim = 1 )
381
357
torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
382
358
383
- @skipIfRefEager (
384
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
385
- )
386
359
def test_associative_scan_reverse (self ):
387
360
"""Test associative_scan with reverse=True parameter."""
388
361
@@ -408,9 +381,6 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
408
381
# Verify reverse parameter is in generated code
409
382
self .assertIn ("reverse=True" , code )
410
383
411
- @skipIfRefEager (
412
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
413
- )
414
384
def test_associative_scan_edge_cases (self ):
415
385
"""Test associative_scan edge cases."""
416
386
@@ -436,9 +406,6 @@ def test_single_element(x: torch.Tensor) -> torch.Tensor:
436
406
expected = torch .tensor ([[3.0 , 10.0 ]], device = DEVICE )
437
407
torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
438
408
439
- @skipIfRefEager (
440
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
441
- )
442
409
def test_associative_scan_large_scale (self ):
443
410
"""Test associative_scan with large tensors for performance validation."""
444
411
@@ -464,9 +431,6 @@ def test_large_kernel(x: torch.Tensor) -> torch.Tensor:
464
431
self .assertEqual (result .shape , x .shape )
465
432
self .assertEqual (result .dtype , x .dtype )
466
433
467
- @skipIfRefEager (
468
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
469
- )
470
434
def test_associative_scan_torch_hops_mapping (self ):
471
435
"""Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan."""
472
436
@@ -502,9 +466,6 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor:
502
466
self .assertIn ("tl.associative_scan" , code )
503
467
self .assertIn ("param_0 + param_1" , code )
504
468
505
- @skipIfRefEager (
506
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
507
- )
508
469
def test_associative_scan_code_generation (self ):
509
470
"""Test that the generated code structure is correct."""
510
471
@@ -744,9 +705,6 @@ def cumulative_argmax_kernel(
744
705
self .assertIn ("def argmax_combine_fn_" , code )
745
706
self .assertIn ("tl.associative_scan" , code )
746
707
747
- @skipIfRefEager (
748
- "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
749
- )
750
708
def test_associative_scan_in_helper_function (self ):
751
709
"""Test calling a function that internally uses hl.associative_scan."""
752
710
@@ -808,7 +766,6 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor:
808
766
self .assertIn ("param_0 + param_1" , code )
809
767
self .assertIn ("tl.associative_scan" , code )
810
768
811
- @skipIfRefEager ("hl.cumsum is not supported by ref eager mode yet" )
812
769
def test_cumsum_reverse (self ):
813
770
"""Test cumsum with reverse=True."""
814
771
@@ -890,7 +847,6 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor:
890
847
self .assertIn ("param_0 * param_1" , code )
891
848
self .assertIn ("tl.associative_scan" , code )
892
849
893
- @skipIfRefEager ("hl.cumprod is not supported by ref eager mode yet" )
894
850
def test_cumprod_reverse (self ):
895
851
"""Test cumprod with reverse=True."""
896
852
@@ -914,7 +870,6 @@ def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
914
870
# Verify reverse parameter is used
915
871
self .assertIn ("reverse=True" , code )
916
872
917
- @skipIfRefEager ("torch.cumprod is not supported by ref eager mode yet" )
918
873
def test_cumprod_different_dtypes (self ):
919
874
"""Test cumprod with different data types."""
920
875
@@ -1033,9 +988,6 @@ def test_segmented_tuple_kernel(
1033
988
self .assertIn ("def helion_combine_tuple_fn_" , code )
1034
989
self .assertIn ("tl.associative_scan" , code )
1035
990
1036
- @skipIfRefEager (
1037
- "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
1038
- )
1039
991
def test_associative_scan_argmax_tuple_format (self ):
1040
992
"""Test cumulative argmax using tuple format combine function."""
1041
993
0 commit comments