@@ -197,6 +197,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = True


 @pytest.mark.parametrize(
@@ -269,6 +270,7 @@ class Case:
     Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+    Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
     Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -315,7 +317,7 @@ class Case:
315317@pytest .mark .parametrize ("has_y_gammas" , [False , True ])
316318@pytest .mark .parametrize ("is_persistent" , [False , True ])
317319def test_op (m , n , k , split_k , do_gather , do_scatter , fused_scatter , inner_expt_opt , has_y_gammas , is_persistent , n_expts_tot ,
318- n_expts_act , mode , act_dtype_str , weight_dtype_str , block_m , hbm_swizzling , epilogue_subtile ,
320+ n_expts_act , mode , act_dtype_str , weight_dtype_str , block_m , hbm_swizzling , colmajor_mxfp_weight , epilogue_subtile ,
319321 x_transpose , w_transpose , y_transpose ,
320322 device , opt_flags_scope ):
321323 # TODO: remove when Triton FP8 supports proper RTNE
@@ -463,14 +465,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
             w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                 mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-        w_scale_tri = wrap_torch_tensor(w_scale_tri)
-        # convert layouts
-        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        w_tri_orig = w_tri
+        if colmajor_mxfp_weight:
+            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+            w_scale_tri = wrap_torch_tensor(w_scale_tri)
+            # convert layouts
+            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        else:
+            if torch.cuda.get_device_capability()[0] < 10:
+                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+            if block_m == 16:
+                pytest.skip("PassManager::run failed from Triton compiler")
+            # TODO: swizzling for rowmajor
+
+            # A typical use case: the weight has already been quantized in
+            # column-major layout, and we want to run the matmul against its
+            # transposed, row-major view without requantizing.
+
+            # Put the abs-max of each 32x32 block on its diagonal so that the
+            # scales of the transposed weight agree with the original ones.
+            w_ndim = w_tri.ndim
+            if w_ndim == 2:
+                w_tri = w_tri.unsqueeze(0)
+            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+                i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
+                j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
+                block = w_tri[e, i:i_end, j:j_end]
+                m_abs = block.abs().max()
+                i_len = i_end - i
+                j_len = j_end - j
+                min_len = min(i_len, j_len)
+                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+                if j_len > i_len:
+                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+                elif i_len > j_len:
+                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+            if w_ndim == 2:
+                w_tri = w_tri.squeeze(0)
+
+            # The matmul with a row-major weight expects the scale to be
+            # constructed separately (this needs little additional memory).
+            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            # Reuse the quantized values from the column-major weight.
+            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+            w_tri = w_tri_rowmajor.data.mT
+
+            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+                x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate")
+                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
+
+            # Check that the generated scales are transpose-invariant, as the construction above intends.
+            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
+            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
+
         precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -479,7 +539,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])
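
Not part of the diff: a minimal pure-PyTorch sketch of why the diagonal abs-max trick in the new row-major branch makes the per-32x32-block scales transpose-invariant. `e8m0_scale` and `BLOCK` are illustrative stand-ins, not the library's `downcast_to_mxfp`; the sketch only models the scale computation (no value quantization, no random signs, square shapes only).

```python
import torch

BLOCK = 32  # MXFP scale-group size along the quantization axis

def e8m0_scale(w: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for the scale part of downcast_to_mxfp: one
    # power-of-two scale per group of 32 consecutive values along dim 0,
    # derived from the group's abs-max.
    k, n = w.shape
    group_max = w.abs().reshape(k // BLOCK, BLOCK, n).amax(dim=1)
    return torch.exp2(torch.floor(torch.log2(group_max)))

torch.manual_seed(0)
w = torch.randn(128, 128)

# Place each 32x32 block's abs-max on its diagonal, as the test's else-branch
# does (the test additionally randomizes signs and handles non-square edge blocks).
for i in range(0, w.shape[0], BLOCK):
    for j in range(0, w.shape[1], BLOCK):
        blk = w[i:i + BLOCK, j:j + BLOCK]
        blk.diagonal()[:] = blk.abs().max()

# Each 32x32 block now exposes the same abs-max whether it is scanned along
# columns (original) or rows (transposed), so the scales agree blockwise.
s_col = e8m0_scale(w)                 # shape [K // 32, N]
s_row = e8m0_scale(w.T.contiguous())  # shape [N // 32, K]
assert torch.equal(s_col[:, ::BLOCK], s_row[:, ::BLOCK].T)
```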