@@ -219,9 +219,9 @@ def create_weights(
         # bias
         if module.bias:
             if w3_w1_bias_shape is None:
-                w3_w1_bias_shape = (module.expert_size_per_partition,
-                                    module.intermediate_size_per_partition *
-                                    module.intermediate_size_expand_ratio)
+                w3_w1_bias_shape = (
+                    module.expert_size_per_partition,
+                    module.expand_intermediate_size_per_partition)
             if w2_bias_shape is None:
                 w2_bias_shape = (module.expert_size_per_partition,
                                  module.hidden_size)
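Throughout the patch, the spelled-out product `module.intermediate_size_per_partition * module.intermediate_size_expand_ratio` is replaced by the single attribute `module.expand_intermediate_size_per_partition`. A minimal sketch of the assumed relationship follows; the helper class and numbers are illustrative, not part of the patch:

class _MoEShapeSketch:
    """Illustrative only: mirrors the attribute names used in this diff."""

    def __init__(self, intermediate_size_per_partition: int,
                 intermediate_size_expand_ratio: int):
        self.intermediate_size_per_partition = intermediate_size_per_partition
        # Typically 2 when the w3 and w1 projections are fused along dim 0.
        self.intermediate_size_expand_ratio = intermediate_size_expand_ratio

    @property
    def expand_intermediate_size_per_partition(self) -> int:
        # Presumed definition: the product the old code spelled out inline.
        return (self.intermediate_size_per_partition *
                self.intermediate_size_expand_ratio)

assert _MoEShapeSketch(1024, 2).expand_intermediate_size_per_partition == 2048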
@@ -515,8 +515,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
     def create_weights(self, module: torch.nn.Module):
         weight_dtype = module.dtype
         w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
                               module.hidden_size)
         w2_weight_shape = (
             module.expert_size_per_partition,
@@ -581,7 +580,7 @@ def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module,
     w3_weight_scale = w3_weight_scale[...].reshape([])
     max_w3_w1_weight_scale = max(w1_weight_scale, w3_weight_scale)

-    split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
+    split_length = module.expand_intermediate_size_per_partition // 2
     w3_weight = dst_w3_w1_weight.narrow(
         dim=0, start=0, length=split_length).to(dtype=module.dtype)
     w1_weight = dst_w3_w1_weight.narrow(
@@ -605,8 +604,7 @@ def create_weights(self, module: torch.nn.Module):
         weight_dtype = torch.float8_e4m3fn

         w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
                               module.hidden_size)
         w2_weight_shape = (
             module.expert_size_per_partition,
@@ -1655,6 +1653,38 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
     """

+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        # Divide by 16 because we use int64 to pack 16 fp4 values
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              module.expand_intermediate_size_per_partition,
+                              module.hidden_size // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
+                           module.intermediate_size_per_partition //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (
+            module.expert_size_per_partition,
+            module.expand_intermediate_size_per_partition, module.hidden_size //
+            module.scaling_vector_size // block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 module.hidden_size,
+                                 module.intermediate_size_per_partition //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                module.expand_intermediate_size_per_partition)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             module.hidden_size)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
+
     def create_weights(self,
                        module: torch.nn.Module,
                        weight_dtype,
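For context on the `// weight_vec_size` and `// block_scales_vec_size` terms above, a small arithmetic sketch; the packed dtypes and sizes are assumptions (int64-packed fp4 weights, int32-packed fp8 block scales, a made-up hidden_size), following the comments in the diff:

import torch

# Assumed packed storage dtypes: 16 fp4 values per int64, 4 fp8 scales per int32.
weight_vec_size = torch.iinfo(torch.int64).bits // 4        # 64 // 4 = 16
block_scales_vec_size = torch.iinfo(torch.int32).bits // 8  # 32 // 8 = 4

hidden_size = 3072            # illustrative
scaling_vector_size = 16      # default block size in create_weights below

packed_weight_cols = hidden_size // weight_vec_size          # 3072 // 16 = 192
packed_scale_cols = (hidden_size // scaling_vector_size //
                     block_scales_vec_size)                   # 3072 // 16 // 4 = 48
print(packed_weight_cols, packed_scale_cols)                  # 192 48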
@@ -1664,35 +1694,23 @@ def create_weights(self,
                        scaling_vector_size=16):

         module.scaling_vector_size = scaling_vector_size
-        # Divide by 16 because we use int64 to pack 16 fp4 values
-        w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
-                              module.hidden_size // weight_vec_size)
-        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
-                           module.intermediate_size_per_partition //
-                           weight_vec_size)
+
+        (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, w2_bias_shape,
+         w3_w1_weight_scale_shape,
+         w2_weight_scale_shape) = self.get_weights_shapes(
+             module, weight_vec_size, block_scales_vec_size)

         # Divide by 4 because we use int32 to pack 4 fp8 values
         # column parallel
-        w3_w1_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.intermediate_size_per_partition *
-                       module.intermediate_size_expand_ratio,
-                       module.hidden_size // module.scaling_vector_size //
-                       block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w3_w1_weight_scale = nn.Parameter(torch.ones(w3_w1_weight_scale_shape,
+                                                     dtype=block_scales_dtype),
+                                          requires_grad=False)
         module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale)

         # row parallel
-        w2_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.hidden_size,
-                       module.intermediate_size_per_partition //
-                       module.scaling_vector_size // block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w2_weight_scale = nn.Parameter(torch.ones(w2_weight_scale_shape,
+                                                  dtype=block_scales_dtype),
+                                       requires_grad=False)
         module.register_parameter("w2_weight_scale", w2_weight_scale)

         fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32),
@@ -1717,8 +1735,12 @@ def create_weights(self,
         # This will be initialized in load_quant_scales if pre_quant_scale exists
         module.register_parameter("fc31_act_scale", None)

-        super().create_weights(module, weight_dtype, w3_w1_weight_shape,
-                               w2_weight_shape)
+        super().create_weights(module,
+                               weight_dtype,
+                               w3_w1_weight_shape=w3_w1_weight_shape,
+                               w2_weight_shape=w2_weight_shape,
+                               w3_w1_bias_shape=w3_w1_bias_shape,
+                               w2_bias_shape=w2_bias_shape)

         self.setup_quant_scales(module)

@@ -2005,6 +2027,55 @@ def setup_quant_scales(self, module: torch.nn.Module):
 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
     block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE
+    NVFP4_ROW_ALIGNMENT = 128
+    NVFP4_COL_ALIGNMENT = 4
+
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        """Override the base method to produce weight shapes that satisfy the Cutlass nvfp4 alignment requirements."""
+        intermediate_size_expand_aligned = (
+            module.expand_intermediate_size_per_partition +
+            self.NVFP4_ROW_ALIGNMENT -
+            1) // self.NVFP4_ROW_ALIGNMENT * self.NVFP4_ROW_ALIGNMENT
+
+        if module.hidden_size % self.NVFP4_COL_ALIGNMENT != 0:
+            raise ValueError(
+                f"hidden_size {module.hidden_size} must be divisible by {self.NVFP4_COL_ALIGNMENT}"
+            )
+        hidden_size_aligned = module.hidden_size
+
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              intermediate_size_expand_aligned,
+                              hidden_size_aligned // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition,
+                           hidden_size_aligned,
+                           intermediate_size_expand_aligned //
+                           module.intermediate_size_expand_ratio //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (module.expert_size_per_partition,
+                                    intermediate_size_expand_aligned,
+                                    hidden_size_aligned //
+                                    module.scaling_vector_size //
+                                    block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 hidden_size_aligned,
+                                 intermediate_size_expand_aligned //
+                                 module.intermediate_size_expand_ratio //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                intermediate_size_expand_aligned)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             hidden_size_aligned)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)

     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
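The override rounds the fused w3/w1 row count up to a multiple of `NVFP4_ROW_ALIGNMENT`. A quick sketch of that ceiling arithmetic (the sizes are illustrative):

NVFP4_ROW_ALIGNMENT = 128

def align_up(value: int, alignment: int) -> int:
    # Same idiom as the override: (value + alignment - 1) // alignment * alignment.
    return (value + alignment - 1) // alignment * alignment

assert align_up(2880, NVFP4_ROW_ALIGNMENT) == 2944  # rounded up to the next multiple of 128
assert align_up(3072, NVFP4_ROW_ALIGNMENT) == 3072  # already aligned, left unchanged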
@@ -2029,21 +2100,16 @@ def load_expert_w3_w1_weight_scale_nvfp4(
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-        # Keep weights in device buffer
-        # w3
-        split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
-        dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=0,
-                                                            length=split_length)
-        dst_w3_weight_scale.copy_(
-            w3_weight_scale.view(dst_w3_weight_scale.dtype))

-        # w1
-        dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=split_length,
-                                                            length=split_length)
-        dst_w1_weight_scale.copy_(
-            w1_weight_scale.view(dst_w1_weight_scale.dtype))
+        cast_w3_weight_scale = w3_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w1_weight_scale = w1_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w31_weight_scale = torch.cat(
+            [cast_w3_weight_scale, cast_w1_weight_scale], dim=0)
+        cast_w31_weight_scale = self._maybe_padding_shape(
+            cast_w31_weight_scale, dst_w3_w1_weight_scale)
+        dst_w3_w1_weight_scale.copy_(cast_w31_weight_scale)

         orig_shape = dst_w3_w1_weight_scale.shape

@@ -2065,9 +2131,12 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                          module.tp_rank,
                                          TensorParallelMode.ROW,
                                          device=device)
+
+        cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
+        cast_w2_weight_scale = self._maybe_padding_shape(
+            cast_w2_weight_scale, dst_w2_weight_scale)
         # Keep weights in device buffer
-        dst_w2_weight_scale.copy_(
-            w2_weight_scale.view(dst_w2_weight_scale.dtype))
+        dst_w2_weight_scale.copy_(cast_w2_weight_scale)

         orig_shape = dst_w2_weight_scale.shape

@@ -2079,6 +2148,60 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,

         dst_w2_weight_scale.copy_(dst_w2_weight_scale_interleaved)

+    def load_expert_w3_w1_weight(self, module: torch.nn.Module,
+                                 w1_weight: torch.Tensor,
+                                 w3_weight: torch.Tensor,
+                                 dst_w3_w1_weight: torch.Tensor):
+        """Load and pad the w1 and w3 weights for each expert to match the shape requirements of the Cutlass nvfp4 alignment."""
+        device = dst_w3_w1_weight.device
+        w1_weight_shard = load_weight_shard(w1_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+        w3_weight_shard = load_weight_shard(w3_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+
+        cast_w1_weight_shard = w1_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w3_weight_shard = w3_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w31_weight_shard = torch.cat(
+            [cast_w3_weight_shard, cast_w1_weight_shard], dim=0)
+        cast_w31_weight_shard = self._maybe_padding_shape(
+            cast_w31_weight_shard, dst_w3_w1_weight)
+        dst_w3_w1_weight.copy_(cast_w31_weight_shard, non_blocking=True)
+
+    def load_expert_w2_weight(self, module: torch.nn.Module,
+                              w2_weight: torch.Tensor,
+                              dst_w2_weight: torch.Tensor):
+        """Load and pad the w2 weight for each expert to match the shape requirements of the Cutlass nvfp4 alignment."""
+        device = dst_w2_weight.device
+        w2_weight_shard = load_weight_shard(w2_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.ROW,
+                                            device=device)
+        cast_w2_weight_shard = w2_weight_shard.view(dst_w2_weight.dtype)
+        cast_w2_weight_shard = self._maybe_padding_shape(
+            cast_w2_weight_shard, dst_w2_weight)
+        dst_w2_weight.copy_(cast_w2_weight_shard, non_blocking=True)
+
+    def _maybe_padding_shape(self, source_tensor, dst_tensor):
+        """Pad the source tensor to match the shape of the destination tensor."""
+        # In the `get_weights_shapes` method, the shapes of the weights and weight scales may be tuned to align with `NVFP4_ROW_ALIGNMENT`.
+        # Pad `source_tensor` to match the shape of `dst_tensor` here.
+        assert len(source_tensor.shape) == 2 and len(
+            dst_tensor.shape) == 2, "Only support 2D weights padding for now."
+        dst_row, dst_col = dst_tensor.shape
+        _row, _col = source_tensor.shape
+        if _row != dst_row or _col != dst_col:
+            source_tensor = torch.nn.functional.pad(
+                source_tensor, (0, dst_col - _col, 0, dst_row - _row),
+                "constant", 0).contiguous()
+        return source_tensor
+

 class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):

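A closing note on `_maybe_padding_shape`: for a 2D tensor, `torch.nn.functional.pad` takes its spec as (left, right, top, bottom), so `(0, dst_col - _col, 0, dst_row - _row)` zero-pads only on the right and bottom, leaving the original shard in the top-left corner. A minimal standalone check (the shapes below are made up):

import torch
import torch.nn.functional as F

src = torch.ones(3, 5)     # unpadded shard (illustrative shape)
dst_row, dst_col = 4, 8    # aligned destination shape (illustrative)

padded = F.pad(src, (0, dst_col - 5, 0, dst_row - 3), "constant", 0)

assert padded.shape == (4, 8)
assert padded[:3, :5].eq(1).all()   # original data stays in the top-left block
assert padded[3:, :].eq(0).all()    # bottom rows are zero padding
assert padded[:, 5:].eq(0).all()    # right columns are zero padding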