Skip to content

Commit a4c3cd7

Browse files
committed
Revert "Arm backend: Merge passes that replace scalars (pytorch#15298)"
This reverts commit de56c81.
1 parent ddfa961 commit a4c3cd7

12 files changed

+37
-67
lines changed

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@
8888
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
8989
from .remove_noop_pass import RemoveNoopPass # noqa
9090
from .replace_scalar_with_tensor_pass import ( # noqa
91-
ReplaceScalarWithTensorByProfilePass,
91+
ReplaceScalarWithTensorArgPassTOSABI,
92+
ReplaceScalarWithTensorArgPassTOSAMI,
9293
)
9394
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
9495
from .rewrite_matmul import RewriteMatmulPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@
8787
QuantizeOperatorArguments,
8888
RemoveNoopPass,
8989
ReplaceInfValues,
90-
ReplaceScalarWithTensorByProfilePass,
90+
ReplaceScalarWithTensorArgPassTOSABI,
91+
ReplaceScalarWithTensorArgPassTOSAMI,
9192
RetraceFoldedDtypesPass,
9293
RewriteConv2dPass,
9394
RewriteMatmulPass,
@@ -171,7 +172,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
171172
self.add_pass(CastToInt32Pass())
172173

173174
self.add_pass(CastBoolToInt8Pass())
174-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
175+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
175176
self.add_pass(AnnotateDecomposedMatmulPass())
176177
self.add_pass(QuantizeOperatorArguments())
177178
self.add_pass(ConvertELUParamsPass())
@@ -241,7 +242,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
241242
self.add_pass(DecomposeSinhPass())
242243
self.add_pass(DecomposeSignPass())
243244
self.add_pass(DecomposeDivTensorModePass())
244-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
245+
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
245246
self.add_pass(DecomposeEmbeddingPass())
246247
self.add_pass(FuseQuantizedActivationPass())
247248
self.add_pass(RemoveGetItemPass())
@@ -334,7 +335,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
334335
self.add_pass(DecomposeAddmmPass())
335336
self.add_pass(DecomposeDivTensorModePass())
336337
self.add_pass(DecomposeAddSubAlphaPass())
337-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
338+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
338339
self.add_pass(ScalarsToAttributePass())
339340
self.add_pass(DecomposeGroupNormPass())
340341
self.add_pass(DecomposeLayerNormPass())

backends/arm/_passes/decompose_acosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorByProfilePass,
15+
ReplaceScalarWithTensorArgPassTOSAMI,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -32,7 +32,7 @@ class DecomposeAcoshPass(ArmPass):
3232
DecomposeSqrtPass,
3333
InsertTableOpsPass,
3434
MatchArgRanksPass,
35-
ReplaceScalarWithTensorByProfilePass,
35+
ReplaceScalarWithTensorArgPassTOSAMI,
3636
MatchArgDtypePass,
3737
}
3838

backends/arm/_passes/decompose_asin_and_acos_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
2020
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
2121
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
22-
ReplaceScalarWithTensorByProfilePass,
22+
ReplaceScalarWithTensorArgPassTOSAMI,
2323
)
2424
from executorch.exir.dialects._ops import ops as exir_ops
2525
from executorch.exir.pass_base import ExportPass
@@ -71,7 +71,7 @@ class DecomposeAsinAndAcosPass(ArmPass):
7171
ConvertFullLikeToFullPass,
7272
MatchArgRanksPass,
7373
MatchArgDtypePass,
74-
ReplaceScalarWithTensorByProfilePass,
74+
ReplaceScalarWithTensorArgPassTOSAMI,
7575
}
7676

7777
def _build_polynomial(

backends/arm/_passes/decompose_asinh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorByProfilePass,
15+
ReplaceScalarWithTensorArgPassTOSAMI,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -32,7 +32,7 @@ class DecomposeAsinhPass(ArmPass):
3232
DecomposeSqrtPass,
3333
InsertTableOpsPass,
3434
MatchArgRanksPass,
35-
ReplaceScalarWithTensorByProfilePass,
35+
ReplaceScalarWithTensorArgPassTOSAMI,
3636
MatchArgDtypePass,
3737
}
3838

backends/arm/_passes/decompose_atan_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorByProfilePass,
15+
ReplaceScalarWithTensorArgPassTOSAMI,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -47,7 +47,7 @@ class DecomposeAtanPass(ArmPass):
4747
InsertTableOpsPass,
4848
MatchArgRanksPass,
4949
MatchArgDtypePass,
50-
ReplaceScalarWithTensorByProfilePass,
50+
ReplaceScalarWithTensorArgPassTOSAMI,
5151
}
5252

5353
def _rational_approximation(self, z, ops, meta):

backends/arm/_passes/decompose_atanh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorByProfilePass,
13+
ReplaceScalarWithTensorArgPassTOSAMI,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -43,7 +43,7 @@ class DecomposeAtanhPass(ArmPass):
4343
InsertTableOpsPass,
4444
MatchArgRanksPass,
4545
MatchArgDtypePass,
46-
ReplaceScalarWithTensorByProfilePass,
46+
ReplaceScalarWithTensorArgPassTOSAMI,
4747
}
4848

4949
def call_operator(self, op, args, kwargs, meta):

backends/arm/_passes/decompose_cosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorByProfilePass,
13+
ReplaceScalarWithTensorArgPassTOSAMI,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -31,7 +31,7 @@ class DecomposeCoshPass(ArmPass):
3131
_passes_required_after: Set[Type[ExportPass]] = {
3232
InsertTableOpsPass,
3333
MatchArgRanksPass,
34-
ReplaceScalarWithTensorByProfilePass,
34+
ReplaceScalarWithTensorArgPassTOSAMI,
3535
MatchArgDtypePass,
3636
}
3737

backends/arm/_passes/decompose_expm1_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorByProfilePass,
15+
ReplaceScalarWithTensorArgPassTOSAMI,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -83,7 +83,7 @@ class DecomposeExpm1Pass(ArmPass):
8383
ConvertIntPowToMuls,
8484
InsertTableOpsPass,
8585
DecomposeDivPass,
86-
ReplaceScalarWithTensorByProfilePass,
86+
ReplaceScalarWithTensorArgPassTOSAMI,
8787
MatchArgDtypePass,
8888
MatchArgRanksPass,
8989
}

backends/arm/_passes/decompose_logit_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorByProfilePass,
15+
ReplaceScalarWithTensorArgPassTOSAMI,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -73,7 +73,7 @@ class DecomposeLogitPass(ArmPass):
7373
InsertTableOpsPass,
7474
MatchArgRanksPass,
7575
MatchArgDtypePass,
76-
ReplaceScalarWithTensorByProfilePass,
76+
ReplaceScalarWithTensorArgPassTOSAMI,
7777
}
7878

7979
def call_operator(self, op, args, kwargs, meta):

0 commit comments

Comments
 (0)