Skip to content

Commit f10d483

Browse files
Arm backend: Do not partition view_copy
aten::view_copy is now considered no-compute by the partitioner. This prevents the case where a partition is left empty because two view_copy ops cancel each other out. This was the case when running torch_audio_hdemucs_high_musdb(_plus), which made Vela output an error. Additionally, some incorrectly named tests in test_sum have been renamed from view to sum. JIRA: MLETORCH-1535 Change-Id: I1fe5175fd560486a064258dee5f0048836022ee1
1 parent 12d17ef commit f10d483

File tree

5 files changed

+36
-14
lines changed

5 files changed

+36
-14
lines changed

backends/arm/test/ops/test_mean_dim.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,12 @@ def forward(self, tensor: torch.Tensor, keepdim: bool):
356356
return tensor.mean()
357357

358358
test_data_suite: dict[str, Callable[[], mean_input_t]] = {
359-
"rank1": lambda: (
360-
torch.rand(
361-
1,
362-
),
359+
"rank_2": lambda: (
360+
torch.rand(1, 2),
363361
False,
364362
),
365-
"rank2": lambda: (torch.rand(5, 5), True),
366-
"rank4": lambda: (torch.rand(5, 1, 10, 1), False),
363+
"rank_2_keepdim": lambda: (torch.rand(5, 5), True),
364+
"rank_4": lambda: (torch.rand(5, 1, 10, 1), False),
367365
}
368366

369367

backends/arm/test/ops/test_sum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_sum_dim_intlist_tosa_INT(test_data: input_t1):
6666

6767
@common.parametrize("test_data", Sum.test_parameters)
6868
@common.XfailIfNoCorstone300
69-
def test_view_u55_INT_1_0(test_data: Tuple):
69+
def test_sum_u55_INT_1_0(test_data: Tuple):
7070
pipeline = EthosU55PipelineINT[input_t1](
7171
Sum(),
7272
test_data(),
@@ -78,7 +78,7 @@ def test_view_u55_INT_1_0(test_data: Tuple):
7878

7979
@common.parametrize("test_data", Sum.test_parameters)
8080
@common.XfailIfNoCorstone320
81-
def test_view_u85_INT_1_0(test_data: Tuple):
81+
def test_sum_u85_INT_1_0(test_data: Tuple):
8282
pipeline = EthosU85PipelineINT[input_t1](
8383
Sum(),
8484
test_data(),
@@ -122,7 +122,7 @@ def test_sum_dim_intlist_vgf_INT(test_data: input_t1):
122122

123123

124124
@common.parametrize("test_data", reject_inputs)
125-
def test_view_u55_INT_failure_set(test_data: Tuple):
125+
def test_sum_u55_INT_failure_set(test_data: Tuple):
126126
pipeline = EthosU55PipelineINT[input_t1](
127127
Sum(),
128128
test_data(),

backends/arm/test/ops/test_unflatten.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ def __init__(self, dim: int, sizes: Tuple[int, ...]):
2929
self.sizes = sizes
3030

3131
def forward(self, x: torch.Tensor) -> torch.Tensor:
32-
return torch.unflatten(x, self.dim, self.sizes)
32+
unflatten_op = torch.unflatten(x, self.dim, self.sizes)
33+
# Because we treat a single view as a no-compute operation and therefore do not partition it,
34+
# we want to provide a mul op to verify that it does indeed get partitioned when bundled with another op.
35+
return unflatten_op * unflatten_op
3336

3437
test_data: dict[str, test_data_t] = {
3538
"rand_3d_batch3": (lambda: (Unflatten(1, (-1, 2)), (torch.rand(3, 4, 4),))),

backends/arm/test/ops/test_view.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def __init__(self, new_shape):
6565
self.new_shape = new_shape
6666

6767
def forward(self, x: torch.Tensor):
68-
return x.view(self.new_shape)
68+
view_op = x.view(self.new_shape)
69+
# Because we treat a single view as a no-compute operation and therefore do not partition it,
70+
# we want to provide a mul op to verify that it does indeed get partitioned when bundled with another op.
71+
return view_op * view_op
6972

7073

7174
@common.parametrize("test_data", View.needs_transpose_tests)
@@ -139,7 +142,7 @@ def test_view_u55_INT_not_delegated(test_data: Tuple):
139142
View(new_shape),
140143
(test_tensor,),
141144
{"executorch_exir_dialects_edge__ops_aten_view_copy": 1},
142-
n_expected_delegates=0,
145+
n_expected_delegates=1,
143146
quantize=True,
144147
u55_subset=True,
145148
)

backends/arm/tosa/partitioner.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
115115
return all(m == 1 for m in multiples) and not changes_rank
116116

117117

118+
def is_view_copy(node: torch.fx.node.Node) -> bool:
119+
"""Return True if node is a ``view_copy``.
120+
121+
view_copy can be regarded as a no-compute op.
122+
123+
Args:
124+
node (torch.fx.Node): FX node to inspect.
125+
126+
Returns:
127+
bool: True if the node targets ``aten.view_copy.default``; otherwise,
128+
False.
129+
130+
"""
131+
return node.target == exir_ops.edge.aten.view_copy.default
132+
133+
118134
def is_partitioned(
119135
node: torch.fx.Node,
120136
tag: str,
@@ -267,16 +283,18 @@ def _tag_module( # noqa
267283
del node.meta["delegation_tag"]
268284
break
269285

270-
is_noop_partition = all(
286+
# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation.
287+
is_nocompute_partition = all(
271288
is_noop_clone(node)
272289
or is_noop_alias_copy(node)
273290
or is_noop_expand(node)
274291
or is_noop_to_dim_order_copy(node)
292+
or is_view_copy(node)
275293
or node.target in Q_OPS
276294
or node.target in DQ_OPS
277295
for node in partition.nodes
278296
)
279-
if is_noop_partition:
297+
if is_nocompute_partition:
280298
reject_partition(
281299
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
282300
partition,

0 commit comments

Comments
 (0)