Skip to content

Commit 9180223

Browse files
Legacy unused torch aggregator is removed
1 parent 7438f86 commit 9180223

File tree

5 files changed

+2
-235
lines changed

5 files changed

+2
-235
lines changed

src/nncf/common/tensor_statistics/reduction.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

src/nncf/torch/statistics/aggregator.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

src/nncf/torch/tensor_statistics/algo.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

tests/torch/function_hook/quantization/test_reducers_and_aggregators.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@
2828
from nncf.common.tensor_statistics.collectors import MeanVarianceReducer
2929
from nncf.common.tensor_statistics.collectors import MinReducer
3030
from nncf.common.tensor_statistics.collectors import QuantileReducer
31-
from nncf.common.tensor_statistics.collectors import TensorCollector
3231
from nncf.tensor import Tensor
3332
from nncf.tensor import functions as fns
34-
from nncf.torch.tensor_statistics.algo import create_register_input_hook
3533
from tests.common.test_reducers_and_aggregators import TemplateTestReducersAggregators
3634

3735

@@ -122,17 +120,3 @@ def test_median_function(use_cuda, size, ref):
122120
res = fns.median(tensor, axis=0)
123121
assert res.data == ref
124122
assert res.data.is_cuda == (device == "cuda")
125-
126-
127-
def test_create_register_input_hook_with_return_type(mocker):
128-
collector = TensorCollector()
129-
collector.register_input_for_all_reducers = mocker.MagicMock()
130-
hook = create_register_input_hook(collector)
131-
input_ = torch.return_types.max([torch.tensor((1,))] * 2)
132-
output_ = hook(input_)
133-
assert input_ is output_
134-
mocker = collector.register_input_for_all_reducers
135-
mocker.assert_called_once()
136-
attr = mocker.call_args_list[0][0][0]
137-
assert isinstance(attr, Tensor)
138-
assert attr.data == torch.tensor(1)

tests/torch/function_hook/quantization/test_statistics_caching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
from nncf.common.graph.transformations.commands import TargetType
1414
from nncf.tensor import Tensor
15+
from nncf.torch.function_hook.statistics.aggregator import PT2StatisticsAggregator
1516
from nncf.torch.graph.transformations.commands import PTTargetPoint
16-
from nncf.torch.statistics.aggregator import PTStatisticsAggregator
1717
from tests.cross_fw.test_templates.test_statistics_caching import TemplateTestStatisticsCaching
1818

1919

@@ -22,7 +22,7 @@ def create_target_point(self, target_point_type: TargetType, name: str, port_id:
2222
return PTTargetPoint(target_type=target_point_type, target_node_name=name, input_port_id=port_id)
2323

2424
def get_statistics_aggregator(self):
25-
return PTStatisticsAggregator(None)
25+
return PT2StatisticsAggregator(None)
2626

2727
def _create_dummy_min_max_tensor(self) -> Tensor:
2828
return Tensor(torch.zeros(3)), Tensor(torch.ones(3))

0 commit comments

Comments
 (0)