Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions model_compression_toolkit/core/common/model_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,21 @@ def create_stats_collector_for_node(node: common.BaseNode,


def create_tensor2node(graph: common.Graph,
node: common.BaseNode):
node: common.BaseNode,
next_node_output_channel_axis: int):
"""
Force statistic collector creation and assignment for a node.
Args:
graph: Graph of the node (for retrieving the current tensor).
node: Node to create a tensor for.
next_node_output_channel_axis: channel output axis of next node.

"""
current_sc = graph.get_out_stats_collector(node)
is_list_nostat_collectors = isinstance(current_sc, list) and len(
[sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
stats_collector = common.StatsCollector(node.out_channel_axis)
stats_collector = common.StatsCollector(next_node_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis)
graph.set_out_stats_collector_to_node(node, stats_collector)


Expand Down Expand Up @@ -175,7 +177,8 @@ def __init__(self, graph: Graph,
for ie in graph.incoming_edges(n):
input_node = ie.source_node
create_tensor2node(graph,
input_node)
input_node,
n.out_channel_axis)
if sc is not None:
graph.set_out_stats_collector_to_node(n, sc)

Expand Down
20 changes: 17 additions & 3 deletions tests_pytest/common_tests/unit_tests/test_model_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,32 @@ def test_create_tensor2node_assigns_stats_collector(self):
graph = Mock()
node = Mock()
node.type = DummyLayer
node.out_channel_axis = 6

# Simulate absence of an output stats collector.
graph.get_out_stats_collector = Mock(return_value=None)

create_tensor2node(graph, node)
create_tensor2node(graph, node, 5)

# Verify that set_out_stats_collector_to_node was called with the node and a StatsCollector.
graph.set_out_stats_collector_to_node.assert_called_once()
args, _ = graph.set_out_stats_collector_to_node.call_args
assigned_node, assigned_collector = args
assert assigned_node is node
assert isinstance(assigned_collector, StatsCollector)
assert assigned_collector.mc.axis == 6
assert assigned_collector.mpcc.axis == 6

node.out_channel_axis = None
create_tensor2node(graph, node, 5)

# Verify that set_out_stats_collector_to_node was called with the node and a StatsCollector.
args, _ = graph.set_out_stats_collector_to_node.call_args
assigned_node, assigned_collector = args
assert assigned_node is node
assert isinstance(assigned_collector, StatsCollector)
assert assigned_collector.mc.axis == 5
assert assigned_collector.mpcc.axis == 5


class TestModelCollectorInit:
Expand Down Expand Up @@ -226,8 +240,8 @@ def test_bias_correction_creates_tensor2node(self, monkeypatch, fw_impl_mock, pa

calls = []
# Define a fake function to record call arguments for create_tensor2node.
def fake_create_tensor2node(graph, node):
calls.append((graph, node))
def fake_create_tensor2node(graph, node, next_node_output_channel_axis):
calls.append((graph, node, next_node_output_channel_axis))

# Patch create_tensor2node in the model_collector module.
monkeypatch.setattr(model_collector, "create_tensor2node", fake_create_tensor2node)
Expand Down