|
3 | 3 | # This source code is licensed under the BSD-style license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
| 6 | +from copy import copy |
6 | 7 | from math import prod |
7 | 8 |
|
8 | 9 | import torch |
@@ -75,35 +76,47 @@ def call_operator(self, op, args, kwargs, meta): |
75 | 76 | return super().call_operator(op, args, kwargs, meta) |
76 | 77 |
|
77 | 78 | x = get_node_arg(args, 0) |
78 | | - input_shape = x.data.size() |
79 | | - output_shape = meta["val"].size() |
| 79 | + input_shape = list(x.data.shape) |
| 80 | + output_shape = list(meta["val"].shape) |
80 | 81 | dims_to_reduce = get_node_arg(args, 1) |
81 | 82 | dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] |
| 83 | + dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1] |
82 | 84 |
|
83 | 85 | dtype = meta["val"].dtype |
84 | 86 | view_op = get_view(op) |
85 | 87 |
|
86 | | - if len(input_shape) > 4: |
87 | | - raise NotImplementedError( |
88 | | - f"{op} with rank > 4 is currently not supported for the TOSA backend." |
89 | | - ) |
| 88 | + # Reshape to 4D |
| 89 | + if len(input_shape) != 4: |
| 90 | + new_shape = copy(input_shape) |
| 91 | + |
| 92 | + while len(new_shape) < 4: |
| 93 | + new_shape.insert(0, 1) |
| 94 | + dims_to_reduce = [dim + 1 for dim in dims_to_reduce] |
90 | 95 |
|
91 | | - # Unsqueeze to 4D |
92 | | - if len(input_shape) < 4: |
93 | | - pad_n = 4 - len(input_shape) |
94 | | - new_shape = [1] * pad_n + list(input_shape) |
95 | | - dims_to_reduce = [dim + pad_n for dim in dims_to_reduce] |
| 96 | + while len(new_shape) > 4: |
| 97 | + i = new_shape.pop(0) |
| 98 | + new_shape[0] = new_shape[0] * i |
| 99 | + dims_to_reduce = [dim - 1 for dim in dims_to_reduce] |
96 | 100 |
|
97 | 101 | x = super().call_operator(view_op, (x, new_shape), {}, meta, True) |
98 | 102 |
|
99 | 103 | # Reduce (h,w) dims by avg pool if possible |
100 | 104 | x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta) |
101 | 105 |
|
| 106 | + # Reshape back to 5D if necessary |
| 107 | + if len(input_shape) > 4: |
| 108 | + original_dims = input_shape[0:-4] |
| 109 | + temp_shape = list(x.data.shape)[1:] |
| 110 | + temp_shape = original_dims + temp_shape |
| 111 | + dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce] |
| 112 | + |
| 113 | + x = super().call_operator(view_op, (x, temp_shape), {}, meta, True) |
| 114 | + |
102 | 115 | # Reduce remaining dims by sum |
103 | 116 | x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype) |
104 | 117 |
|
105 | 118 | # Reshape to correct output shape if necessary |
106 | | - if x.data.size() != output_shape: |
| 119 | + if list(x.data.shape) != output_shape: |
107 | 120 | x = super().call_operator(view_op, (x, output_shape), {}, meta, True) |
108 | 121 |
|
109 | 122 | return x |
|
0 commit comments