Commit 93f9a2b

wip
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 4a5ef84 commit 93f9a2b

2 files changed: +15 -1 lines

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

3 additions, 0 deletions

@@ -1344,6 +1344,7 @@ def _shard_parameter_node(
 
     # Shard weight using the unified function (also updates the parameter)
     original_weight = gm.get_parameter(weight_key)
+
     _, weight_new_shape = shard_weight_tensor(
         gm=gm,
         weight_tensor=original_weight,
@@ -1892,6 +1893,8 @@ def _process_ssm_sharding(
         if "out_proj" not in str(n)
     ]
     for weight_node in weight_nodes:
+        # if is_any_ssm_op(list(weight_node.users)[0]):
+        #     continue
         weight_key = weight_node.target
         # Get the weight parameter
         try:
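The commented-out lines sketch a filter that would skip any weight whose first consumer is an SSM op. A minimal, self-contained illustration of that users-based check on a torch.fx graph; the toy module, `first_user_is`, and the linear/scale names below are hypothetical stand-ins for the repo's `is_any_ssm_op` helper, not its actual code:

import torch
import torch.fx as fx


class Tiny(torch.nn.Module):
    """Toy module: one weight feeds F.linear, another feeds a plain multiply."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        self.scale = torch.nn.Parameter(torch.ones(4))

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight) * self.scale


gm = fx.symbolic_trace(Tiny())


def first_user_is(node: fx.Node, target) -> bool:
    # list(node.users)[0] is the first op consuming this weight; this is the
    # node the commented-out is_any_ssm_op(...) check would inspect.
    users = list(node.users)
    return bool(users) and users[0].target is target


for weight_node in [n for n in gm.graph.nodes if n.op == "get_attr"]:
    if first_user_is(weight_node, torch.nn.functional.linear):
        continue  # stand-in for the is_any_ssm_op skip
    print("would process:", weight_node.target)  # -> scale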

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

12 additions, 1 deletion

@@ -131,6 +131,8 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
 
 def extract_weight_node(node: Node) -> int:
     """Extracts the weight node from the given parametrized node"""
+    gm = node.graph.owning_module
+    param_names = {name for name, _ in gm.named_parameters()}
 
     def find_get_attr_node(weight_node: Node) -> Node:
         """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
@@ -141,7 +143,7 @@ def find_get_attr_node(weight_node: Node) -> Node:
             torch.ops.aten.view.default,
         }
 
-        if weight_node.op == "get_attr":
+        if weight_node.op == "get_attr" and weight_node.target in param_names:
             return weight_node
 
         # If node is not in the list of allowable ops then return None
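The new `param_names` guard restricts `find_get_attr_node` to attributes that are actual registered parameters, so a `get_attr` on a buffer or other constant no longer qualifies as a weight. A standalone sketch of the distinction (the toy module and its attribute names are hypothetical):

import torch
import torch.fx as fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))  # registered parameter
        self.register_buffer("mask", torch.ones(4, 4))       # buffer, not a parameter

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight * self.mask)


gm = fx.symbolic_trace(M())
param_names = {name for name, _ in gm.named_parameters()}

for node in gm.graph.nodes:
    if node.op == "get_attr":
        # Both `weight` and `mask` appear as get_attr nodes, but only
        # `weight` is in param_names, so only it passes the new check.
        print(node.target, node.target in param_names)  # weight True / mask False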
@@ -325,6 +327,15 @@ def is_any_ssm_op(node: Node) -> bool:
     )
 
 
+def is_any_conv_op(node: Node) -> bool:
+    return is_op(
+        node,
+        ops=[
+            torch.ops.auto_deploy.torch_causal_conv1d,
+        ],
+    )
+
+
 def is_any_attention_op(node: Node) -> bool:
     return is_op(
         node,
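`is_any_conv_op` follows the same predicate shape as the neighboring `is_any_ssm_op` / `is_any_attention_op` helpers. A self-contained sketch of that pattern, where a simplified `is_op` and the stock `aten.conv1d` overload stand in for the repo's `is_op` helper and the custom `torch.ops.auto_deploy.torch_causal_conv1d` op, which only exists inside the auto_deploy runtime:

import torch
import torch.fx as fx


def is_op(node: fx.Node, ops) -> bool:
    # Simplified stand-in for node_utils.is_op: true for call_function nodes
    # whose target is one of the given operator overloads.
    return node.op == "call_function" and node.target in tuple(ops)


def is_any_conv_op(node: fx.Node) -> bool:
    # Mirrors the new helper; aten.conv1d substitutes for torch_causal_conv1d.
    return is_op(node, ops=[torch.ops.aten.conv1d.default])


# Hand-built graph with one conv call to exercise the predicate.
g = fx.Graph()
x = g.placeholder("x")
w = g.placeholder("w")
conv = g.call_function(torch.ops.aten.conv1d.default, (x, w))
print(is_any_conv_op(conv), is_any_conv_op(x))  # True False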
