23 changes: 17 additions & 6 deletions backends/arm/_passes/convert_split_to_slice.py
@@ -46,13 +46,24 @@ def call(self, graph_module: torch.fx.GraphModule):
 dim = (dim + rank) % rank

 # Validate that split lengths cover the entire dimension
-length_sum = sum(split_lengths)
+
 dim_size = shape[dim]
-if length_sum != dim_size:
-    raise ValueError(
-        f"Split sizes {split_lengths} sum to {length_sum}, "
-        f"but dimension {dim} has size {dim_size}"
-    )
+if isinstance(split_lengths, int):
+    if split_lengths <= 0:
+        raise ValueError(
+            f"Split size must be positive, got {split_lengths}"
+        )
+    full_chunks, remainder = divmod(dim_size, split_lengths)
+    split_lengths = [split_lengths] * full_chunks
+    if remainder:
+        split_lengths.append(remainder)
+else:
+    length_sum = sum(split_lengths)
+    if length_sum != dim_size:
+        raise ValueError(
+            f"Split sizes {split_lengths} sum to {length_sum}, "
+            f"but dimension {dim} has size {dim_size}"
+        )

 # Convert split argument 'split_lengths' to slice arguments start and end.
 starts = [0] * len(split_lengths)
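Reviewer note (not part of the diff): a minimal standalone sketch of the normalization this hunk adds. The helper name and the plain-list setting are ours for illustration; the actual pass reads these values off FX graph node arguments.

```python
from itertools import accumulate

import torch


def normalize_split_lengths(split_lengths, dim_size):
    # Hypothetical helper mirroring the pass logic above: an int split size N
    # expands to chunks of N, with a smaller trailing chunk when N does not
    # divide the dimension evenly -- the same semantics as torch.split.
    if isinstance(split_lengths, int):
        if split_lengths <= 0:
            raise ValueError(f"Split size must be positive, got {split_lengths}")
        full_chunks, remainder = divmod(dim_size, split_lengths)
        split_lengths = [split_lengths] * full_chunks
        if remainder:
            split_lengths.append(remainder)
    elif sum(split_lengths) != dim_size:
        raise ValueError(
            f"Split sizes {split_lengths} sum to {sum(split_lengths)}, "
            f"but dimension has size {dim_size}"
        )
    return split_lengths


lengths = normalize_split_lengths(3, 10)
assert lengths == [3, 3, 3, 1]

# The slice arguments derived next in the pass follow from running sums:
ends = list(accumulate(lengths))  # [3, 6, 9, 10]
starts = [0] + ends[:-1]          # [0, 3, 6, 9]

# Sanity check against eager-mode torch.split semantics:
assert [t.shape[0] for t in torch.split(torch.rand(10), 3)] == lengths
```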
@@ -55,6 +55,7 @@
 exir_ops.edge.aten.log.default,
 exir_ops.edge.aten.linear.default,
 exir_ops.edge.aten.split_with_sizes_copy.default,
+exir_ops.edge.aten.split_copy.Tensor,
 exir_ops.edge.aten.floor.default,
 exir_ops.edge.aten.full.default,
 exir_ops.edge.aten.full_like.default,
@@ -152,6 +153,7 @@
 exir_ops.edge.aten.log.default,
 exir_ops.edge.aten.linear.default,
 exir_ops.edge.aten.split_with_sizes_copy.default,
+exir_ops.edge.aten.split_copy.Tensor,
 exir_ops.edge.aten.floor.default,
 exir_ops.edge.aten.full.default,
 exir_ops.edge.aten.full_like.default,
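Context for these additions (our illustration, not part of the diff): `aten.split_copy` is the non-aliasing counterpart of `aten.split`, returning copies rather than views, which is why it needs its own entry next to `split_with_sizes_copy.default` in both operator-support lists. A quick eager-mode comparison:

```python
import torch

x = torch.arange(6.0)
views = torch.split(x, 2)        # views that alias x's storage
copies = torch.split_copy(x, 2)  # independent copies

x[0] = 100.0
print(views[0])   # tensor([100., 1.]) -- reflects the mutation of x
print(copies[0])  # tensor([0., 1.])   -- unaffected
```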
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -330,6 +330,7 @@ def _match_pattern(
 torch.ops.aten.slice_copy.Tensor,
 torch.ops.aten.split.Tensor,
 torch.ops.aten.split_with_sizes.default,
+torch.ops.aten.split_copy.Tensor,
 torch.ops.aten.transpose.Dimname,
 torch.ops.aten.transpose.int,
 torch.ops.aten.transpose_copy.int,
96 changes: 85 additions & 11 deletions backends/arm/test/ops/test_split.py
@@ -22,7 +22,6 @@


 class Split(torch.nn.Module):
-
     test_data = {
         "split_1d_2_size_0_dim": lambda: (torch.rand(10), 2, 0),
         "split_2d_3_size_1_dim": lambda: (torch.rand(10, 10), 3, 1),
@@ -60,12 +59,24 @@ def forward(
         return x.split(split_size=split_size_or_sections, dim=dim)[1:3]


+class SplitCopy(torch.nn.Module):
+    aten_op = "torch.ops.aten.split_copy.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_split_copy_Tensor"
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        split_size: int,
+        dim: int,
+    ):
+        return torch.split_copy(x, split_size=split_size, dim=dim)
+
+
 @common.parametrize(
     "test_data",
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_FP(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         Split(),
         test_data(),
@@ -77,7 +88,6 @@ def test_split_with_sizes_tosa_FP(test_data: input_t1):

 @common.parametrize("test_data", Split.test_data_list)
 def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         SplitWithSizes(),
         test_data(),
@@ -92,7 +102,6 @@ def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         SplitSingleOut(),
         test_data(),
@@ -107,7 +116,6 @@ def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         SplitTwoOut(),
         test_data(),
@@ -122,7 +130,6 @@ def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_INT(test_data: input_t1):
-
     pipeline = TosaPipelineINT[input_t1](
         Split(),
         test_data(),
@@ -161,7 +168,6 @@ def test_split_with_sizes_u55_INT(test_data: input_t1):
 )
 @common.XfailIfNoCorstone320
 def test_split_with_sizes_u85_INT(test_data: input_t1):
-
     pipeline = EthosU85PipelineINT[input_t1](
         Split(),
         test_data(),
@@ -190,7 +196,6 @@ def test_split_with_sizes_vgf_FP(test_data: input_t1):
 @common.parametrize("test_data", Split.test_data_list)
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         SplitWithSizes(),
         test_data(),
@@ -207,7 +212,6 @@ def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
 )
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         SplitSingleOut(),
         test_data(),
@@ -224,7 +228,6 @@ def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
 )
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         SplitTwoOut(),
         test_data(),
@@ -241,7 +244,6 @@ def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
 )
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_INT(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         Split(),
         test_data(),
@@ -250,3 +252,75 @@ def test_split_with_sizes_vgf_INT(test_data: input_t1):
tosa_version="TOSA-1.0+INT",
)
pipeline.run()


@common.parametrize("test_data", Split.test_data)
def test_split_tensor_tosa_FP(test_data: Tuple):
pipeline = TosaPipelineFP[input_t1](
SplitCopy(),
test_data(),
aten_op=SplitCopy.aten_op,
exir_op=SplitCopy.exir_op,
)
pipeline.run()


@common.parametrize("test_data", Split.test_data)
def test_split_tensor_tosa_INT(test_data: Tuple):
pipeline = TosaPipelineINT[input_t1](
SplitCopy(),
test_data(),
aten_op=SplitCopy.aten_op,
exir_op=SplitCopy.exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_data", Split.test_data)
def test_split_tensor_u55_INT(test_data: Tuple):
pipeline = EthosU55PipelineINT[input_t1](
SplitCopy(),
test_data(),
aten_ops=SplitCopy.aten_op,
exir_ops=SplitCopy.exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_data", Split.test_data)
def test_split_tensor_u85_INT(test_data: Tuple):
pipeline = EthosU85PipelineINT[input_t1](
SplitCopy(),
test_data(),
aten_ops=SplitCopy.aten_op,
exir_ops=SplitCopy.exir_op,
)
pipeline.run()


@common.parametrize("test_data", Split.test_data)
@common.SkipIfNoModelConverter
def test_split_tensor_vgf_FP(test_data: Tuple):
pipeline = VgfPipeline[input_t1](
SplitCopy(),
test_data(),
aten_op=SplitCopy.aten_op,
exir_op=SplitCopy.exir_op,
tosa_version="TOSA-1.0+FP",
)
pipeline.run()


@common.parametrize("test_data", Split.test_data)
@common.SkipIfNoModelConverter
def test_split_tensor_vgf_INT(test_data: Tuple):
pipeline = VgfPipeline[input_t1](
SplitCopy(),
test_data(),
aten_op=SplitCopy.aten_op,
exir_op=SplitCopy.exir_op,
tosa_version="TOSA-1.0+INT",
)
pipeline.run()
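To exercise the new coverage locally, an invocation along the lines of `python -m pytest backends/arm/test/ops/test_split.py -k split_tensor` should select the `split_copy` tests added above (the exact command depends on the repository's test setup; per the decorators, the Corstone-gated tests are xfailed and the VGF tests skipped when the corresponding tooling is absent).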