@@ -197,7 +197,6 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
-    colmajor_mxfp_weight: bool = True
 
 
 @pytest.mark.parametrize(
@@ -270,7 +269,6 @@ class Case:
         Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-        Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
         Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -317,7 +315,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -465,72 +463,14 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
             mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri_orig = w_tri
-        if colmajor_mxfp_weight:
-            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-            w_scale_tri = wrap_torch_tensor(w_scale_tri)
-            # convert layouts
-            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
-        else:
-            if torch.cuda.get_device_capability()[0] < 10:
-                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
-            if block_m == 16:
-                pytest.skip("PassManager::run failed from Triton compiler")
-            # TODO: swizzling for rowmajor
-
-            # A typical use case is we already quantized col-major weight,
-            # and we want matmul with its transposed row-major weight w/o
-            # requantization.
-
-            # put abs_max of each 32x32 block to diagonal so scales of transposed agree
-            w_ndim = w_tri.ndim
-            if w_ndim == 2:
-                w_tri = w_tri.unsqueeze(0)
-            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
-            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
-                i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
-                j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
-                block = w_tri[e, i:i_end, j:j_end]
-                m_abs = block.abs().max()
-                i_len = i_end - i
-                j_len = j_end - j
-                min_len = min(i_len, j_len)
-                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
-                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
-                if j_len > i_len:
-                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
-                elif i_len > j_len:
-                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
-            if w_ndim == 2:
-                w_tri = w_tri.squeeze(0)
-
-            # matmul with rowmajor weight expects scale is separately
-            # constructed (not much additional memory needed).
-            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-            # reuse quantized value from colmajor
-            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
-            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
-            w_tri = w_tri_rowmajor.data.mT
-
-            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
-                x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate")
-                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
-
-            # check if generated scale is transpose-invariant as intended construction
-            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
-            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
-            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
-            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
-            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
-            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
-            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
-            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
-            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
-
+        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+        w_scale_tri = wrap_torch_tensor(w_scale_tri)
+        # convert layouts
+        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
         precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -539,7 +479,7 @@ def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])