@@ -2300,6 +2300,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2300
2300
return result
2301
2301
2302
2302
2303
+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
2304
+ class ReplaceMulTensorWithMulAndFullOpsPass (ExportPass ):
2305
+ """
2306
+ Extracts a single value argument of mul op to a separate full op.
2307
+ """
2308
+
2309
+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
2310
+ for mul_node in graph_module .graph .find_nodes (
2311
+ op = "call_function" , target = torch .ops .aten .mul .Tensor
2312
+ ):
2313
+ x_arg , const_arg = mul_node .args
2314
+
2315
+ # Swap arguments if the order is wrong
2316
+ if isinstance (const_arg , torch .fx .Node ):
2317
+ x_arg , const_arg = const_arg , x_arg
2318
+
2319
+ # Skip if the const_arg is not a scalar
2320
+ if not isinstance (const_arg , (float , int )) or not isinstance (
2321
+ x_arg , torch .fx .Node
2322
+ ):
2323
+ continue
2324
+
2325
+ # Cast the const_arg to the dtype of the x_arg
2326
+ full_arg = self .resolve_full_arg (x_arg , const_arg )
2327
+
2328
+ # Extract an argument to a separate full op.
2329
+ with graph_module .graph .inserting_before (mul_node ):
2330
+ full_tensor = graph_module .graph .call_function (
2331
+ exir_ops .edge .aten .full .default , args = ([1 ], full_arg )
2332
+ )
2333
+ new_mul_node = graph_module .graph .call_function (
2334
+ torch .ops .aten .mul .Tensor , args = (x_arg , full_tensor )
2335
+ )
2336
+ # Replace the old mul with a newly created mul.
2337
+ mul_node .replace_all_uses_with (new_mul_node )
2338
+ graph_module .graph .erase_node (mul_node )
2339
+ return super ().call (graph_module )
2340
+
2341
+ def resolve_full_arg (self , x_arg , const_arg ):
2342
+ if x_arg .meta ["val" ].dtype == torch .float32 and isinstance (const_arg , int ):
2343
+ const_arg = float (const_arg )
2344
+ if x_arg .meta ["val" ].dtype == torch .int32 and isinstance (const_arg , float ):
2345
+ const_arg = int (const_arg )
2346
+ return const_arg
2347
+
2348
+
2303
2349
# This class encapsulates all the functions that replace/switch one op in the
2304
2350
# graph with another.
2305
2351
class CadenceReplaceOpsInGraph :
0 commit comments