Skip to content

Commit 24f9c81

Browse files
Nikita Savelyev and daniil-lyakhov
authored
Don't use RawReducer for activation shape collection in Fast Bias Correction (#3642)
### Changes As in the title. ### Reason for changes This PR reduces memory footprint when applying the Fast Bias Correction algorithm: collecting raw activations is not required to obtain their shapes. Avoiding using raw reducers allows saving some memory otherwise allocated for the activations. Example quantization run on vision encoder from `OpenGVLab/InternVL2-1B` with 4 calibration data samples: | Before | After | |-|-| | <img width="1000" height="600" alt="system_memory_usage_from-zero" src="https://github.com/user-attachments/assets/73354e2f-db21-48a9-8c8a-b5a80426b41a" /> | <img width="1000" height="600" alt="system_memory_usage_from-zero" src="https://github.com/user-attachments/assets/a5343cd6-7bef-413a-8688-427b901194b7" /> | Since there is no need to allocate so much memory, statistics collection time also improves. ### Related tickets 172800 ### Tests Existing tests cover the new changes. - NNCF/job/manual/job/post_training_quantization/730 - NNCF/job/manual/job/post_training_quantization_performance/119/ --------- Co-authored-by: dlyakhov <[email protected]>
1 parent c00a0eb commit 24f9c81

File tree

13 files changed

+89
-45
lines changed

13 files changed

+89
-45
lines changed

src/nncf/experimental/common/tensor_statistics/collectors.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,9 @@ class ShapeReducer(TensorReducerBase):
433433
def __init__(self, inplace: bool = False):
434434
super().__init__(inplace=inplace)
435435

436-
def _reduce_out_of_place(self, x: list[TensorType]) -> list[tuple[int, ...]]:
437-
return [x[0].shape]
436+
def _reduce_out_of_place(self, x: list[TensorType]) -> list[TensorType]:
437+
# Return as tensor for consistency, because in-place reducer returns a tensor
438+
return [fns.tensor(x[0].shape, backend=x[0].backend, dtype=TensorDataType.int32, device=x[0].device)]
438439

439440
def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
440441
return None
@@ -561,25 +562,24 @@ def __hash__(self) -> int:
561562

562563

563564
class NoopAggregator(AggregatorBase):
564-
def __init__(self, num_samples: Optional[int]):
565-
super().__init__(None, num_samples=num_samples)
565+
def __init__(self, num_samples: Optional[int], return_first: bool = False):
566+
"""
567+
Creates an aggregator that only accumulates data without any additional processing.
568+
:param num_samples: The number of samples to collect. If None, all samples are collected.
569+
:param return_first: If True, the first collected sample is returned on aggregate call.
570+
If False, all collected samples are returned as a list.
571+
"""
572+
if return_first and num_samples is not None and num_samples != 1:
573+
msg = "NoopAggregator with return_first=True should not have num_samples > 1"
574+
raise nncf.InternalError(msg)
575+
super().__init__(None, num_samples=1 if return_first else num_samples)
576+
self._return_first = return_first
566577

567578
def _register_reduced_input_impl(self, x: TensorType) -> None:
568579
self._container.append(x)
569580

570581
def _aggregate_impl(self):
571-
return self._container
572-
573-
574-
class ShapeAggregator(AggregatorBase):
575-
def __init__(self):
576-
super().__init__(None, num_samples=1)
577-
578-
def _register_reduced_input_impl(self, x: TensorType) -> None:
579-
self._container = x
580-
581-
def _aggregate_impl(self):
582-
return self._container.shape
582+
return self._container[0] if self._return_first else self._container
583583

584584

585585
class OnlineAggregatorBase(AggregatorBase, ABC):

src/nncf/experimental/common/tensor_statistics/statistics.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import nncf
2020
from nncf.tensor import Tensor
21+
from nncf.tensor import TensorDataType
2122
from nncf.tensor import functions as fns
2223

2324

@@ -108,31 +109,34 @@ def __eq__(self, other: TensorStatistic):
108109
return False
109110

110111

111-
@dataclass
112+
@dataclass(init=False)
112113
class MeanTensorStatistic(TensorStatistic):
113114
MEAN_STAT: ClassVar[str] = "mean_values"
114115
SHAPE_STAT: ClassVar[str] = "shape"
115116

116117
mean_values: Tensor
117118
shape: tuple[int, ...]
118119

120+
def __init__(self, mean_values: Tensor, shape: Tensor) -> None:
121+
self.mean_values = mean_values
122+
self.shape = tuple(shape.tolist())
123+
119124
def __eq__(self, other: TensorStatistic):
120125
if isinstance(other, MeanTensorStatistic):
121126
return self.shape == other.shape and fns.allclose(self.mean_values, other.mean_values)
122127
return False
123128

124129
def _get_serialized_data(self) -> dict[str, Tensor]:
125130
backend = self.mean_values.backend
126-
dtype = self.mean_values.dtype
127131
device = self.mean_values.device
128132
return {
129133
self.MEAN_STAT: self.mean_values,
130-
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=dtype, device=device),
134+
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=TensorDataType.int32, device=device),
131135
}
132136

133137
def load_data(self, loaded_data: dict[str, Tensor]) -> None:
134138
self.mean_values = loaded_data[self.MEAN_STAT]
135-
self.shape_values = tuple(loaded_data[self.SHAPE_STAT].tolist())
139+
self.shape = tuple(loaded_data[self.SHAPE_STAT].tolist())
136140

137141

138142
@dataclass
@@ -270,14 +274,13 @@ def __eq__(self, other: Any) -> bool:
270274

271275
def _get_serialized_data(self) -> dict[str, Tensor]:
272276
backend = self.mean_values[0].backend
273-
dtype = self.mean_values[0].dtype
274277
device = self.mean_values[0].device
275278
return {
276279
self.MEAN_STAT: fns.stack(self.mean_values),
277280
self.SHAPE_STAT: fns.tensor(
278-
[[dim.data for dim in shape] for shape in self.shape_values],
281+
self.shape_values,
279282
backend=backend,
280-
dtype=dtype,
283+
dtype=TensorDataType.int32,
281284
device=device,
282285
),
283286
}
@@ -292,5 +295,5 @@ def from_config(cls, config: dict[str, Any]) -> TensorStatistic:
292295
if cls.MEAN_STAT in config and config[cls.MEAN_STAT] is not None:
293296
mean_values = [fns.squeeze(it) for it in config[cls.MEAN_STAT]]
294297
if cls.SHAPE_STAT in config and config[cls.SHAPE_STAT] is not None:
295-
shape_values = [tuple(it) for it in config[cls.SHAPE_STAT]]
298+
shape_values = [tuple(it.tolist()) for it in config[cls.SHAPE_STAT]]
296299
return cls(mean_values=mean_values, shape_values=shape_values)

src/nncf/onnx/statistics/collectors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from nncf.experimental.common.tensor_statistics.collectors import MeanPerChReducer
1717
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
1818
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
19-
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
19+
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
2020
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
2121
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
2222
from nncf.experimental.common.tensor_statistics.statistics import RawTensorStatistic
@@ -40,19 +40,19 @@ def get_mean_statistic_collector(
4040
reducer = BatchMeanReducer(inplace)
4141
else:
4242
reducer = MeanPerChReducer(channel_axis=channel_axis, inplace=inplace)
43-
raw_reducer = RawReducer()
43+
shape_reducer = ShapeReducer(inplace=inplace)
4444

4545
kwargs = {
4646
"num_samples": num_samples,
4747
"window_size": window_size,
4848
}
4949

5050
aggregate_mean = MeanAggregator(**kwargs)
51-
aggregate_shape = ShapeAggregator()
51+
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)
5252

5353
collector = TensorCollector(MeanTensorStatistic)
5454
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
55-
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
55+
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
5656
return collector
5757

5858

src/nncf/openvino/statistics/collectors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
2727
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
2828
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
29-
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
3029
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
3130
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
3231
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
@@ -120,18 +119,18 @@ def get_mean_statistic_collector(
120119
reducer = OVBatchMeanReducer(inplace)
121120
else:
122121
reducer = OVMeanPerChanelReducer(channel_axis=channel_axis, inplace=inplace)
123-
raw_reducer = RawReducer()
122+
shape_reducer = OVShapeReducer(inplace=inplace)
124123

125124
kwargs = {
126125
"num_samples": num_samples,
127126
"window_size": window_size,
128127
}
129128
aggregate_mean = MeanAggregator(**kwargs)
130-
aggregate_shape = ShapeAggregator()
129+
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)
131130

132131
collector = TensorCollector(MeanTensorStatistic)
133132
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
134-
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
133+
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
135134
return collector
136135

137136

src/nncf/tensor/functions/numeric.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,3 +922,13 @@ def as_numpy_tensor(a: Tensor) -> Tensor:
922922
:param a: Tensor to change backend for.
923923
:return: Tensor in numpy backend.
924924
"""
925+
926+
927+
@tensor_dispatcher
928+
def tolist(a: Tensor) -> Any:
929+
"""
930+
Returns the tensor as a nested list.
931+
For scalars, a standard Python number is returned, just like with item().
932+
933+
:return: The tensor as a nested list.
934+
"""

src/nncf/tensor/functions/numpy_numeric.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,3 +495,8 @@ def tensor(
495495
validate_device(device)
496496
np_dtype = convert_to_numpy_dtype(dtype)
497497
return np.array(data, dtype=np_dtype)
498+
499+
500+
@numeric.tolist.register
501+
def _(a: T_NUMPY) -> Any:
502+
return a.tolist()

src/nncf/tensor/functions/tf_numeric.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,8 @@ def tensor(
572572
@numeric.as_numpy_tensor.register
573573
def _(a: tf.Tensor) -> npt.NDArray[Any]:
574574
return a.numpy()
575+
576+
577+
@numeric.tolist.register
578+
def _(a: tf.Tensor) -> Any:
579+
return a.numpy().tolist()

src/nncf/tensor/functions/torch_numeric.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,3 +542,8 @@ def tensor(
542542
@numeric.as_numpy_tensor.register
543543
def _(a: torch.Tensor) -> NDArray[Any]:
544544
return a.cpu().detach().numpy()
545+
546+
547+
@numeric.tolist.register
548+
def _(a: torch.Tensor) -> Any:
549+
return a.tolist()

src/nncf/tensor/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ def clone(self) -> Tensor:
205205
def as_numpy_tensor(self) -> Tensor:
206206
return cast(Tensor, _call_function("as_numpy_tensor", self))
207207

208+
def tolist(self) -> Any:
209+
return _call_function("tolist", self)
210+
208211
def as_openvino_tensor(self) -> Tensor:
209212
x = self
210213
if x.backend == TensorBackend.numpy:

src/nncf/torch/tensor_statistics/collectors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
2929
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
3030
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
31-
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
31+
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
3232
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
3333
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
3434
from nncf.experimental.common.tensor_statistics.statistics import MedianMADTensorStatistic
@@ -306,18 +306,18 @@ def get_mean_statistic_collector(
306306
reducer = BatchMeanReducer()
307307
else:
308308
reducer = MeanPerChReducer(channel_axis=channel_axis)
309-
raw_reducer = RawReducer()
309+
shape_reducer = ShapeReducer()
310310

311311
kwargs = {
312312
"num_samples": num_samples,
313313
"window_size": window_size,
314314
}
315315
aggregate_mean = MeanAggregator(**kwargs)
316-
aggregate_shape = ShapeAggregator()
316+
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)
317317

318318
collector = TensorCollector(MeanTensorStatistic)
319319
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
320-
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
320+
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
321321
return collector
322322

323323

0 commit comments

Comments
 (0)