Skip to content

Commit a6aceea

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Implement ReplaceMulTensorWithMulAndFullOpsPass. (#11577)
Summary: Pull Request resolved: #11577 Implement pass which searches for the mul.Tensor op and extracts the argument to a standalone full op. Reviewed By: zonglinpeng, hsharma35 Differential Revision: D76469624
1 parent f211904 commit a6aceea

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,6 +2300,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
23002300
return result
23012301

23022302

2303+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
2304+
class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass):
2305+
"""
2306+
Extracts a single value argument of mul op to a separate full op.
2307+
"""
2308+
2309+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2310+
for mul_node in graph_module.graph.find_nodes(
2311+
op="call_function", target=torch.ops.aten.mul.Tensor
2312+
):
2313+
x_arg, const_arg = mul_node.args
2314+
2315+
# Swap arguments if the order is wrong
2316+
if isinstance(const_arg, torch.fx.Node):
2317+
x_arg, const_arg = const_arg, x_arg
2318+
2319+
# Skip if the const_arg is not a scalar
2320+
if not isinstance(const_arg, (float, int)) or not isinstance(
2321+
x_arg, torch.fx.Node
2322+
):
2323+
continue
2324+
2325+
# Cast the const_arg to the dtype of the x_arg
2326+
full_arg = self.resolve_full_arg(x_arg, const_arg)
2327+
2328+
# Extract an argument to a separate full op.
2329+
with graph_module.graph.inserting_before(mul_node):
2330+
full_tensor = graph_module.graph.call_function(
2331+
exir_ops.edge.aten.full.default, args=([1], full_arg)
2332+
)
2333+
new_mul_node = graph_module.graph.call_function(
2334+
torch.ops.aten.mul.Tensor, args=(x_arg, full_tensor)
2335+
)
2336+
# Replace the old mul with a newly created mul.
2337+
mul_node.replace_all_uses_with(new_mul_node)
2338+
graph_module.graph.erase_node(mul_node)
2339+
return super().call(graph_module)
2340+
2341+
def resolve_full_arg(self, x_arg, const_arg):
2342+
if x_arg.meta["val"].dtype == torch.float32 and isinstance(const_arg, int):
2343+
const_arg = float(const_arg)
2344+
if x_arg.meta["val"].dtype == torch.int32 and isinstance(const_arg, float):
2345+
const_arg = int(const_arg)
2346+
return const_arg
2347+
2348+
23032349
# This class encapsulates all the functions that replace/switch one op in the
23042350
# graph with another.
23052351
class CadenceReplaceOpsInGraph:

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
GraphBuilder,
2121
single_op_builder,
2222
)
23-
from executorch.backends.cadence.aot.pass_utils import count_node
23+
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
2424
from executorch.backends.cadence.aot.replace_ops import (
2525
ForceChannelLastForConvPass,
2626
MakeSliceAndCatDimOutermostPass,
@@ -36,6 +36,7 @@
3636
ReplaceLinearWithFullyConnectedOpPass,
3737
ReplaceMatmulWithTransposedMatmulPass,
3838
ReplaceMMWithAddMMPass,
39+
ReplaceMulTensorWithMulAndFullOpsPass,
3940
ReplaceNopTransposeOrPermuteWithViewPass,
4041
ReplacePadWithCatPass,
4142
ReplacePermuteWithTransposePass,
@@ -1870,3 +1871,30 @@ def test_empty_slice(self):
18701871
),
18711872
1,
18721873
)
1874+
1875+
@parameterized.expand(
1876+
[
1877+
("int", int(123)),
1878+
("float", float(456.0)),
1879+
],
1880+
)
1881+
@torch.no_grad()
1882+
def test_extract_mul_argument_to_full(self, _, value) -> None:
1883+
x = torch.randn(2, 1, 64)
1884+
gm = single_op_builder(
1885+
placeholders=(x,),
1886+
op=torch.ops.aten.mul.Tensor,
1887+
args=(x, value),
1888+
kwargs={},
1889+
)
1890+
p = ReplaceMulTensorWithMulAndFullOpsPass()
1891+
graph_after_passes = p.call(gm).graph_module
1892+
self.assertTrue(
1893+
op_counts_match(
1894+
graph_after_passes,
1895+
expected_op_counts={
1896+
torch.ops.aten.mul.Tensor: 1,
1897+
exir_ops.edge.aten.full.default: 1,
1898+
},
1899+
)
1900+
)

0 commit comments

Comments
 (0)