@@ -2300,6 +2300,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
23002300 return result
23012301
23022302
2303+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
2304+ class ReplaceMulTensorWithMulAndFullOpsPass (ExportPass ):
2305+ """
2306+ Extracts a single value argument of mul op to a separate full op.
2307+ """
2308+
2309+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
2310+ for mul_node in graph_module .graph .find_nodes (
2311+ op = "call_function" , target = torch .ops .aten .mul .Tensor
2312+ ):
2313+ x_arg , const_arg = mul_node .args
2314+
2315+ # Swap arguments if the order is wrong
2316+ if isinstance (const_arg , torch .fx .Node ):
2317+ x_arg , const_arg = const_arg , x_arg
2318+
2319+ # Skip if the const_arg is not a scalar
2320+ if not isinstance (const_arg , (float , int )) or not isinstance (
2321+ x_arg , torch .fx .Node
2322+ ):
2323+ continue
2324+
2325+ # Cast the const_arg to the dtype of the x_arg
2326+ full_arg = self .resolve_full_arg (x_arg , const_arg )
2327+
2328+ # Extract an argument to a separate full op.
2329+ with graph_module .graph .inserting_before (mul_node ):
2330+ full_tensor = graph_module .graph .call_function (
2331+ exir_ops .edge .aten .full .default , args = ([1 ], full_arg )
2332+ )
2333+ new_mul_node = graph_module .graph .call_function (
2334+ torch .ops .aten .mul .Tensor , args = (x_arg , full_tensor )
2335+ )
2336+ # Replace the old mul with a newly created mul.
2337+ mul_node .replace_all_uses_with (new_mul_node )
2338+ graph_module .graph .erase_node (mul_node )
2339+ return super ().call (graph_module )
2340+
2341+ def resolve_full_arg (self , x_arg , const_arg ):
2342+ if x_arg .meta ["val" ].dtype == torch .float32 and isinstance (const_arg , int ):
2343+ const_arg = float (const_arg )
2344+ if x_arg .meta ["val" ].dtype == torch .int32 and isinstance (const_arg , float ):
2345+ const_arg = int (const_arg )
2346+ return const_arg
2347+
2348+
23032349# This class encapsulates all the functions that replace/switch one op in the
23042350# graph with another.
23052351class CadenceReplaceOpsInGraph :
0 commit comments