7272import triton .language as tl
7373import triton .tools .experimental_descriptor
7474import triton .profiler as proton
75- from triton .tools .experimental_descriptor import TmaDescKernelParam
75+ from triton .tools .experimental_descriptor import TensorDescriptor
7676from triton .tools .mxfp import MXFP4Tensor , MXScaleTensor
7777
7878
@@ -106,7 +106,7 @@ def _matmul_launch_metadata(grid, kernel, args):
106106@triton .jit (launch_metadata = _matmul_launch_metadata )
107107def block_scaled_matmul_kernel ( #
108108 a_desc , a_scale , #
109- b_desc_or_tensor , b_scale , #
109+ b_desc , b_scale , #
110110 c_desc , #
111111 M : tl .constexpr , N : tl .constexpr , K : tl .constexpr , #
112112 stride_sk : tl .constexpr , stride_sb : tl .constexpr , stride_sc : tl .constexpr , stride_sd : tl .constexpr ,
@@ -120,16 +120,6 @@ def block_scaled_matmul_kernel( #
120120 NUM_STAGES : tl .constexpr , #
121121 USE_2D_SCALE_LOAD : tl .constexpr ): #
122122
123- if ELEM_PER_BYTE_A == 1 :
124- dtype_a = tl .float8e4nv
125- elif ELEM_PER_BYTE_A == 2 :
126- dtype_a = tl .dtype ("uint8" )
127-
128- if ELEM_PER_BYTE_B == 1 :
129- dtype_b = tl .float8e4nv
130- elif ELEM_PER_BYTE_B == 2 :
131- dtype_b = tl .dtype ("uint8" )
132-
133123 if output_type == 0 :
134124 output_dtype = tl .float32
135125 elif output_type == 1 :
@@ -152,23 +142,6 @@ def block_scaled_matmul_kernel( #
152142
153143 MIXED_PREC : tl .constexpr = ELEM_PER_BYTE_A == 1 and ELEM_PER_BYTE_B == 2
154144
155- if MIXED_PREC :
156- b_desc = tl .make_tensor_descriptor (
157- b_desc_or_tensor ,
158- shape = [N , K // ELEM_PER_BYTE_B ],
159- strides = [K // ELEM_PER_BYTE_B , 1 ],
160- block_shape = [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
161- )
162- else :
163- b_desc = b_desc_or_tensor
164- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [b_desc ], dtype = tl .int32 ,
165- is_pure = False , pack = 1 )
166-
167- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [a_desc ], dtype = tl .int32 , is_pure = False ,
168- pack = 1 )
169- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [c_desc ], dtype = tl .int32 , is_pure = False ,
170- pack = 1 )
171-
172145 # For now it is recommended to use 2D scale loads for better performance.
173146 # In the future we will bring additional optimizations to either allow 5D loads,
174147 # the use of TMAs for scale factors, or both.
@@ -192,15 +165,8 @@ def block_scaled_matmul_kernel( #
192165
193166 accumulator = tl .zeros ((BLOCK_M , BLOCK_N ), dtype = tl .float32 )
194167 for k in tl .range (0 , tl .cdiv (K , BLOCK_K ), num_stages = NUM_STAGES ):
195- a = tl ._experimental_descriptor_load (a_desc , [offs_am , offs_k_a ], [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ],
196- dtype_a )
197-
198- if MIXED_PREC :
199- b = b_desc .load ([offs_bn , offs_k_b ])
200- else :
201- b = tl ._experimental_descriptor_load (b_desc , [offs_bn , offs_k_b ], [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
202- dtype_b )
203-
168+ a = a_desc .load ([offs_am , offs_k_a ])
169+ b = b_desc .load ([offs_bn , offs_k_b ])
204170 scale_a = tl .load (a_scale_ptr )
205171 scale_b = tl .load (b_scale_ptr )
206172 if USE_2D_SCALE_LOAD :
@@ -221,10 +187,10 @@ def block_scaled_matmul_kernel( #
221187 a_scale_ptr += (BLOCK_K // VEC_SIZE // 4 ) * stride_sb
222188 b_scale_ptr += (BLOCK_K // VEC_SIZE // 4 ) * stride_sb
223189
224- tl . _experimental_descriptor_store ( c_desc , accumulator .to (output_dtype ), [ offs_am , offs_bn ] )
190+ c_desc . store ([ offs_am , offs_bn ], accumulator .to (output_dtype ))
225191
226192
227- def block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , dtype_dst , M , N , K , configs ):
193+ def block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , dtype_dst , M , N , K , configs ):
228194 output = torch .empty ((M , N ), dtype = dtype_dst , device = "cuda" )
229195 if dtype_dst == torch .float32 :
230196 dtype_dst = 0
@@ -235,11 +201,12 @@ def block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, dtype_dst, M
235201 else :
236202 raise ValueError (f"Unsupported dtype: { dtype_dst } " )
237203
238- c_desc = TmaDescKernelParam (output .data_ptr (), output .shape , [configs ["BLOCK_SIZE_M" ], configs ["BLOCK_SIZE_N" ]],
239- output .element_size ())
204+ BLOCK_M = configs ["BLOCK_SIZE_M" ]
205+ BLOCK_N = configs ["BLOCK_SIZE_N" ]
206+ c_desc = TensorDescriptor .from_tensor (output , [BLOCK_M , BLOCK_N ])
240207
241- grid = (triton .cdiv (M , configs [ "BLOCK_SIZE_M" ] ) * triton .cdiv (N , configs [ "BLOCK_SIZE_N" ] ), 1 )
242- block_scaled_matmul_kernel [grid ](a_desc , a_scale , b_desc_or_tensor , b_scale , c_desc , M , N , K , a_scale .stride (0 ),
208+ grid = (triton .cdiv (M , BLOCK_M ) * triton .cdiv (N , BLOCK_N ), 1 )
209+ block_scaled_matmul_kernel [grid ](a_desc , a_scale , b_desc , b_scale , c_desc , M , N , K , a_scale .stride (0 ),
243210 a_scale .stride (1 ), a_scale .stride (2 ), a_scale .stride (3 ), dtype_dst ,
244211 configs ["ELEM_PER_BYTE_A" ], configs ["ELEM_PER_BYTE_B" ], configs ["VEC_SIZE" ],
245212 configs ["BLOCK_SIZE_M" ], configs ["BLOCK_SIZE_N" ], configs ["BLOCK_SIZE_K" ],
@@ -284,12 +251,17 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
284251
285252 b_ref = b_ref .to (torch .float32 ).T
286253
287- a_desc = TmaDescKernelParam ( a . data_ptr (), a . shape , [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ], 1 )
254+ a_desc = TensorDescriptor . from_tensor ( a , [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ])
288255
289256 if block_scale_type == "mixed" :
290- b_desc_or_tensor = b
257+ b_desc = TensorDescriptor (
258+ b ,
259+ shape = [N , K // ELEM_PER_BYTE_B ],
260+ strides = [K // ELEM_PER_BYTE_B , 1 ],
261+ block_shape = [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
262+ )
291263 else :
292- b_desc_or_tensor = TmaDescKernelParam ( b . data_ptr (), b . shape , [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ], 1 )
264+ b_desc = TensorDescriptor . from_tensor ( b , [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ])
293265
294266 epsilon = 1e-8
295267 a_scale = torch .rand ((M // 128 , K // VEC_SIZE // 4 , 32 , 4 , 4 ), device = device ) + epsilon
@@ -327,7 +299,7 @@ def unpack_scale(packed):
327299 "ELEM_PER_BYTE_B" : ELEM_PER_BYTE_B ,
328300 "VEC_SIZE" : VEC_SIZE ,
329301 }
330- return a_desc , a_scale , b_desc_or_tensor , b_scale , configs , reference
302+ return a_desc , a_scale , b_desc , b_scale , configs , reference
331303
332304
333305def validate_block_scaled (M , N , K , block_scale_type = "nvfp4" ):
@@ -340,9 +312,9 @@ def alloc_fn(size: int, align: int, _):
340312 # TMA load for mixed-precision fp4 is supported only by device TMA.
341313 triton .set_allocator (alloc_fn )
342314
343- a_desc , a_scale , b_desc_or_tensor , b_scale , configs , reference = initialize_block_scaled (
344- M , N , K , block_scale_type , compute_reference = True )
345- output = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
315+ a_desc , a_scale , b_desc , b_scale , configs , reference = initialize_block_scaled (M , N , K , block_scale_type ,
316+ compute_reference = True )
317+ output = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
346318 torch .testing .assert_close (reference , output .to (torch .float32 ), atol = 1e-3 , rtol = 1e-3 )
347319 print (f"✅ (pass { block_scale_type } )" )
348320
@@ -353,19 +325,13 @@ def bench_block_scaled(K, block_scale_type="nvfp4", reps=10):
353325 N = 8192
354326 print (f"Problem Shape = { M } x{ N } x{ K } " )
355327
356- def alloc_fn (size : int , align : int , _ ):
357- return torch .empty (size , dtype = torch .int8 , device = "cuda" )
358-
359- if block_scale_type == "mixed" :
360- triton .set_allocator (alloc_fn )
361-
362- a_desc , a_scale , b_desc_or_tensor , b_scale , configs , _ = initialize_block_scaled (
363- M , N , K , block_scale_type , compute_reference = False )
364- _ = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
328+ a_desc , a_scale , b_desc , b_scale , configs , _ = initialize_block_scaled (M , N , K , block_scale_type ,
329+ compute_reference = False )
330+ _ = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
365331
366332 proton .activate (0 )
367333 for _ in range (reps ):
368- _ = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
334+ _ = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
369335 proton .deactivate (0 )
370336 print ("Done benchmarking" )
371337
0 commit comments