Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.
# ==============================================================================

from abc import ABC, abstractmethod
import numpy as np
from model_compression_toolkit.logger import Logger


class BaseCollector(object):
class BaseCollector(ABC):
"""
Base class for statistics collection object.
"""
Expand All @@ -26,6 +27,7 @@ def __init__(self):
# When manipulation statistics in a granularity they were not collected by, the data is invalid.
self.is_legal = True

@abstractmethod
def scale(self, scale_factor: np.ndarray):
"""
Scale all statistics in collector by some factor.
Expand All @@ -37,6 +39,7 @@ def scale(self, scale_factor: np.ndarray):
raise NotImplemented(
f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover

@abstractmethod
def shift(self, shift_value: np.ndarray):
"""
Shift all statistics in collector by some value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def update(self,
x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
"""
self.i += 1 # Update the iteration index
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS or self.axis is None else self.axis
n = x.shape[axis]
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def update(self,
x: Tensor that goes through the collector and needs to be considered in the min/max computation.
"""

axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS or self.axis is None else self.axis
n = x.shape[axis]
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
Expand Down
41 changes: 0 additions & 41 deletions model_compression_toolkit/core/common/model_validation.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_out_channel_axis(cls, node_type: Any):
Node's output channel axis.

"""
return cls.out_channel_axis_mapping.get(node_type, -1)
return cls.out_channel_axis_mapping.get(node_type)


def set_keras_info(func):
Expand Down
37 changes: 0 additions & 37 deletions model_compression_toolkit/core/keras/keras_model_validation.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_out_channel_axis(cls, node_type: Any):
Node's output channel axis.

"""
return cls.out_channel_axis_mapping.get(node_type, 1)
return cls.out_channel_axis_mapping.get(node_type)


def set_pytorch_info(func):
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import tensorflow as tf
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from tensorflow.keras.models import Model
from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
Expand Down Expand Up @@ -235,8 +234,6 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
if core_config.debug_config.bypass:
return in_model, None

KerasModelValidation(model=in_model).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/ptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
AttachTpcToKeras
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from tensorflow.keras.models import Model
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
Expand Down Expand Up @@ -129,8 +128,6 @@ def keras_post_training_quantization(in_model: Model,
if core_config.debug_config.bypass:
return in_model, None

KerasModelValidation(model=in_model).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config to mixed-precision facade is not of type "
Expand Down
3 changes: 0 additions & 3 deletions model_compression_toolkit/qat/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info

Expand Down Expand Up @@ -175,8 +174,6 @@ def keras_quantization_aware_training_init_experimental(in_model: Model,
f"If you encounter an issue, please open an issue in our GitHub "
f"project https://github.com/sony/model_optimization")

KerasModelValidation(model=in_model).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config to mixed-precision facade is not of type "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from model_compression_toolkit.core.keras.constants import EPSILON_VAL, GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE
from model_compression_toolkit.core.keras.default_framework_info import KerasInfo
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
from model_compression_toolkit.core.keras.statistics_correction.apply_second_moment_correction import \
keras_apply_second_moment_correction
from model_compression_toolkit.core.runner import core_runner
Expand Down Expand Up @@ -269,8 +268,6 @@ def prepare_graph(self,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> \
Tuple[Graph, Graph]:

KerasModelValidation(model=in_model).validate()

core_config = CoreConfig(quantization_config=quant_config,
debug_config=DebugConfig(analyze_similarity=analyze_similarity,
network_editor=network_editor)
Expand Down
Loading