Skip to content

Commit 81ed69c

Browse files
committed
Extract node's channel axis info for statistics collectors from adjacent nodes.
1 parent afa5889 commit 81ed69c

3 files changed

Lines changed: 15 additions & 4 deletions

File tree

model_compression_toolkit/core/common/collectors/mean_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ def update(self,
8787
x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
8888
"""
8989
self.i += 1 # Update the iteration index
90-
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
91-
if axis is None:
90+
if self.axis is None:
9291
mu = np.mean(np.reshape(x, [1, -1]), axis=-1) # mean per channel for a batch
9392
else:
93+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
9494
n = x.shape[axis]
9595
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
9696
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch

model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ def update(self,
130130
x: Tensor that goes through the collector and needs to be considered in the min/max computation.
131131
"""
132132

133-
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
134-
if axis is None:
133+
if self.axis is None:
135134
x_reshape = np.reshape(x, [1, -1])
136135
else:
136+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
137137
n = x.shape[axis]
138138
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
139139
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])

model_compression_toolkit/core/common/model_collector.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,17 @@ def __init__(self, graph: Graph,
157157
for n in graph.get_topo_sorted_nodes():
158158
quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
159159
sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
160+
if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
161+
# Missing output channel axis info, so try to extract it from previous and next nodes output channel axis.
162+
possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
163+
# Filter out None values.
164+
possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
165+
if len(possible_output_channel_axis_list) > 0:
166+
if len(possible_output_channel_axis_list) > 1:
167+
Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
168+
sc.mc.axis = possible_output_channel_axis_list[0]
169+
sc.mpcc.axis = possible_output_channel_axis_list[0]
170+
160171
# If we use bias correction, and the node has kernel weights to quantize, we need to make sure
161172
# its previous nodes' tensors are consistent with this node.
162173
if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(

0 commit comments

Comments
 (0)