Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2263,9 +2263,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplacePowWithMullPass(ExportPass):
class ReplacePowWithMulPass(ExportPass):
"""
Replace the pow op with degree 2 for a mul op.
Replace the pow op for a mul op.
"""

def call_operator(
Expand All @@ -2275,19 +2275,32 @@ def call_operator(
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
# TODO(eigen): Add support for other degrees.
if (
op
not in {
exir_ops.edge.aten.pow.Scalar,
if not (
len(args) > 1
and isinstance(args[1], int)
and cast(int, args[1]) > 1
and cast(int, args[1]) < 5
and op
in {
exir_ops.edge.aten.pow.Tensor_Scalar,
}
or args[0] != 2
):
return super().call_operator(op, args, kwargs, meta)

x = args[0]
exponent = cast(int, args[1])

if exponent > 2:
for _ in range(exponent, 2, -1):
x = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(x, args[0]),
{},
meta,
)
return super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(args[1], args[1]),
(x, args[0]),
{},
meta,
)
Expand Down Expand Up @@ -2429,5 +2442,5 @@ class CadenceReplaceOpsInGraph:
ReplaceWhereWithFullArgsWithWhereScalar,
ReplaceGeluWithApproximateGeluPass,
ReplaceSplitWithSlicePass,
ReplacePowWithMullPass,
ReplacePowWithMulPass,
]
45 changes: 40 additions & 5 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
ReplaceNopTransposeOrPermuteWithViewPass,
ReplacePadWithCatPass,
ReplacePermuteWithTransposePass,
ReplacePowWithMullPass,
ReplacePowWithMulPass,
ReplaceRepeatWithCatPass,
ReplaceScalarTensorWithFullPass,
ReplaceScalarWithTensorArgPass,
Expand Down Expand Up @@ -1382,22 +1382,23 @@ def test_replace_split_with_sizes_with_slice(self):
2,
)

def test_replace_pow_with_mul(self):
@parameterized.expand([[2], [3], [4]])
def test_replace_pow_with_mul(self, exponent: int):
class Pow(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.pow.Scalar(2, input)
return torch.ops.aten.pow.Tensor_Scalar(input, exponent)

input = torch.randn(2, 1, 64)

graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module

p = ReplacePowWithMullPass()
p = ReplacePowWithMulPass()
graph_after_passes = cast(PassResult, p(graph_module)).graph_module

self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.pow.Scalar,
exir_ops.edge.aten.pow.Tensor_Scalar,
),
0,
)
Expand All @@ -1407,9 +1408,43 @@ def forward(self, input):
graph_after_passes,
exir_ops.edge.aten.mul.Tensor,
),
exponent - 1,
)

@parameterized.expand(
[
[1],
[1.5],
]
)
def test_replace_pow_with_mul_not_applied(self, exponent):
class Pow(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.pow.Tensor_Scalar(input, exponent)

input = torch.randn(2, 1, 64)

graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module

p = ReplacePowWithMulPass()
graph_after_passes = cast(PassResult, p(graph_module)).graph_module

self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.pow.Tensor_Scalar,
),
1,
)

self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.mul.Tensor,
),
0,
)


class TestReplaceIm2rowWithViewPass(unittest.TestCase):
def test_no_replacement_for_conv(self):
Expand Down
Loading