Skip to content

Commit 7741e70

Browse files
Arm backend: Merge FP and INT pass pipelines (#15518)
Merges FP and INT pass pipelines into one pipeline. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 6ab8723 commit 7741e70

24 files changed

+167
-113
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6464
)
6565
matmul_targets = {
6666
exir_ops.edge.aten.bmm.default,
67+
exir_ops.edge.aten.mm.default,
6768
}
6869
for partition in matmul_partitions:
6970
quantized_input = all(

backends/arm/_passes/arm_pass_manager.py

Lines changed: 22 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -155,79 +155,26 @@ def _transform(self, graph_module: GraphModule):
155155
with TosaLoweringContext(self.tosa_spec):
156156
return self(graph_module).graph_module
157157

158-
def _tosa_INT_pipeline(
158+
def _tosa_pipeline(
159159
self, exported_program: ExportedProgram, graph_module: GraphModule
160160
) -> GraphModule:
161161
self.add_pass(AnnotateOutputDimOrderPass())
162162
self.add_pass(FuseQuantizedActivationPass())
163163
self.add_pass(RemoveGetItemPass())
164-
self.add_pass(ConvertSplitToSlicePass())
165-
self.add_pass(ConvertMmToBmmPass())
166-
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
167-
self.add_pass(ConvertFullLikeToFullPass())
168164
self.add_pass(ConvertToClampPass())
169-
self.add_pass(ConvertMinMaxPass())
170-
self.add_pass(ConvertAnyDefaultDimDimsPass())
171-
self.add_pass(MatchArgDtypePass())
172-
if self.tosa_spec.is_U55_subset:
173-
self.add_pass(CastToInt32Pass())
174-
175-
self.add_pass(CastBoolToInt8Pass())
176-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
165+
self.add_pass(DecomposeGroupNormPass())
166+
self.add_pass(DecomposeLayerNormPass())
167+
self.add_pass(DecomposeBatchNormNoStatsPass())
168+
self.add_pass(DecomposeVarPass())
169+
self.add_pass(
170+
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
171+
)
177172
self.add_pass(AnnotateDecomposedMatmulPass())
178-
self.add_pass(QuantizeOperatorArguments())
179173
self.add_pass(ConvertELUParamsPass())
174+
self.add_pass(ConvertSplitToSlicePass())
175+
self.add_pass(QuantizeOperatorArguments())
180176
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
181177
self.add_pass(FuseDuplicateUsersPass())
182-
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
183-
self.add_pass(MatchArgRanksPass(exported_program))
184-
if self.tosa_spec.is_U55_subset:
185-
self.add_pass(BroadcastArgsPass())
186-
self.add_pass(DecomposeLinearPass())
187-
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
188-
self.add_pass(DecomposeAvgPool2d())
189-
self.add_pass(ComputeConstantOpsAOT(exported_program))
190-
191-
self.add_pass(DecomposeGroupedConv())
192-
193-
self.add_pass(ConvertExpandCopyToRepeatPass())
194-
self.add_pass(UnsqueezeBeforeRepeatPass())
195-
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
196-
self.add_pass(DecomposeCumsumPass(exported_program))
197-
self.add_pass(Conv1dUnsqueezePass())
198-
self.add_pass(DecomposeMaxPool2DPass())
199-
self.add_pass(SizeAdjustInputPass())
200-
self.add_pass(DecomposeSelectPass())
201-
self.add_pass(ConvertSqueezesToViewPass())
202-
203-
self.add_pass(FuseViewCopyTransform())
204-
self.add_pass(FuseConstantArgsPass(exported_program))
205-
self.add_pass(InsertTableOpsPass(exported_program))
206-
# If we have a conv2d with int16 activation, split it up into a convolution
207-
# and an addition to work around the lack of support for int48 in torch. This
208-
# needs to happen before RewriteConv2dPass, but after the table ops are inserted
209-
# to be able to validate that conv2d has right dtype arguments.
210-
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
211-
self.add_pass(RewriteConv2dPass(exported_program))
212-
213-
self.add_pass(RewriteMatmulPass())
214-
self.add_pass(RewriteUpsamplePass())
215-
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
216-
217-
self.add_pass(InsertRescaleInt32Pass())
218-
self.add_pass(DecomposeSumPass())
219-
self.add_pass(ToTosaMemoryFormatPass(exported_program))
220-
self.add_pass(RemoveNoopPass())
221-
self.add_pass(InsertRescalePass())
222-
223-
self.validate_constraints_mandatory()
224-
return self._transform(graph_module)
225-
226-
def _tosa_FP_pipeline(
227-
self, exported_program: ExportedProgram, graph_module: GraphModule
228-
) -> GraphModule:
229-
self.add_pass(AnnotateOutputDimOrderPass())
230-
self.add_pass(FuseDuplicateUsersPass())
231178
self.add_pass(DecomposeExpm1Pass())
232179
self.add_pass(DecomposeLogitPass())
233180
self.add_pass(DecomposeMaskedFill())
@@ -252,32 +199,20 @@ def _tosa_FP_pipeline(
252199
self.add_pass(DecomposeRemainderPass())
253200
self.add_pass(DecomposeDivTensorModePass())
254201
self.add_pass(DecomposeEmbeddingPass())
255-
self.add_pass(FuseQuantizedActivationPass())
256-
self.add_pass(RemoveGetItemPass())
257-
self.add_pass(ConvertSplitToSlicePass())
258202
self.add_pass(FuseBatchnorm2DPass(exported_program))
259203
self.add_pass(ConvertMmToBmmPass())
260204
self.add_pass(DecomposeGluPass())
261205
self.add_pass(DecomposeLinearPass())
262206
self.add_pass(DecomposeLeakyReLUPass())
263-
self.add_pass(DecomposeGroupNormPass())
264-
self.add_pass(DecomposeLayerNormPass())
265-
self.add_pass(DecomposeBatchNormNoStatsPass())
266-
self.add_pass(DecomposeVarPass())
267-
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
268207
self.add_pass(DecomposeNotEqualPass())
269208
self.add_pass(DecomposeDivPass())
270209
self.add_pass(DecomposeAddSubAlphaPass())
271210
self.add_pass(DecomposeSoftmaxPass())
272211
self.add_pass(DecomposeGeluPass())
273212
self.add_pass(ConvertFullLikeToFullPass())
274-
self.add_pass(ConvertToClampPass())
275213
self.add_pass(ConvertMinMaxPass())
276214
self.add_pass(ConvertAnyDefaultDimDimsPass())
277215
self.add_pass(MatchArgDtypePass())
278-
self.add_pass(AnnotateDecomposedMatmulPass())
279-
self.add_pass(QuantizeOperatorArguments())
280-
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
281216
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
282217
self.add_pass(MatchArgRanksPass(exported_program))
283218
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
@@ -290,22 +225,26 @@ def _tosa_FP_pipeline(
290225
self.add_pass(DecomposeGroupedConv())
291226
self.add_pass(ConvertExpandCopyToRepeatPass())
292227
self.add_pass(UnsqueezeBeforeRepeatPass())
293-
self.add_pass(DecomposeSumPass())
294228
self.add_pass(DecomposeCumsumPass(exported_program))
295229
self.add_pass(Conv1dUnsqueezePass())
296230
self.add_pass(DecomposeMaxPool2DPass())
297231
self.add_pass(SizeAdjustInputPass())
298232
self.add_pass(DecomposeSelectPass())
299233
self.add_pass(ConvertSqueezesToViewPass())
234+
self.add_pass(CastToInt32Pass())
235+
self.add_pass(BroadcastArgsPass())
300236

301237
self.add_pass(FuseViewCopyTransform())
302238
self.add_pass(FuseConstantArgsPass(exported_program))
303-
self.add_pass(RewriteConv2dPass(exported_program))
239+
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
304240
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
305-
self.add_pass(RewriteUpsamplePass())
306241
self.add_pass(InsertTableOpsPass(exported_program))
242+
self.add_pass(RewriteUpsamplePass())
243+
self.add_pass(RewriteConv2dPass(exported_program))
307244
self.add_pass(RewriteMatmulPass())
308245
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
246+
self.add_pass(InsertRescaleInt32Pass())
247+
self.add_pass(DecomposeSumPass())
309248
self.add_pass(ToTosaMemoryFormatPass(exported_program))
310249
self.add_pass(RemoveNoopPass())
311250
self.add_pass(InsertRescalePass())
@@ -317,10 +256,11 @@ def transform_to_backend_pipeline(
317256
self, exported_program: ExportedProgram, graph_module: GraphModule
318257
):
319258
"""Apply passes before transforming program to backend"""
320-
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
321-
return self._tosa_FP_pipeline(exported_program, graph_module)
322-
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
323-
return self._tosa_INT_pipeline(exported_program, graph_module)
259+
if self.tosa_spec in (
260+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
261+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
262+
):
263+
return self._tosa_pipeline(exported_program, graph_module)
324264
else:
325265
raise NotImplementedError(
326266
f"No pass pipeline implemented for {self.tosa_spec=}"

backends/arm/_passes/broadcast_args_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_node,
1212
get_first_fake_tensor,
1313
)
14+
from executorch.backends.arm.tosa.specification import get_context_spec
1415

1516
from executorch.exir.dialects._ops import ops as exir_ops
1617

@@ -34,6 +35,9 @@ class BroadcastArgsPass(ArmPass):
3435
}
3536

3637
def call(self, graph_module: GraphModule) -> PassResult:
38+
tosa_spec = get_context_spec()
39+
if not tosa_spec.is_U55_subset:
40+
return PassResult(graph_module, False)
3741
for node in graph_module.graph.nodes:
3842
if node.op != "call_function" or node.target not in self.targeted_ops:
3943
continue

backends/arm/_passes/cast_to_int32_pass.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import torch
99

1010
from executorch.backends.arm._passes.arm_pass import ArmPass
11+
12+
from executorch.backends.arm.tosa.specification import get_context_spec
1113
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import ExportPass
14+
from executorch.exir.pass_base import ExportPass, PassResult
1315

1416

1517
class CastToInt32Pass(ArmPass):
@@ -22,6 +24,12 @@ class CastToInt32Pass(ArmPass):
2224
exir_ops.edge.aten.bitwise_right_shift.Tensor,
2325
}
2426

27+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
28+
tosa_spec = get_context_spec()
29+
if not tosa_spec.is_U55_subset:
30+
return PassResult(graph_module, False)
31+
return super().call(graph_module)
32+
2533
def call_operator(self, op, args, kwargs, meta):
2634
if op not in self.targeted_ops:
2735
return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/convert_elu_params.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from executorch.backends.arm._passes import ArmPass
1010
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.backends.arm.constants import DQ_OPS
1112
from executorch.exir.dialects._ops import ops as exir_ops
1213
from executorch.exir.pass_base import ExportPass, PassResult
1314

@@ -30,6 +31,12 @@ def call(self, graph_module: torch.fx.GraphModule):
3031
op="call_function", target=exir_ops.edge.aten.elu.default
3132
)
3233
for node in node_list:
34+
input_node = node.all_input_nodes[0]
35+
is_quantized = (
36+
input_node.op == "call_function" and input_node.target in DQ_OPS
37+
)
38+
if not is_quantized:
39+
continue
3340
with graph.inserting_after(node):
3441
replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
3542
old_args = list(node.args)

backends/arm/_passes/convert_int_pow_to_mul.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ def call_operator(self, op, args, kwargs, meta):
2424
if op != exir_ops.edge.aten.pow.Tensor_Scalar:
2525
return super().call_operator(op, args, kwargs, meta)
2626

27+
is_quantized = (
28+
len(meta.data.get("input_qparams", {})) > 0
29+
and len(meta.data.get("output_qparams", {})) > 0
30+
)
31+
if is_quantized:
32+
# If quantized, the node should be replaced by a table op
33+
return super().call_operator(op, args, kwargs, meta)
34+
2735
x = args[0]
2836
exp = args[1]
2937

backends/arm/_passes/decompose_acosh_pass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
4141
if op is not edge_acosh_op:
4242
return super().call_operator(op, args, kwargs, meta, updated)
4343

44+
is_quantized = (
45+
len(meta.data.get("input_qparams", {})) > 0
46+
and len(meta.data.get("output_qparams", {})) > 0
47+
)
48+
if is_quantized:
49+
# If quantized, the node should be replaced by a table op
50+
return super().call_operator(op, args, kwargs, meta, updated)
51+
4452
log_op, sqrt_op, mul_op, sub_op, add_op, add_op_scalar = (
4553
exir_ops.edge.aten.log.default,
4654
exir_ops.edge.aten.sqrt.default,

backends/arm/_passes/decompose_asin_and_acos_pass.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ def _combine_branches(
123123
def call_operator(self, op, args, kwargs, meta):
124124
if op not in (edge_asin_op + edge_acos_op):
125125
return super().call_operator(op, args, kwargs, meta)
126+
127+
is_quantized = (
128+
len(meta.data.get("input_qparams", {})) > 0
129+
and len(meta.data.get("output_qparams", {})) > 0
130+
)
131+
if is_quantized:
132+
# If quantized, the node should be replaced by a table op
133+
return super().call_operator(op, args, kwargs, meta)
134+
126135
logging.info(
127136
f"Approximating {op}. This may introduce small numerical errors. For details, see {__file__}."
128137
)

backends/arm/_passes/decompose_asinh_pass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ def call_operator(self, op, args, kwargs, meta):
4040
if op not in edge_asinh_op:
4141
return super().call_operator(op, args, kwargs, meta)
4242

43+
is_quantized = (
44+
len(meta.data.get("input_qparams", {})) > 0
45+
and len(meta.data.get("output_qparams", {})) > 0
46+
)
47+
if is_quantized:
48+
# If quantized, the node should be replaced by a table op
49+
return super().call_operator(op, args, kwargs, meta)
50+
4351
log_op, sqrt_op, mul_op, add_op_scalar, add_op = (
4452
exir_ops.edge.aten.log.default,
4553
exir_ops.edge.aten.sqrt.default,

backends/arm/_passes/decompose_atan_pass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ def call_operator(self, op, args, kwargs, meta):
8080
if op is not edge_atan:
8181
return super().call_operator(op, args, kwargs, meta, updated=False)
8282

83+
is_quantized = (
84+
len(meta.data.get("input_qparams", {})) > 0
85+
and len(meta.data.get("output_qparams", {})) > 0
86+
)
87+
if is_quantized:
88+
# If quantized, the node should be replaced by a table op
89+
return super().call_operator(op, args, kwargs, meta)
90+
8391
logging.info(
8492
f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}."
8593
)

0 commit comments

Comments
 (0)