Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.
# ==============================================================================

from abc import ABC, abstractmethod
import numpy as np
from model_compression_toolkit.logger import Logger


class BaseCollector(object):
class BaseCollector(ABC):
"""
Base class for statistics collection object.
"""
Expand All @@ -26,6 +27,7 @@ def __init__(self):
# When manipulating statistics in a granularity they were not collected by, the data is invalid.
self.is_legal = True

@abstractmethod
def scale(self, scale_factor: np.ndarray):
"""
Scale all statistics in collector by some factor.
Expand All @@ -37,6 +39,7 @@ def scale(self, scale_factor: np.ndarray):
raise NotImplemented(
f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover

@abstractmethod
def shift(self, shift_value: np.ndarray):
"""
Shift all statistics in collector by some value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,13 @@ def update(self,
x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
"""
self.i += 1 # Update the iteration index
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
n = x.shape[axis]
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
if self.axis is None:
mu = np.mean(np.reshape(x, [1, -1]), axis=-1)  # no channel axis: single mean over the entire flattened batch
else:
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
n = x.shape[axis]
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
self.current_sum += mu # sum of all batches
self.current_mean = self.current_sum / self.i # mean of all batches

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,13 @@ def update(self,
x: Tensor that goes through the collector and needs to be considered in the min/max computation.
"""

axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
n = x.shape[axis]
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
if self.axis is None:
x_reshape = np.reshape(x, [1, -1])
else:
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
n = x.shape[axis]
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
if self.state is None:
x_max = np.max(x_reshape, axis=-1)
x_min = np.min(x_reshape, axis=-1)
Expand Down
11 changes: 11 additions & 0 deletions model_compression_toolkit/core/common/model_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,17 @@ def __init__(self, graph: Graph,
for n in graph.get_topo_sorted_nodes():
quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
# Missing output channel axis info, so try to extract it from previous and next nodes output channel axis.
possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
# Filter out None values.
possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
if len(possible_output_channel_axis_list) > 0:
if len(possible_output_channel_axis_list) > 1:
Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
sc.mc.axis = possible_output_channel_axis_list[0]
sc.mpcc.axis = possible_output_channel_axis_list[0]

# If we use bias correction, and the node has kernel weights to quantize, we need to make sure
# its previous nodes' tensors are consistent with this node.
if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(
Expand Down
41 changes: 0 additions & 41 deletions model_compression_toolkit/core/common/model_validation.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def get_pruned_node_num_params(self,
num_oc = np.sum(output_mask)
else:
# Get the node channel axis from framework info
channel_axis = node.out_channel_axis
channel_axis = self.fw_impl.default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
if channel_axis is None:
Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_out_channel_axis(cls, node_type: Any):
Node's output channel axis.

"""
return cls.out_channel_axis_mapping.get(node_type, -1)
return cls.out_channel_axis_mapping.get(node_type)


def set_keras_info(func):
Expand Down
37 changes: 0 additions & 37 deletions model_compression_toolkit/core/keras/keras_model_validation.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from model_compression_toolkit.logger import Logger


# default output channel axis to use when it's not defined in node's fw_info.
_default_output_channel_axis = -1


class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
"""
Implementation of the PruningFramework for the Keras framework. This class provides
Expand Down Expand Up @@ -172,6 +176,10 @@ def attrs_oi_channels_info_for_pruning(self,

return attributes_with_axis

@property
def default_output_channel_axis(self):
return _default_output_channel_axis


def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_out_channel_axis(cls, node_type: Any):
Node's output channel axis.

"""
return cls.out_channel_axis_mapping.get(node_type, 1)
return cls.out_channel_axis_mapping.get(node_type)


def set_pytorch_info(func):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from model_compression_toolkit.logger import Logger


# default output channel axis to use when it's not defined in node's fw_info.
_default_output_channel_axis = 1


class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
"""
Implementation of the PruningFramework for the Pytorch framework. This class provides
Expand Down Expand Up @@ -190,6 +194,10 @@ def attrs_oi_channels_info_for_pruning(self,

return attributes_with_axis

@property
def default_output_channel_axis(self):
return _default_output_channel_axis


def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
"""
Expand Down Expand Up @@ -283,7 +291,7 @@ def _edit_node_input_shape(node: BaseNode,

# Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
# This is done by summing the mask, as each '1' in the mask represents a retained channel.
channel_axis = node.out_channel_axis
channel_axis = _default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
new_input_shape[0][channel_axis] = int(np.sum(input_mask))

# Update the node's input shape with the new dimensions.
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import tensorflow as tf
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from tensorflow.keras.models import Model
from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
Expand Down Expand Up @@ -235,8 +234,6 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
if core_config.debug_config.bypass:
return in_model, None

KerasModelValidation(model=in_model).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/ptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
AttachTpcToKeras
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from tensorflow.keras.models import Model
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
Expand Down Expand Up @@ -129,8 +128,6 @@ def keras_post_training_quantization(in_model: Model,
if core_config.debug_config.bypass:
return in_model, None

KerasModelValidation(model=in_model).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config to mixed-precision facade is not of type "
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/qat/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info

Expand Down Expand Up @@ -175,8 +174,6 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
f"If you encounter an issue, please open an issue in our GitHub "
f"project https://github.com/sony/model_optimization")

KerasModelValidation(model=in_model).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config to mixed-precision facade is not of type "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from model_compression_toolkit.core.keras.constants import EPSILON_VAL, GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE
from model_compression_toolkit.core.keras.default_framework_info import KerasInfo
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from model_compression_toolkit.core.keras.statistics_correction.apply_second_moment_correction import \
keras_apply_second_moment_correction
from model_compression_toolkit.core.runner import core_runner
Expand Down Expand Up @@ -269,8 +268,6 @@ def prepare_graph(self,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> \
Tuple[Graph, Graph]:

KerasModelValidation(model=in_model).validate()

core_config = CoreConfig(quantization_config=quant_config,
debug_config=DebugConfig(analyze_similarity=analyze_similarity,
network_editor=network_editor)
Expand Down