@@ -1494,8 +1494,6 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
     if not is_native:
         if num_ctas != 1:
             pytest.skip("Multi-CTA not supported")
-        if descriptor == "host":
-            pytest.skip("NYI: Host side tensor descriptor fallback")
     if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
         pytest.skip("Broken on rocm")

@@ -1573,3 +1571,105 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     expect = REDUCE_OP[kind](inp, out)
     kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)
     torch.testing.assert_close(expect, unwrap_tensor(out), check_dtype=False)
+
+
+@pytest.mark.interpreter()
+@pytest.mark.parametrize("dtype_str", tma_dtypes)
+@pytest.mark.parametrize("num_ctas", [1, 2])
+@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32), (8, 128)])
+def test_host_tensor_descriptor_load(dtype_str, num_ctas, M_BLOCK, N_BLOCK, device):
+    if num_ctas == 2 and (not is_cuda() or torch.cuda.get_device_capability(0)[0] not in (9, 10)):
+        pytest.skip("CTAs is unsupported for these cards")
+
+    @triton.jit(debug=True)
+    def kernel(out_ptr, desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+        assert desc.shape[0] == M
+        assert desc.shape[1] == N
+        assert desc.strides[0] == N
+        assert desc.strides[1] == 1
+        assert desc.block_shape == [M_BLOCK, N_BLOCK]
+        block = desc.load([M_BLOCK, 2 * N_BLOCK])
+        idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
+        tl.store(out_ptr + idx, block)
+
+    M, N = M_BLOCK * 3, N_BLOCK * 4
+    inp = to_triton(numpy_random((M, N), dtype_str), device=device, dst_type=dtype_str)
+    out = inp.new_empty((M_BLOCK, N_BLOCK))
+
+    inp_desc = TensorDescriptor(inp, shape=inp.shape, strides=inp.stride(), block_shape=[M_BLOCK, N_BLOCK])
+    kernel[(1, )](out, inp_desc, M, N, M_BLOCK, N_BLOCK, num_ctas=num_ctas)
+
+    expect = unwrap_tensor(inp)[1 * M_BLOCK:2 * M_BLOCK, 2 * N_BLOCK:3 * N_BLOCK]
+    torch.testing.assert_close(expect, unwrap_tensor(out))
+
+
+@triton.jit
+def matmul_kernel_host_tensor_descriptor(a_desc, b_desc, c_desc):
+    K = a_desc.shape[1]
+    BLOCK_M: tl.constexpr = a_desc.block_shape[0]
+    BLOCK_K: tl.constexpr = a_desc.block_shape[1]
+    BLOCK_N: tl.constexpr = b_desc.block_shape[1]
+
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    offs_am = pid_m * BLOCK_M
+    offs_bn = pid_n * BLOCK_N
+    offs_k = 0
+
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        a = a_desc.load([offs_am, offs_k])
+        b = b_desc.load([offs_k, offs_bn])
+        accumulator = tl.dot(a, b, acc=accumulator)
+        offs_k += BLOCK_K
+    accumulator = accumulator.to(a_desc.dtype)
+    c_desc.store([offs_am, offs_bn], accumulator)
+
+
+@pytest.mark.interpreter()
+@pytest.mark.parametrize("num_ctas", [1, 2])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, num_stages", [
+    (128, 128, 16, 1),
+    (256, 64, 32, 2),
+    (64, 512, 32, 2),
+    (128, 128, 16, 4),
+    (64, 128, 32, 4),
+    (32, 32, 32, 4),
+    (256, 128, 32, 4),
+])
+def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, BLOCK_K, device):
+    if num_ctas == 2 and (not is_cuda() or torch.cuda.get_device_capability(0)[0] not in (9, 10)):
+        pytest.skip("CTAs is unsupported for these cards")
+
+    if is_hip() and (BLOCK_M, BLOCK_N, BLOCK_K, num_stages) == (256, 128, 32, 4):
+        pytest.skip("Insufficient shared memory on HIP devices")
+
+    if is_interpreter():
+        M, N, K = BLOCK_M, BLOCK_N, BLOCK_K
+    else:
+        M, N, K = 1024, 512, 256
+    torch.manual_seed(42)
+    A = torch.randn((M, K), dtype=torch.float16, device=device)
+    B = torch.randn((K, N), dtype=torch.float16, device=device)
+    C = torch.empty((M, N), dtype=torch.float16, device=device)
+    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), 1)
+
+    A_desc = TensorDescriptor(A, A.shape, A.stride(), [BLOCK_M, BLOCK_K])
+    B_desc = TensorDescriptor(B, B.shape, B.stride(), [BLOCK_K, BLOCK_N])
+    C_desc = TensorDescriptor(C, C.shape, C.stride(), [BLOCK_M, BLOCK_N])
+
+    kernel = matmul_kernel_host_tensor_descriptor[grid](
+        A_desc,
+        B_desc,
+        C_desc,  #
+        num_warps=8,
+        num_stages=num_stages,
+        num_ctas=num_ctas,
+    )
+    ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
+    torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
+
+    if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_cuda() and torch.cuda.get_device_capability()[0] == 9:
+        # TODO: The use of stmatrix for Blackwell is currently not supported.
+        # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
+        assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]