diff --git a/model_compression_toolkit/core/common/collectors/base_collector.py b/model_compression_toolkit/core/common/collectors/base_collector.py index b4081e0e6..4854134b7 100644 --- a/model_compression_toolkit/core/common/collectors/base_collector.py +++ b/model_compression_toolkit/core/common/collectors/base_collector.py @@ -13,11 +13,12 @@ # limitations under the License. # ============================================================================== +from abc import ABC, abstractmethod import numpy as np from model_compression_toolkit.logger import Logger -class BaseCollector(object): +class BaseCollector(ABC): """ Base class for statistics collection object. """ @@ -26,6 +27,7 @@ def __init__(self): # When manipulation statistics in a granularity they were not collected by, the data is invalid. self.is_legal = True + @abstractmethod def scale(self, scale_factor: np.ndarray): """ Scale all statistics in collector by some factor. @@ -37,6 +39,7 @@ def scale(self, scale_factor: np.ndarray): raise NotImplemented( f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover + @abstractmethod def shift(self, shift_value: np.ndarray): """ Shift all statistics in collector by some value. diff --git a/model_compression_toolkit/core/common/collectors/mean_collector.py b/model_compression_toolkit/core/common/collectors/mean_collector.py index abb84cb72..de2505d88 100644 --- a/model_compression_toolkit/core/common/collectors/mean_collector.py +++ b/model_compression_toolkit/core/common/collectors/mean_collector.py @@ -87,10 +87,13 @@ def update(self, x: Tensor that goes through the mean collector and needs to be considered in the mean computation. """ self.i += 1 # Update the iteration index - axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis - n = x.shape[axis] - transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]] - mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch + if self.axis is None: + mu = np.mean(np.reshape(x, [1, -1]), axis=-1) # mean per channel for a batch + else: + axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis + n = x.shape[axis] + transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]] + mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch self.current_sum += mu # sum of all batches self.current_mean = self.current_sum / self.i # mean of all batches diff --git a/model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py b/model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py index c935499a7..b0f73a4db 100644 --- a/model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +++ b/model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py @@ -130,10 +130,13 @@ def update(self, x: Tensor that goes through the collector and needs to be considered in the min/max computation. """ - axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis - n = x.shape[axis] - transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]] - x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1]) + if self.axis is None: + x_reshape = np.reshape(x, [1, -1]) + else: + axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis + n = x.shape[axis] + transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]] + x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1]) if self.state is None: x_max = np.max(x_reshape, axis=-1) x_min = np.min(x_reshape, axis=-1) diff --git a/model_compression_toolkit/core/common/model_collector.py b/model_compression_toolkit/core/common/model_collector.py index b3a677318..b734cc57a 100644 --- a/model_compression_toolkit/core/common/model_collector.py +++ b/model_compression_toolkit/core/common/model_collector.py @@ -157,6 +157,17 @@ def __init__(self, graph: Graph, for n in graph.get_topo_sorted_nodes(): quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n) sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node + if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None): + # Missing output channel axis info, so try to extract it from previous and next nodes output channel axis. + possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)} + # Filter out None values. + possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set)) + if len(possible_output_channel_axis_list) > 0: + if len(possible_output_channel_axis_list) > 1: + Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.') + sc.mc.axis = possible_output_channel_axis_list[0] + sc.mpcc.axis = possible_output_channel_axis_list[0] + # If we use bias correction, and the node has kernel weights to quantize, we need to make sure # its previous nodes' tensors are consistent with this node. if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled( diff --git a/model_compression_toolkit/core/common/model_validation.py b/model_compression_toolkit/core/common/model_validation.py deleted file mode 100644 index 9343171b9..000000000 --- a/model_compression_toolkit/core/common/model_validation.py +++ /dev/null @@ -1,41 +0,0 @@ -from abc import abstractmethod -from typing import Any - -from model_compression_toolkit.core import FrameworkInfo - - -class ModelValidation: - """ - Class to define validation methods in order to validate the received model to quantize. - """ - - def __init__(self, - model: Any): - """ - Initialize a ModelValidation object. - - Args: - model: Model to check its validity. - """ - self.model = model - - @abstractmethod - def validate_output_channel_consistency(self): - """ - - Validate that output channels index in all layers of the model are the same. - If the model has layers with different output channels index, it should throw an exception. - - """ - raise NotImplemented( - f'Framework validation class did not implement validate_output_channel_consistency') # pragma: no cover - - def validate(self): - """ - - Run all validation methods before the quantization process starts. - - """ - self.validate_output_channel_consistency() - - diff --git a/model_compression_toolkit/core/common/pruning/memory_calculator.py b/model_compression_toolkit/core/common/pruning/memory_calculator.py index ae6d7b217..180283953 100644 --- a/model_compression_toolkit/core/common/pruning/memory_calculator.py +++ b/model_compression_toolkit/core/common/pruning/memory_calculator.py @@ -303,7 +303,7 @@ def get_pruned_node_num_params(self, num_oc = np.sum(output_mask) else: # Get the node channel axis from framework info - channel_axis = node.out_channel_axis + channel_axis = self.fw_impl.default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis if channel_axis is None: Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.") diff --git a/model_compression_toolkit/core/keras/default_framework_info.py b/model_compression_toolkit/core/keras/default_framework_info.py index 22815b382..d1c41aaf7 100644 --- a/model_compression_toolkit/core/keras/default_framework_info.py +++ b/model_compression_toolkit/core/keras/default_framework_info.py @@ -143,7 +143,7 @@ def get_out_channel_axis(cls, node_type: Any): Node's output channel axis. """ - return cls.out_channel_axis_mapping.get(node_type, -1) + return cls.out_channel_axis_mapping.get(node_type) def set_keras_info(func): diff --git a/model_compression_toolkit/core/keras/keras_model_validation.py b/model_compression_toolkit/core/keras/keras_model_validation.py deleted file mode 100644 index 3b541736a..000000000 --- a/model_compression_toolkit/core/keras/keras_model_validation.py +++ /dev/null @@ -1,37 +0,0 @@ -from tensorflow.keras.models import Model - -from model_compression_toolkit.core.common.framework_info import get_fw_info -from model_compression_toolkit.core.common.framework_info import ChannelAxis -from model_compression_toolkit.core.common.model_validation import ModelValidation -from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST - - -class KerasModelValidation(ModelValidation): - """ - Class to define validation methods in order to validate the received Keras model to quantize. - """ - - def __init__(self, model: Model): - """ - Initialize a KerasModelValidation object. - - Args: - model: Keras model to check its validity. - """ - - super(KerasModelValidation, self).__init__(model=model) - - def validate_output_channel_consistency(self): - """ - - Validate that output channels index in all layers of the model are the same. - If the model has layers with different output channels index, an exception is thrown. - - """ - fw_info = get_fw_info() - for layer in self.model.layers: - data_format = layer.get_config().get(CHANNELS_FORMAT) - if data_format is not None: - assert (data_format == CHANNELS_FORMAT_LAST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NHWC.value - or data_format == CHANNELS_FORMAT_FIRST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NCHW.value), \ - f'Model can not have layers with different data formats.' diff --git a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py index cacd18971..3d632a744 100644 --- a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +++ b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py @@ -28,6 +28,10 @@ from model_compression_toolkit.logger import Logger +# default output channel axis to use when it's not defined in node's fw_info. +_default_output_channel_axis = -1 + + class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation): """ Implementation of the PruningFramework for the Keras framework. This class provides @@ -172,6 +176,10 @@ def attrs_oi_channels_info_for_pruning(self, return attributes_with_axis + @property + def default_output_channel_axis(self): + return _default_output_channel_axis + def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool: """ diff --git a/model_compression_toolkit/core/pytorch/default_framework_info.py b/model_compression_toolkit/core/pytorch/default_framework_info.py index 3d47e50b6..ae1d76e80 100644 --- a/model_compression_toolkit/core/pytorch/default_framework_info.py +++ b/model_compression_toolkit/core/pytorch/default_framework_info.py @@ -101,7 +101,7 @@ def get_out_channel_axis(cls, node_type: Any): Node's output channel axis. """ - return cls.out_channel_axis_mapping.get(node_type, 1) + return cls.out_channel_axis_mapping.get(node_type) def set_pytorch_info(func): diff --git a/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py index 78ff90882..928da5137 100644 --- a/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py @@ -29,6 +29,10 @@ from model_compression_toolkit.logger import Logger +# default output channel axis to use when it's not defined in node's fw_info. +_default_output_channel_axis = 1 + + class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation): """ Implementation of the PruningFramework for the Pytorch framework. This class provides @@ -190,6 +194,10 @@ def attrs_oi_channels_info_for_pruning(self, return attributes_with_axis + @property + def default_output_channel_axis(self): + return _default_output_channel_axis + def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool: """ @@ -283,7 +291,7 @@ def _edit_node_input_shape(node: BaseNode, # Adjust the last dimension of the shape to match the number of unpruned (retained) channels. # This is done by summing the mask, as each '1' in the mask represents a retained channel. - channel_axis = node.out_channel_axis + channel_axis = _default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis new_input_shape[0][channel_axis] = int(np.sum(input_mask)) # Update the node's input shape with the new dimensions. diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py index 6258c5fd0..574146db2 100644 --- a/model_compression_toolkit/gptq/keras/quantization_facade.py +++ b/model_compression_toolkit/gptq/keras/quantization_facade.py @@ -43,7 +43,6 @@ import tensorflow as tf from model_compression_toolkit.core.keras.default_framework_info import set_keras_info from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation - from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation from tensorflow.keras.models import Model from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL @@ -235,8 +234,6 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da if core_config.debug_config.bypass: return in_model, None - KerasModelValidation(model=in_model).validate() - if core_config.is_mixed_precision_enabled: if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. " diff --git a/model_compression_toolkit/ptq/keras/quantization_facade.py b/model_compression_toolkit/ptq/keras/quantization_facade.py index c4648c01b..db1abd283 100644 --- a/model_compression_toolkit/ptq/keras/quantization_facade.py +++ b/model_compression_toolkit/ptq/keras/quantization_facade.py @@ -38,7 +38,6 @@ AttachTpcToKeras from model_compression_toolkit.core.keras.default_framework_info import set_keras_info from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation - from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation from tensorflow.keras.models import Model from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model @@ -129,8 +128,6 @@ def keras_post_training_quantization(in_model: Model, if core_config.debug_config.bypass: return in_model, None - KerasModelValidation(model=in_model).validate() - if core_config.is_mixed_precision_enabled: if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): Logger.critical("Given quantization config to mixed-precision facade is not of type " diff --git a/model_compression_toolkit/qat/keras/quantization_facade.py b/model_compression_toolkit/qat/keras/quantization_facade.py index 71c8d1bd3..deb76423e 100644 --- a/model_compression_toolkit/qat/keras/quantization_facade.py +++ b/model_compression_toolkit/qat/keras/quantization_facade.py @@ -38,7 +38,6 @@ from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation - from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL from model_compression_toolkit.core.keras.default_framework_info import set_keras_info @@ -175,8 +174,6 @@ def keras_quantization_aware_training_init_experimental(in_model: Model, f"If you encounter an issue, please open an issue in our GitHub " f"project https://github.com/sony/model_optimization") - KerasModelValidation(model=in_model).validate() - if core_config.is_mixed_precision_enabled: if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): Logger.critical("Given quantization config to mixed-precision facade is not of type " diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py index b3114fdd7..ff68918e4 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py @@ -33,7 +33,6 @@ from model_compression_toolkit.core.keras.constants import EPSILON_VAL, GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE from model_compression_toolkit.core.keras.default_framework_info import KerasInfo from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation -from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation from model_compression_toolkit.core.keras.statistics_correction.apply_second_moment_correction import \ keras_apply_second_moment_correction from model_compression_toolkit.core.runner import core_runner @@ -269,8 +268,6 @@ def prepare_graph(self, target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> \ Tuple[Graph, Graph]: - KerasModelValidation(model=in_model).validate() - core_config = CoreConfig(quantization_config=quant_config, debug_config=DebugConfig(analyze_similarity=analyze_similarity, network_editor=network_editor)