tensorrt_llm/_torch/auto_deploy: 2 files changed, +15 −1 lines

@@ -1344,6 +1344,7 @@ def _shard_parameter_node(
 
     # Shard weight using the unified function (also updates the parameter)
     original_weight = gm.get_parameter(weight_key)
+
     _, weight_new_shape = shard_weight_tensor(
         gm=gm,
         weight_tensor=original_weight,
@@ -1892,6 +1893,8 @@ def _process_ssm_sharding(
         if "out_proj" not in str(n)
     ]
     for weight_node in weight_nodes:
+        # if is_any_ssm_op(list(weight_node.users)[0]):
+        #     continue
         weight_key = weight_node.target
         # Get the weight parameter
         try:
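
Note: the commented-out lines above sketch an intended filter that would skip weight nodes whose consumer is an SSM op. A minimal sketch of that filter once enabled (this assumes is_any_ssm_op is importable at this point and adds a guard for weight nodes with no users, which the commented one-liner lacks):

    for weight_node in weight_nodes:
        users = list(weight_node.users)
        # Hypothetical guard: a dangling get_attr node has no users, so
        # indexing users[0] unguarded would raise IndexError.
        if users and is_any_ssm_op(users[0]):
            continue
        weight_key = weight_node.target
        ...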
@@ -131,6 +131,8 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
 
 def extract_weight_node(node: Node) -> int:
     """Extracts the weight node from the given parametrized node"""
+    gm = node.graph.owning_module
+    param_names = {name for name, _ in gm.named_parameters()}
 
     def find_get_attr_node(weight_node: Node) -> Node:
         """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
@@ -141,7 +143,7 @@ def find_get_attr_node(weight_node: Node) -> Node:
             torch.ops.aten.view.default,
         }
 
-        if weight_node.op == "get_attr":
+        if weight_node.op == "get_attr" and weight_node.target in param_names:
             return weight_node
 
         # If node is not in the list of allowable ops then return None
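
Note: the param_names check tightens find_get_attr_node so that a get_attr node only qualifies as a weight when its target is a registered nn.Parameter, not a buffer or a plain tensor attribute. A self-contained illustration of the distinction (the Tiny module and its attribute names are hypothetical, not from this PR):

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(4, 4))  # real parameter
            self.register_buffer("scale", torch.ones(4))   # buffer, not a parameter

        def forward(self, x):
            return (x @ self.weight) * self.scale

    gm = symbolic_trace(Tiny())
    param_names = {name for name, _ in gm.named_parameters()}
    for n in gm.graph.nodes:
        if n.op == "get_attr":
            # Prints "weight True" and "scale False": only the true
            # parameter would pass the new check.
            print(n.target, n.target in param_names)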
@@ -325,6 +327,15 @@ def is_any_ssm_op(node: Node) -> bool:
     )
 
 
+def is_any_conv_op(node: Node) -> bool:
+    return is_op(
+        node,
+        ops=[
+            torch.ops.auto_deploy.torch_causal_conv1d,
+        ],
+    )
+
+
 def is_any_attention_op(node: Node) -> bool:
     return is_op(
         node,
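
Note: is_any_conv_op follows the same pattern as the neighboring is_any_ssm_op and is_any_attention_op helpers: it delegates to is_op with a list of target ops, here only torch.ops.auto_deploy.torch_causal_conv1d. A hedged usage sketch (gm is assumed to be an already-traced GraphModule; the variable names are illustrative):

    # Collect causal-conv1d call sites in the traced graph, e.g. so a
    # sharding pass can treat their weights differently from linear layers.
    conv_nodes = [n for n in gm.graph.nodes if is_any_conv_op(n)]
    for n in conv_nodes:
        print("causal conv1d op:", n.target)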