diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 7578c07ca53..2cce0315c12 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -46,13 +46,24 @@ def call(self, graph_module: torch.fx.GraphModule): dim = (dim + rank) % rank # Validate that split lengths cover the entire dimension - length_sum = sum(split_lengths) + dim_size = shape[dim] - if length_sum != dim_size: - raise ValueError( - f"Split sizes {split_lengths} sum to {length_sum}, " - f"but dimension {dim} has size {dim_size}" - ) + if isinstance(split_lengths, int): + if split_lengths <= 0: + raise ValueError( + f"Split size must be positive, got {split_lengths}" + ) + full_chunks, remainder = divmod(dim_size, split_lengths) + split_lengths = [split_lengths] * full_chunks + if remainder: + split_lengths.append(remainder) + else: + length_sum = sum(split_lengths) + if length_sum != dim_size: + raise ValueError( + f"Split sizes {split_lengths} sum to {length_sum}, " + f"but dimension {dim} has size {dim_size}" + ) # Convert split argument 'split_lengths' to slice arguments start and end. starts = [0] * len(split_lengths) diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index 86db2d9b0b6..ee61aa4cce6 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -55,6 +55,7 @@ exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split_copy.Tensor, exir_ops.edge.aten.floor.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, @@ -152,6 +153,7 @@ exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split_copy.Tensor, exir_ops.edge.aten.floor.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 349aa3e6b21..7ed4f5b06ff 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -330,6 +330,7 @@ def _match_pattern( torch.ops.aten.slice_copy.Tensor, torch.ops.aten.split.Tensor, torch.ops.aten.split_with_sizes.default, + torch.ops.aten.split_copy.Tensor, torch.ops.aten.transpose.Dimname, torch.ops.aten.transpose.int, torch.ops.aten.transpose_copy.int, diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 284c142a34e..6c4d87e9de2 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -22,7 +22,6 @@ class Split(torch.nn.Module): - test_data = { "split_1d_2_size_0_dim": lambda: (torch.rand(10), 2, 0), "split_2d_3_size_1_dim": lambda: (torch.rand(10, 10), 3, 1), @@ -60,12 +59,24 @@ def forward( return x.split(split_size=split_size_or_sections, dim=dim)[1:3] +class SplitCopy(torch.nn.Module): + aten_op = "torch.ops.aten.split_copy.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_split_copy_Tensor" + + def forward( + self, + x: torch.Tensor, + split_size: int, + dim: int, + ): + return torch.split_copy(x, split_size=split_size, dim=dim) + + @common.parametrize( "test_data", (Split.test_data | Split.test_data_list), ) def test_split_with_sizes_tosa_FP(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( Split(), test_data(), @@ -77,7 +88,6 @@ def test_split_with_sizes_tosa_FP(test_data: input_t1): @common.parametrize("test_data", Split.test_data_list) def test_split_with_sizes_tosa_FP_2(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( SplitWithSizes(), test_data(), @@ -92,7 +102,6 @@ def test_split_with_sizes_tosa_FP_2(test_data: input_t1): (Split.test_data | Split.test_data_list), ) def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( SplitSingleOut(), test_data(), @@ -107,7 +116,6 @@ def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1): (Split.test_data | Split.test_data_list), ) def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( SplitTwoOut(), test_data(), @@ -122,7 +130,6 @@ def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1): (Split.test_data | Split.test_data_list), ) def test_split_with_sizes_tosa_INT(test_data: input_t1): - pipeline = TosaPipelineINT[input_t1]( Split(), test_data(), @@ -161,7 +168,6 @@ def test_split_with_sizes_u55_INT(test_data: input_t1): ) @common.XfailIfNoCorstone320 def test_split_with_sizes_u85_INT(test_data: input_t1): - pipeline = EthosU85PipelineINT[input_t1]( Split(), test_data(), @@ -190,7 +196,6 @@ def test_split_with_sizes_vgf_FP(test_data: input_t1): @common.parametrize("test_data", Split.test_data_list) @common.SkipIfNoModelConverter def test_split_with_sizes_vgf_FP_2(test_data: input_t1): - pipeline = VgfPipeline[input_t1]( SplitWithSizes(), test_data(), @@ -207,7 +212,6 @@ def test_split_with_sizes_vgf_FP_2(test_data: input_t1): ) @common.SkipIfNoModelConverter def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1): - pipeline = VgfPipeline[input_t1]( SplitSingleOut(), test_data(), @@ -224,7 +228,6 @@ def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1): ) @common.SkipIfNoModelConverter def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1): - pipeline = VgfPipeline[input_t1]( SplitTwoOut(), test_data(), @@ -241,7 +244,6 @@ def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1): ) @common.SkipIfNoModelConverter def test_split_with_sizes_vgf_INT(test_data: input_t1): - pipeline = VgfPipeline[input_t1]( Split(), test_data(), @@ -250,3 +252,75 @@ def test_split_with_sizes_vgf_INT(test_data: input_t1): tosa_version="TOSA-1.0+INT", ) pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_tosa_FP(test_data: Tuple): + pipeline = TosaPipelineFP[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + exir_op=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_tosa_INT(test_data: Tuple): + pipeline = TosaPipelineINT[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + exir_op=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_u55_INT(test_data: Tuple): + pipeline = EthosU55PipelineINT[input_t1]( + SplitCopy(), + test_data(), + aten_ops=SplitCopy.aten_op, + exir_ops=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + SplitCopy(), + test_data(), + aten_ops=SplitCopy.aten_op, + exir_ops=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +@common.SkipIfNoModelConverter +def test_split_tensor_vgf_FP(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + exir_op=SplitCopy.exir_op, + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +@common.SkipIfNoModelConverter +def test_split_tensor_vgf_INT(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + exir_op=SplitCopy.exir_op, + tosa_version="TOSA-1.0+INT", + ) + pipeline.run()