Skip to content

Commit 3d319cf

Browse files
author
Nikita Savelyev
committed
Make out of place reducer return tensor
1 parent 2664489 commit 3d319cf

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/nncf/experimental/common/tensor_statistics/collectors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from nncf.quantization.advanced_parameters import AggregatorType
2727
from nncf.quantization.range_estimator import StatisticsType
2828
from nncf.tensor import Tensor
29+
from nncf.tensor import TensorDataType
2930

3031
InplaceInsertionFNType = TypeVar("InplaceInsertionFNType")
3132
AggregationAxes = tuple[int, ...]
@@ -428,7 +429,8 @@ def __init__(self, inplace: bool = False):
428429
super().__init__(inplace=inplace)
429430

430431
def _reduce_out_of_place(self, x: list[TensorType]) -> list[tuple[int, ...]]:
431-
return [x[0].shape]
432+
# Return as tensor for consistency
433+
return [fns.tensor(x[0].shape, backend=x[0].backend, dtype=TensorDataType.int32, device=x[0].device)]
432434

433435
def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
434436
return None

src/nncf/experimental/common/tensor_statistics/statistics.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from __future__ import annotations
1313

1414
from collections import Counter
15-
from collections import deque
1615
from dataclasses import dataclass
1716
from dataclasses import fields
1817
from typing import Any, ClassVar
@@ -111,7 +110,7 @@ class MeanTensorStatistic(TensorStatistic):
111110

112111
def __post_init__(self):
113112
if isinstance(self.shape[0], Tensor):
114-
# If in-place shape reducer and Noop aggregator were used, shape is a sequence containing a single tensor
113+
# If shape reducer and Noop aggregator were used, shape is a sequence containing a single tensor
115114
self.shape = tuple(self.shape[0].data.tolist())
116115

117116
def __eq__(self, other: TensorStatistic):

0 commit comments

Comments
 (0)