Skip to content

Commit 69f79b9

Browse files
authored
Arm backend: Use reshape instead of view before edge (#15269)
The view operator can't handle non-contigious strides, such as the result of an expand. These are normalized after to_edge, but in the transform_for_annotation_pipeline we shouldn't use views for that reason. Reshape is the equivalent operator that can handle such strides. Signed-off-by: Erik Lundell <[email protected]>
1 parent c66078c commit 69f79b9

File tree

7 files changed

+32
-10
lines changed

7 files changed

+32
-10
lines changed

backends/arm/_passes/decompose_embedding_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class DecomposeEmbeddingPass(ArmPass):
4242
def get_decomposition(self, op):
4343
if op in self.aten_ops:
4444
return (
45-
torch.ops.aten.view_copy.default,
45+
torch.ops.aten.reshape.default,
4646
torch.ops.aten.index_select.default,
4747
)
4848

backends/arm/_passes/decompose_groupnorm_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_group_norm_decomposition(op) -> tuple:
3939
torch.ops.aten.add.Tensor,
4040
torch.ops.aten.rsqrt.default,
4141
torch.ops.aten.mul.Tensor,
42-
torch.ops.aten.view_copy.default,
42+
torch.ops.aten.reshape.default,
4343
)
4444
raise RuntimeError(f"Can't get group_norm composition for op {op}")
4545

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_layer_norm_decomposition(op) -> tuple:
3939
torch.ops.aten.add.Tensor,
4040
torch.ops.aten.rsqrt.default,
4141
torch.ops.aten.mul.Tensor,
42-
torch.ops.aten.view_copy.default,
42+
torch.ops.aten.reshape.default,
4343
)
4444
raise RuntimeError(f"Can't get layer_norm composition for op {op}")
4545

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_view(op):
4646
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
4747
return exir_ops.edge.aten.view_copy.default
4848
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
49-
return torch.ops.aten.view_copy.default
49+
return torch.ops.aten.reshape.default
5050
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5151

5252

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def _get_sum_decomp(op):
1919
exir_ops.edge.aten.sum.dim_IntList,
2020
)
2121
case torch.ops.aten.sum.dim_IntList:
22-
return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
22+
return (torch.ops.aten.reshape.default, torch.ops.aten.sum.dim_IntList)
2323
case _:
2424
raise RuntimeError("Unvalid op in DecomposeSumPass")
2525

backends/arm/test/ops/test_embedding.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@ def forward(self, weights: torch.Tensor, indices: torch.Tensor):
2727
return torch.embedding(weights, indices)
2828

2929

30-
input_params = Tuple[torch.Tensor, torch.Tensor, torch.dtype]
30+
class ExpandEmbedding(Embedding):
31+
example_inputs = (torch.randn(10, 3), torch.tensor([[1, 2, 3]], dtype=torch.int32))
32+
33+
def forward(self, weights: torch.Tensor, indices: torch.Tensor):
34+
return torch.embedding(weights, indices.expand(2, 3))
35+
36+
37+
input_params = Tuple[torch.Tensor, torch.Tensor]
3138

3239

33-
test_input: dict[input_params] = {
40+
test_input: dict[str, input_params] = {
3441
"test_1": (
3542
torch.randn(10, 3),
3643
torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32),
@@ -89,6 +96,21 @@ def test_embedding_tosa_INT(test_input: input_params):
8996
pipeline.run()
9097

9198

99+
def test_expand_embedding_tosa_INT():
100+
op = ExpandEmbedding()
101+
pipeline = TosaPipelineINT(
102+
op,
103+
ExpandEmbedding.example_inputs,
104+
ExpandEmbedding.aten_op,
105+
ExpandEmbedding.exir_op,
106+
use_to_edge_transform_and_lower=True,
107+
)
108+
pipeline.pop_stage("check.aten")
109+
pipeline.pop_stage("check_count.exir")
110+
111+
pipeline.run()
112+
113+
92114
@pytest.mark.skip("reason=MLETORCH-1274 Improve data type checks during partitioning")
93115
@common.parametrize("test_input", test_input)
94116
@common.SkipIfNoModelConverter

backends/arm/test/passes/test_decompose_meandim_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class MeanDim(torch.nn.Module):
2828
}
2929

3030
ops_not_after_pass = u55_ops_not_after_pass = [
31-
"torch.ops.aten.view_copy.default",
31+
"torch.ops.aten.reshape.default",
3232
"torch.ops.aten.avg_pool2d.default",
3333
"torch.ops.aten.mean.dim",
3434
]
@@ -52,7 +52,7 @@ class MeanDimTensor(torch.nn.Module):
5252
"torch.ops.aten.sum.dim_IntList": 2,
5353
"torch.ops.aten.mul.Tensor": 1,
5454
"torch.ops.aten.avg_pool2d.default": 1,
55-
"torch.ops.aten.view_copy.default": 1,
55+
"torch.ops.aten.reshape.default": 1,
5656
}
5757

5858
ops_not_after_pass = [
@@ -62,7 +62,7 @@ class MeanDimTensor(torch.nn.Module):
6262
u55_ops_after_pass = {
6363
"torch.ops.aten.sum.dim_IntList": 2,
6464
"torch.ops.aten.mul.Tensor": 1,
65-
"torch.ops.aten.view_copy.default": 1,
65+
"torch.ops.aten.reshape.default": 1,
6666
}
6767

6868
u55_ops_not_after_pass = [

0 commit comments

Comments
 (0)