Skip to content

Commit 902e53f

Browse files
committed
Arm Backend: Add support for split_copy.default
Signed-off-by: Agrima Khare <[email protected]>
Change-Id: I9320188aceb5778de9649f414bf9fc7fc062b407
1 parent ce6e2cf commit 902e53f

File tree

4 files changed

+105
-17
lines changed

4 files changed

+105
-17
lines changed

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,24 @@ def call(self, graph_module: torch.fx.GraphModule):
4646
dim = (dim + rank) % rank
4747

4848
# Validate that split lengths cover the entire dimension
49-
length_sum = sum(split_lengths)
49+
5050
dim_size = shape[dim]
51-
if length_sum != dim_size:
52-
raise ValueError(
53-
f"Split sizes {split_lengths} sum to {length_sum}, "
54-
f"but dimension {dim} has size {dim_size}"
55-
)
51+
if isinstance(split_lengths, int):
52+
if split_lengths <= 0:
53+
raise ValueError(
54+
f"Split size must be positive, got {split_lengths}"
55+
)
56+
full_chunks, remainder = divmod(dim_size, split_lengths)
57+
split_lengths = [split_lengths] * full_chunks
58+
if remainder:
59+
split_lengths.append(remainder)
60+
else:
61+
length_sum = sum(split_lengths)
62+
if length_sum != dim_size:
63+
raise ValueError(
64+
f"Split sizes {split_lengths} sum to {length_sum}, "
65+
f"but dimension {dim} has size {dim_size}"
66+
)
5667

5768
# Convert split argument 'split_lengths' to slice arguments start and end.
5869
starts = [0] * len(split_lengths)

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
exir_ops.edge.aten.log.default,
5757
exir_ops.edge.aten.linear.default,
5858
exir_ops.edge.aten.split_with_sizes_copy.default,
59+
exir_ops.edge.aten.split_copy.Tensor,
5960
exir_ops.edge.aten.floor.default,
6061
exir_ops.edge.aten.full.default,
6162
exir_ops.edge.aten.full_like.default,
@@ -172,6 +173,7 @@
172173
exir_ops.edge.aten.log.default,
173174
exir_ops.edge.aten.linear.default,
174175
exir_ops.edge.aten.split_with_sizes_copy.default,
176+
exir_ops.edge.aten.split_copy.Tensor,
175177
exir_ops.edge.aten.floor.default,
176178
exir_ops.edge.aten.full.default,
177179
exir_ops.edge.aten.full_like.default,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def _match_pattern(
323323
torch.ops.aten.slice_copy.Tensor,
324324
torch.ops.aten.split.Tensor,
325325
torch.ops.aten.split_with_sizes.default,
326+
torch.ops.aten.split_copy.Tensor,
326327
torch.ops.aten.transpose.Dimname,
327328
torch.ops.aten.transpose.int,
328329
torch.ops.aten.transpose_copy.int,

backends/arm/test/ops/test_split.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323

2424
class Split(torch.nn.Module):
25-
2625
test_data = {
2726
"split_1d_2_size_0_dim": lambda: (torch.rand(10), 2, 0),
2827
"split_2d_3_size_1_dim": lambda: (torch.rand(10, 10), 3, 1),
@@ -60,12 +59,24 @@ def forward(
6059
return x.split(split_size=split_size_or_sections, dim=dim)[1:3]
6160

6261

62+
class SplitCopy(torch.nn.Module):
63+
aten_op = "torch.ops.aten.split_copy.Tensor"
64+
exir_op = "executorch_exir_dialects_edge__ops_aten_split_copy_Tensor"
65+
66+
def forward(
67+
self,
68+
x: torch.Tensor,
69+
split_size: int,
70+
dim: int,
71+
):
72+
return torch.split_copy(x, split_size=split_size, dim=dim)
73+
74+
6375
@common.parametrize(
6476
"test_data",
6577
(Split.test_data | Split.test_data_list),
6678
)
6779
def test_split_with_sizes_tosa_FP(test_data: input_t1):
68-
6980
pipeline = TosaPipelineFP[input_t1](
7081
Split(),
7182
test_data(),
@@ -77,7 +88,6 @@ def test_split_with_sizes_tosa_FP(test_data: input_t1):
7788

7889
@common.parametrize("test_data", Split.test_data_list)
7990
def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
80-
8191
pipeline = TosaPipelineFP[input_t1](
8292
SplitWithSizes(),
8393
test_data(),
@@ -92,7 +102,6 @@ def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
92102
(Split.test_data | Split.test_data_list),
93103
)
94104
def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
95-
96105
pipeline = TosaPipelineFP[input_t1](
97106
SplitSingleOut(),
98107
test_data(),
@@ -107,7 +116,6 @@ def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
107116
(Split.test_data | Split.test_data_list),
108117
)
109118
def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
110-
111119
pipeline = TosaPipelineFP[input_t1](
112120
SplitTwoOut(),
113121
test_data(),
@@ -122,7 +130,6 @@ def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
122130
(Split.test_data | Split.test_data_list),
123131
)
124132
def test_split_with_sizes_tosa_INT(test_data: input_t1):
125-
126133
pipeline = TosaPipelineINT[input_t1](
127134
Split(),
128135
test_data(),
@@ -152,7 +159,6 @@ def test_split_with_sizes_u55_INT(test_data: input_t1):
152159
(Split.test_data | Split.test_data_list),
153160
)
154161
def test_split_with_sizes_u85_INT(test_data: input_t1):
155-
156162
pipeline = EthosU85PipelineINT[input_t1](
157163
Split(),
158164
test_data(),
@@ -182,7 +188,6 @@ def test_split_with_sizes_vgf_FP(test_data: input_t1):
182188
@common.parametrize("test_data", Split.test_data_list)
183189
@common.SkipIfNoModelConverter
184190
def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
185-
186191
pipeline = VgfPipeline[input_t1](
187192
SplitWithSizes(),
188193
test_data(),
@@ -199,7 +204,6 @@ def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
199204
)
200205
@common.SkipIfNoModelConverter
201206
def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
202-
203207
pipeline = VgfPipeline[input_t1](
204208
SplitSingleOut(),
205209
test_data(),
@@ -216,7 +220,6 @@ def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
216220
)
217221
@common.SkipIfNoModelConverter
218222
def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
219-
220223
pipeline = VgfPipeline[input_t1](
221224
SplitTwoOut(),
222225
test_data(),
@@ -233,7 +236,6 @@ def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
233236
)
234237
@common.SkipIfNoModelConverter
235238
def test_split_with_sizes_vgf_INT(test_data: input_t1):
236-
237239
pipeline = VgfPipeline[input_t1](
238240
Split(),
239241
test_data(),
@@ -242,3 +244,75 @@ def test_split_with_sizes_vgf_INT(test_data: input_t1):
242244
tosa_version="TOSA-1.0+INT",
243245
)
244246
pipeline.run()
247+
248+
249+
@common.parametrize("test_data", Split.test_data)
250+
def test_split_tensor_tosa_FP(test_data: Tuple):
251+
pipeline = TosaPipelineFP[input_t1](
252+
SplitCopy(),
253+
test_data(),
254+
aten_op=SplitCopy.aten_op,
255+
exir_op=SplitCopy.exir_op,
256+
)
257+
pipeline.run()
258+
259+
260+
@common.parametrize("test_data", Split.test_data)
261+
def test_split_tensor_tosa_INT(test_data: Tuple):
262+
pipeline = TosaPipelineINT[input_t1](
263+
SplitCopy(),
264+
test_data(),
265+
aten_op=SplitCopy.aten_op,
266+
exir_op=SplitCopy.exir_op,
267+
)
268+
pipeline.run()
269+
270+
271+
@common.XfailIfNoCorstone300
272+
@common.parametrize("test_data", Split.test_data)
273+
def test_split_tensor_u55_INT(test_data: Tuple):
274+
pipeline = EthosU55PipelineINT[input_t1](
275+
SplitCopy(),
276+
test_data(),
277+
aten_ops=SplitCopy.aten_op,
278+
exir_ops=SplitCopy.exir_op,
279+
)
280+
pipeline.run()
281+
282+
283+
@common.XfailIfNoCorstone320
284+
@common.parametrize("test_data", Split.test_data)
285+
def test_split_tensor_u85_INT(test_data: Tuple):
286+
pipeline = EthosU85PipelineINT[input_t1](
287+
SplitCopy(),
288+
test_data(),
289+
aten_ops=SplitCopy.aten_op,
290+
exir_ops=SplitCopy.exir_op,
291+
)
292+
pipeline.run()
293+
294+
295+
@common.parametrize("test_data", Split.test_data)
296+
@common.SkipIfNoModelConverter
297+
def test_split_tensor_vgf_FP(test_data: Tuple):
298+
pipeline = VgfPipeline[input_t1](
299+
SplitCopy(),
300+
test_data(),
301+
aten_op=SplitCopy.aten_op,
302+
exir_op=SplitCopy.exir_op,
303+
tosa_version="TOSA-1.0+FP",
304+
)
305+
pipeline.run()
306+
307+
308+
@common.parametrize("test_data", Split.test_data)
309+
@common.SkipIfNoModelConverter
310+
def test_split_tensor_vgf_INT(test_data: Tuple):
311+
pipeline = VgfPipeline[input_t1](
312+
SplitCopy(),
313+
test_data(),
314+
aten_op=SplitCopy.aten_op,
315+
exir_op=SplitCopy.exir_op,
316+
tosa_version="TOSA-1.0+INT",
317+
)
318+
pipeline.run()

0 commit comments

Comments (0)