Skip to content

Commit c69bac4

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Merge passes that replace scalars
Besides having obsolete names, ReplaceScalarWithTensorPassTOSAMI and ReplaceScalarWithTensorPassTOSABI was causing difficulties for defining chronological pass dependencies with the `_passes_required_after` attribute. This is because, with the current design, there is no way to distinguish which profile is referred to when defining `_passes_required_after`; a pass simply declares its chronological dependencies globally. Solve this by merging the two pass classes together into one called `ReplaceScalarWithTensorByProfilePass`. This means that a pass should include this new pass in `_passes_required_after` no matter which TOSA profile its working towards. Signed-off-by: Martin Lindström <[email protected]> Change-Id: I00c0e3658426fb5c1da996eeee89093a395630db
1 parent 5689785 commit c69bac4

12 files changed

+67
-37
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@
8888
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
8989
from .remove_noop_pass import RemoveNoopPass # noqa
9090
from .replace_scalar_with_tensor_pass import ( # noqa
91-
ReplaceScalarWithTensorArgPassTOSABI,
92-
ReplaceScalarWithTensorArgPassTOSAMI,
91+
ReplaceScalarWithTensorByProfilePass,
9392
)
9493
from .rewrite_upsample import RewriteUpsamplePass # noqa
9594
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@
8888
QuantizeOperatorArguments,
8989
RemoveNoopPass,
9090
ReplaceInfValues,
91-
ReplaceScalarWithTensorArgPassTOSABI,
92-
ReplaceScalarWithTensorArgPassTOSAMI,
91+
ReplaceScalarWithTensorByProfilePass,
9392
RetraceFoldedDtypesPass,
9493
RewriteUpsamplePass,
9594
ScalarsToAttributePass,
@@ -171,7 +170,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
171170
self.add_pass(CastToInt32Pass())
172171

173172
self.add_pass(CastBoolToInt8Pass())
174-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
173+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
175174
self.add_pass(AnnotateDecomposedMatmulPass())
176175
self.add_pass(QuantizeOperatorArguments())
177176
self.add_pass(ConvertELUParamsPass())
@@ -239,7 +238,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
239238
self.add_pass(DecomposeSinhPass())
240239
self.add_pass(DecomposeSignPass())
241240
self.add_pass(DecomposeDivTensorModePass())
242-
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
241+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
243242
self.add_pass(DecomposeEmbeddingPass())
244243
self.add_pass(FuseQuantizedActivationPass())
245244
self.add_pass(RemoveGetItemPass())
@@ -329,7 +328,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
329328
self.add_pass(DecomposeSignPass())
330329
self.add_pass(DecomposeAddmmPass())
331330
self.add_pass(DecomposeDivTensorModePass())
332-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
331+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
333332
self.add_pass(ScalarsToAttributePass())
334333
self.add_pass(DecomposeGroupNormPass())
335334
self.add_pass(DecomposeLayerNormPass())

backends/arm/_passes/decompose_acosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1414
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1515
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
16-
ReplaceScalarWithTensorArgPassTOSAMI,
16+
ReplaceScalarWithTensorByProfilePass,
1717
)
1818
from executorch.exir.dialects._ops import ops as exir_ops
1919
from executorch.exir.pass_base import ExportPass
@@ -33,7 +33,7 @@ class DecomposeAcoshPass(ArmPass):
3333
DecomposeSqrtPass,
3434
InsertTableOpsPass,
3535
MatchArgRanksPass,
36-
ReplaceScalarWithTensorArgPassTOSAMI,
36+
ReplaceScalarWithTensorByProfilePass,
3737
MatchArgDtypePass,
3838
}
3939

backends/arm/_passes/decompose_asin_and_acos_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
2121
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
2222
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
23-
ReplaceScalarWithTensorArgPassTOSAMI,
23+
ReplaceScalarWithTensorByProfilePass,
2424
)
2525
from executorch.exir.dialects._ops import ops as exir_ops
2626
from executorch.exir.pass_base import ExportPass
@@ -72,7 +72,7 @@ class DecomposeAsinAndAcosPass(ArmPass):
7272
ConvertFullLikeToFullPass,
7373
MatchArgRanksPass,
7474
MatchArgDtypePass,
75-
ReplaceScalarWithTensorArgPassTOSAMI,
75+
ReplaceScalarWithTensorByProfilePass,
7676
}
7777

7878
def _build_polynomial(

backends/arm/_passes/decompose_asinh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1515
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1616
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
17-
ReplaceScalarWithTensorArgPassTOSAMI,
17+
ReplaceScalarWithTensorByProfilePass,
1818
)
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.pass_base import ExportPass
@@ -34,7 +34,7 @@ class DecomposeAsinhPass(ArmPass):
3434
DecomposeSqrtPass,
3535
InsertTableOpsPass,
3636
MatchArgRanksPass,
37-
ReplaceScalarWithTensorArgPassTOSAMI,
37+
ReplaceScalarWithTensorByProfilePass,
3838
MatchArgDtypePass,
3939
}
4040

backends/arm/_passes/decompose_atan_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -47,7 +47,7 @@ class DecomposeAtanPass(ArmPass):
4747
InsertTableOpsPass,
4848
MatchArgRanksPass,
4949
MatchArgDtypePass,
50-
ReplaceScalarWithTensorArgPassTOSAMI,
50+
ReplaceScalarWithTensorByProfilePass,
5151
}
5252

5353
def _rational_approximation(self, z, ops, meta):

backends/arm/_passes/decompose_atanh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorArgPassTOSAMI,
13+
ReplaceScalarWithTensorByProfilePass,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -43,7 +43,7 @@ class DecomposeAtanhPass(ArmPass):
4343
InsertTableOpsPass,
4444
MatchArgRanksPass,
4545
MatchArgDtypePass,
46-
ReplaceScalarWithTensorArgPassTOSAMI,
46+
ReplaceScalarWithTensorByProfilePass,
4747
}
4848

4949
def call_operator(self, op, args, kwargs, meta):

backends/arm/_passes/decompose_cosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorArgPassTOSAMI,
13+
ReplaceScalarWithTensorByProfilePass,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -31,7 +31,7 @@ class DecomposeCoshPass(ArmPass):
3131
_passes_required_after: Set[Type[ExportPass]] = {
3232
InsertTableOpsPass,
3333
MatchArgRanksPass,
34-
ReplaceScalarWithTensorArgPassTOSAMI,
34+
ReplaceScalarWithTensorByProfilePass,
3535
MatchArgDtypePass,
3636
}
3737

backends/arm/_passes/decompose_expm1_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -83,7 +83,7 @@ class DecomposeExpm1Pass(ArmPass):
8383
ConvertIntPowToMuls,
8484
InsertTableOpsPass,
8585
DecomposeDivPass,
86-
ReplaceScalarWithTensorArgPassTOSAMI,
86+
ReplaceScalarWithTensorByProfilePass,
8787
MatchArgDtypePass,
8888
MatchArgRanksPass,
8989
}

backends/arm/_passes/decompose_logit_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -73,7 +73,7 @@ class DecomposeLogitPass(ArmPass):
7373
InsertTableOpsPass,
7474
MatchArgRanksPass,
7575
MatchArgDtypePass,
76-
ReplaceScalarWithTensorArgPassTOSAMI,
76+
ReplaceScalarWithTensorByProfilePass,
7777
}
7878

7979
def call_operator(self, op, args, kwargs, meta):

0 commit comments

Comments
 (0)