Skip to content

Commit 47bf2dd

Browse files
authored
Remove default output axis values from fw_info. (#1489)
* Remove default output axis values from fw_info. * Handle missing defaults: 1. Collectors for nodes that don't have a known output channel axis: extract the node's channel-axis info for statistics collectors from adjacent nodes, or collect for the whole tensor. 2. Pruning: retain defaults that are now defined in the PruningImplementation.
1 parent 64e3adb commit 47bf2dd

15 files changed

Lines changed: 49 additions & 103 deletions

File tree

model_compression_toolkit/core/common/collectors/base_collector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
from abc import ABC, abstractmethod
1617
import numpy as np
1718
from model_compression_toolkit.logger import Logger
1819

1920

20-
class BaseCollector(object):
21+
class BaseCollector(ABC):
2122
"""
2223
Base class for statistics collection object.
2324
"""
@@ -26,6 +27,7 @@ def __init__(self):
2627
# When manipulation statistics in a granularity they were not collected by, the data is invalid.
2728
self.is_legal = True
2829

30+
@abstractmethod
2931
def scale(self, scale_factor: np.ndarray):
3032
"""
3133
Scale all statistics in collector by some factor.
@@ -37,6 +39,7 @@ def scale(self, scale_factor: np.ndarray):
3739
raise NotImplemented(
3840
f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover
3941

42+
@abstractmethod
4043
def shift(self, shift_value: np.ndarray):
4144
"""
4245
Shift all statistics in collector by some value.

model_compression_toolkit/core/common/collectors/mean_collector.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,13 @@ def update(self,
8787
x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
8888
"""
8989
self.i += 1 # Update the iteration index
90-
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
91-
n = x.shape[axis]
92-
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
93-
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
90+
if self.axis is None:
91+
mu = np.mean(np.reshape(x, [1, -1]), axis=-1) # mean per channel for a batch
92+
else:
93+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
94+
n = x.shape[axis]
95+
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
96+
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
9497
self.current_sum += mu # sum of all batches
9598
self.current_mean = self.current_sum / self.i # mean of all batches
9699

model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,13 @@ def update(self,
130130
x: Tensor that goes through the collector and needs to be considered in the min/max computation.
131131
"""
132132

133-
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
134-
n = x.shape[axis]
135-
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
136-
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
133+
if self.axis is None:
134+
x_reshape = np.reshape(x, [1, -1])
135+
else:
136+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
137+
n = x.shape[axis]
138+
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
139+
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
137140
if self.state is None:
138141
x_max = np.max(x_reshape, axis=-1)
139142
x_min = np.min(x_reshape, axis=-1)

model_compression_toolkit/core/common/model_collector.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,17 @@ def __init__(self, graph: Graph,
157157
for n in graph.get_topo_sorted_nodes():
158158
quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
159159
sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
160+
if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
161+
# Missing output channel axis info, so try to extract it from previous and next nodes output channel axis.
162+
possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
163+
# Filter out None values.
164+
possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
165+
if len(possible_output_channel_axis_list) > 0:
166+
if len(possible_output_channel_axis_list) > 1:
167+
Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
168+
sc.mc.axis = possible_output_channel_axis_list[0]
169+
sc.mpcc.axis = possible_output_channel_axis_list[0]
170+
160171
# If we use bias correction, and the node has kernel weights to quantize, we need to make sure
161172
# its previous nodes' tensors are consistent with this node.
162173
if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(

model_compression_toolkit/core/common/model_validation.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

model_compression_toolkit/core/common/pruning/memory_calculator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def get_pruned_node_num_params(self,
303303
num_oc = np.sum(output_mask)
304304
else:
305305
# Get the node channel axis from framework info
306-
channel_axis = node.out_channel_axis
306+
channel_axis = self.fw_impl.default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
307307
if channel_axis is None:
308308
Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
309309

model_compression_toolkit/core/keras/default_framework_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def get_out_channel_axis(cls, node_type: Any):
143143
Node's output channel axis.
144144
145145
"""
146-
return cls.out_channel_axis_mapping.get(node_type, -1)
146+
return cls.out_channel_axis_mapping.get(node_type)
147147

148148

149149
def set_keras_info(func):

model_compression_toolkit/core/keras/keras_model_validation.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from model_compression_toolkit.logger import Logger
2929

3030

31+
# default output channel axis to use when it's not defined in node's fw_info.
32+
_default_output_channel_axis = -1
33+
34+
3135
class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
3236
"""
3337
Implementation of the PruningFramework for the Keras framework. This class provides
@@ -172,6 +176,10 @@ def attrs_oi_channels_info_for_pruning(self,
172176

173177
return attributes_with_axis
174178

179+
@property
180+
def default_output_channel_axis(self):
181+
return _default_output_channel_axis
182+
175183

176184
def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
177185
"""

model_compression_toolkit/core/pytorch/default_framework_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def get_out_channel_axis(cls, node_type: Any):
101101
Node's output channel axis.
102102
103103
"""
104-
return cls.out_channel_axis_mapping.get(node_type, 1)
104+
return cls.out_channel_axis_mapping.get(node_type)
105105

106106

107107
def set_pytorch_info(func):

0 commit comments

Comments (0)