@@ -194,6 +194,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = True  # False exercises the transposed (row-major) mxfp weight path
 
 
 @pytest.mark.parametrize(
@@ -267,6 +268,7 @@ class Case:
         Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+        Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
         Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -313,7 +315,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -461,14 +463,83 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
             w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                 mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-        w_scale_tri = wrap_torch_tensor(w_scale_tri)
-        # convert layouts
-        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        w_tri_orig = w_tri
+        if colmajor_mxfp_weight:
+            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+            w_scale_tri = wrap_torch_tensor(w_scale_tri)
+            # convert layouts
+            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        else:
+            if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
+                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+            if block_m == 16:
+                pytest.skip("PassManager::run failed from Triton compiler")
+            # TODO: swizzling for rowmajor
+
+            # A typical use case: the col-major weight has already been quantized,
+            # and we want to matmul with its transposed (row-major) view without
+            # requantizing it.
+
+            # Put the abs max of each 32x32 block on its diagonal so that the scales of the transposed weight agree.
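+            # Why this works (sketch): each mxfp scale is derived from the abs max of
+            # a 32-element group. Planting +/- abs_max of the whole 32x32 tile on its
+            # diagonal (and on one extra row/column for ragged edge tiles) guarantees
+            # every row and every column of the tile attains that max, so quantizing
+            # along either axis produces identical per-tile scales.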
+            w_ndim = w_tri.ndim
+            if w_ndim == 2:
+                w_tri = w_tri.unsqueeze(0)
+            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+                i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
+                j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
+                block = w_tri[e, i:i_end, j:j_end]
+                m_abs = block.abs().max()
+                i_len = i_end - i
+                j_len = j_end - j
+                min_len = min(i_len, j_len)
+                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+                if j_len > i_len:
+                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+                elif i_len > j_len:
+                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+            if w_ndim == 2:
+                w_tri = w_tri.squeeze(0)
+
+            # Matmul with a row-major weight expects the scale to be constructed
+            # separately (only a little additional memory is needed).
+            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            # reuse the quantized values from colmajor
+            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+            w_tri = w_tri_rowmajor.data.mT
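+            # w_tri now aliases the row-major quantized values through a transposed,
+            # non-contiguous view, while w_scale_tri still comes from the col-major
+            # quantization; the diagonal construction above makes both scale sets agree.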
+
+            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+                x = torch.nn.functional.pad(x, (0, -x.shape[-1] % BLOCK_SIZE), mode="replicate")
+                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
+
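+            # The construction above makes every scale entry within a 32-wide group
+            # identical, so sampling one entry per group loses no information and the
+            # col-major / row-major scale grids can be compared directly.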
+            # Check that the generated scales are transpose-invariant, as intended by the construction above.
+            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
+            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
+
         precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -477,7 +548,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])