Skip to content
Merged
26 changes: 11 additions & 15 deletions src/nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nncf.quantization.advanced_parameters import AggregatorType
from nncf.quantization.range_estimator import StatisticsType
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType

InplaceInsertionFNType = TypeVar("InplaceInsertionFNType")
AggregationAxes = tuple[int, ...]
Expand Down Expand Up @@ -427,8 +428,9 @@ class ShapeReducer(TensorReducerBase):
def __init__(self, inplace: bool = False):
super().__init__(inplace=inplace)

def _reduce_out_of_place(self, x: list[TensorType]) -> list[tuple[int, ...]]:
return [x[0].shape]
def _reduce_out_of_place(self, x: list[TensorType]) -> list[TensorType]:
    """Return the shape of the first input wrapped as a 1-D int32 tensor.

    The shape tuple is converted into a Tensor (same backend/device as the
    input, int32 dtype) so this out-of-place path yields the same value type
    as the in-place reducer, which returns a tensor.
    """
    # Return as tensor for consistency, because in-place reducer returns a tensor
    return [fns.tensor(x[0].shape, backend=x[0].backend, dtype=TensorDataType.int32, device=x[0].device)]

def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
    # This base reducer provides no in-place (in-graph) shape extraction;
    # backend-specific subclasses may override this to supply one.
    return None
Expand Down Expand Up @@ -555,25 +557,19 @@ def __hash__(self) -> int:


class NoopAggregator(AggregatorBase):
def __init__(self, num_samples: Optional[int]):
def __init__(self, num_samples: Optional[int], return_first: bool = False):
    """Aggregator that stores reduced inputs unchanged.

    :param num_samples: Maximum number of samples to collect, or None for
        no explicit limit.
    :param return_first: If True, aggregate() yields only the first collected
        sample instead of the whole container. In that mode collecting more
        than one sample is pointless, so num_samples > 1 is rejected.
    """
    if return_first and num_samples is not None and num_samples > 1:
        msg = "NoopAggregator with return_first=True should not have num_samples > 1"
        raise nncf.InternalError(msg)
    # With return_first and no explicit limit, a single sample is sufficient.
    num_samples = 1 if num_samples is None and return_first else num_samples
    super().__init__(None, num_samples=num_samples)
    self._return_first = return_first

def _register_reduced_input_impl(self, x: TensorType) -> None:
    # Store the reduced input as-is; any selection (e.g. return_first)
    # happens later in _aggregate_impl.
    self._container.append(x)

def _aggregate_impl(self):
return self._container


class ShapeAggregator(AggregatorBase):
def __init__(self):
super().__init__(None, num_samples=1)

def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container = x

def _aggregate_impl(self):
return self._container.shape
return self._container[0] if self._return_first else self._container


class OnlineAggregatorBase(AggregatorBase, ABC):
Expand Down
19 changes: 11 additions & 8 deletions src/nncf/experimental/common/tensor_statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import nncf
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
from nncf.tensor import functions as fns


Expand Down Expand Up @@ -99,31 +100,34 @@ def __eq__(self, other: TensorStatistic):
return False


@dataclass
@dataclass(init=False)
class MeanTensorStatistic(TensorStatistic):
MEAN_STAT: ClassVar[str] = "mean_values"
SHAPE_STAT: ClassVar[str] = "shape"

mean_values: Tensor
shape: tuple[int, ...]

def __init__(self, mean_values: Tensor, shape: Tensor) -> None:
    """
    :param mean_values: Aggregated mean-values tensor.
    :param shape: Tensor holding the input shape; eagerly converted to a
        plain tuple of Python ints for storage and equality comparison.
    """
    self.mean_values = mean_values
    self.shape = tuple(shape.data.tolist())

def __eq__(self, other: TensorStatistic) -> bool:
    # Equal when shapes match exactly and mean values are element-wise close
    # (fns.allclose absorbs floating-point noise); any non-MeanTensorStatistic
    # compares unequal.
    if isinstance(other, MeanTensorStatistic):
        return self.shape == other.shape and fns.allclose(self.mean_values, other.mean_values)
    return False

def _get_serialized_data(self) -> dict[str, Tensor]:
backend = self.mean_values.backend
dtype = self.mean_values.dtype
device = self.mean_values.device
return {
self.MEAN_STAT: self.mean_values,
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=dtype, device=device),
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=TensorDataType.int32, device=device),
}

def load_data(self, loaded_data: dict[str, Tensor]) -> None:
self.mean_values = loaded_data[self.MEAN_STAT]
self.shape_values = tuple(loaded_data[self.SHAPE_STAT].tolist())
self.shape = tuple(loaded_data[self.SHAPE_STAT].data.tolist())


@dataclass
Expand Down Expand Up @@ -261,14 +265,13 @@ def __eq__(self, other: Any) -> bool:

def _get_serialized_data(self) -> dict[str, Tensor]:
backend = self.mean_values[0].backend
dtype = self.mean_values[0].dtype
device = self.mean_values[0].device
return {
self.MEAN_STAT: fns.stack(self.mean_values),
self.SHAPE_STAT: fns.tensor(
[[dim.data for dim in shape] for shape in self.shape_values],
self.shape_values,
backend=backend,
dtype=dtype,
dtype=TensorDataType.int32,
device=device,
),
}
Expand All @@ -283,5 +286,5 @@ def from_config(cls, config: dict[str, Any]) -> TensorStatistic:
if cls.MEAN_STAT in config and config[cls.MEAN_STAT] is not None:
mean_values = [fns.squeeze(it) for it in config[cls.MEAN_STAT]]
if cls.SHAPE_STAT in config and config[cls.SHAPE_STAT] is not None:
shape_values = [tuple(it) for it in config[cls.SHAPE_STAT]]
shape_values = [tuple(it.data.tolist()) for it in config[cls.SHAPE_STAT]]
return cls(mean_values=mean_values, shape_values=shape_values)
8 changes: 4 additions & 4 deletions src/nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from nncf.experimental.common.tensor_statistics.collectors import MeanPerChReducer
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import RawTensorStatistic
Expand All @@ -40,19 +40,19 @@ def get_mean_statistic_collector(
reducer = BatchMeanReducer(inplace)
else:
reducer = MeanPerChReducer(channel_axis=channel_axis, inplace=inplace)
raw_reducer = RawReducer()
shape_reducer = ShapeReducer(inplace=inplace)

kwargs = {
"num_samples": num_samples,
"window_size": window_size,
}

aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator()
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
return collector


Expand Down
7 changes: 3 additions & 4 deletions src/nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
Expand Down Expand Up @@ -120,18 +119,18 @@ def get_mean_statistic_collector(
reducer = OVBatchMeanReducer(inplace)
else:
reducer = OVMeanPerChanelReducer(channel_axis=channel_axis, inplace=inplace)
raw_reducer = RawReducer()
shape_reducer = OVShapeReducer(inplace=inplace)

kwargs = {
"num_samples": num_samples,
"window_size": window_size,
}
aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator()
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
return collector


Expand Down
8 changes: 4 additions & 4 deletions src/nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import MedianMADTensorStatistic
Expand Down Expand Up @@ -306,18 +306,18 @@ def get_mean_statistic_collector(
reducer = BatchMeanReducer()
else:
reducer = MeanPerChReducer(channel_axis=channel_axis)
raw_reducer = RawReducer()
shape_reducer = ShapeReducer()

kwargs = {
"num_samples": num_samples,
"window_size": window_size,
}
aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator()
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
return collector


Expand Down
14 changes: 7 additions & 7 deletions tests/common/experimental/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.tensor import functions as fns

Expand Down Expand Up @@ -297,15 +296,16 @@ def test_noop_aggregator(self):
for val in aggregated:
assert fns.allclose(val, self.get_nncf_tensor(input_))

def test_shape_aggregator(self):
aggregator = ShapeAggregator()
def test_noop_aggregator_return_first(self):
aggregator = NoopAggregator(None, return_first=True)

ref_shape = (1, 3, 5, 7, 9)
input_ = np.empty(ref_shape)
for _ in range(3):
aggregator.register_reduced_input(self.get_nncf_tensor(input_))
input_ = np.arange(np.prod(ref_shape)).reshape(ref_shape)
aggregator.register_reduced_input(self.get_nncf_tensor(input_))

assert aggregator._collected_samples == 1
assert ref_shape == aggregator.aggregate()
aggregated = aggregator.aggregate()
assert fns.allclose(aggregated, self.get_nncf_tensor(input_))

@pytest.mark.parametrize(
"offline_aggregators_test_desc",
Expand Down
4 changes: 2 additions & 2 deletions tests/common/experimental/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,10 @@ def test_mean_max_stat_building(self):
tensor_collector.register_statistic_branch(
MeanTensorStatistic.SHAPE_STAT, DummyTensorReducer("B"), DummyTensorAggregator()
)
tensor_collector.register_input_for_all_reducers(Tensor(np.array(1)))
tensor_collector.register_input_for_all_reducers(Tensor(np.array([1])))
statistic = tensor_collector.get_statistics()
assert isinstance(statistic, MeanTensorStatistic)
assert statistic.mean_values == statistic.shape == Tensor(np.array(1))
assert statistic.mean_values == statistic.shape == Tensor(np.array([1]))

def test_median_mad_stat_building(self):
class DummyMADPercentileAggregator(DummyTensorAggregator):
Expand Down