Skip to content

Commit 2c8d568

Browse files
authored
Fix channel axis init for collectors when a collection if required for the input of a node and previous node doesn't need collector (i.e. Dropout->Linear). (SonySemiconductorSolutions#1492)
1 parent 3fdc043 commit 2c8d568

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

model_compression_toolkit/core/common/model_collector.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,21 @@ def create_stats_collector_for_node(node: common.BaseNode,
5757

5858

5959
def create_tensor2node(graph: common.Graph,
60-
node: common.BaseNode):
60+
node: common.BaseNode,
61+
next_node_output_channel_axis: int):
6162
"""
6263
Force statistic collector creation and assignment for a node.
6364
Args:
6465
graph: Graph of the node (for retrieving the current tensor).
6566
node: Node to create a tensor for.
67+
next_node_output_channel_axis: channel output axis of next node.
6668
6769
"""
6870
current_sc = graph.get_out_stats_collector(node)
6971
is_list_nostat_collectors = isinstance(current_sc, list) and len(
7072
[sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
7173
if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
72-
stats_collector = common.StatsCollector(node.out_channel_axis)
74+
stats_collector = common.StatsCollector(next_node_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis)
7375
graph.set_out_stats_collector_to_node(node, stats_collector)
7476

7577

@@ -175,7 +177,8 @@ def __init__(self, graph: Graph,
175177
for ie in graph.incoming_edges(n):
176178
input_node = ie.source_node
177179
create_tensor2node(graph,
178-
input_node)
180+
input_node,
181+
n.out_channel_axis)
179182
if sc is not None:
180183
graph.set_out_stats_collector_to_node(n, sc)
181184

tests_pytest/common_tests/unit_tests/test_model_collector.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,32 @@ def test_create_tensor2node_assigns_stats_collector(self):
132132
graph = Mock()
133133
node = Mock()
134134
node.type = DummyLayer
135+
node.out_channel_axis = 6
135136

136137
# Simulate absence of an output stats collector.
137138
graph.get_out_stats_collector = Mock(return_value=None)
138139

139-
create_tensor2node(graph, node)
140+
create_tensor2node(graph, node, 5)
140141

141142
# Verify that set_out_stats_collector_to_node was called with the node and a StatsCollector.
142143
graph.set_out_stats_collector_to_node.assert_called_once()
143144
args, _ = graph.set_out_stats_collector_to_node.call_args
144145
assigned_node, assigned_collector = args
145146
assert assigned_node is node
146147
assert isinstance(assigned_collector, StatsCollector)
148+
assert assigned_collector.mc.axis == 6
149+
assert assigned_collector.mpcc.axis == 6
150+
151+
node.out_channel_axis = None
152+
create_tensor2node(graph, node, 5)
153+
154+
# Verify that set_out_stats_collector_to_node was called with the node and a StatsCollector.
155+
args, _ = graph.set_out_stats_collector_to_node.call_args
156+
assigned_node, assigned_collector = args
157+
assert assigned_node is node
158+
assert isinstance(assigned_collector, StatsCollector)
159+
assert assigned_collector.mc.axis == 5
160+
assert assigned_collector.mpcc.axis == 5
147161

148162

149163
class TestModelCollectorInit:
@@ -226,8 +240,8 @@ def test_bias_correction_creates_tensor2node(self, monkeypatch, fw_impl_mock, pa
226240

227241
calls = []
228242
# Define a fake function to record call arguments for create_tensor2node.
229-
def fake_create_tensor2node(graph, node):
230-
calls.append((graph, node))
243+
def fake_create_tensor2node(graph, node, next_node_output_channel_axis):
244+
calls.append((graph, node, next_node_output_channel_axis))
231245

232246
# Patch create_tensor2node in the model_collector module.
233247
monkeypatch.setattr(model_collector, "create_tensor2node", fake_create_tensor2node)

0 commit comments

Comments
 (0)