Skip to content

Commit 04a05b5

Browse files
author
Nikita Savelyev
committed
Initial commit
1 parent a376326 commit 04a05b5

File tree

4 files changed

+23
-19
lines changed

4 files changed

+23
-19
lines changed

src/nncf/experimental/common/tensor_statistics/statistics.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
from __future__ import annotations
1313

1414
from collections import Counter
15+
from collections import deque
1516
from dataclasses import dataclass
1617
from dataclasses import fields
1718
from typing import Any, ClassVar
1819

1920
import nncf
2021
from nncf.tensor import Tensor
22+
from nncf.tensor import TensorDataType
2123
from nncf.tensor import functions as fns
2224

2325

@@ -107,23 +109,27 @@ class MeanTensorStatistic(TensorStatistic):
107109
mean_values: Tensor
108110
shape: tuple[int, ...]
109111

112+
def __post_init__(self):
113+
if isinstance(self.shape, (deque, list)):
114+
# If NoopAggregator was used, shape can be a sequence containing a single tensor
115+
self.shape = tuple(self.shape[0].data.tolist())
116+
110117
def __eq__(self, other: TensorStatistic):
111118
if isinstance(other, MeanTensorStatistic):
112119
return self.shape == other.shape and fns.allclose(self.mean_values, other.mean_values)
113120
return False
114121

115122
def _get_serialized_data(self) -> dict[str, Tensor]:
116123
backend = self.mean_values.backend
117-
dtype = self.mean_values.dtype
118124
device = self.mean_values.device
119125
return {
120126
self.MEAN_STAT: self.mean_values,
121-
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=dtype, device=device),
127+
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=TensorDataType.int32, device=device),
122128
}
123129

124130
def load_data(self, loaded_data: dict[str, Tensor]) -> None:
125131
self.mean_values = loaded_data[self.MEAN_STAT]
126-
self.shape_values = tuple(loaded_data[self.SHAPE_STAT].tolist())
132+
self.shape = tuple(loaded_data[self.SHAPE_STAT].data.tolist())
127133

128134

129135
@dataclass
@@ -261,14 +267,13 @@ def __eq__(self, other: Any) -> bool:
261267

262268
def _get_serialized_data(self) -> dict[str, Tensor]:
263269
backend = self.mean_values[0].backend
264-
dtype = self.mean_values[0].dtype
265270
device = self.mean_values[0].device
266271
return {
267272
self.MEAN_STAT: fns.stack(self.mean_values),
268273
self.SHAPE_STAT: fns.tensor(
269-
[[dim.data for dim in shape] for shape in self.shape_values],
274+
self.shape_values,
270275
backend=backend,
271-
dtype=dtype,
276+
dtype=TensorDataType.int32,
272277
device=device,
273278
),
274279
}
@@ -283,5 +288,5 @@ def from_config(cls, config: dict[str, Any]) -> TensorStatistic:
283288
if cls.MEAN_STAT in config and config[cls.MEAN_STAT] is not None:
284289
mean_values = [fns.squeeze(it) for it in config[cls.MEAN_STAT]]
285290
if cls.SHAPE_STAT in config and config[cls.SHAPE_STAT] is not None:
286-
shape_values = [tuple(it) for it in config[cls.SHAPE_STAT]]
291+
shape_values = [tuple(it.data.tolist()) for it in config[cls.SHAPE_STAT]]
287292
return cls(mean_values=mean_values, shape_values=shape_values)

src/nncf/onnx/statistics/collectors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from nncf.experimental.common.tensor_statistics.collectors import MeanPerChReducer
1717
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
1818
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
19-
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
19+
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
2020
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
2121
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
2222
from nncf.experimental.common.tensor_statistics.statistics import RawTensorStatistic
@@ -40,19 +40,19 @@ def get_mean_statistic_collector(
4040
reducer = BatchMeanReducer(inplace)
4141
else:
4242
reducer = MeanPerChReducer(channel_axis=channel_axis, inplace=inplace)
43-
raw_reducer = RawReducer()
43+
shape_reducer = ShapeReducer(inplace=inplace)
4444

4545
kwargs = {
4646
"num_samples": num_samples,
4747
"window_size": window_size,
4848
}
4949

5050
aggregate_mean = MeanAggregator(**kwargs)
51-
aggregate_shape = ShapeAggregator()
51+
aggregate_noop = NoopAggregator(num_samples=1)
5252

5353
collector = TensorCollector(MeanTensorStatistic)
5454
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
55-
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
55+
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
5656
return collector
5757

5858

src/nncf/openvino/statistics/collectors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
2727
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
2828
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
29-
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
3029
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
3130
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
3231
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
@@ -120,18 +119,18 @@ def get_mean_statistic_collector(
120119
reducer = OVBatchMeanReducer(inplace)
121120
else:
122121
reducer = OVMeanPerChanelReducer(channel_axis=channel_axis, inplace=inplace)
123-
raw_reducer = RawReducer()
122+
shape_reducer = OVShapeReducer(inplace=inplace)
124123

125124
kwargs = {
126125
"num_samples": num_samples,
127126
"window_size": window_size,
128127
}
129128
aggregate_mean = MeanAggregator(**kwargs)
130-
aggregate_shape = ShapeAggregator()
129+
aggregate_noop = NoopAggregator(num_samples=1)
131130

132131
collector = TensorCollector(MeanTensorStatistic)
133132
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
134-
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
133+
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
135134
return collector
136135

137136

src/nncf/torch/tensor_statistics/collectors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
2929
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
3030
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
31-
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
31+
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
3232
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
3333
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
3434
from nncf.experimental.common.tensor_statistics.statistics import MedianMADTensorStatistic
@@ -306,18 +306,18 @@ def get_mean_statistic_collector(
306306
reducer = BatchMeanReducer()
307307
else:
308308
reducer = MeanPerChReducer(channel_axis=channel_axis)
309-
raw_reducer = RawReducer()
309+
shape_reducer = ShapeReducer()
310310

311311
kwargs = {
312312
"num_samples": num_samples,
313313
"window_size": window_size,
314314
}
315315
aggregate_mean = MeanAggregator(**kwargs)
316-
aggregate_shape = ShapeAggregator()
316+
aggregate_noop = NoopAggregator(num_samples=1)
317317

318318
collector = TensorCollector(MeanTensorStatistic)
319319
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
320-
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
320+
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
321321
return collector
322322

323323

0 commit comments

Comments
 (0)