@@ -81,15 +81,15 @@ def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
8181def _p_matmul_ogs (
8282 Y , YPtr , stride_y_k , stride_y_z , stride_y_m , stride_y_n ,
8383 YExpectedScale , YActualScale , YChecksumScale ,
84- stride_y_mx_z , stride_y_mx_m , stride_y_mx_n ,
84+ stride_y_mx_k , stride_y_mx_z , stride_y_mx_m , stride_y_mx_n ,
8585 X , XPtr , stride_x_z , stride_x_m , stride_x_k ,
8686 XScale ,
8787 XMxScale , stride_x_mx_z , stride_x_mx_m , stride_x_mx_k ,
8888 W , WPtr , stride_w_e , stride_w_k , stride_w_n , W_TRANSPOSE : tl .constexpr ,
8989 WScale ,
90- MxScale , stride_mx_e , stride_mx_k , stride_mx_n ,
90+ WMxScale , stride_w_mx_e , stride_w_mx_k , stride_w_mx_n ,
9191 B , stride_b_e , # Bias
92- NRows , M , N , K , # shapes
92+ M , N , K , # shapes
9393 # expt data
9494 Betas , Gammas ,
9595 GatherIndx ,
@@ -133,14 +133,14 @@ def _p_matmul_ogs(
133133 if Y_TMA_MODE is not None :
134134 Y = tl .make_tensor_descriptor (YPtr , Y .shape , Y .strides [:- 1 ] + (1 ,), Y .block_shape )
135135
136- is_microscaled_format : tl .constexpr = MxScale is not None
137- tl .static_assert (not is_microscaled_format or W_TRANSPOSE , "NYI. Non-transposed mxfp4 weights" )
136+ is_w_microscaled : tl .constexpr = WMxScale is not None
137+ tl .static_assert (not is_w_microscaled or W_TRANSPOSE , "NYI. Non-transposed mxfp4 weights" )
138138 MX_PACK_DIVISOR : tl .constexpr = MXFP_BLOCK_SIZE
139- if is_microscaled_format :
139+ if is_w_microscaled :
140140 w_type : tl .constexpr = get_dtype (W )
141141 tl .static_assert (w_type == tl .uint8 or (w_type == tl .float8e4nv or w_type == tl .float8e5 ),
142- "mx_weight_ptr must be uint8" )
143- tl .static_assert (get_dtype (MxScale ) == tl .uint8 , "mx_scale_ptr must be uint8" )
142+ "mx_weight_ptr must be uint8 or fp8 " )
143+ tl .static_assert (get_dtype (WMxScale ) == tl .uint8 , "mx_scale_ptr must be uint8" )
144144 tl .static_assert (BLOCK_K % MX_PACK_DIVISOR == 0 , "BLOCK_K must be a multiple of MX_PACK_DIVISOR" )
145145 tl .static_assert (SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None , "Only Blackwell swizzling is supported for scales" )
146146
@@ -153,6 +153,13 @@ def _p_matmul_ogs(
153153 MX_SCALE_BLOCK_K : tl .constexpr = 1
154154 PACKED_BLOCK_K_W : tl .constexpr = BLOCK_K
155155 tl .static_assert (SWIZZLE_MX_SCALE is None )
156+ is_x_microscaled : tl .constexpr = XMxScale is not None
157+ if is_x_microscaled :
158+ x_type : tl .constexpr = get_dtype (X )
159+ tl .static_assert (x_type == tl .float8e4nv , "mx_act_ptr must be float8e4nv" )
160+ tl .static_assert (XMxScale .dtype .element_ty == tl .uint8 , "mx_scale_ptr must be uint8" )
161+ tl .static_assert (BLOCK_K % MX_PACK_DIVISOR == 0 , "BLOCK_K must be a multiple of MX_PACK_DIVISOR" )
162+ is_out_microscaled : tl .constexpr = stride_y_mx_z is not None
156163
157164 if ExptOffsSum is not None :
158165 # Determine how much padding there is on the expert data. This allows us to
@@ -214,7 +221,7 @@ def _p_matmul_ogs(
214221 THREADS_PER_BLOCK : tl .constexpr = tl .extra .cuda .num_threads ()
215222 local_absmax = tl .full ([THREADS_PER_BLOCK ], 0.0 , tl .uint32 )
216223
217- DISALLOW_ACC_MULTI_BUFFER : tl .constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
224+ DISALLOW_ACC_MULTI_BUFFER : tl .constexpr = is_w_microscaled and BLOCK_M * BLOCK_N >= 128 * 256
218225
219226 for tile_id in tl .range (tl .program_id (0 ), num_tiles , NUM_SMS , flatten = True , disallow_acc_multi_buffer = DISALLOW_ACC_MULTI_BUFFER , warp_specialize = True ):
220227 expt_id , start_z , start_m , eM , off_m , off_n , pid_k = _load_tile_attrs (
@@ -241,25 +248,42 @@ def _p_matmul_ogs(
241248 else :
242249 offs_x_m = tl .load (GatherIndx + start_m .to (index_type ) + offs_m ,
243250 mask = mask_m , other = - N_EXPTS_ACT ) // N_EXPTS_ACT
244- elif X_TMA_MODE is None :
245- tl .static_assert (HAS_GATHER )
251+ elif X_TMA_MODE is None or is_x_microscaled :
246252 offs_m = off_m + tl .arange (0 , BLOCK_M )
247253 if M is not None :
248254 offs_m = tl .max_contiguous (tl .multiple_of (offs_m % M , BLOCK_M ), BLOCK_M )
249255 else :
250256 offs_m = tl .max_contiguous (tl .multiple_of (offs_m % eM , BLOCK_M ), BLOCK_M )
251257 # no need to bounds-check here because `offs_m` wraps around the M dim
252- offs_m = tl .load (GatherIndx + start_m .to (index_type ) + offs_m ) // N_EXPTS_ACT
258+ if GatherIndx is not None :
259+ tl .static_assert (HAS_GATHER )
260+ offs_m = tl .load (GatherIndx + start_m .to (index_type ) + offs_m ) // N_EXPTS_ACT
253261 offs_x_m = offs_m .to (index_type )[:, None ] * stride_x_m
254262
255263
264+ if is_x_microscaled :
265+ XMxScalePtrs = XMxScale + start_z .to (index_type ) * stride_x_mx_z
266+ if GatherIndx is None :
267+ XMxScalePtrs += start_m * stride_x_mx_m
268+ offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl .arange (0 , MX_SCALE_BLOCK_K )
269+ XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m ).to (index_type )[:, None ] * stride_x_mx_m
270+ XMxScalePtrs += offs_k_scale .to (index_type )[None , :] * stride_x_mx_k
271+ else :
272+ XMxScalePtrs = None
273+
256274 acc = tl .zeros ((BLOCK_N , BLOCK_M ) if SWAP_XW else (BLOCK_M , BLOCK_N ), dtype = tl .float32 )
257275 for ki in tl .range (k_tiles , disallow_acc_multi_buffer = DISALLOW_ACC_MULTI_BUFFER ):
258276 off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
259277 off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
260278 off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
261279
262280 # --- load x ---
281+ if is_x_microscaled :
282+ if EVEN_K :
283+ mask_k_scale = tl .full ([MX_SCALE_BLOCK_K ], True , dtype = tl .int1 )
284+ else :
285+ mask_k_scale = offs_k_scale < tl .cdiv (K , MX_PACK_DIVISOR )
286+
263287 if USE_GATHER_TMA :
264288 x = X .gather (offs_x_m , off_k )
265289 elif X_TMA_MODE == "dense" :
@@ -288,28 +312,33 @@ def _p_matmul_ogs(
288312 w = tl .reshape (W .load ([expt_id , off_k_w , off_n ]), W .block_shape [1 :])
289313
290314 # --- load w_scale ---
291- if is_microscaled_format :
315+ if is_w_microscaled :
292316 x_format : tl .constexpr = get_scaled_dot_format_string (x .dtype )
293- mx_format : tl .constexpr = get_scaled_dot_format_string (w .dtype )
294- if x_format == "fp16" or x_format == "bf16" :
317+ w_format : tl .constexpr = get_scaled_dot_format_string (w .dtype )
318+
319+ if is_x_microscaled :
320+ x_scales = tl .load (XMxScalePtrs , mask = mask_k_scale [None , :], other = 0.0 )
321+ elif x_format == "fp16" or x_format == "bf16" :
295322 x_scales : tl .constexpr = None
296323 else :
297324 x_scales = tl .full ((BLOCK_M , BLOCK_K // MX_PACK_DIVISOR ), 127 , dtype = tl .uint8 )
298325 if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" :
299326 flattened_expt_n_idx = expt_id * ((N + 127 ) // 128 ) + (off_n // 128 )
300- w_scales = MxScale .load ([0 , flattened_expt_n_idx , pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K ), 0 , 0 ])
327+ w_scales = WMxScale .load ([0 , flattened_expt_n_idx , pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K ), 0 , 0 ])
301328 w_scales = w_scales .reshape ((w_scales .shape [1 ], w_scales .shape [2 ] * w_scales .shape [- 2 ] * w_scales .shape [- 1 ]))
302329 w_scales = unswizzle_mx_scale_bw (w_scales )
303330 else :
304- w_scales = MxScale .load ([expt_id , off_k_mx , off_n ])
331+ w_scales = WMxScale .load ([expt_id , off_k_mx , off_n ])
305332 w_scales = tl .reshape (w_scales , * w_scales .shape [1 :]).T
306333
307334 # --- update accumulator ---
308- if is_microscaled_format :
335+ if is_w_microscaled :
309336 if SWAP_XW :
310- acc = tl .dot_scaled (w .T , w_scales , mx_format , x .T , x_scales , x_format , acc = acc , fast_math = True )
337+ acc = tl .dot_scaled (w .T , w_scales , w_format , x .T , x_scales , x_format , acc = acc , fast_math = True )
311338 else :
312- acc = tl .dot_scaled (x , x_scales , x_format , w , w_scales , mx_format , acc = acc , fast_math = True )
339+ acc = tl .dot_scaled (x , x_scales , x_format , w , w_scales , w_format , acc = acc , fast_math = True )
340+ if is_x_microscaled :
341+ XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K ) * stride_x_mx_k
313342 else :
314343 if SWAP_XW :
315344 acc = tl .dot (w .T , x .T , acc , max_num_imprecise_acc = MAX_NUM_IMPRECISE_ACC , allow_tf32 = ALLOW_TF32 )
@@ -392,6 +421,10 @@ def _p_matmul_ogs(
392421 tl .static_assert (EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR )
393422 tl .static_assert (len (accs ) == SUBTILE_FACTOR )
394423
424+ if is_out_microscaled :
425+ MX_SCALE_BLOCK_N : tl .constexpr = OUT_BLOCK_N // MXFP_BLOCK_SIZE
426+ N_MX_BLOCK : tl .constexpr = tl .cdiv (N , MXFP_BLOCK_SIZE )
427+
395428 for a_i in tl .static_range (len (accs )):
396429 acc_tile = accs [a_i ]
397430 acc_tile *= x_scale * w_scale
@@ -414,20 +447,47 @@ def _p_matmul_ogs(
414447
415448 if MASK_ACC :
416449 out = tl .where (mask_m [:, None ], out , 0.0 )
417- # Flexpoint
418- out_view = tl .reshape (out , [out .numel // THREADS_PER_BLOCK , THREADS_PER_BLOCK ], can_reorder = True )
419- local_absmax = tl .maximum (local_absmax , nan_propagating_absmax_reduce (out_view , axis = 0 ))
420- out = float_to_flex (
421- out , YExpectedScale ,
422- None , # ActualScale: local absmax is tracked and updated after the loop
423- YChecksumScale ,
424- None , # mask: out is manually masked to 0
425- YPtr , FLEXPOINT_SATURATE_INF
426- )
427- if EPILOGUE_FN is not None :
428- out = EPILOGUE_FN (out , * epilogue_fn_args , target_dtype = YPtr .dtype .element_ty , pid = len (accs )* tile_id1 + a_i )
429-
430450 out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
451+ if is_out_microscaled :
452+ tl .static_assert (EPILOGUE_FN is not None )
453+ offs_y_n = out_off_n + tl .arange (0 , OUT_BLOCK_N )
454+ mask_n = offs_y_n < yN
455+ out , out_scale = EPILOGUE_FN (out , mask_m [:, None ] & mask_n [None , :], * epilogue_fn_args )
456+ tl .static_assert (BLOCK_N % MX_SCALE_BLOCK_N == 0 , "" )
457+ offs_y_n_scale = off_n1 // ACTIVATION_REDUCTION_N // MXFP_BLOCK_SIZE + a_i * MX_SCALE_BLOCK_N + tl .arange (0 , MX_SCALE_BLOCK_N )
458+ mask_n_scale = offs_y_n_scale < N_MX_BLOCK
459+ offs_y_mx_k = 0
460+ if USE_SCATTER_TMA :
461+ # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
462+ # there shouldn't be any other negative values.
463+ offs_y_mx_z = 0
464+ offs_y_mx_m = (offs_y_m .to (tl .uint32 , bitcast = True ) & 0x7FFFFFFF ).to (tl .int32 , bitcast = True )
465+ elif Y_TMA_MODE == "dense" :
466+ offs_y_mx_z = pid_k * batch_size + start_z1
467+ offs_y_mx_m = off_m1 + tl .arange (0 , BLOCK_M )
468+ elif Y_TMA_MODE == "ragged" :
469+ offs_y_mx_z = pid_k
470+ offs_y_mx_m = start_m1 + off_m1 + tl .arange (0 , BLOCK_M )
471+ else :
472+ tl .static_assert (Y_TMA_MODE is None )
473+ offs_y_mx_k = pid_k1
474+ offs_y_mx_z = start_z1
475+ YActualScalePtrs = YActualScale + offs_y_mx_k .to (index_type ) * stride_y_mx_k + offs_y_mx_z .to (index_type ) * stride_y_mx_z + offs_y_mx_m .to (index_type )[:, None ] * stride_y_mx_m + offs_y_n_scale .to (index_type )[None , :] * stride_y_mx_n
476+ tl .store (YActualScalePtrs , out_scale , mask = mask_m [:, None ] & mask_n_scale [None , :])
477+ else :
478+ # Flexpoint
479+ out_view = tl .reshape (out , [out .numel // THREADS_PER_BLOCK , THREADS_PER_BLOCK ], can_reorder = True )
480+ local_absmax = tl .maximum (local_absmax , nan_propagating_absmax_reduce (out_view , axis = 0 ))
481+ out = float_to_flex (
482+ out , YExpectedScale ,
483+ None , # ActualScale: local absmax is tracked and updated after the loop
484+ YChecksumScale ,
485+ None , # mask: out is manually masked to 0
486+ YPtr , FLEXPOINT_SATURATE_INF
487+ )
488+ if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8 :
489+ out = EPILOGUE_FN (out , * epilogue_fn_args , target_dtype = YPtr .dtype .element_ty , pid = len (accs )* tile_id1 + a_i )
490+
431491 out = out .to (YPtr .dtype .element_ty )
432492 if USE_SCATTER_TMA :
433493 # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
@@ -452,7 +512,7 @@ def _p_matmul_ogs(
452512
453513
454514 # Update the flexpoint scales
455- if YActualScale is not None :
515+ if YActualScale is not None and not is_out_microscaled :
456516 tl .atomic_max (YActualScale , compute_scale (local_absmax .to (tl .float32 , bitcast = True ), YPtr ), sem = "relaxed" )
457517
458518
0 commit comments