@@ -57,7 +57,7 @@ def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr):
 @triton.jit
 def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
                       M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                      BYVAL_TMA: tl.constexpr):
+                      BYVAL_TMA: tl.constexpr, dtype: tl.constexpr):
     if not BYVAL_TMA:
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
@@ -72,11 +72,11 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     offs_k = 0
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float16)
-        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16)
+        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
+        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype)
         accumulator = tl.dot(a, b, acc=accumulator)
         offs_k += BLOCK_SIZE_K
-    accumulator = accumulator.to(tl.float16)
+    accumulator = accumulator.to(dtype)
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])


@@ -101,7 +101,7 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
     desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
     kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
                                 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
-                                    num_warps=8, num_stages=num_stages)
+                                    num_warps=8, num_stages=num_stages, dtype=tl.float16)
     ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
     torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
     if BLOCK_M >= 64 and BLOCK_N >= 64:
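The net effect of this commit is that the kernel's element type becomes a compile-time `tl.constexpr` parameter instead of hard-coded `tl.float16`, so the same TMA matmul kernel can be specialized for other precisions. A minimal host-side sketch of how that could be exercised is below; it assumes `matmul_kernel_tma` and the `create_tma_desc_gmem_ptr` helper defined earlier in this test file, and the `run_tma_matmul` wrapper itself is hypothetical, not part of the commit:

import torch
import triton
import triton.language as tl

def run_tma_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, torch_dtype, tl_dtype):
    # Hypothetical wrapper: materialize operands in the requested precision.
    # (torch.randn cannot create fp8 tensors directly, hence the cast.)
    A = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_dtype)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16).to(torch_dtype)
    C = torch.empty(M, N, device="cuda", dtype=torch_dtype)
    # Build TMA descriptors with the helper from this test file.
    desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size())
    desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size())
    desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
    # dtype is now a constexpr launch argument rather than baked into the kernel body.
    matmul_kernel_tma[grid](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
                            BYVAL_TMA=False, dtype=tl_dtype, num_warps=8, num_stages=3)
    return C

# fp16, as in the existing test; an fp8e4m3 run would pass
# torch.float8_e4m3fn / tl.float8e4nv instead.
out = run_tma_matmul(512, 512, 512, 128, 128, 64, torch.float16, tl.float16)

Because `dtype` is a `tl.constexpr`, Triton compiles a separate kernel specialization per element type, so the fp16 launch in the test and any later fp8 launch would not share a binary.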