88
99import torch
1010from pydantic import BaseModel , ConfigDict
11+ from torch import nn
1112from torch ._ops import OpOverload , OpOverloadPacket
1213from torch .fx import GraphModule , Node
1314
@@ -51,6 +52,13 @@ class LayerSubgraph(BaseModel):
5152 min_local_shape : int = 1
5253
5354
class WeightNode(BaseModel):
    """A weight (or bias) parameter extracted from a parametrized graph node.

    Bundles the FX ``get_attr`` node, its fully qualified parameter name, and
    (when resolvable) the submodule that owns the parameter.
    """

    # Node and nn.Module are not pydantic-native types, so allow them.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The get_attr node in the FX graph that references the parameter.
    node: Node
    # Fully qualified parameter name (the node's target), e.g. "layers.0.weight".
    node_key: str
    # Owning submodule of the parameter. Optional with default None because some
    # call sites (e.g. the bmm path in extract_weight_nodes) construct a
    # WeightNode without a submodule; a required field would raise a
    # pydantic ValidationError there.
    submod: Optional[nn.Module] = None
61+
5462@dataclass
5563class modelopt_quant_params :
5664 input_node : torch .fx .node .Node = None
@@ -129,10 +137,12 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
129137 return input_params , weight_params , output_params
130138
131139
132- def extract_weight_node (node : Node ) -> int :
133- """Extracts the weight node from the given parametrized node"""
140+ def extract_weight_nodes (node : Node ) -> Tuple [ List [ WeightNode ], List [ WeightNode ]] :
141+ """Extracts the list of weight node and optional bias node from the given parametrized node"""
134142 gm = node .graph .owning_module
135- param_names = {name for name , _ in gm .named_parameters ()}
143+ param_names = {name for name , _ in gm .named_parameters ()}.union (
144+ {name for name , _ in gm .named_buffers ()}
145+ )
136146
137147 def find_get_attr_node (weight_node : Node ) -> Node :
138148 """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
@@ -157,55 +167,68 @@ def find_get_attr_node(weight_node: Node) -> Node:
157167 return None
158168
159169 if is_op (node , torch .ops .aten .bmm ):
160- weight_node = node .args [1 ]
170+ # no bias for bmm
171+ return [WeightNode (node = node .args [1 ], node_key = node .args [1 ].target )], []
161172 # for other parametrized nodes, we need to find the weight node
162173 else :
163- weight_nodes = [
174+ all_weight_nodes = [
164175 n for n in node .args if isinstance (n , Node ) and find_get_attr_node (n ) is not None
165176 ]
166- # can be two weights (if bias weight is present)
167- weight_node = None
168- if weight_nodes :
169- weight_node = weight_nodes [0 ]
170- # for modelopt quantized graph, there will be a quantize_op
171- _ , weight_params , _ = get_quantization_params_from_linear_node (node )
172- weight_node = weight_params .input_node if weight_params else weight_node
173- assert weight_node is not None , "Expected at least one weight node in the parametrized node"
174- return find_get_attr_node (weight_node )
177+ # separate weight nodes and bias nodes
178+ weight_nodes = [n for n in all_weight_nodes if n .target .endswith ("weight" )]
179+ bias_nodes = [n for n in all_weight_nodes if n .target .endswith ("bias" )]
180+ weight_nodes = [
181+ WeightNode (
182+ node = n , node_key = n .target , submod = gm .get_submodule (n .target .rpartition ("." )[0 ])
183+ )
184+ for n in weight_nodes
185+ ]
186+ bias_nodes = [
187+ WeightNode (
188+ node = n , node_key = n .target , submod = gm .get_submodule (n .target .rpartition ("." )[0 ])
189+ )
190+ for n in bias_nodes
191+ ]
192+ return weight_nodes , bias_nodes
175193
176194
def num_users_of_weight_node(node: Node) -> int:
    """Returns the number of users of the first weight node of the given parametrized node.

    Args:
        node: parametrized node in the graph (e.g. a linear/bmm op).

    Returns:
        The number of graph users of the first extracted weight's FX node, or 0
        if no weight node could be extracted.
    """
    # extract_weight_nodes returns (weight_nodes, bias_nodes); each entry is a
    # WeightNode wrapper, so the FX node (and its .users) lives on `.node`.
    # Indexing the tuple alone yields a list, which has no `.users` attribute.
    weight_nodes = extract_weight_nodes(node)[0]
    return len(weight_nodes[0].node.users) if weight_nodes else 0
181199
182200
def extract_param_names_from_node(node: Node) -> Tuple[List[str], List[str]]:
    """Extracts the names of the parameters associated with the given parametrized node.

    Args:
        node: node with weight parameters in the graph.

    Returns:
        A tuple ``(weight_names, bias_names)`` where each element is the list of
        fully qualified parameter names (the ``node_key`` of each extracted
        weight/bias node). ``bias_names`` is an empty list when the node has no
        bias, never ``None``.
    """
    weight_nodes, bias_nodes = extract_weight_nodes(node)
    return [n.node_key for n in weight_nodes], [n.node_key for n in bias_nodes]
209232
210233
211234def get_op_overload_packet (node : Union [OpOverloadPacket , OpOverload ]) -> OpOverloadPacket :
@@ -751,9 +774,9 @@ def get_weight_shape(
751774 if not is_any_lin_op (node ):
752775 return None
753776 if dim is None :
754- return shape (extract_weight_node (node ))
777+ return shape (extract_weight_nodes (node )[ 0 ] )
755778 else :
756- return shape (extract_weight_node (node ))[dim ]
779+ return shape (extract_weight_nodes (node )[ 0 ] )[dim ]
757780
758781
759782def get_layer_after_linear_node (
0 commit comments