@@ -99,6 +99,9 @@ def jit_add_combine_fn(x, y):


 class TestAssociativeScan(RefEagerTestBase, TestCase):
+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_basic_addition(self):
         """Test basic associative_scan functionality with prefix sum."""

@@ -132,6 +135,9 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("param_0 + param_1", code)
         self.assertIn("tl.associative_scan", code)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_maximum(self):
         """Test associative_scan with maximum combine function."""

@@ -164,6 +170,9 @@ def test_max_kernel(x: torch.Tensor) -> torch.Tensor:
             "tl.maximum" in code or "triton_helpers.maximum" in code
         )

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_multiplication(self):
         """Test associative_scan with multiplication combine function."""

@@ -194,6 +203,9 @@ def test_mul_kernel(x: torch.Tensor) -> torch.Tensor:
         # Verify the generated code contains multiplication
         self.assertIn("param_0 * param_1", code)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_minimum(self):
         """Test associative_scan with minimum combine function."""

@@ -226,6 +238,9 @@ def test_min_kernel(x: torch.Tensor) -> torch.Tensor:
             "tl.minimum" in code or "triton_helpers.minimum" in code
         )

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_multiple_functions(self):
         """Test using multiple different combine functions in one kernel."""

@@ -262,6 +277,9 @@ def test_multi_kernel(x: torch.Tensor) -> torch.Tensor:
             "tl.maximum" in code or "triton_helpers.maximum" in code
         )

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_type_propagation(self):
         """Test that associative_scan type propagation works correctly."""

@@ -286,6 +304,9 @@ def test_type_kernel(x: torch.Tensor) -> torch.Tensor:
         # Use relaxed tolerance for large tensors due to accumulated floating-point errors
         torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_different_dtypes(self):
         """Test associative_scan with different data types."""

@@ -320,6 +341,9 @@ def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = expected.to(result.dtype)
         torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_different_sizes(self):
         """Test associative_scan with different tensor sizes."""

@@ -356,6 +380,9 @@ def test_size_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.cumsum(x, dim=1)
         torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_reverse(self):
         """Test associative_scan with reverse=True parameter."""

@@ -381,6 +408,9 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
         # Verify reverse parameter is in generated code
         self.assertIn("reverse=True", code)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_edge_cases(self):
         """Test associative_scan edge cases."""

@@ -406,6 +436,9 @@ def test_single_element(x: torch.Tensor) -> torch.Tensor:
         expected = torch.tensor([[3.0, 10.0]], device=DEVICE)
         torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_large_scale(self):
         """Test associative_scan with large tensors for performance validation."""

@@ -431,6 +464,9 @@ def test_large_kernel(x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(result.shape, x.shape)
         self.assertEqual(result.dtype, x.dtype)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_torch_hops_mapping(self):
         """Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan."""

@@ -466,6 +502,9 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("tl.associative_scan", code)
         self.assertIn("param_0 + param_1", code)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_code_generation(self):
         """Test that the generated code structure is correct."""

@@ -705,6 +744,9 @@ def cumulative_argmax_kernel(
         self.assertIn("def argmax_combine_fn_", code)
         self.assertIn("tl.associative_scan", code)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
+    )
     def test_associative_scan_in_helper_function(self):
         """Test calling a function that internally uses hl.associative_scan."""

@@ -766,6 +808,7 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("param_0 + param_1", code)
         self.assertIn("tl.associative_scan", code)

+    @skipIfRefEager("hl.cumsum is not supported by ref eager mode yet")
     def test_cumsum_reverse(self):
         """Test cumsum with reverse=True."""

@@ -847,6 +890,7 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("param_0 * param_1", code)
         self.assertIn("tl.associative_scan", code)

+    @skipIfRefEager("hl.cumprod is not supported by ref eager mode yet")
     def test_cumprod_reverse(self):
         """Test cumprod with reverse=True."""

@@ -870,6 +914,7 @@ def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
         # Verify reverse parameter is used
         self.assertIn("reverse=True", code)

+    @skipIfRefEager("torch.cumprod is not supported by ref eager mode yet")
     def test_cumprod_different_dtypes(self):
         """Test cumprod with different data types."""

@@ -988,6 +1033,9 @@ def test_segmented_tuple_kernel(
         self.assertIn("def helion_combine_tuple_fn_", code)
         self.assertIn("tl.associative_scan", code)

+    @skipIfRefEager(
+        "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
+    )
     def test_associative_scan_argmax_tuple_format(self):
         """Test cumulative argmax using tuple format combine function."""

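For context on the pattern this diff applies throughout the file: `skipIfRefEager` is a skip decorator from this repository's test utilities. Below is a minimal sketch of how such a helper can be built on `unittest.skipIf`; the `HELION_INTERPRET` environment check and the `Example` class are illustrative assumptions, not the project's actual implementation.

# Minimal sketch of a skipIfRefEager-style decorator. Assumption: the suite
# toggles ref eager (interpreted) mode via an environment variable; the real
# helper in this repository may detect the mode differently.
import os
import unittest


def skipIfRefEager(reason: str):
    """Skip the decorated test when the suite runs in ref eager mode."""
    # Assumed flag; substitute whatever signal the test utilities actually use.
    in_ref_eager_mode = os.environ.get("HELION_INTERPRET", "0") == "1"
    return unittest.skipIf(in_ref_eager_mode, reason)


class Example(unittest.TestCase):
    @skipIfRefEager(
        "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
    )
    def test_prefix_sum(self) -> None:
        # Runs normally; skipped with the given reason when the flag is set.
        self.assertEqual([1, 3, 6], [1, 1 + 2, 1 + 2 + 3])

Because `unittest.skipIf` evaluates its condition at decoration time, this keeps the skip decision out of each test body, which is why the diff can annotate every unsupported test with a one-line decorator and a human-readable reason.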