diff --git a/model_compression_toolkit/core/common/model_collector.py b/model_compression_toolkit/core/common/model_collector.py index b734cc57a..95f841abe 100644 --- a/model_compression_toolkit/core/common/model_collector.py +++ b/model_compression_toolkit/core/common/model_collector.py @@ -57,19 +57,21 @@ def create_stats_collector_for_node(node: common.BaseNode, def create_tensor2node(graph: common.Graph, - node: common.BaseNode): + node: common.BaseNode, + next_node_output_channel_axis: int): """ Force statistic collector creation and assignment for a node. Args: graph: Graph of the node (for retrieving the current tensor). node: Node to create a tensor for. + next_node_output_channel_axis: channel output axis of next node. """ current_sc = graph.get_out_stats_collector(node) is_list_nostat_collectors = isinstance(current_sc, list) and len( [sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0 if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors: - stats_collector = common.StatsCollector(node.out_channel_axis) + stats_collector = common.StatsCollector(next_node_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis) graph.set_out_stats_collector_to_node(node, stats_collector) @@ -175,7 +177,8 @@ def __init__(self, graph: Graph, for ie in graph.incoming_edges(n): input_node = ie.source_node create_tensor2node(graph, - input_node) + input_node, + n.out_channel_axis) if sc is not None: graph.set_out_stats_collector_to_node(n, sc) diff --git a/tests_pytest/common_tests/unit_tests/test_model_collector.py b/tests_pytest/common_tests/unit_tests/test_model_collector.py index e6a393e95..20ad64076 100644 --- a/tests_pytest/common_tests/unit_tests/test_model_collector.py +++ b/tests_pytest/common_tests/unit_tests/test_model_collector.py @@ -132,11 +132,12 @@ def test_create_tensor2node_assigns_stats_collector(self): graph = Mock() node = Mock() node.type = DummyLayer + node.out_channel_axis = 6 # Simulate absence of an output stats collector. graph.get_out_stats_collector = Mock(return_value=None) - create_tensor2node(graph, node) + create_tensor2node(graph, node, 5) # Verify that set_out_stats_collector_to_node was called with the node and a StatsCollector. graph.set_out_stats_collector_to_node.assert_called_once() @@ -144,6 +145,19 @@ def test_create_tensor2node_assigns_stats_collector(self): assigned_node, assigned_collector = args assert assigned_node is node assert isinstance(assigned_collector, StatsCollector) + assert assigned_collector.mc.axis == 6 + assert assigned_collector.mpcc.axis == 6 + + node.out_channel_axis = None + create_tensor2node(graph, node, 5) + + # Verify that set_out_stats_collector_to_node was called with the node and a StatsCollector. + args, _ = graph.set_out_stats_collector_to_node.call_args + assigned_node, assigned_collector = args + assert assigned_node is node + assert isinstance(assigned_collector, StatsCollector) + assert assigned_collector.mc.axis == 5 + assert assigned_collector.mpcc.axis == 5 class TestModelCollectorInit: @@ -226,8 +240,8 @@ def test_bias_correction_creates_tensor2node(self, monkeypatch, fw_impl_mock, pa calls = [] # Define a fake function to record call arguments for create_tensor2node. - def fake_create_tensor2node(graph, node): - calls.append((graph, node)) + def fake_create_tensor2node(graph, node, next_node_output_channel_axis): + calls.append((graph, node, next_node_output_channel_axis)) # Patch create_tensor2node in the model_collector module. monkeypatch.setattr(model_collector, "create_tensor2node", fake_create_tensor2node)