@@ -124,49 +124,6 @@ def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
124124 return freq
125125
126126
def get_output_node(graph: torch.fx.Graph) -> torch.fx.Node:
    """Return the single output node of *graph*.

    The output node is always the last node of an fx graph; we verify that
    it really is an ``output`` op with exactly one argument before
    returning it.
    """
    assert graph is not None, "Cannot get output of an empty graph"
    last_node = next(iter(reversed(graph.nodes)))
    is_valid_output = (
        last_node is not None
        and last_node.op == "output"
        and len(last_node.args) == 1
    )
    assert is_valid_output, "Failed to find output node"
    return last_node
136-
def is_node_in_flattened_output(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    """Return True iff *node* appears in the (flattened) graph output.

    The output node's single argument may be an arbitrarily nested
    structure; flatten it and test membership.
    """
    flat_outputs, _ = tree_flatten(get_output_node(graph).args[0])
    return node in flat_outputs
142-
# Return the shape of the incoming node.
def get_shape(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node
) -> Union[torch.Size, None]:
    """
    Return the shape of the tensor corresponding to node. If the node has a
    tensor spec, return the shape from the metadata. If the node is a param,
    return its shape. Otherwise return None.
    """
    try:
        # Case 1. node is a scalar (callers may pass raw Python values
        # taken from node args, despite the Node annotation)
        if isinstance(node, (float, int, bool)):
            return torch.Size([1])
        # Case 2. node has TensorSpec metadata: the FakeTensor recorded
        # during tracing carries the shape
        fake_tensor = node.meta.get("val")
        if fake_tensor is not None:
            return fake_tensor.shape
        # Case 3. node holds a param: resolve the attribute on the owning
        # module and use its shape
        if node.op == "get_attr":
            attr_node = getattr(graph_module, node.target)
            return attr_node.shape
        # Default: shape is unknown for this node
        return None
    except RuntimeError:
        # e.g. symbolic/unbacked shapes that cannot be materialized
        return None
169-
170127# Print the ops and how many times they occur in multiple graph modules:
171128# from export, from to_edge, and from final. Print the available
172129# implementations for each op, and error out if the op is not supported.
0 commit comments