Skip to content

Commit 16f7f7a

Browse files
committed
Revert "Arm backend: Use reshape instead of view before edge (pytorch#15269)"
This reverts commit 69f79b9.
1 parent a4c3cd7 commit 16f7f7a

File tree

7 files changed

+10
-32
lines changed

7 files changed

+10
-32
lines changed

backends/arm/_passes/decompose_embedding_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class DecomposeEmbeddingPass(ArmPass):
4242
def get_decomposition(self, op):
4343
if op in self.aten_ops:
4444
return (
45-
torch.ops.aten.reshape.default,
45+
torch.ops.aten.view_copy.default,
4646
torch.ops.aten.index_select.default,
4747
)
4848

backends/arm/_passes/decompose_groupnorm_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_group_norm_decomposition(op) -> tuple:
3939
torch.ops.aten.add.Tensor,
4040
torch.ops.aten.rsqrt.default,
4141
torch.ops.aten.mul.Tensor,
42-
torch.ops.aten.reshape.default,
42+
torch.ops.aten.view_copy.default,
4343
)
4444
raise RuntimeError(f"Can't get group_norm composition for op {op}")
4545

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_layer_norm_decomposition(op) -> tuple:
3939
torch.ops.aten.add.Tensor,
4040
torch.ops.aten.rsqrt.default,
4141
torch.ops.aten.mul.Tensor,
42-
torch.ops.aten.reshape.default,
42+
torch.ops.aten.view_copy.default,
4343
)
4444
raise RuntimeError(f"Can't get layer_norm composition for op {op}")
4545

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_view(op):
4646
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
4747
return exir_ops.edge.aten.view_copy.default
4848
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
49-
return torch.ops.aten.reshape.default
49+
return torch.ops.aten.view_copy.default
5050
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5151

5252

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def _get_sum_decomp(op):
1919
exir_ops.edge.aten.sum.dim_IntList,
2020
)
2121
case torch.ops.aten.sum.dim_IntList:
22-
return (torch.ops.aten.reshape.default, torch.ops.aten.sum.dim_IntList)
22+
return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
2323
case _:
2424
raise RuntimeError("Unvalid op in DecomposeSumPass")
2525

backends/arm/test/ops/test_embedding.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,10 @@ def forward(self, weights: torch.Tensor, indices: torch.Tensor):
2727
return torch.embedding(weights, indices)
2828

2929

30-
class ExpandEmbedding(Embedding):
31-
example_inputs = (torch.randn(10, 3), torch.tensor([[1, 2, 3]], dtype=torch.int32))
32-
33-
def forward(self, weights: torch.Tensor, indices: torch.Tensor):
34-
return torch.embedding(weights, indices.expand(2, 3))
35-
36-
37-
input_params = Tuple[torch.Tensor, torch.Tensor]
30+
input_params = Tuple[torch.Tensor, torch.Tensor, torch.dtype]
3831

3932

40-
test_input: dict[str, input_params] = {
33+
test_input: dict[input_params] = {
4134
"test_1": (
4235
torch.randn(10, 3),
4336
torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32),
@@ -96,21 +89,6 @@ def test_embedding_tosa_INT(test_input: input_params):
9689
pipeline.run()
9790

9891

99-
def test_expand_embedding_tosa_INT():
100-
op = ExpandEmbedding()
101-
pipeline = TosaPipelineINT(
102-
op,
103-
ExpandEmbedding.example_inputs,
104-
ExpandEmbedding.aten_op,
105-
ExpandEmbedding.exir_op,
106-
use_to_edge_transform_and_lower=True,
107-
)
108-
pipeline.pop_stage("check.aten")
109-
pipeline.pop_stage("check_count.exir")
110-
111-
pipeline.run()
112-
113-
11492
@pytest.mark.skip("reason=MLETORCH-1274 Improve data type checks during partitioning")
11593
@common.parametrize("test_input", test_input)
11694
@common.SkipIfNoModelConverter

backends/arm/test/passes/test_decompose_meandim_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class MeanDim(torch.nn.Module):
2828
}
2929

3030
ops_not_after_pass = u55_ops_not_after_pass = [
31-
"torch.ops.aten.reshape.default",
31+
"torch.ops.aten.view_copy.default",
3232
"torch.ops.aten.avg_pool2d.default",
3333
"torch.ops.aten.mean.dim",
3434
]
@@ -52,7 +52,7 @@ class MeanDimTensor(torch.nn.Module):
5252
"torch.ops.aten.sum.dim_IntList": 2,
5353
"torch.ops.aten.mul.Tensor": 1,
5454
"torch.ops.aten.avg_pool2d.default": 1,
55-
"torch.ops.aten.reshape.default": 1,
55+
"torch.ops.aten.view_copy.default": 1,
5656
}
5757

5858
ops_not_after_pass = [
@@ -62,7 +62,7 @@ class MeanDimTensor(torch.nn.Module):
6262
u55_ops_after_pass = {
6363
"torch.ops.aten.sum.dim_IntList": 2,
6464
"torch.ops.aten.mul.Tensor": 1,
65-
"torch.ops.aten.reshape.default": 1,
65+
"torch.ops.aten.view_copy.default": 1,
6666
}
6767

6868
u55_ops_not_after_pass = [

0 commit comments

Comments
 (0)