72
72
import triton .language as tl
73
73
import triton .tools .experimental_descriptor
74
74
import triton .profiler as proton
75
- from triton .tools .experimental_descriptor import TmaDescKernelParam
75
+ from triton .tools .experimental_descriptor import TensorDescriptor
76
76
from triton .tools .mxfp import MXFP4Tensor , MXScaleTensor
77
77
78
78
@@ -106,7 +106,7 @@ def _matmul_launch_metadata(grid, kernel, args):
106
106
@triton .jit (launch_metadata = _matmul_launch_metadata )
107
107
def block_scaled_matmul_kernel ( #
108
108
a_desc , a_scale , #
109
- b_desc_or_tensor , b_scale , #
109
+ b_desc , b_scale , #
110
110
c_desc , #
111
111
M : tl .constexpr , N : tl .constexpr , K : tl .constexpr , #
112
112
stride_sk : tl .constexpr , stride_sb : tl .constexpr , stride_sc : tl .constexpr , stride_sd : tl .constexpr ,
@@ -120,16 +120,6 @@ def block_scaled_matmul_kernel( #
120
120
NUM_STAGES : tl .constexpr , #
121
121
USE_2D_SCALE_LOAD : tl .constexpr ): #
122
122
123
- if ELEM_PER_BYTE_A == 1 :
124
- dtype_a = tl .float8e4nv
125
- elif ELEM_PER_BYTE_A == 2 :
126
- dtype_a = tl .dtype ("uint8" )
127
-
128
- if ELEM_PER_BYTE_B == 1 :
129
- dtype_b = tl .float8e4nv
130
- elif ELEM_PER_BYTE_B == 2 :
131
- dtype_b = tl .dtype ("uint8" )
132
-
133
123
if output_type == 0 :
134
124
output_dtype = tl .float32
135
125
elif output_type == 1 :
@@ -152,23 +142,6 @@ def block_scaled_matmul_kernel( #
152
142
153
143
MIXED_PREC : tl .constexpr = ELEM_PER_BYTE_A == 1 and ELEM_PER_BYTE_B == 2
154
144
155
- if MIXED_PREC :
156
- b_desc = tl .make_tensor_descriptor (
157
- b_desc_or_tensor ,
158
- shape = [N , K // ELEM_PER_BYTE_B ],
159
- strides = [K // ELEM_PER_BYTE_B , 1 ],
160
- block_shape = [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
161
- )
162
- else :
163
- b_desc = b_desc_or_tensor
164
- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [b_desc ], dtype = tl .int32 ,
165
- is_pure = False , pack = 1 )
166
-
167
- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [a_desc ], dtype = tl .int32 , is_pure = False ,
168
- pack = 1 )
169
- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [c_desc ], dtype = tl .int32 , is_pure = False ,
170
- pack = 1 )
171
-
172
145
# For now it is recommended to use 2D scale loads for better performance.
173
146
# In the future we will bring additional optimizations to either allow 5D loads,
174
147
# the use of TMAs for scale factors, or both.
@@ -192,15 +165,8 @@ def block_scaled_matmul_kernel( #
192
165
193
166
accumulator = tl .zeros ((BLOCK_M , BLOCK_N ), dtype = tl .float32 )
194
167
for k in tl .range (0 , tl .cdiv (K , BLOCK_K ), num_stages = NUM_STAGES ):
195
- a = tl ._experimental_descriptor_load (a_desc , [offs_am , offs_k_a ], [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ],
196
- dtype_a )
197
-
198
- if MIXED_PREC :
199
- b = b_desc .load ([offs_bn , offs_k_b ])
200
- else :
201
- b = tl ._experimental_descriptor_load (b_desc , [offs_bn , offs_k_b ], [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
202
- dtype_b )
203
-
168
+ a = a_desc .load ([offs_am , offs_k_a ])
169
+ b = b_desc .load ([offs_bn , offs_k_b ])
204
170
scale_a = tl .load (a_scale_ptr )
205
171
scale_b = tl .load (b_scale_ptr )
206
172
if USE_2D_SCALE_LOAD :
@@ -221,10 +187,10 @@ def block_scaled_matmul_kernel( #
221
187
a_scale_ptr += (BLOCK_K // VEC_SIZE // 4 ) * stride_sb
222
188
b_scale_ptr += (BLOCK_K // VEC_SIZE // 4 ) * stride_sb
223
189
224
- tl . _experimental_descriptor_store ( c_desc , accumulator .to (output_dtype ), [ offs_am , offs_bn ] )
190
+ c_desc . store ([ offs_am , offs_bn ], accumulator .to (output_dtype ))
225
191
226
192
227
- def block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , dtype_dst , M , N , K , configs ):
193
+ def block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , dtype_dst , M , N , K , configs ):
228
194
output = torch .empty ((M , N ), dtype = dtype_dst , device = "cuda" )
229
195
if dtype_dst == torch .float32 :
230
196
dtype_dst = 0
@@ -235,11 +201,12 @@ def block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, dtype_dst, M
235
201
else :
236
202
raise ValueError (f"Unsupported dtype: { dtype_dst } " )
237
203
238
- c_desc = TmaDescKernelParam (output .data_ptr (), output .shape , [configs ["BLOCK_SIZE_M" ], configs ["BLOCK_SIZE_N" ]],
239
- output .element_size ())
204
+ BLOCK_M = configs ["BLOCK_SIZE_M" ]
205
+ BLOCK_N = configs ["BLOCK_SIZE_N" ]
206
+ c_desc = TensorDescriptor .from_tensor (output , [BLOCK_M , BLOCK_N ])
240
207
241
- grid = (triton .cdiv (M , configs [ "BLOCK_SIZE_M" ] ) * triton .cdiv (N , configs [ "BLOCK_SIZE_N" ] ), 1 )
242
- block_scaled_matmul_kernel [grid ](a_desc , a_scale , b_desc_or_tensor , b_scale , c_desc , M , N , K , a_scale .stride (0 ),
208
+ grid = (triton .cdiv (M , BLOCK_M ) * triton .cdiv (N , BLOCK_N ), 1 )
209
+ block_scaled_matmul_kernel [grid ](a_desc , a_scale , b_desc , b_scale , c_desc , M , N , K , a_scale .stride (0 ),
243
210
a_scale .stride (1 ), a_scale .stride (2 ), a_scale .stride (3 ), dtype_dst ,
244
211
configs ["ELEM_PER_BYTE_A" ], configs ["ELEM_PER_BYTE_B" ], configs ["VEC_SIZE" ],
245
212
configs ["BLOCK_SIZE_M" ], configs ["BLOCK_SIZE_N" ], configs ["BLOCK_SIZE_K" ],
@@ -284,12 +251,17 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
284
251
285
252
b_ref = b_ref .to (torch .float32 ).T
286
253
287
- a_desc = TmaDescKernelParam ( a . data_ptr (), a . shape , [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ], 1 )
254
+ a_desc = TensorDescriptor . from_tensor ( a , [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ])
288
255
289
256
if block_scale_type == "mixed" :
290
- b_desc_or_tensor = b
257
+ b_desc = TensorDescriptor (
258
+ b ,
259
+ shape = [N , K // ELEM_PER_BYTE_B ],
260
+ strides = [K // ELEM_PER_BYTE_B , 1 ],
261
+ block_shape = [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
262
+ )
291
263
else :
292
- b_desc_or_tensor = TmaDescKernelParam ( b . data_ptr (), b . shape , [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ], 1 )
264
+ b_desc = TensorDescriptor . from_tensor ( b , [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ])
293
265
294
266
epsilon = 1e-8
295
267
a_scale = torch .rand ((M // 128 , K // VEC_SIZE // 4 , 32 , 4 , 4 ), device = device ) + epsilon
@@ -327,7 +299,7 @@ def unpack_scale(packed):
327
299
"ELEM_PER_BYTE_B" : ELEM_PER_BYTE_B ,
328
300
"VEC_SIZE" : VEC_SIZE ,
329
301
}
330
- return a_desc , a_scale , b_desc_or_tensor , b_scale , configs , reference
302
+ return a_desc , a_scale , b_desc , b_scale , configs , reference
331
303
332
304
333
305
def validate_block_scaled (M , N , K , block_scale_type = "nvfp4" ):
@@ -340,9 +312,9 @@ def alloc_fn(size: int, align: int, _):
340
312
# TMA load for mixed-precision fp4 is supported only by device TMA.
341
313
triton .set_allocator (alloc_fn )
342
314
343
- a_desc , a_scale , b_desc_or_tensor , b_scale , configs , reference = initialize_block_scaled (
344
- M , N , K , block_scale_type , compute_reference = True )
345
- output = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
315
+ a_desc , a_scale , b_desc , b_scale , configs , reference = initialize_block_scaled (M , N , K , block_scale_type ,
316
+ compute_reference = True )
317
+ output = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
346
318
torch .testing .assert_close (reference , output .to (torch .float32 ), atol = 1e-3 , rtol = 1e-3 )
347
319
print (f"✅ (pass { block_scale_type } )" )
348
320
@@ -353,19 +325,13 @@ def bench_block_scaled(K, block_scale_type="nvfp4", reps=10):
353
325
N = 8192
354
326
print (f"Problem Shape = { M } x{ N } x{ K } " )
355
327
356
- def alloc_fn (size : int , align : int , _ ):
357
- return torch .empty (size , dtype = torch .int8 , device = "cuda" )
358
-
359
- if block_scale_type == "mixed" :
360
- triton .set_allocator (alloc_fn )
361
-
362
- a_desc , a_scale , b_desc_or_tensor , b_scale , configs , _ = initialize_block_scaled (
363
- M , N , K , block_scale_type , compute_reference = False )
364
- _ = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
328
+ a_desc , a_scale , b_desc , b_scale , configs , _ = initialize_block_scaled (M , N , K , block_scale_type ,
329
+ compute_reference = False )
330
+ _ = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
365
331
366
332
proton .activate (0 )
367
333
for _ in range (reps ):
368
- _ = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
334
+ _ = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
369
335
proton .deactivate (0 )
370
336
print ("Done benchmarking" )
371
337
0 commit comments