
Commit 8af5121

[FMDL-1222][feat] Support weight and weight_scale padding for NVFP4 MoE cutlass (#9358)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent ce7a42f commit 8af5121

File tree

3 files changed: +298 / -49 lines

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 4 additions & 0 deletions
@@ -719,6 +719,10 @@ def enable_alltoall(self):
         """
         return False
 
+    @property
+    def expand_intermediate_size_per_partition(self):
+        return self.intermediate_size_per_partition * self.intermediate_size_expand_ratio
+
     def reducescatter_or_allreduce(
         self,
         inputs,
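
The new `expand_intermediate_size_per_partition` property centralizes the `intermediate_size_per_partition * intermediate_size_expand_ratio` expression that the quantization code below previously repeated inline. A minimal, self-contained sketch of the idea (the class name and concrete sizes here are made up for illustration; gated MLPs that fuse the w3/w1 projections typically use an expand ratio of 2):

```python
# Illustrative sketch only -- not the real MoE interface class.
class MoEInterfaceSketch:

    def __init__(self, intermediate_size_per_partition: int,
                 intermediate_size_expand_ratio: int):
        self.intermediate_size_per_partition = intermediate_size_per_partition
        self.intermediate_size_expand_ratio = intermediate_size_expand_ratio

    @property
    def expand_intermediate_size_per_partition(self) -> int:
        # The fused w3/w1 projection stacks both halves along the
        # intermediate dimension, so its size is the per-partition
        # intermediate size times the expand ratio.
        return (self.intermediate_size_per_partition *
                self.intermediate_size_expand_ratio)


m = MoEInterfaceSketch(intermediate_size_per_partition=1440,
                       intermediate_size_expand_ratio=2)
assert m.expand_intermediate_size_per_partition == 2880
```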

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 172 additions & 49 deletions
@@ -219,9 +219,9 @@ def create_weights(
         # bias
         if module.bias:
             if w3_w1_bias_shape is None:
-                w3_w1_bias_shape = (module.expert_size_per_partition,
-                                    module.intermediate_size_per_partition *
-                                    module.intermediate_size_expand_ratio)
+                w3_w1_bias_shape = (
+                    module.expert_size_per_partition,
+                    module.expand_intermediate_size_per_partition)
             if w2_bias_shape is None:
                 w2_bias_shape = (module.expert_size_per_partition,
                                  module.hidden_size)
@@ -515,8 +515,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
     def create_weights(self, module: torch.nn.Module):
         weight_dtype = module.dtype
         w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
                               module.hidden_size)
         w2_weight_shape = (
             module.expert_size_per_partition,
@@ -581,7 +580,7 @@ def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module,
     w3_weight_scale = w3_weight_scale[...].reshape([])
     max_w3_w1_weight_scale = max(w1_weight_scale, w3_weight_scale)
 
-    split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
+    split_length = module.expand_intermediate_size_per_partition // 2
     w3_weight = dst_w3_w1_weight.narrow(
         dim=0, start=0, length=split_length).to(dtype=module.dtype)
     w1_weight = dst_w3_w1_weight.narrow(
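
For context on why the fused dimension is halved: `requantize_expert_w3_w1_weight_fp8_qdq` slices the fused [w3; w1] buffer back into its two halves with `narrow()`, so `split_length` is half of the expanded intermediate size. A toy illustration of that slicing with made-up shapes:

```python
import torch

# Toy stand-in for a fused [w3; w1] expert weight; sizes are made up.
fused_w3_w1 = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
split_length = fused_w3_w1.shape[0] // 2  # half of the expanded intermediate size

w3 = fused_w3_w1.narrow(dim=0, start=0, length=split_length)
w1 = fused_w3_w1.narrow(dim=0, start=split_length, length=split_length)

# The two narrowed views cover the fused buffer exactly, w3 first then w1.
assert torch.equal(torch.cat([w3, w1], dim=0), fused_w3_w1)
```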
@@ -605,8 +604,7 @@ def create_weights(self, module: torch.nn.Module):
         weight_dtype = torch.float8_e4m3fn
 
         w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
                               module.hidden_size)
         w2_weight_shape = (
             module.expert_size_per_partition,
@@ -1655,6 +1653,38 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
     """
 
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        # Divide by 16 because we use int64 to pack 16 fp4 values
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              module.expand_intermediate_size_per_partition,
+                              module.hidden_size // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
+                           module.intermediate_size_per_partition //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (
+            module.expert_size_per_partition,
+            module.expand_intermediate_size_per_partition, module.hidden_size //
+            module.scaling_vector_size // block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 module.hidden_size,
+                                 module.intermediate_size_per_partition //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                module.expand_intermediate_size_per_partition)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             module.hidden_size)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
+
     def create_weights(self,
                        module: torch.nn.Module,
                        weight_dtype,
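
To make the packing arithmetic concrete, here is a rough numeric walk-through of the base `get_weights_shapes`. All sizes are hypothetical; the vector sizes follow the comments in the diff (one int64 packs 16 FP4 weight values, one int32 packs 4 FP8 block scales):

```python
# Hypothetical per-partition sizes, chosen only so the divisions are exact.
expert_size_per_partition = 8
hidden_size = 3072
intermediate_size_per_partition = 1024
intermediate_size_expand_ratio = 2   # fused w3/w1
scaling_vector_size = 16             # NVFP4 block size
weight_vec_size = 16                 # int64 packs 16 fp4 values
block_scales_vec_size = 4            # int32 packs 4 fp8 scales

expand_intermediate = intermediate_size_per_partition * intermediate_size_expand_ratio

w3_w1_weight_shape = (expert_size_per_partition, expand_intermediate,
                      hidden_size // weight_vec_size)
w2_weight_shape = (expert_size_per_partition, hidden_size,
                   intermediate_size_per_partition // weight_vec_size)
w3_w1_weight_scale_shape = (expert_size_per_partition, expand_intermediate,
                            hidden_size // scaling_vector_size // block_scales_vec_size)
w2_weight_scale_shape = (expert_size_per_partition, hidden_size,
                         intermediate_size_per_partition // scaling_vector_size //
                         block_scales_vec_size)

assert w3_w1_weight_shape == (8, 2048, 192)
assert w2_weight_shape == (8, 3072, 64)
assert w3_w1_weight_scale_shape == (8, 2048, 48)
assert w2_weight_scale_shape == (8, 3072, 16)
```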
@@ -1664,35 +1694,23 @@ def create_weights(self,
                        scaling_vector_size=16):
 
         module.scaling_vector_size = scaling_vector_size
-        # Divide by 16 because we use int64 to pack 16 fp4 values
-        w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
-                              module.hidden_size // weight_vec_size)
-        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
-                           module.intermediate_size_per_partition //
-                           weight_vec_size)
+
+        (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, w2_bias_shape,
+         w3_w1_weight_scale_shape,
+         w2_weight_scale_shape) = self.get_weights_shapes(
+             module, weight_vec_size, block_scales_vec_size)
 
         # Divide by 4 because we use int32 to pack 4 fp8 values
         # column parallel
-        w3_w1_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.intermediate_size_per_partition *
-                       module.intermediate_size_expand_ratio,
-                       module.hidden_size // module.scaling_vector_size //
-                       block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w3_w1_weight_scale = nn.Parameter(torch.ones(w3_w1_weight_scale_shape,
+                                                     dtype=block_scales_dtype),
+                                          requires_grad=False)
         module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale)
 
         # row parallel
-        w2_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.hidden_size,
-                       module.intermediate_size_per_partition //
-                       module.scaling_vector_size // block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w2_weight_scale = nn.Parameter(torch.ones(w2_weight_scale_shape,
+                                                  dtype=block_scales_dtype),
+                                       requires_grad=False)
         module.register_parameter("w2_weight_scale", w2_weight_scale)
 
         fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32),
@@ -1717,8 +1735,12 @@ def create_weights(self,
         # This will be initialized in load_quant_scales if pre_quant_scale exists
         module.register_parameter("fc31_act_scale", None)
 
-        super().create_weights(module, weight_dtype, w3_w1_weight_shape,
-                               w2_weight_shape)
+        super().create_weights(module,
+                               weight_dtype,
+                               w3_w1_weight_shape=w3_w1_weight_shape,
+                               w2_weight_shape=w2_weight_shape,
+                               w3_w1_bias_shape=w3_w1_bias_shape,
+                               w2_bias_shape=w2_bias_shape)
 
         self.setup_quant_scales(module)
 
@@ -2005,6 +2027,55 @@ def setup_quant_scales(self, module: torch.nn.Module):
 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
     block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE
+    NVFP4_ROW_ALIGNMENT = 128
+    NVFP4_COL_ALIGNMENT = 4
+
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        """Override the base method to return weight shapes padded to the Cutlass NVFP4 alignment requirements."""
+        intermediate_size_expand_aligned = (
+            module.expand_intermediate_size_per_partition +
+            self.NVFP4_ROW_ALIGNMENT -
+            1) // self.NVFP4_ROW_ALIGNMENT * self.NVFP4_ROW_ALIGNMENT
+
+        if module.hidden_size % self.NVFP4_COL_ALIGNMENT != 0:
+            raise ValueError(
+                f"hidden_size {module.hidden_size} must be divisible by {self.NVFP4_COL_ALIGNMENT}"
+            )
+        hidden_size_aligned = module.hidden_size
+
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              intermediate_size_expand_aligned,
+                              hidden_size_aligned // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition,
+                           hidden_size_aligned,
+                           intermediate_size_expand_aligned //
+                           module.intermediate_size_expand_ratio //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (module.expert_size_per_partition,
+                                    intermediate_size_expand_aligned,
+                                    hidden_size_aligned //
+                                    module.scaling_vector_size //
+                                    block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 hidden_size_aligned,
+                                 intermediate_size_expand_aligned //
+                                 module.intermediate_size_expand_ratio //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                intermediate_size_expand_aligned)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             hidden_size_aligned)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
 
     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
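
The Cutlass override pads the fused w3/w1 dimension up to a multiple of `NVFP4_ROW_ALIGNMENT` (128) with the usual round-up-by-ceil-division trick, and derives w2's reduction dimension from the aligned value so both layers stay consistent. A small sketch of that arithmetic with made-up sizes:

```python
NVFP4_ROW_ALIGNMENT = 128
weight_vec_size = 16  # int64 packs 16 fp4 values


def round_up(x: int, alignment: int) -> int:
    # Same expression as in the override: (x + alignment - 1) // alignment * alignment
    return (x + alignment - 1) // alignment * alignment


# Hypothetical sizes where padding actually kicks in.
intermediate_size_per_partition = 1000
intermediate_size_expand_ratio = 2
expand_intermediate = intermediate_size_per_partition * intermediate_size_expand_ratio  # 2000

aligned = round_up(expand_intermediate, NVFP4_ROW_ALIGNMENT)  # 2048
padding_rows = aligned - expand_intermediate                  # 48 zero rows in the fused w3/w1 dim

# w2's packed reduction dim is computed from the aligned value:
# 2048 // 2 // 16 = 64 int64 columns per row.
w2_packed_cols = aligned // intermediate_size_expand_ratio // weight_vec_size

assert (aligned, padding_rows, w2_packed_cols) == (2048, 48, 64)
```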
@@ -2029,21 +2100,16 @@ def load_expert_w3_w1_weight_scale_nvfp4(
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-        # Keep weights in device buffer
-        # w3
-        split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
-        dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=0,
-                                                            length=split_length)
-        dst_w3_weight_scale.copy_(
-            w3_weight_scale.view(dst_w3_weight_scale.dtype))
 
-        # w1
-        dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=split_length,
-                                                            length=split_length)
-        dst_w1_weight_scale.copy_(
-            w1_weight_scale.view(dst_w1_weight_scale.dtype))
+        cast_w3_weight_scale = w3_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w1_weight_scale = w1_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w31_weight_scale = torch.cat(
+            [cast_w3_weight_scale, cast_w1_weight_scale], dim=0)
+        cast_w31_weight_scale = self._maybe_padding_shape(
+            cast_w31_weight_scale, dst_w3_w1_weight_scale)
+        dst_w3_w1_weight_scale.copy_(cast_w31_weight_scale)
 
         orig_shape = dst_w3_w1_weight_scale.shape
 
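Instead of two `narrow()`/`copy_()` calls per half, the loader now builds the fused scale tensor up front: cast w3 and w1 to the destination dtype, concatenate them along dim 0 (w3 first, then w1, matching the fused buffer layout), pad to the aligned destination shape if needed, and issue a single `copy_()`. A toy sketch of the concatenation order (shapes are made up):

```python
import torch

# Toy stand-ins for per-expert w3 and w1 block-scale shards.
w3_scale = torch.zeros(3, 4)
w1_scale = torch.ones(3, 4)

# w3 first, then w1 -- the order the fused w3_w1 buffer expects.
fused_scale = torch.cat([w3_scale, w1_scale], dim=0)  # shape (6, 4)

assert torch.equal(fused_scale[:3], w3_scale)
assert torch.equal(fused_scale[3:], w1_scale)
```

Building the fused tensor first is what makes the padding step possible: the zero rows can be appended once, after concatenation, rather than handled per half.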
@@ -2065,9 +2131,12 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                             module.tp_rank,
                                             TensorParallelMode.ROW,
                                             device=device)
+
+        cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
+        cast_w2_weight_scale = self._maybe_padding_shape(
+            cast_w2_weight_scale, dst_w2_weight_scale)
         # Keep weights in device buffer
-        dst_w2_weight_scale.copy_(
-            w2_weight_scale.view(dst_w2_weight_scale.dtype))
+        dst_w2_weight_scale.copy_(cast_w2_weight_scale)
 
         orig_shape = dst_w2_weight_scale.shape
 
@@ -2079,6 +2148,60 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
 
         dst_w2_weight_scale.copy_(dst_w2_weight_scale_interleaved)
 
+    def load_expert_w3_w1_weight(self, module: torch.nn.Module,
+                                 w1_weight: torch.Tensor,
+                                 w3_weight: torch.Tensor,
+                                 dst_w3_w1_weight: torch.Tensor):
+        """Load and pad the w1 and w3 weights for each expert to match the Cutlass NVFP4 alignment requirements."""
+        device = dst_w3_w1_weight.device
+        w1_weight_shard = load_weight_shard(w1_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+        w3_weight_shard = load_weight_shard(w3_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+
+        cast_w1_weight_shard = w1_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w3_weight_shard = w3_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w31_weight_shard = torch.cat(
+            [cast_w3_weight_shard, cast_w1_weight_shard], dim=0)
+        cast_w31_weight_shard = self._maybe_padding_shape(
+            cast_w31_weight_shard, dst_w3_w1_weight)
+        dst_w3_w1_weight.copy_(cast_w31_weight_shard, non_blocking=True)
+
+    def load_expert_w2_weight(self, module: torch.nn.Module,
+                              w2_weight: torch.Tensor,
+                              dst_w2_weight: torch.Tensor):
+        """Load and pad the w2 weight for each expert to match the Cutlass NVFP4 alignment requirements."""
+        device = dst_w2_weight.device
+        w2_weight_shard = load_weight_shard(w2_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.ROW,
+                                            device=device)
+        cast_w2_weight_shard = w2_weight_shard.view(dst_w2_weight.dtype)
+        cast_w2_weight_shard = self._maybe_padding_shape(
+            cast_w2_weight_shard, dst_w2_weight)
+        dst_w2_weight.copy_(cast_w2_weight_shard, non_blocking=True)
+
+    def _maybe_padding_shape(self, source_tensor, dst_tensor):
+        """Pad the source tensor to match the shape of the destination tensor."""
+        # In `get_weights_shapes`, the shapes of the weights and weight scales may be rounded up to `NVFP4_ROW_ALIGNMENT`.
+        # Pad `source_tensor` to match the shape of `dst_tensor` here.
+        assert len(source_tensor.shape) == 2 and len(
+            dst_tensor.shape) == 2, "Only 2D weight padding is supported for now."
+        dst_row, dst_col = dst_tensor.shape
+        _row, _col = source_tensor.shape
+        if _row != dst_row or _col != dst_col:
+            source_tensor = torch.nn.functional.pad(
+                source_tensor, (0, dst_col - _col, 0, dst_row - _row),
+                "constant", 0).contiguous()
+        return source_tensor
+
 
 class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
 
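`_maybe_padding_shape` leans on `torch.nn.functional.pad`, whose padding tuple is ordered from the last dimension backwards: for a 2D tensor it reads `(left, right, top, bottom)`. The helper passes `(0, dst_col - _col, 0, dst_row - _row)`, i.e. zeros only on the right and bottom. A quick, self-contained illustration with a toy tensor:

```python
import torch
import torch.nn.functional as F

src = torch.arange(6.0).reshape(2, 3)  # unpadded shard, shape (2, 3)
dst_rows, dst_cols = 4, 3              # destination rounded up to 4 rows

# Pad tuple is (left, right, top, bottom); only right/bottom receive zeros here.
padded = F.pad(src, (0, dst_cols - src.shape[1], 0, dst_rows - src.shape[0]),
               "constant", 0).contiguous()

assert padded.shape == (dst_rows, dst_cols)
assert torch.equal(padded[:2], src)
assert torch.equal(padded[2:], torch.zeros(2, 3))
```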