@@ -484,9 +484,6 @@ def kernel(
484
484
torch .testing .assert_close (src_result , expected_src )
485
485
torch .testing .assert_close (dst_result , expected_dst )
486
486
487
- @skipIfNormalMode (
488
- "AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
489
- )
490
487
def test_2d_full_slice (self ):
491
488
"""Test both setter from scalar and getter for [:,:]"""
492
489
@@ -537,33 +534,79 @@ def kernel(
537
534
torch .testing .assert_close (src_result , expected_src )
538
535
torch .testing .assert_close (dst_result , expected_dst )
539
536
540
- @skipIfNormalMode (
541
- "AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
542
- )
543
537
def test_1d_full_slice (self ):
544
- """Test both setter from scalar and getter for [:]"""
538
+ """Test both setter from scalar and getter for [:] with multiple scalar types """
545
539
546
- @helion .kernel (use_default_config = True )
540
+ @helion .kernel (config = { "block_size" : 128 } )
547
541
def kernel (
548
- src : torch .Tensor , dst : torch .Tensor
549
- ) -> tuple [torch .Tensor , torch .Tensor ]:
550
- N = src .shape [0 ]
551
- for _ in hl .grid (N ):
552
- dst [:] = 1.0 # Test setter with scalar
553
- src [:] = dst [:] # Test getter from dst and setter to src
554
- return src , dst
542
+ src_float : torch .Tensor ,
543
+ dst_float : torch .Tensor ,
544
+ src_int : torch .Tensor ,
545
+ dst_int : torch .Tensor ,
546
+ src_symint : torch .Tensor ,
547
+ dst_symint : torch .Tensor ,
548
+ ) -> tuple [
549
+ torch .Tensor ,
550
+ torch .Tensor ,
551
+ torch .Tensor ,
552
+ torch .Tensor ,
553
+ torch .Tensor ,
554
+ torch .Tensor ,
555
+ ]:
556
+ N = src_float .shape [0 ]
557
+ for tile in hl .tile (N ):
558
+ # Test float scalar
559
+ dst_float [:] = 1.0
560
+ src_float [:] = dst_float [:]
561
+
562
+ # Test int scalar
563
+ dst_int [:] = 99
564
+ src_int [:] = dst_int [:]
565
+
566
+ # Test SymInt scalar
567
+ dst_symint [:] = tile .block_size
568
+ src_symint [:] = dst_symint [:]
569
+
570
+ return (
571
+ src_float ,
572
+ dst_float ,
573
+ src_int ,
574
+ dst_int ,
575
+ src_symint ,
576
+ dst_symint ,
577
+ )
555
578
556
579
N = 128
557
- src = torch .zeros ([N ], device = DEVICE )
558
- dst = torch .zeros ([N ], device = DEVICE )
580
+ src_float = torch .zeros ([N ], device = DEVICE )
581
+ dst_float = torch .zeros ([N ], device = DEVICE )
582
+ src_int = torch .zeros ([N ], device = DEVICE )
583
+ dst_int = torch .zeros ([N ], device = DEVICE )
584
+ src_symint = torch .zeros ([N ], device = DEVICE )
585
+ dst_symint = torch .zeros ([N ], device = DEVICE )
586
+
587
+ results = kernel (
588
+ src_float ,
589
+ dst_float ,
590
+ src_int ,
591
+ dst_int ,
592
+ src_symint ,
593
+ dst_symint ,
594
+ )
559
595
560
- src_result , dst_result = kernel (src , dst )
596
+ # Check float results
597
+ expected_float = torch .ones ([N ], device = DEVICE )
598
+ torch .testing .assert_close (results [0 ], expected_float )
599
+ torch .testing .assert_close (results [1 ], expected_float )
561
600
562
- # Both should be ones after the kernel
563
- expected_src = torch .ones ([N ], device = DEVICE )
564
- expected_dst = torch .ones ([N ], device = DEVICE )
565
- torch .testing .assert_close (src_result , expected_src )
566
- torch .testing .assert_close (dst_result , expected_dst )
601
+ # Check int results
602
+ expected_int = torch .full ([N ], 99.0 , device = DEVICE )
603
+ torch .testing .assert_close (results [2 ], expected_int )
604
+ torch .testing .assert_close (results [3 ], expected_int )
605
+
606
+ # Check SymInt results
607
+ expected_symint = torch .full ([N ], 128.0 , device = DEVICE )
608
+ torch .testing .assert_close (results [4 ], expected_symint )
609
+ torch .testing .assert_close (results [5 ], expected_symint )
567
610
568
611
@skipIfNormalMode (
569
612
"RankMismatch: Expected ndim=1, but got ndim=0 - LHS/RHS shape mismatch in type_propagation.py"
@@ -624,9 +667,6 @@ def kernel(buf: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
624
667
expected = torch .zeros ([N ], device = DEVICE )
625
668
torch .testing .assert_close (result , expected )
626
669
627
- @skipIfNormalMode (
628
- "AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
629
- )
630
670
def test_mixed_slice_index (self ):
631
671
"""Test both setter from scalar and getter for [i,:]"""
632
672
0 commit comments