Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder, PytorchPreservingActivationQuantizationHolder
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder, PytorchPreservingActivationQuantizationHolder, PytorchFLNActivationQuantizationHolder


def _build_input_tensors_list(node: BaseNode,
Expand Down Expand Up @@ -347,6 +347,12 @@ def _add_modules(self, reused_nodes_only=False):
holder_type=PytorchPreservingActivationQuantizationHolder,
**holder_kwargs)

elif node.is_fln_quantization():
holder_kwargs = {'quantization_bypass': True}
activation_quantizer_holder = self.get_activation_quantizer_holder(node,
holder_type=PytorchFLNActivationQuantizationHolder,
**holder_kwargs)

if activation_quantizer_holder is not None:
activation_quantizer_holder_name = node.name + '_' + ACTIVATION_HOLDER_QUANTIZER
self.add_module(activation_quantizer_holder_name, activation_quantizer_holder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_inferable_quantizers(node: BaseNode,
weight_quantizer = get_weights_quantizer_for_node(node, attr)
weight_quantizers[attr] = weight_quantizer

if node.is_activation_quantization_enabled():
if node.is_activation_quantization_enabled() or node.is_fln_quantization():
num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1
activation_quantizers = [get_activations_quantizer_for_node(node)] * num_of_outputs

Expand Down
142 changes: 142 additions & 0 deletions tests_pytest/pytorch_tests/e2e_tests/test_fln_quantization_holder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import model_compression_toolkit as mct
import torch
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchFLNActivationQuantizationHolder

from tests_pytest._test_util.tpc_util import configure_mp_activation_opsets
from model_compression_toolkit.target_platform_capabilities.schema.v2 import QuantizationMethod, AttributeQuantizationConfig, \
OpQuantizationConfig, QuantizationConfigOptions, Signedness, OperatorSetNames, TargetPlatformCapabilities, Fusing, OperatorsSet
from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc


def build_tpc():
    """Build a TargetPlatformCapabilities for the FLN-holder e2e test.

    The TPC declares three fusing patterns: CONV+RELU and
    FULLY_CONNECTED+HARDSWISH each carry an explicit 16-bit
    fuse_op_quantization_config, while CONV+SIGMOID declares none.
    """
    base_cfg = OpQuantizationConfig(
        default_weight_attr_config=AttributeQuantizationConfig(),
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=8,
        supported_input_activation_n_bits=[8],
        enable_activation_quantization=True,
        enable_weights_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=Signedness.AUTO)

    operator_sets, _ = configure_mp_activation_opsets(
        opset_names=[OperatorSetNames.CONV,
                     OperatorSetNames.RELU,
                     OperatorSetNames.SIGMOID,
                     OperatorSetNames.FULLY_CONNECTED,
                     OperatorSetNames.HARDSWISH],
        base_op_config=base_cfg,
        a_nbits=[8])

    # 16-bit config attached to the two "quantized fusing" patterns below.
    fused_qc = generate_test_op_qc(**generate_test_attr_configs(), activation_n_bits=16)

    def _fuse(first, second, qc=None):
        # Helper: a two-operator fusing pattern, optionally with its own q-config.
        groups = (OperatorsSet(name=first), OperatorsSet(name=second))
        if qc is None:
            return Fusing(operator_groups=groups)
        return Fusing(operator_groups=groups, fuse_op_quantization_config=qc)

    return TargetPlatformCapabilities(
        default_qco=QuantizationConfigOptions(quantization_configurations=[base_cfg]),
        operator_set=operator_sets,
        fusing_patterns=[
            _fuse(OperatorSetNames.CONV, OperatorSetNames.RELU, fused_qc),
            _fuse(OperatorSetNames.CONV, OperatorSetNames.SIGMOID),
            _fuse(OperatorSetNames.FULLY_CONNECTED, OperatorSetNames.HARDSWISH, fused_qc),
        ])

def representative_data_gen(shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1):
    """Yield `num_iter` batches of random calibration data.

    Each yielded batch is a list of `num_inputs` references to a single
    random tensor of shape (batch_size, *shape).
    """
    for _ in range(num_iter):
        batch = torch.randn((batch_size,) + tuple(shape))
        yield [batch] * num_inputs

def get_float_model():
    """Return a small float model: conv1 -> relu -> conv2 -> sigmoid -> flatten -> fc -> hswish."""
    class BaseModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=3)
            self.relu = torch.nn.ReLU()
            self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3)
            self.sigmoid = torch.nn.Sigmoid()
            self.flatten = torch.nn.Flatten()
            self.fc = torch.nn.Linear(48, 10)
            self.hswish = torch.nn.Hardswish()

        def forward(self, x):
            # Apply the layers strictly in declaration order.
            for layer in (self.conv1, self.relu, self.conv2, self.sigmoid,
                          self.flatten, self.fc, self.hswish):
                x = layer(x)
            return x

    return BaseModel()

def test_fln_quantization_holder():
    """E2E: PTQ attaches PytorchFLNActivationQuantizationHolder (16-bit, bypass)
    to the first op of fusing patterns with an explicit fuse config (conv1, fc),
    a plain PytorchActivationQuantizationHolder (8-bit) elsewhere, and no holder
    at all for conv2 (CONV+SIGMOID pattern without a fuse config)."""
    quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
        in_module=get_float_model(),
        representative_data_gen=representative_data_gen,
        target_platform_capabilities=build_tpc())

    # conv2 must have no activation holder module.
    assert not hasattr(quantized_model, 'conv2_activation_holder_quantizer')

    # layer name -> (expected holder type, expected activation bit-width)
    expected = {
        'conv1': (PytorchFLNActivationQuantizationHolder, 16),
        'relu': (PytorchActivationQuantizationHolder, 8),
        'sigmoid': (PytorchActivationQuantizationHolder, 8),
        'fc': (PytorchFLNActivationQuantizationHolder, 16),
        'hswish': (PytorchActivationQuantizationHolder, 8),
    }
    for layer_name, (holder_cls, n_bits) in expected.items():
        attr_name = layer_name + '_activation_holder_quantizer'
        assert hasattr(quantized_model, attr_name)
        holder = getattr(quantized_model, attr_name)
        assert isinstance(holder, holder_cls)
        assert holder.activation_holder_quantizer.num_bits == n_bits
        if holder_cls is PytorchFLNActivationQuantizationHolder:
            assert holder.quantization_bypass == True
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from unittest.mock import Mock
from typing import List
import torch
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_activation_quantizer_holder, fully_quantized_wrapper
from mct_quantizers import QuantizationMethod, PytorchActivationQuantizationHolder, PytorchFLNActivationQuantizationHolder

from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.edge import Edge
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig, NodeQuantizationConfig
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig, ActivationQuantizationMode
from model_compression_toolkit.target_platform_capabilities import AttributeQuantizationConfig, OpQuantizationConfig, Signedness
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import FrameworkQuantizationCapabilities
from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo
from model_compression_toolkit.core.common.framework_info import set_fw_info
from tests_pytest._test_util.graph_builder_utils import DummyLayer
from tests_pytest._test_util.tpc_util import minimal_tpc


def build_node(name='node', framework_attr=None, layer_class=DummyLayer,
               qcs: List[CandidateNodeQuantizationConfig] = None):
    """Create a BaseNode test fixture with fixed input/output shapes.

    Args:
        name: node name.
        framework_attr: framework attributes dict; defaults to an empty dict.
            (Was a mutable default argument `{}`, shared across calls — fixed
            by using None as sentinel.)
        layer_class: the framework layer type the node represents.
        qcs: optional list of candidate quantization configs; the first entry
            becomes the base config.

    Returns:
        The constructed BaseNode, with quantization_cfg set when qcs is given.
    """
    node = BaseNode(name=name,
                    framework_attr={} if framework_attr is None else framework_attr,
                    input_shape=(4, 5, 6),
                    output_shape=(4, 5, 6),
                    weights={},
                    layer_class=layer_class,
                    reuse=False)
    if qcs:
        assert isinstance(qcs, list)
        node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=qcs[0],
                                                       candidates_quantization_cfg=qcs)
    return node

def build_qc(q_mode=ActivationQuantizationMode.QUANT):
    """Build a CandidateNodeQuantizationConfig with 8-bit power-of-two
    activation quantization and weight/bias attributes, in the given
    activation quantization mode."""
    base_op_cfg = OpQuantizationConfig(
        default_weight_attr_config=AttributeQuantizationConfig(),
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=8,
        enable_activation_quantization=True,
        quantization_preserving=False,
        supported_input_activation_n_bits=8,
        signedness=Signedness.AUTO)

    activation_cfg = NodeActivationQuantizationConfig(op_cfg=base_op_cfg)
    activation_cfg.set_qc(QuantizationConfig())
    activation_cfg.quant_mode = q_mode

    weights_cfg = NodeWeightsQuantizationConfig(
        op_cfg=base_op_cfg,
        weights_channels_axis=ChannelAxisMapping(0, 1),
        node_attrs_list=['weight', 'bias'])

    return CandidateNodeQuantizationConfig(activation_quantization_cfg=activation_cfg,
                                           weights_quantization_cfg=weights_cfg)

def get_test_graph():
    """Build a linear test graph conv1 -> relu -> conv2 -> sigmoid -> flatten -> fc -> hswish.

    conv1/fc are FLN_QUANT, conv2 is FLN_NO_QUANT, flatten is PRESERVE_QUANT,
    and the remaining nodes use the default QUANT mode.
    """
    set_fw_info(PyTorchInfo)

    conv_attrs = {'in_channels': 3, 'out_channels': 3, 'kernel_size': 3}
    conv1 = build_node('conv1', framework_attr=dict(conv_attrs),
                       layer_class=torch.nn.Conv2d,
                       qcs=[build_qc(q_mode=ActivationQuantizationMode.FLN_QUANT)])
    relu = build_node('relu', layer_class=torch.nn.ReLU, qcs=[build_qc()])
    conv2 = build_node('conv2', framework_attr=dict(conv_attrs),
                       layer_class=torch.nn.Conv2d,
                       qcs=[build_qc(q_mode=ActivationQuantizationMode.FLN_NO_QUANT)])
    sigmoid = build_node('sigmoid', layer_class=torch.nn.Sigmoid, qcs=[build_qc()])
    flatten = build_node('flatten', layer_class=torch.nn.Flatten,
                         qcs=[build_qc(q_mode=ActivationQuantizationMode.PRESERVE_QUANT)])
    fc = build_node('fc', framework_attr={'in_features': 48, 'out_features': 10},
                    layer_class=torch.nn.Linear,
                    qcs=[build_qc(q_mode=ActivationQuantizationMode.FLN_QUANT)])
    hswish = build_node('hswish', layer_class=torch.nn.Hardswish, qcs=[build_qc()])

    # Chain the nodes in order; edges connect each node to its successor.
    chain = [conv1, relu, conv2, sigmoid, flatten, fc, hswish]
    graph = Graph('g',
                  input_nodes=[chain[0]],
                  nodes=chain[1:-1],
                  output_nodes=[chain[-1]],
                  edge_list=[Edge(src, dst, 0, 0) for src, dst in zip(chain, chain[1:])])
    graph.set_fqc(FrameworkQuantizationCapabilities(tpc=minimal_tpc(), name="test"))
    return graph

def get_inferable_quantizers_mock(node):
    """Return mocked (weight_quantizers, activation_quantizers) for a node.

    The activation quantizer's num_bits/signed/threshold_np depend on the
    node name; unknown names get no quantizers at all.
    """
    # node name -> (num_bits, signed, threshold_np)
    specs = {
        'conv2': (8, False, 8.0),
        'relu': (8, False, 8.0),
        'conv1': (16, True, 16.0),
        'fc': (16, True, 16.0),
        'sigmoid': (4, False, 4.0),
        'hswish': (4, False, 4.0),
    }
    spec = specs.get(node.name)
    if spec is None:
        return {}, []

    quantizer = Mock()
    quantizer.num_bits, quantizer.signed, quantizer.threshold_np = spec
    return {}, [quantizer]


class TestPyTorchModelBuilder:
    """Unit test of PyTorchModelBuilder FLN-holder wiring with mocked quantizers."""

    def test_pytorch_model(self, fw_impl_mock):
        """FLN_QUANT nodes get a bypassed PytorchFLNActivationQuantizationHolder,
        QUANT nodes get a regular holder, and FLN_NO_QUANT nodes get none."""
        graph = get_test_graph()
        fw_impl_mock.get_inferable_quantizers.side_effect = get_inferable_quantizers_mock

        wrapper_fn = lambda n, m: fully_quantized_wrapper(n, m, fw_impl=fw_impl_mock)
        holder_fn = lambda n, holder_type, **kwargs: get_activation_quantizer_holder(
            n, holder_type, fw_impl=fw_impl_mock, **kwargs)
        exportable_model, _ = PyTorchModelBuilder(
            graph=graph,
            wrapper=wrapper_fn,
            get_activation_quantizer_holder_fn=holder_fn).build_model()

        # conv2 (FLN_NO_QUANT) must have no activation holder module.
        assert not hasattr(exportable_model, 'conv2_activation_holder_quantizer')

        # node name -> (holder type, num_bits, signed, threshold_np)
        expected = {
            'conv1': (PytorchFLNActivationQuantizationHolder, 16, True, 16.0),
            'relu': (PytorchActivationQuantizationHolder, 8, False, 8.0),
            'sigmoid': (PytorchActivationQuantizationHolder, 4, False, 4.0),
            'fc': (PytorchFLNActivationQuantizationHolder, 16, True, 16.0),
            'hswish': (PytorchActivationQuantizationHolder, 4, False, 4.0),
        }
        for node_name, (holder_cls, n_bits, signed, threshold) in expected.items():
            attr_name = node_name + '_activation_holder_quantizer'
            assert hasattr(exportable_model, attr_name)
            holder = getattr(exportable_model, attr_name)
            assert isinstance(holder, holder_cls)
            quantizer = holder.activation_holder_quantizer
            assert quantizer.num_bits == n_bits
            assert quantizer.signed == signed
            assert quantizer.threshold_np == threshold
            if holder_cls is PytorchFLNActivationQuantizationHolder:
                assert holder.quantization_bypass == True
Loading
Loading