diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 456bcbb1a9b..efae6923311 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from copy import copy from math import prod import torch @@ -75,35 +76,47 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, args, kwargs, meta) x = get_node_arg(args, 0) - input_shape = x.data.size() - output_shape = meta["val"].size() + input_shape = list(x.data.shape) + output_shape = list(meta["val"].shape) dims_to_reduce = get_node_arg(args, 1) dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] + dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1] dtype = meta["val"].dtype view_op = get_view(op) - if len(input_shape) > 4: - raise NotImplementedError( - f"{op} with rank > 4 is currently not supported for the TOSA backend." - ) + # Reshape to 4D + if len(input_shape) != 4: + new_shape = copy(input_shape) + + while len(new_shape) < 4: + new_shape.insert(0, 1) + dims_to_reduce = [dim + 1 for dim in dims_to_reduce] - # Unsqueeze to 4D - if len(input_shape) < 4: - pad_n = 4 - len(input_shape) - new_shape = [1] * pad_n + list(input_shape) - dims_to_reduce = [dim + pad_n for dim in dims_to_reduce] + while len(new_shape) > 4: + i = new_shape.pop(0) + new_shape[0] = new_shape[0] * i + dims_to_reduce = [dim - 1 for dim in dims_to_reduce] x = super().call_operator(view_op, (x, new_shape), {}, meta, True) # Reduce (h,w) dims by avg pool if possible x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta) + # Reshape back to 5D if necessary + if len(input_shape) > 4: + original_dims = input_shape[0:-4] + temp_shape = list(x.data.shape)[1:] + temp_shape = original_dims + temp_shape + dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce] + + x = super().call_operator(view_op, (x, temp_shape), {}, meta, True) + # Reduce remaining dims by sum x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype) # Reshape to correct output shape if necessary - if x.data.size() != output_shape: + if list(x.data.shape) != output_shape: x = super().call_operator(view_op, (x, output_shape), {}, meta, True) return x diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 0ee6e3c64f3..fef11d365ea 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -195,6 +195,21 @@ class MeanDim(torch.nn.Module): (-4, -3, -2, -1), False, ), + "rank5_01234": lambda: ( + torch.rand(1, 1, 7, 3, 2), + (-5, -4, -3, -2, -1), + False, + ), + "rank5_234": lambda: ( + torch.rand(1, 1, 7, 3, 2), + (-3, -2, -1), + False, + ), + "rank5_12": lambda: ( + torch.rand(1, 1, 7, 3, 2), + (1, 2), + False, + ), "u55_avg_pool_not_supported": lambda: ( torch.rand(1, 1, 1, 257), (0, 1, 2, 3), @@ -236,7 +251,14 @@ def test_mean_dim_tosa_BI(test_data): pipeline.run() -@common.parametrize("test_data", MeanDim.test_data_suite) +xfails = { + "rank5_01234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)", + "rank5_234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)", + "rank5_12": "Rank 5 graph input currently not supported in EthosUBackend", +} + + +@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False) @common.XfailIfNoCorstone300 def test_mean_dim_u55_BI(test_data): test_data, dim, keep_dim = test_data() @@ -256,7 +278,7 @@ def test_mean_dim_u55_BI(test_data): pipeline.run() -@common.parametrize("test_data", MeanDim.test_data_suite) +@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False) @common.XfailIfNoCorstone320 def test_mean_dim_u85_BI(test_data): test_data, dim, keep_dim = test_data()