Commit c2f2316

mcremon-meta authored and facebook-github-bot committed
Move the transpose matmul pass to OSS and run it earlier in the flow (#10433)
Summary: Pull Request resolved: #10433

That pass does a lot more than it appears to, and it's just easier to move it back to where it was. CPU backends will possibly see more cycles due to the added permutes, but we don't care about that. All DSP backends should be more efficient on transposed matmuls. Should that not be the case in the future, we can re-evaluate. When we do the survey of passes and reorder them properly, we can think about this more.

Reviewed By: hsharma35

Differential Revision: D73600069
1 parent 6b877de commit c2f2316
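
For context, here is a minimal PyTorch sketch of the rewrite described in the summary; it is illustrative only, and `transposed_matmul` below is a stand-in for the Cadence quantized_matmul kernel called with transposed=True, not a real API. The pass hands the second matmul operand to the kernel already transposed, and the kernel undoes the transpose internally, which (per the summary) DSP backends handle more efficiently.

import torch

def transposed_matmul(A: torch.Tensor, B_t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a "transposed matmul" kernel: it receives B
    # already transposed and undoes the transpose internally, so the result
    # matches a plain matmul.
    return torch.matmul(A, B_t.transpose(-1, -2))

A = torch.randn(64, 33)
B = torch.randn(33, 128)
# The pass inserts the transpose of B and flips the `transposed` flag;
# the numerical result is unchanged.
assert torch.allclose(torch.matmul(A, B), transposed_matmul(A, B.transpose(-1, -2)))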

2 files changed: +146 -1 lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 100 additions & 1 deletion
@@ -32,7 +32,10 @@
     is_quantized_tensor,
     quantize_tensor_multiplier,
 )
-from executorch.backends.cadence.aot.fuse_ops import FuseCascadedViewOps
+from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedTransposeOrPermuteOps,
+    FuseCascadedViewOps,
+)
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     register_cadence_pass,
@@ -2290,6 +2293,101 @@ def call_operator(
         )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceMatmulWithTransposedMatmulPass(ExportPass):
+    """
+    For certain backends, we have efficient kernels for transposed matmul. We
+    replace AxB with AxB' for such backends.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.cadence.quantized_matmul.default or args[-1] is True:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the args
+        if len(args) == 9:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                transposed,
+            ) = args
+        elif len(args) == 8:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+            ) = args
+            transposed = False
+        else:
+            raise AssertionError(
+                f"Unexpected number of args for quantized_matmul: {len(args)}"
+            )
+
+        # If the matmul is already transposed, bail
+        if transposed:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the second tensor
+        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        # Concretize the bias
+        zero_bias = super().call_operator(
+            exir_ops.edge.aten.full.default,
+            ([Y_tensor.size(-1)], 0),
+            {"dtype": torch.int32},
+            meta,
+        )
+
+        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
+        # can simply transpose the tensor inplace.
+        if isinstance(Y_arg, ProxyValue):
+            transpose_args = (Y_arg, -1, -2)
+            transpose_node = super().call_operator(
+                exir_ops.edge.aten.transpose_copy.int,
+                transpose_args,
+                {},
+                meta,
+            )
+            Y_arg_t = transpose_node
+        else:
+            Y_arg_t = Y_tensor.transpose(-1, -2)
+
+        # Construct the new args, and return the transposed matmult op
+        new_args = (
+            X_arg,
+            X_zero_point,
+            Y_arg_t,
+            Y_zero_point,
+            zero_bias,
+            out_multiplier,
+            out_shift,
+            out_zero_point,
+            True,
+        )
+        return super().call_operator(op, new_args, kwargs, meta)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        result = super().call(graph_module)
+        # Fuse any inserted transpose node with transpose/permute nodes
+        # surrounding it.
+        result = FuseCascadedTransposeOrPermuteOps()(result.graph_module)
+        assert result is not None
+        # Replace permute with transpose.
+        result = ReplacePermuteWithTransposePass()(result.graph_module)
+        assert result is not None
+        return result
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2317,6 +2415,7 @@ class CadenceReplaceOpsInGraph:
         # This pass should be after passes that replace conv -> im2row + linear.
         ReplaceIm2RowWithViewPass,
         MakeSliceAndCatDimOutermostPass,
+        ReplaceMatmulWithTransposedMatmulPass,
         ReplaceNopTransposeOrPermuteWithViewPass,
         ReplaceLinearWithFullyConnectedOpPass,
         ReplaceScalarTensorWithFullPass,
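
Roughly how the new pass is exercised end to end, condensed from the test added below; this is a sketch that assumes the pass is importable from backends/cadence/aot/replace_ops.py where the diff adds it, and it reuses the `quantize_pt2` and `export_to_edge` helpers the test imports from the Cadence AOT compiler.

import torch
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.replace_ops import (
    ReplaceMatmulWithTransposedMatmulPass,
)

class MatMul(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)

inputs = (torch.randn(64, 33), torch.randn(33, 128))
# Quantize and lower to the edge dialect so quantized_matmul nodes exist.
quantized = quantize_pt2(MatMul(), inputs)
graph_module = export_to_edge(quantized, inputs).exported_program().graph_module
# The pass rewrites quantized_matmul to take a pre-transposed Y (transposed=True)
# and inserts a transpose_copy node for Y.
result = ReplaceMatmulWithTransposedMatmulPass()(graph_module)
assert result is not None
graph_after = result.graph_module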

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 46 additions & 0 deletions
@@ -16,6 +16,7 @@
 from executorch.backends.cadence.aot.compiler import (
     export_to_edge,
     quantize_and_export_to_edge,
+    quantize_pt2,
 )
 from executorch.backends.cadence.aot.graph_builder import (
     GraphBuilder,
@@ -35,6 +36,7 @@
     ReplaceGeluWithApproximateGeluPass,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
+    ReplaceMatmulWithTransposedMatmulPass,
     ReplaceMMWithAddMMPass,
     ReplaceNopTransposeOrPermuteWithViewPass,
     ReplacePadWithCatPass,
@@ -85,6 +87,50 @@ def assertTargetCountsEqual(
         for target, expected_count in targets_and_counts:
             self.assertTargetCountEqual(graph_module, target, expected_count)

+    @parameterized.expand(
+        [
+            # Regular MM
+            [(64, 33), (33, 128)],
+            # Batched MM
+            [(2, 48, 48), (2, 48, 48)],
+        ]
+    )
+    @torch.no_grad()
+    def test_replace_matmul_with_transposed_matmul(
+        self,
+        x_shape: Tuple[int],
+        y_shape: Tuple[int],
+    ) -> None:
+        class MatMul(torch.nn.Module):
+            def __init__(self) -> None:
+                super(MatMul, self).__init__()
+
+            def forward(self, x, y):
+                return torch.matmul(x, y)
+
+        model = MatMul()
+        X = torch.randn(x_shape)
+        Y = torch.randn(y_shape)
+        p = ReplaceMatmulWithTransposedMatmulPass()
+        inputs = (X, Y)
+        quantized_model = quantize_pt2(model, inputs)
+        graph_module = (
+            export_to_edge(quantized_model, inputs).exported_program().graph_module
+        )
+        # pyre-fixme[16]: Optional type has no attribute `graph_module`
+        graph_after_passes = p(graph_module).graph_module
+
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
+            1,
+        )
+        self.assertEqual(
+            count_node(
+                graph_after_passes, exir_ops.edge.cadence.quantized_matmul.default
+            ),
+            1,
+        )
+
     @parameterized.expand(
         [
             [(3, 5), (0, 0)],
