
Commit ca705a9

working SSM sharding
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 00c1d17 commit ca705a9

4 files changed: 123 additions and 173 deletions

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 11 additions & 10 deletions
@@ -14,7 +14,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
-    extract_param_names_from_node,
+    extract_weight_nodes,
     get_quantization_params_from_linear_node,
     is_bmm_op,
     is_linear_op,
@@ -136,13 +136,12 @@ def _insert_quantized_linear(

         The state_dict is also updated to contain the sharded weights.
         """
-        param_name, _ = extract_param_names_from_node(node)
-        original_weight = gm.get_parameter(param_name[0])
-        new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
-        modname, _, attrname = param_name.rpartition(".")
+        weight_nodes = extract_weight_nodes(node)
+        lin_weight = weight_nodes.weights[0]
+        new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
+        modname, _, attrname = lin_weight.node_key.rpartition(".")

-        submod = gm.get_submodule(modname)
-        setattr(submod, attrname, new_param)
+        setattr(lin_weight.submod, attrname, new_param)

         # check modelopt quantizers from graph
         if is_quantized_graph:
@@ -168,10 +167,12 @@ def _insert_quantized_linear(
             )
             # Note: canonicalize_graph() will remove input/weight/output quantizer

-        for scale_name, scale in self.default_scales(original_weight.shape).items():
-            submod.register_buffer(scale_name, scale)
+        for scale_name, scale in self.default_scales(lin_weight.tensor.shape).items():
+            lin_weight.submod.register_buffer(scale_name, scale)

-        gm._register_load_state_dict_pre_hook(partial(self.load_hook, weight_name=param_name))
+        gm._register_load_state_dict_pre_hook(
+            partial(self.load_hook, weight_name=lin_weight.node_key)
+        )

         with gm.graph.inserting_before(node):
             scales = {}
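Note: the refactor above assumes extract_weight_nodes(node) returns a small container whose weight/bias entries carry the parameter tensor, its fully qualified key, and the owning submodule. A minimal sketch of that assumed shape (class names are hypothetical, not the actual node_utils implementation):

# Hypothetical sketch of the accessor shape assumed by the new code above.
from dataclasses import dataclass, field
from typing import List

import torch
import torch.nn as nn


@dataclass
class WeightNodeInfo:
    tensor: torch.Tensor   # the parameter tensor behind the get_attr node
    node_key: str          # fully qualified name, e.g. "model.layers.0.q_proj.weight"
    submod: nn.Module      # module that owns the parameter


@dataclass
class WeightNodes:
    weights: List[WeightNodeInfo] = field(default_factory=list)
    biases: List[WeightNodeInfo] = field(default_factory=list)


# With such a container the quantization pass can swap the parameter in place:
#   lin_weight = extract_weight_nodes(node).weights[0]
#   attrname = lin_weight.node_key.rpartition(".")[2]
#   setattr(lin_weight.submod, attrname, nn.Parameter(q, requires_grad=False))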

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 21 additions & 125 deletions
@@ -38,6 +38,7 @@
     LayerSubgraph,
     LayerType,
     bfs,
+    extract_weight_name,
     extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
@@ -1272,10 +1273,6 @@ def split_fused_tensor(
         fused_dims: list = fused_weight_dims,
         d: int = dim,
     ) -> torch.Tensor:
-        # dim_d = t.shape[d]
-        # num_parts = 1
-        # part_size = dim_d // num_parts
-        # fused_dims = [part_size] * num_parts
         return torch.cat(
             [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)],
             dim=d,
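For context, split_fused_tensor splits each fused segment (e.g. the stacked blocks of a fused projection) separately along the shard dimension and re-concatenates the per-rank slices, so every rank keeps a proportional piece of each segment. A self-contained sketch under that reading, with split_tensor replaced by a simple chunk-per-rank stand-in:

import torch

def shard_fused_tensor(t: torch.Tensor, fused_dims: list, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Stand-in for the pass's split_tensor: take this rank's chunk along `dim`.
    def split_tensor(w: torch.Tensor) -> torch.Tensor:
        return torch.chunk(w, world_size, dim=dim)[rank]

    # Split each fused segment separately, shard it, then re-concatenate,
    # so e.g. a fused [Q|K|V] weight keeps a slice of Q, K and V on every rank.
    return torch.cat([split_tensor(w) for w in torch.split(t, fused_dims, dim=dim)], dim=dim)

# Example: fused weight with row segments of size 4, 2, 2 sharded across 2 ranks.
w = torch.arange(8 * 3).reshape(8, 3).float()
print(shard_fused_tensor(w, fused_dims=[4, 2, 2], dim=0, rank=0, world_size=2).shape)  # torch.Size([4, 3])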
@@ -1343,23 +1340,35 @@ def _shard_parameter_node(
     # # Shard weight using the unified function (also updates the parameter)
     # original_weight = gm.get_parameter(weight_key)
     weight_nodes = extract_weight_nodes(node)
-    for weight_node, bias_node in weight_nodes:
+    for weight_node in weight_nodes.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=weight_node.node,
+            weight_tensor=weight_node.tensor,
             param_key=weight_node.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
             min_local_shape=min_local_shape,
             fused_weight_dims=fused_weight_dims,
         )
+        if quantization_cb is not None:
+            quantization_cb(
+                gm=gm,
+                submod=weight_node.submod,
+                node=node,
+                weight_key=weight_node.node_key,
+                weight_new_shape=weight_new_shape,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+            )

-        if bias_node is not None and dim == 0:
+    for bias_node in weight_nodes.biases:
+        if dim == 0:
             # update bias for dim 0 --> we can handle it like the weight
             shard_weight_tensor(
                 gm=gm,
-                weight_tensor=bias_node.node,
+                weight_tensor=bias_node.tensor,
                 param_key=bias_node.node_key,
                 dim=dim,
                 rank=rank,
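For reference, the per-entry sharding that shard_weight_tensor performs is, at its core, a slice of the parameter along dim for the local rank; a minimal sketch of that idea (not the actual helper, which also re-registers the parameter and honors min_local_shape and fused_weight_dims):

import torch

def shard_weight(weight: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Column-parallel (dim=0, output rows) or row-parallel (dim=1, input cols) slice for this rank.
    local = weight.shape[dim] // world_size
    return weight.narrow(dim, rank * local, local).contiguous()

w = torch.randn(12, 6)
print(shard_weight(w, dim=0, rank=1, world_size=4).shape)  # torch.Size([3, 6])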
@@ -1381,18 +1390,6 @@ def _shard_parameter_node(
                 partial(_load_hook_remove, param_key=bias_node.node_key)
             )

-    if quantization_cb is not None:
-        quantization_cb(
-            gm=gm,
-            submod=weight_node.submod,
-            node=node,
-            weight_key=weight_node.node_key,
-            weight_new_shape=weight_new_shape,
-            dim=dim,
-            rank=rank,
-            world_size=world_size,
-        )
-
     # # # column shard with no gather: the output is sharded
     if not add_dist:
         return
@@ -1423,107 +1420,6 @@ def _update_node_args(node: Node, args: tuple) -> None:
     )


-def _insert_sharded_moe_stacked(
-    gm: GraphModule,
-    node: Node,
-    rank: int,
-    world_size: int,
-    allreduce_strategy: AllReduceStrategy,
-    scale_names: Sequence[str] = (),
-):
-    """Update the torch_moe node with sliced stacked weight tensors,
-    sharded `selected_experts` and `final_scales(router_logics)`.
-    Add an all_reduce node after the moe node.
-
-    For torch_moe with stacked tensor format (single-element lists containing 3D tensors).
-
-    NOTE: allreduce_strategy is MANDATORY and must be explicitly provided.
-    """
-    if allreduce_strategy is None:
-        raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}")
-
-    # Extract the stacked tensors from single-element lists
-    # args[3] = w1_weight (Node representing list with one 3D tensor, or direct list)
-    # args[4] = w2_weight (Node representing list with one 3D tensor, or direct list)
-
-    # Helper to extract tensor node from list (handles both Node and direct list)
-    def extract_tensor_from_list_arg(list_arg):
-        if isinstance(list_arg, Node) and list_arg.target is list:
-            # It's a list() call node - extract from its args
-            return list_arg.args[0][0]  # args[0] is the list content, [0] is first element
-        elif isinstance(list_arg, (list, tuple)):
-            # Direct list
-            return list_arg[0]
-        else:
-            raise ValueError(f"Unexpected list format: {type(list_arg)}")
-
-    w3_w1_tensor_node = extract_tensor_from_list_arg(node.args[3])
-    w2_tensor_node = extract_tensor_from_list_arg(node.args[4])
-    num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_node)
-
-    args = list(node.args)
-
-    # -- Handle selected_experts and final_scales sharding --
-    selected_experts = args[1]
-    final_scales = args[2]
-
-    experts_per_rank = num_experts // world_size
-
-    with gm.graph.inserting_before(node):
-        lower = experts_per_rank * rank
-        # selected_experts_local = selected_experts - low
-        selected_experts_local = gm.graph.create_node(
-            "call_function", operator.sub, args=(selected_experts, lower), kwargs={}
-        )
-
-        # For num_experts % world_size != 0 case,
-        # assign the last (num_experts % world_size) experts to the last rank
-        div_node = gm.graph.create_node(
-            "call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={}
-        )
-
-        comp_op = torch.ge if rank == world_size - 1 else torch.eq
-        rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={})
-
-        # final_scales_local = final_scales * rank_mask
-        final_scales_local = gm.graph.create_node(
-            "call_function", operator.mul, args=(final_scales, rank_mask), kwargs={}
-        )
-
-    # -- Transform expert weight parameters --
-    local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank)
-
-    # Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H)
-    if isinstance(w3_w1_tensor_node, Node):
-        _transform_bmm_moe_weight_param(
-            gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True
-        )
-
-    # Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I)
-    if isinstance(w2_tensor_node, Node):
-        _transform_bmm_moe_weight_param(gm, w2_tensor_node, local_lo, local_hi, swap_gate_up=False)
-
-    # -- Update args (keep same lists/nodes, just with transformed parameters) --
-    args[1] = selected_experts_local
-    args[2] = final_scales_local
-    # args[3] and args[4] stay the same - we modified the parameters in-place
-
-    ad_logger.debug(
-        f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
-    )
-
-    node.args = tuple(args)
-
-    # -- add an all_reduce node --
-    with gm.graph.inserting_after(node):
-        dist_node = gm.graph.call_function(
-            torch.ops.auto_deploy.torch_dist_all_reduce.default,
-            args=(node, allreduce_strategy),
-        )
-        node.replace_all_uses_with(dist_node)
-        dist_node.replace_input_with(dist_node, node)
-
-
 def _insert_sharded_moe(
     gm: GraphModule,
     node: Node,
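The expert-partitioning arithmetic in the removed helper can be checked in isolation: expert ids are shifted into the local range, and a rank mask built from a floor division zeroes the routing scores of experts owned by other ranks, with torch.ge letting the last rank absorb the remainder when num_experts % world_size != 0. A standalone sketch of that arithmetic:

import torch

num_experts, world_size = 8, 3                  # remainder case: 8 % 3 != 0
experts_per_rank = num_experts // world_size    # 2

selected_experts = torch.tensor([0, 3, 5, 7])   # global expert ids chosen by the router
final_scales = torch.tensor([0.4, 0.3, 0.2, 0.1])

for rank in range(world_size):
    lower = experts_per_rank * rank
    selected_local = selected_experts - lower
    # last rank keeps every expert with id >= its first expert (absorbs the remainder)
    comp_op = torch.ge if rank == world_size - 1 else torch.eq
    rank_mask = comp_op(selected_experts // experts_per_rank, rank)
    scales_local = final_scales * rank_mask
    print(rank, selected_local.tolist(), scales_local.tolist())
# rank 0 owns experts 0-1, rank 1 owns 2-3, rank 2 owns 4-7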
@@ -2251,9 +2147,9 @@ def detect_sharding_from_config(

     for lin_node in linear_nodes:
         # use node's weight name to get the module name
-        module_name = extract_weight_nodes(lin_node)[0].target
+        weight_name = extract_weight_name(lin_node)

-        if any(attn_name in module_name for attn_name in attn_names):
+        if any(attn_name in weight_name for attn_name in attn_names):
             # find the next attention node and infer the head_dim
             next_attention_node, _ = bfs(
                 lin_node, is_any_attention_op, attr_next="users", include_root=False
@@ -2277,7 +2173,7 @@ def detect_sharding_from_config(
             # Then we escape dots, and finally we replace @ with .*
             pattern_string = pattern_string.replace("*", "@")
             pattern_regex = re.escape(pattern_string).replace("@", ".*")
-            if re.match(pattern_regex, module_name):
+            if re.match(pattern_regex, weight_name):
                 # we have a match. Get the config for this layer
                 config = tp_plan[key]

@@ -2316,7 +2212,7 @@ def detect_sharding_from_config(
             elif "local" in config:
                 # Check if this applies to shared experts in EP parallelism.
                 # If yes, apply the TP col-row shard.
-                if "shared" in module_name:
+                if "shared" in weight_name:
                     col_row_action = config.replace("local_", "")
                     if col_row_action == "colwise":
                         transform_container.add(
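The wildcard matching used in detect_sharding_from_config above can be replayed standalone: '*' is first swapped for a placeholder so that re.escape only escapes the literal dots, and the placeholder is then expanded to '.*'. A worked example (the pattern and weight name are made up for illustration):

import re

pattern_string = "model.layers.*.self_attn.q_proj.weight"   # hypothetical tp_plan key
weight_name = "model.layers.7.self_attn.q_proj.weight"      # hypothetical parameter name

# Escape the dots but keep the wildcard: '*' -> '@' -> '.*' after re.escape.
pattern_string = pattern_string.replace("*", "@")
pattern_regex = re.escape(pattern_string).replace("@", ".*")

print(bool(re.match(pattern_regex, weight_name)))  # True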
