Skip to content

Commit d27f965

Browse files
Apply comments
1 parent ebf8bcc commit d27f965

File tree

9 files changed

+167
-147
lines changed

9 files changed

+167
-147
lines changed

src/nncf/quantization/advanced_parameters.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from nncf.parameters import StrEnum
2828
from nncf.quantization.range_estimator import AggregatorType
2929
from nncf.quantization.range_estimator import RangeEstimatorParameters
30-
from nncf.quantization.range_estimator import StatisticsCollectorParameters
3130
from nncf.quantization.range_estimator import StatisticsType
3231

3332
TTensor = Any
@@ -273,12 +272,8 @@ class AdvancedQuantizationParameters:
273272
quantizer_propagation_rule: QuantizerPropagationRule = QuantizerPropagationRule.MERGE_ALL_IN_ONE
274273

275274
# Range estimator parameters
276-
activations_range_estimator_params: Union[RangeEstimatorParameters, StatisticsCollectorParameters] = field(
277-
default_factory=RangeEstimatorParameters
278-
)
279-
weights_range_estimator_params: Union[RangeEstimatorParameters, StatisticsCollectorParameters] = field(
280-
default_factory=RangeEstimatorParameters
281-
)
275+
activations_range_estimator_params: RangeEstimatorParameters = field(default_factory=RangeEstimatorParameters)
276+
weights_range_estimator_params: RangeEstimatorParameters = field(default_factory=RangeEstimatorParameters)
282277

283278
# Advanced BiasCorrection algorithm parameters
284279
bias_correction_params: AdvancedBiasCorrectionParameters = field(default_factory=AdvancedBiasCorrectionParameters)
@@ -533,19 +528,13 @@ def convert_quantization_parameters_to_dict(params: Optional[QuantizationParamet
533528
return result
534529

535530

536-
def convert_range_estimator_parameters_to_dict(
537-
params: Union[RangeEstimatorParameters, StatisticsCollectorParameters],
538-
) -> dict[str, Any]:
531+
def convert_range_estimator_parameters_to_dict(params: RangeEstimatorParameters) -> dict[str, Any]:
539532
"""
540533
Converts range estimator parameters to the dict in the legacy format
541534
542535
:param params: Range estimator parameters
543536
:return: range estimator parameters as dict in the legacy format
544537
"""
545-
if isinstance(params, StatisticsCollectorParameters):
546-
msg = "Single branch statistic collection is not supported for this backend yet."
547-
raise nncf.ParameterNotSupportedError(msg)
548-
549538
if params.min.clipping_value is not None or params.max.clipping_value is not None:
550539
msg = "clipping_value parameter is not supported in the legacy format"
551540
raise nncf.ParameterNotSupportedError(msg)

src/nncf/quantization/algorithms/min_max/algorithm.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
from nncf.quantization.range_estimator import AggregatorType
6868
from nncf.quantization.range_estimator import RangeEstimatorParameters
6969
from nncf.quantization.range_estimator import RangeEstimatorParametersSet
70-
from nncf.quantization.range_estimator import StatisticsCollectorParameters
7170
from nncf.quantization.range_estimator import StatisticsType
7271
from nncf.scopes import IgnoredScope
7372
from nncf.scopes import get_ignored_node_names_from_ignored_scope
@@ -433,16 +432,23 @@ def _get_range_estimator_parameters(
433432
user_params = self._range_estimator_params[quantizer_group]
434433
if user_params is None:
435434
return deepcopy(params)
436-
437-
if isinstance(user_params, StatisticsCollectorParameters):
435+
if (
436+
user_params.min.aggregator_type is AggregatorType.HISTOGRAM
437+
or user_params.max.aggregator_type is AggregatorType.HISTOGRAM
438+
):
439+
if user_params != RangeEstimatorParametersSet.HISTOGRAM:
440+
msg = (
441+
f"Given parameters set {user_params} is not supported by NNCF."
442+
" Please use the RangeEstimatorParametersSet.HISTOGRAM to enable the histogram aggregation."
443+
)
444+
raise nncf.ParameterNotSupportedError(msg)
438445
if quantizer_config.per_channel:
439446
msg = (
440-
f"Could not create signle aggregator with parameters {user_params}",
441-
" Per channel statistic collection is not supported for the single aggregator case yet.",
447+
f"Rollback to RangeEstimatorParametersSet.MINMAX for the target point: '{target_point}' as"
448+
" HistogramAggregator does not support per-channel activations."
442449
)
443-
raise nncf.InternalError(msg)
444-
return deepcopy(user_params)
445-
450+
nncf_logger.warning(msg)
451+
user_params = RangeEstimatorParametersSet.MINMAX
446452
min_changes = changes_asdict(user_params.min)
447453
min_statistic_collector = dataclasses.replace(params.min, **min_changes)
448454

@@ -503,9 +509,22 @@ def _get_stat_collector(
503509
num_samples=num_samples,
504510
)
505511

512+
def _get_histogram_statistic_collector(self, num_samples: int) -> TensorCollector:
513+
"""
514+
Return the histogram statistic collector.
515+
516+
:param num_samples: Maximum number of samples to collect.
517+
:return: An histogram TensorCollector for the statistics calculation.
518+
"""
519+
reducer = self._backend_entity.reducer_map[StatisticsType.RAW]()
520+
aggregator = AGGREGATORS_MAP[AggregatorType.HISTOGRAM](num_samples=num_samples)
521+
collector = TensorCollector(MinMaxTensorStatistic)
522+
collector.register_statistic_branch(MinMaxTensorStatistic.MIN_MAX_STAT, reducer, aggregator)
523+
return collector
524+
506525
def _get_statistic_collector(
507526
self,
508-
range_estimator_params: Union[RangeEstimatorParameters, StatisticsCollectorParameters],
527+
range_estimator_params: RangeEstimatorParameters,
509528
use_abs_max: bool,
510529
reduction_axes: Optional[tuple[int, ...]],
511530
aggregation_axes: Optional[tuple[int, ...]],
@@ -523,24 +542,13 @@ def _get_statistic_collector(
523542
:param num_samples: Maximum number of samples to collect.
524543
:return: TensorCollector for the statistics calculation.
525544
"""
526-
collector = TensorCollector(MinMaxTensorStatistic)
527-
if isinstance(range_estimator_params, StatisticsCollectorParameters):
528-
if range_estimator_params.statistics_type is not StatisticsType.RAW:
529-
msg = "Only RAW statistic type is suppored for single aggregator case."
530-
raise nncf.InternalError(msg)
531-
532-
if range_estimator_params.aggregator_type is not AggregatorType.HISTOGRAM:
533-
msg = "Only HISTOGRAM aggregator type is suppored for single aggregator case."
534-
raise nncf.InternalError(msg)
535-
536-
reducer = self._backend_entity.reducer_map[StatisticsType.RAW]()
537-
aggregator = AGGREGATORS_MAP[AggregatorType.HISTOGRAM]()
538-
collector.register_statistic_branch(MinMaxTensorStatistic.MIN_MAX_STAT, reducer, aggregator)
539-
return collector
545+
if range_estimator_params == RangeEstimatorParametersSet.HISTOGRAM:
546+
return self._get_histogram_statistic_collector(num_samples)
540547

541548
if not self._backend_entity.supports_inplace_statistics:
542549
inplace = False
543550

551+
collector = TensorCollector(MinMaxTensorStatistic)
544552
for params, container_key in zip(
545553
[range_estimator_params.min, range_estimator_params.max],
546554
[MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT],

src/nncf/quantization/range_estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,8 @@ class RangeEstimatorParametersSet:
153153
min=StatisticsCollectorParameters(statistics_type=StatisticsType.QUANTILE, aggregator_type=AggregatorType.MEAN),
154154
max=StatisticsCollectorParameters(statistics_type=StatisticsType.QUANTILE, aggregator_type=AggregatorType.MEAN),
155155
)
156+
157+
HISTOGRAM = RangeEstimatorParameters(
158+
min=StatisticsCollectorParameters(statistics_type=StatisticsType.RAW, aggregator_type=AggregatorType.HISTOGRAM),
159+
max=StatisticsCollectorParameters(statistics_type=StatisticsType.RAW, aggregator_type=AggregatorType.HISTOGRAM),
160+
)

tests/common/experimental/test_statistic_collector.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def _aggregate_impl(self):
6767
return self._container[0]
6868

6969

70+
class DummyMinMaxAggregatedAggregator(DummyTensorAggregator):
71+
def _aggregate_impl(self):
72+
return {MinMaxTensorStatistic.MIN_STAT: self._container[0], MinMaxTensorStatistic.MAX_STAT: self._container[0]}
73+
74+
7075
def get_output_info(reducers: list[DummyTensorReducer]) -> list[tuple[int, list[str]]]:
7176
retval = []
7277
for reducer in reducers:
@@ -362,6 +367,16 @@ def test_min_max_stat_building(self):
362367
assert isinstance(statistic, MinMaxTensorStatistic)
363368
assert statistic.min_values == statistic.max_values == Tensor(np.array(1))
364369

370+
def test_min_max_stat_composed(self):
371+
tensor_collector = TensorCollector(MinMaxTensorStatistic)
372+
tensor_collector.register_statistic_branch(
373+
MinMaxTensorStatistic.MIN_MAX_STAT, DummyTensorReducer("A"), DummyMinMaxAggregatedAggregator()
374+
)
375+
tensor_collector.register_input_for_all_reducers(Tensor(np.array(1)))
376+
statistic = tensor_collector.get_statistics()
377+
assert isinstance(statistic, MinMaxTensorStatistic)
378+
assert statistic.min_values == statistic.max_values == Tensor(np.array(1))
379+
365380
def test_mean_max_stat_building(self):
366381
tensor_collector = TensorCollector(MeanTensorStatistic)
367382
tensor_collector.register_statistic_branch(

tests/cross_fw/test_templates/test_ptq_params.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@
2626
from nncf.common.quantization.structs import QuantizerGroup
2727
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
2828
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
29+
from nncf.experimental.common.tensor_statistics.collectors import HistogramAggregator
2930
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
3031
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
3132
from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
33+
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
3234
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
3335
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
3436
from nncf.parameters import ModelType
3537
from nncf.quantization.advanced_parameters import OverflowFix
3638
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
3739
from nncf.quantization.passes import transform_to_inference_graph
40+
from nncf.quantization.range_estimator import AggregatorType
3841
from nncf.quantization.range_estimator import RangeEstimatorParametersSet
3942
from nncf.scopes import IgnoredScope
4043
from nncf.tensor import Tensor
@@ -154,6 +157,14 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl
154157
assert MeanAggregator in aggrs
155158
assert aggrs[0].__class__ == aggrs[1].__class__
156159

160+
def check_is_histogram_statistic_collector(self, tensor_collector: TensorCollector):
161+
reducers = [r.__class__ for r in tensor_collector.reducers]
162+
assert len(reducers) == 1
163+
assert RawReducer in reducers
164+
aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()]
165+
assert len(aggrs) == 1
166+
assert HistogramAggregator in aggrs
167+
157168
@abstractmethod
158169
def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_q):
159170
pass
@@ -193,32 +204,71 @@ def nncf_graph_cls(self):
193204
def get_backend_tensor(self, value):
194205
pass
195206

196-
@pytest.mark.parametrize(
197-
"range_estimator_params", [RangeEstimatorParametersSet.MINMAX, RangeEstimatorParametersSet.MEAN_MINMAX, None]
198-
)
207+
RANGE_ESTIMATOR_TEST_PARAMS = [
208+
RangeEstimatorParametersSet.MINMAX,
209+
RangeEstimatorParametersSet.MEAN_MINMAX,
210+
None,
211+
RangeEstimatorParametersSet.HISTOGRAM,
212+
]
213+
214+
@pytest.mark.parametrize("range_estimator_params", RANGE_ESTIMATOR_TEST_PARAMS)
199215
def test_range_estimator_per_tensor(self, test_params, range_estimator_params):
216+
for stat_point_, tensor_collector in self._get_stat_points(
217+
test_params, range_estimator_params, "test_range_estimator_per_tensor"
218+
):
219+
if stat_point_.target_point.is_weight_target_point():
220+
# default tensor_collector for weights
221+
self.check_is_min_max_statistic_collector(tensor_collector)
222+
continue
223+
if range_estimator_params in [None, RangeEstimatorParametersSet.MEAN_MINMAX]:
224+
# default tensor_collector for per-tensor
225+
self.check_is_mean_min_max_statistic_collector(tensor_collector)
226+
elif range_estimator_params == RangeEstimatorParametersSet.MINMAX:
227+
self.check_is_min_max_statistic_collector(tensor_collector)
228+
elif range_estimator_params == RangeEstimatorParametersSet.HISTOGRAM:
229+
self.check_is_histogram_statistic_collector(tensor_collector)
230+
231+
@pytest.mark.parametrize("range_estimator_params", RANGE_ESTIMATOR_TEST_PARAMS)
232+
def test_range_estimator_per_channel(self, test_params, range_estimator_params):
233+
for stat_point_, tensor_collector in self._get_stat_points(
234+
test_params, range_estimator_params, "test_range_estimator_per_channel"
235+
):
236+
if stat_point_.target_point.is_weight_target_point():
237+
# default tensor_collector for weights
238+
self.check_is_min_max_statistic_collector(tensor_collector)
239+
continue
240+
if range_estimator_params in [
241+
None,
242+
RangeEstimatorParametersSet.MINMAX,
243+
RangeEstimatorParametersSet.HISTOGRAM,
244+
]:
245+
# default tensor_collector for per-tensor
246+
self.check_is_min_max_statistic_collector(tensor_collector)
247+
elif range_estimator_params == RangeEstimatorParametersSet.MEAN_MINMAX:
248+
self.check_is_mean_min_max_statistic_collector(tensor_collector)
249+
250+
def _get_stat_points(self, test_params, range_estimator_params, test_params_key):
200251
min_max_algo = MinMaxQuantization(activations_range_estimator_params=range_estimator_params)
201252
min_max_algo._backend_entity = self.get_algo_backend()
202253
assert min_max_algo._range_estimator_params[QuantizerGroup.ACTIVATIONS] == range_estimator_params
203254

204-
params = test_params["test_range_estimator_per_tensor"]
255+
params = test_params[test_params_key]
205256
stat_points = min_max_algo.get_statistic_points(params["model"], params["nncf_graph"])
206257
assert len(stat_points) == params["stat_points_num"]
207-
208258
for stat_point in stat_points.values():
209259
for stat_point_ in stat_point:
210260
for tensor_collector in stat_point_.algorithm_to_tensor_collectors[min_max_algo._algorithm_key]:
211-
if stat_point_.target_point.is_weight_target_point():
212-
# default tensor_collector for weights
213-
self.check_is_min_max_statistic_collector(tensor_collector)
214-
continue
215-
if range_estimator_params is None:
216-
# default tensor_collector for per-tensor
217-
self.check_is_mean_min_max_statistic_collector(tensor_collector)
218-
if range_estimator_params == RangeEstimatorParametersSet.MINMAX:
219-
self.check_is_min_max_statistic_collector(tensor_collector)
220-
elif range_estimator_params == RangeEstimatorParametersSet.MEAN_MINMAX:
221-
self.check_is_mean_min_max_statistic_collector(tensor_collector)
261+
yield stat_point_, tensor_collector
262+
263+
def test_unsupported_params(self, test_params):
264+
range_estimator_params = deepcopy(RangeEstimatorParametersSet.HISTOGRAM)
265+
range_estimator_params.min.aggregator_type = AggregatorType.MAX
266+
267+
params = test_params["test_range_estimator_per_tensor"]
268+
min_max_algo = MinMaxQuantization(activations_range_estimator_params=range_estimator_params)
269+
min_max_algo._backend_entity = self.get_algo_backend()
270+
with pytest.raises(nncf.ParameterNotSupportedError):
271+
min_max_algo.get_statistic_points(params["model"], params["nncf_graph"])
222272

223273
@pytest.mark.parametrize("quantize_outputs", [False, True])
224274
def test_quantize_outputs(self, test_params, quantize_outputs):

0 commit comments

Comments
 (0)