
Commit 8c3b621

Browse files
mcremon-meta authored and facebook-github-bot committed
Move the transpose matmul pass to OSS and run it earlier in the flow (#10433)
Summary: That pass does a lot more than it looks like, and it's just easier to move it back to where it was. CPU backends will possibly see more cycles due to the added permutes, but we don't care about that here. All DSP backends should be more efficient on transposed matmuls; should that not be the case in the future, we can re-evaluate. When we do the survey of passes and reorder them properly, we can think about this more.

Reviewed By: hsharma35

Differential Revision: D73600069
1 parent 6b877de commit 8c3b621
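
The rewrite is purely a re-expression of the same computation: the second matmul operand is handed to the kernel already transposed, together with a flag saying so, and the kernel is expected to account for that. The exact kernel semantics are backend-specific; the following is only a minimal float-level sketch of the intended behavior, and the transposed_matmul helper is hypothetical, not the Cadence quantized kernel.

import torch

def transposed_matmul(x: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference kernel: the second operand arrives with its last
    # two dims already swapped, and the kernel undoes that internally.
    return torch.matmul(x, y_t.transpose(-1, -2))

x = torch.randn(4, 3)
y = torch.randn(3, 5)

# Rewriting matmul(x, y) as transposed_matmul(x, y^T) changes nothing numerically.
assert torch.allclose(torch.matmul(x, y), transposed_matmul(x, y.transpose(-1, -2)))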

File tree

2 files changed: +145 -1 lines changed


backends/cadence/aot/replace_ops.py

Lines changed: 100 additions & 1 deletion
@@ -32,7 +32,10 @@
     is_quantized_tensor,
     quantize_tensor_multiplier,
 )
-from executorch.backends.cadence.aot.fuse_ops import FuseCascadedViewOps
+from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedTransposeOrPermuteOps,
+    FuseCascadedViewOps,
+)
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     register_cadence_pass,
@@ -2290,6 +2293,101 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceMatmulWithTransposedMatmulPass(ExportPass):
+    """
+    For certain backends, we have efficient kernels for transposed matmul. We
+    replace AxB with AxB' for such backends.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.cadence.quantized_matmul.default or args[-1] is True:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the args
+        if len(args) == 9:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                transposed,
+            ) = args
+        elif len(args) == 8:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+            ) = args
+            transposed = False
+        else:
+            raise AssertionError(
+                f"Unexpected number of args for quantized_matmul: {len(args)}"
+            )
+
+        # If the matmul is already transposed, bail
+        if transposed:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the second tensor
+        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        # Concretize the bias
+        zero_bias = super().call_operator(
+            exir_ops.edge.aten.full.default,
+            ([Y_tensor.size(-1)], 0),
+            {"dtype": torch.int32},
+            meta,
+        )
+
+        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
+        # can simply transpose the tensor in place.
+        if isinstance(Y_arg, ProxyValue):
+            transpose_args = (Y_arg, -1, -2)
+            transpose_node = super().call_operator(
+                exir_ops.edge.aten.transpose_copy.int,
+                transpose_args,
+                {},
+                meta,
+            )
+            Y_arg_t = transpose_node
+        else:
+            Y_arg_t = Y_tensor.transpose(-1, -2)
+
+        # Construct the new args, and return the transposed matmul op
+        new_args = (
+            X_arg,
+            X_zero_point,
+            Y_arg_t,
+            Y_zero_point,
+            zero_bias,
+            out_multiplier,
+            out_shift,
+            out_zero_point,
+            True,
+        )
+        return super().call_operator(op, new_args, kwargs, meta)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        result = super().call(graph_module)
+        # Fuse any inserted transpose node with transpose/permute nodes
+        # surrounding it.
+        result = FuseCascadedTransposeOrPermuteOps()(result.graph_module)
+        assert result is not None
+        # Replace permute with transpose.
+        result = ReplacePermuteWithTransposePass()(result.graph_module)
+        assert result is not None
+        return result
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2317,6 +2415,7 @@ class CadenceReplaceOpsInGraph:
         # This pass should be after passes that replace conv -> im2row + linear.
         ReplaceIm2RowWithViewPass,
         MakeSliceAndCatDimOutermostPass,
+        ReplaceMatmulWithTransposedMatmulPass,
         ReplaceNopTransposeOrPermuteWithViewPass,
         ReplaceLinearWithFullyConnectedOpPass,
         ReplaceScalarTensorWithFullPass,
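
For reference, here is a plain-Python sketch of the argument rewrite performed by the 8-argument (non-transposed) branch above, with eager tensors standing in for the FX values and a torch.full call standing in for the aten.full node the pass emits; the helper and its names are illustrative only:

import torch

def rewrite_quantized_matmul_args(args):
    # Mirror the 8-arg branch of ReplaceMatmulWithTransposedMatmulPass: unpack,
    # transpose the second operand over its last two dims, swap in a zero bias
    # of length Y.size(-1), and append the transposed=True flag.
    (x, x_zero_point, y, y_zero_point, _bias,
     out_multiplier, out_shift, out_zero_point) = args
    zero_bias = torch.full((y.size(-1),), 0, dtype=torch.int32)
    y_t = y.transpose(-1, -2)
    return (x, x_zero_point, y_t, y_zero_point, zero_bias,
            out_multiplier, out_shift, out_zero_point, True)

The extra transpose_copy this inserts is then cleaned up in call(): if the operand already flows through a transpose or permute, FuseCascadedTransposeOrPermuteOps can fold the two together, which helps keep the permute overhead mentioned in the summary bounded.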

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 45 additions & 0 deletions
@@ -35,6 +35,7 @@
     ReplaceGeluWithApproximateGeluPass,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
+    ReplaceMatmulWithTransposedMatmulPass,
     ReplaceMMWithAddMMPass,
     ReplaceNopTransposeOrPermuteWithViewPass,
     ReplacePadWithCatPass,
@@ -85,6 +86,50 @@ def assertTargetCountsEqual(
         for target, expected_count in targets_and_counts:
             self.assertTargetCountEqual(graph_module, target, expected_count)
 
+    @parameterized.expand(
+        [
+            # Regular MM
+            [(64, 33), (33, 128)],
+            # Batched MM
+            [(2, 48, 48), (2, 48, 48)],
+        ]
+    )
+    @torch.no_grad()
+    def test_replace_matmul_with_transposed_matmul(
+        self,
+        x_shape: Tuple[int],
+        y_shape: Tuple[int],
+    ) -> None:
+        class MatMul(torch.nn.Module):
+            def __init__(self) -> None:
+                super(MatMul, self).__init__()
+
+            def forward(self, x, y):
+                return torch.matmul(x, y)
+
+        model = MatMul()
+        X = torch.randn(x_shape)
+        Y = torch.randn(y_shape)
+        p = ReplaceMatmulWithTransposedMatmulPass()
+        inputs = (X, Y)
+        quantized_model = quantize_pt2(model, inputs)
+        graph_module = (
+            export_to_edge(quantized_model, inputs).exported_program().graph_module
+        )
+        # pyre-fixme[16]: Optional type has no attribute `graph_module`
+        graph_after_passes = p(graph_module).graph_module
+
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
+            1,
+        )
+        self.assertEqual(
+            count_node(
+                graph_after_passes, exir_ops.edge.cadence.quantized_matmul.default
+            ),
+            1,
+        )
+
     @parameterized.expand(
         [
             [(3, 5), (0, 0)],
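
As a quick sanity check outside the FX machinery, the shapes exercised by the new test behave as expected under the rewrite; a small eager-mode sketch (float math only, not the quantized kernel):

import torch

for x_shape, y_shape in [((64, 33), (33, 128)), ((2, 48, 48), (2, 48, 48))]:
    x, y = torch.randn(x_shape), torch.randn(y_shape)
    y_t = y.transpose(-1, -2)
    # What the rewritten (transposed) matmul is expected to compute in float
    # terms: undo the transpose internally and multiply as before.
    assert torch.allclose(torch.matmul(x, y), torch.matmul(x, y_t.transpose(-1, -2)))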
