From fad9410193c1d5e0643f9f47f5545daa6ec907b7 Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Wed, 30 Apr 2025 16:30:44 -0700 Subject: [PATCH] Implement a coversion pass from pow(E,x) to E-1 mul ops. (#10564) Summary: Update pow to mul pass by accepting exponent more than 2. Reviewed By: Vysarat Differential Revision: D73473271 --- backends/cadence/aot/replace_ops.py | 33 +++++++++----- .../aot/tests/test_replace_ops_passes.py | 45 ++++++++++++++++--- 2 files changed, 63 insertions(+), 15 deletions(-) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 17eff9de0eb..34a1abdf0b1 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2263,9 +2263,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplacePowWithMullPass(ExportPass): +class ReplacePowWithMulPass(ExportPass): """ - Replace the pow op with degree 2 for a mul op. + Replace the pow op for a mul op. """ def call_operator( @@ -2275,19 +2275,32 @@ def call_operator( kwargs: Dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - # TODO(eigen): Add support for other degrees. - if ( - op - not in { - exir_ops.edge.aten.pow.Scalar, + if not ( + len(args) > 1 + and isinstance(args[1], int) + and cast(int, args[1]) > 1 + and cast(int, args[1]) < 5 + and op + in { + exir_ops.edge.aten.pow.Tensor_Scalar, } - or args[0] != 2 ): return super().call_operator(op, args, kwargs, meta) + x = args[0] + exponent = cast(int, args[1]) + + if exponent > 2: + for _ in range(exponent, 2, -1): + x = super().call_operator( + exir_ops.edge.aten.mul.Tensor, + (x, args[0]), + {}, + meta, + ) return super().call_operator( exir_ops.edge.aten.mul.Tensor, - (args[1], args[1]), + (x, args[0]), {}, meta, ) @@ -2429,5 +2442,5 @@ class CadenceReplaceOpsInGraph: ReplaceWhereWithFullArgsWithWhereScalar, ReplaceGeluWithApproximateGeluPass, ReplaceSplitWithSlicePass, - ReplacePowWithMullPass, + ReplacePowWithMulPass, ] diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index f2b78ccd800..b8ebe21832c 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -41,7 +41,7 @@ ReplaceNopTransposeOrPermuteWithViewPass, ReplacePadWithCatPass, ReplacePermuteWithTransposePass, - ReplacePowWithMullPass, + ReplacePowWithMulPass, ReplaceRepeatWithCatPass, ReplaceScalarTensorWithFullPass, ReplaceScalarWithTensorArgPass, @@ -1382,22 +1382,23 @@ def test_replace_split_with_sizes_with_slice(self): 2, ) - def test_replace_pow_with_mul(self): + @parameterized.expand([[2], [3], [4]]) + def test_replace_pow_with_mul(self, exponent: int): class Pow(torch.nn.Module): def forward(self, input): - return torch.ops.aten.pow.Scalar(2, input) + return torch.ops.aten.pow.Tensor_Scalar(input, exponent) input = torch.randn(2, 1, 64) graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module - p = ReplacePowWithMullPass() + p = ReplacePowWithMulPass() graph_after_passes = cast(PassResult, p(graph_module)).graph_module self.assertEqual( count_node( graph_after_passes, - exir_ops.edge.aten.pow.Scalar, + exir_ops.edge.aten.pow.Tensor_Scalar, ), 0, ) @@ -1407,9 +1408,43 @@ def forward(self, input): graph_after_passes, exir_ops.edge.aten.mul.Tensor, ), + exponent - 1, + ) + + @parameterized.expand( + [ + [1], + [1.5], + ] + ) + def test_replace_pow_with_mul_not_applied(self, exponent): + class Pow(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.pow.Tensor_Scalar(input, exponent) + + input = torch.randn(2, 1, 64) + + graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module + + p = ReplacePowWithMulPass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.aten.pow.Tensor_Scalar, + ), 1, ) + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.aten.mul.Tensor, + ), + 0, + ) + class TestReplaceIm2rowWithViewPass(unittest.TestCase): def test_no_replacement_for_conv(self):