Skip to content

Commit f520030

Browse files
eigen-kfacebook-github-bot
authored andcommitted
to be merged with main commit
Differential Revision: D73817641
1 parent c5dd476 commit f520030

File tree

2 files changed

+64
-16
lines changed

2 files changed

+64
-16
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2263,9 +2263,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
22632263

22642264

22652265
@register_cadence_pass(CadencePassAttribute(opt_level=1))
2266-
class ReplacePowWithMullPass(ExportPass):
2266+
class ReplacePowWithMulPass(ExportPass):
22672267
"""
2268-
Replace the pow op with degree 2 for a mul op.
2268+
Replace the pow op for a mul op.
22692269
"""
22702270

22712271
def call_operator(
@@ -2275,19 +2275,29 @@ def call_operator(
22752275
kwargs: Dict[str, Argument],
22762276
meta: NodeMetadata,
22772277
) -> ProxyValue:
2278-
# TODO(eigen): Add support for other degrees.
2279-
if (
2280-
op
2281-
not in {
2282-
exir_ops.edge.aten.pow.Scalar,
2283-
}
2284-
or args[0] != 2
2278+
if not (
2279+
len(args) > 1 and
2280+
isinstance(args[1], int) and
2281+
cast(int, args[1]) > 1 and
2282+
cast(int, args[1]) < 5 and
2283+
op in { exir_ops.edge.aten.pow.Tensor_Scalar,}
22852284
):
22862285
return super().call_operator(op, args, kwargs, meta)
22872286

2287+
x = args[0]
2288+
exponent = cast(int, args[1])
2289+
2290+
if exponent > 2 :
2291+
for _ in range(exponent, 2, -1):
2292+
x = super().call_operator(
2293+
exir_ops.edge.aten.mul.Tensor,
2294+
(x, args[0]),
2295+
{},
2296+
meta,
2297+
)
22882298
return super().call_operator(
22892299
exir_ops.edge.aten.mul.Tensor,
2290-
(args[1], args[1]),
2300+
(x, args[0]),
22912301
{},
22922302
meta,
22932303
)
@@ -2429,5 +2439,5 @@ class CadenceReplaceOpsInGraph:
24292439
ReplaceWhereWithFullArgsWithWhereScalar,
24302440
ReplaceGeluWithApproximateGeluPass,
24312441
ReplaceSplitWithSlicePass,
2432-
ReplacePowWithMullPass,
2442+
ReplacePowWithMulPass,
24332443
]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@
3434
ReplaceEmptyTensorsWithFullPass,
3535
ReplaceFunctionallyEquivalentOpTargets,
3636
ReplaceGeluWithApproximateGeluPass,
37+
ReplacePowWithMulPass,
3738
ReplaceIm2RowWithViewPass,
3839
ReplaceLinearWithFullyConnectedOpPass,
3940
ReplaceMatmulWithTransposedMatmulPass,
4041
ReplaceMMWithAddMMPass,
4142
ReplaceNopTransposeOrPermuteWithViewPass,
4243
ReplacePadWithCatPass,
4344
ReplacePermuteWithTransposePass,
44-
ReplacePowWithMullPass,
4545
ReplaceRepeatWithCatPass,
4646
ReplaceScalarTensorWithFullPass,
4747
ReplaceScalarWithTensorArgPass,
@@ -1382,22 +1382,27 @@ def test_replace_split_with_sizes_with_slice(self):
13821382
2,
13831383
)
13841384

1385-
def test_replace_pow_with_mul(self):
1385+
@parameterized.expand(
1386+
[
1387+
[2], [3], [4]
1388+
]
1389+
)
1390+
def test_replace_pow_with_mul(self, exponent: int):
13861391
class Pow(torch.nn.Module):
13871392
def forward(self, input):
1388-
return torch.ops.aten.pow.Scalar(2, input)
1393+
return torch.ops.aten.pow.Tensor_Scalar(input, exponent)
13891394

13901395
input = torch.randn(2, 1, 64)
13911396

13921397
graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module
13931398

1394-
p = ReplacePowWithMullPass()
1399+
p = ReplacePowWithMulPass()
13951400
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
13961401

13971402
self.assertEqual(
13981403
count_node(
13991404
graph_after_passes,
1400-
exir_ops.edge.aten.pow.Scalar,
1405+
exir_ops.edge.aten.pow.Tensor_Scalar,
14011406
),
14021407
0,
14031408
)
@@ -1407,9 +1412,42 @@ def forward(self, input):
14071412
graph_after_passes,
14081413
exir_ops.edge.aten.mul.Tensor,
14091414
),
1415+
exponent - 1,
1416+
)
1417+
1418+
@parameterized.expand(
1419+
[
1420+
[1], [1.5],
1421+
]
1422+
)
1423+
def test_replace_pow_with_mul_not_applied(self, exponent):
1424+
class Pow(torch.nn.Module):
1425+
def forward(self, input):
1426+
return torch.ops.aten.pow.Tensor_Scalar(input, exponent)
1427+
1428+
input = torch.randn(2, 1, 64)
1429+
1430+
graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module
1431+
1432+
p = ReplacePowWithMulPass()
1433+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1434+
1435+
self.assertEqual(
1436+
count_node(
1437+
graph_after_passes,
1438+
exir_ops.edge.aten.pow.Tensor_Scalar,
1439+
),
14101440
1,
14111441
)
14121442

1443+
self.assertEqual(
1444+
count_node(
1445+
graph_after_passes,
1446+
exir_ops.edge.aten.mul.Tensor,
1447+
),
1448+
0,
1449+
)
1450+
14131451

14141452
class TestReplaceIm2rowWithViewPass(unittest.TestCase):
14151453
def test_no_replacement_for_conv(self):

0 commit comments

Comments
 (0)