Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from copy import copy
from math import prod

import torch
Expand Down Expand Up @@ -75,35 +76,47 @@ def call_operator(self, op, args, kwargs, meta):
return super().call_operator(op, args, kwargs, meta)

x = get_node_arg(args, 0)
input_shape = x.data.size()
output_shape = meta["val"].size()
input_shape = list(x.data.shape)
output_shape = list(meta["val"].shape)
dims_to_reduce = get_node_arg(args, 1)
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]

dtype = meta["val"].dtype
view_op = get_view(op)

if len(input_shape) > 4:
raise NotImplementedError(
f"{op} with rank > 4 is currently not supported for the TOSA backend."
)
# Reshape to 4D
if len(input_shape) != 4:
new_shape = copy(input_shape)

while len(new_shape) < 4:
new_shape.insert(0, 1)
dims_to_reduce = [dim + 1 for dim in dims_to_reduce]

# Unsqueeze to 4D
if len(input_shape) < 4:
pad_n = 4 - len(input_shape)
new_shape = [1] * pad_n + list(input_shape)
dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]
while len(new_shape) > 4:
i = new_shape.pop(0)
new_shape[0] = new_shape[0] * i
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]

x = super().call_operator(view_op, (x, new_shape), {}, meta, True)

# Reduce (h,w) dims by avg pool if possible
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)

# Reshape back to 5D if necessary
if len(input_shape) > 4:
original_dims = input_shape[0:-4]
temp_shape = list(x.data.shape)[1:]
temp_shape = original_dims + temp_shape
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]

x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)

# Reduce remaining dims by sum
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)

# Reshape to correct output shape if necessary
if x.data.size() != output_shape:
if list(x.data.shape) != output_shape:
x = super().call_operator(view_op, (x, output_shape), {}, meta, True)

return x
Expand Down
26 changes: 24 additions & 2 deletions backends/arm/test/ops/test_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ class MeanDim(torch.nn.Module):
(-4, -3, -2, -1),
False,
),
"rank5_01234": lambda: (
torch.rand(1, 1, 7, 3, 2),
(-5, -4, -3, -2, -1),
False,
),
"rank5_234": lambda: (
torch.rand(1, 1, 7, 3, 2),
(-3, -2, -1),
False,
),
"rank5_12": lambda: (
torch.rand(1, 1, 7, 3, 2),
(1, 2),
False,
),
"u55_avg_pool_not_supported": lambda: (
torch.rand(1, 1, 1, 257),
(0, 1, 2, 3),
Expand Down Expand Up @@ -236,7 +251,14 @@ def test_mean_dim_tosa_BI(test_data):
pipeline.run()


@common.parametrize("test_data", MeanDim.test_data_suite)
xfails = {
"rank5_01234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
"rank5_234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
"rank5_12": "Rank 5 graph input currently not supported in EthosUBackend",
}


@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False)
@common.XfailIfNoCorstone300
def test_mean_dim_u55_BI(test_data):
test_data, dim, keep_dim = test_data()
Expand All @@ -256,7 +278,7 @@ def test_mean_dim_u55_BI(test_data):
pipeline.run()


@common.parametrize("test_data", MeanDim.test_data_suite)
@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False)
@common.XfailIfNoCorstone320
def test_mean_dim_u85_BI(test_data):
test_data, dim, keep_dim = test_data()
Expand Down
Loading