Skip to content

Commit 629f837

Browse files
committed
Merge branch 'main' into apply_quant_info_to_fusinginfo
2 parents 36d4bd2 + 5bfd07d commit 629f837

File tree

14 files changed

+452
-31
lines changed

14 files changed

+452
-31
lines changed

model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,15 @@ def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) ->
143143
'max_candidate_idx': max_candidate_idx
144144
}
145145

146-
def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQuantizationHolder:
146+
def mixed_precision_activation_holder(self, n: BaseNode, holder_type: PytorchActivationQuantizationHolder = PytorchActivationQuantizationHolder) -> PytorchActivationQuantizationHolder:
147147
"""
148148
Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization for a node.
149149
The layer should hold either a configurable activation quantizer, if it is quantized with mixed precision,
150150
or an inferable quantizer for fixed single bit-width quantization.
151151
152152
Args:
153153
n: Node to get PytorchActivationQuantizationHolder to attach in its output.
154+
holder_type: The type of the activation quantization holder to use.
154155
155156
Returns:
156157
A PytorchActivationQuantizationHolder layer for the node activation quantization.
@@ -192,7 +193,7 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua
192193
# thus we make sure this is the only possible case (unless it's a node with no activation
193194
# quantization, which in this case has an empty list).
194195
if len(activation_quantizers) == 1:
195-
return PytorchActivationQuantizationHolder(activation_quantizers[0])
196+
return holder_type(activation_quantizers[0])
196197

197198
Logger.critical(f"PytorchActivationQuantizationHolder expects a single quantizer, but ({len(activation_quantizers)}) quantizers were found for node {n}.")# pragma: no cover
198199

model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
3434
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
3535
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
3636
from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
37-
from mct_quantizers import PytorchQuantizationWrapper
37+
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder, PytorchPreservingActivationQuantizationHolder
3838

3939

4040
def _build_input_tensors_list(node: BaseNode,
@@ -332,13 +332,21 @@ def _add_modules(self, reused_nodes_only=False):
332332
else:
333333
self.add_module(node.name, node_op)
334334

335-
# Add activation quantization modules if an activation holder is configured for this node
336-
if node.is_activation_quantization_enabled() and self.get_activation_quantizer_holder is not None:
337-
activation_quantizer_holder = self.get_activation_quantizer_holder(node)
338-
if activation_quantizer_holder is not None:
339-
self.add_module(node.name + '_' + ACTIVATION_HOLDER_QUANTIZER, activation_quantizer_holder)
340-
self.node_to_activation_quantization_holder.update(
341-
{node.name: node.name + '_' + ACTIVATION_HOLDER_QUANTIZER})
335+
activation_quantizer_holder = None
336+
if self.use_activation_holder_during_model_building:
337+
if node.is_activation_quantization_enabled():
338+
activation_quantizer_holder = self.get_activation_quantizer_holder(node, holder_type=PytorchActivationQuantizationHolder)
339+
340+
elif node.is_quantization_preserving():
341+
prev_node = self.graph.retrieve_preserved_quantization_node(node)
342+
if prev_node.is_activation_quantization_enabled():
343+
activation_quantizer_holder = self.get_activation_quantizer_holder(prev_node, holder_type=PytorchPreservingActivationQuantizationHolder)
344+
345+
if activation_quantizer_holder is not None:
346+
activation_quantizer_holder_name = node.name + '_' + ACTIVATION_HOLDER_QUANTIZER
347+
self.add_module(activation_quantizer_holder_name, activation_quantizer_holder)
348+
self.node_to_activation_quantization_holder.update(
349+
{node.name: activation_quantizer_holder_name})
342350

343351
def forward(self,
344352
*args: Any) -> Any:

model_compression_toolkit/core/pytorch/reader/reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from model_compression_toolkit.core.common import Graph
2424
from model_compression_toolkit.core.pytorch.reader.graph_builders import edges_builder, nodes_builder
2525
from model_compression_toolkit.core.pytorch.utils import set_model
26-
from sony_custom_layers.pytorch import CustomLayer
26+
from edgemdt_cl.pytorch import CustomLayer
2727

2828

2929
def _trace_model(root: Union[torch.nn.Module, Callable[..., Any]]) -> GraphModule:

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ def export(self, output_names=None) -> None:
100100
model_output = self.model(*model_input) if isinstance(model_input, (list, tuple)) else self.model(
101101
model_input)
102102

103-
input_nodes = [n for n in self.model.node_sort if n.type == DummyPlaceHolder]
104-
input_names = [f"input_{i}" for i in range(len(input_nodes))] if len(input_nodes) > 1 else ["input"]
103+
input_names = [f"input_{i}" for i in range(len(model_input))] if len(model_input) > 1 else ["input"]
105104
dynamic_axes = {name: {0: 'batch_size'} for name in input_names}
106105
if output_names is None:
107106
# Determine number of outputs and prepare output_names and dynamic_axes

model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
2323

2424
if FOUND_TORCH:
2525
import torch
26-
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
26+
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder, PytorchPreservingActivationQuantizationHolder
2727
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
2828
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
2929
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -65,22 +65,26 @@ def fully_quantized_wrapper(node: common.BaseNode,
6565
return module
6666

6767

68-
def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable:
68+
def get_activation_quantizer_holder(node: BaseNode, holder_type: PytorchActivationQuantizationHolder, fw_impl) -> Callable:
6969
"""
7070
Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
7171
If the layer is not supposed to be wrapped with an activation quantizer - return None.
7272
Args:
7373
node: Node to attach a PytorchActivationQuantizationHolder to its output.
74+
holder_type: The type of the activation quantization holder to use.
7475
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
7576
Returns:
7677
A PytorchActivationQuantizationHolder module for the node's activation quantization.
7778
"""
78-
_, activation_quantizers = fw_impl.get_inferable_quantizers(node)
7979
# Holder by definition uses a single quantizer for the activation quantization
8080
# thus we make sure this is the only possible case (unless it's a node with no activation
8181
# quantization, which in this case has an empty list).
82+
_, activation_quantizers = fw_impl.get_inferable_quantizers(node)
8283
if len(activation_quantizers) == 1:
83-
return PytorchActivationQuantizationHolder(activation_quantizers[0])
84+
if holder_type == PytorchActivationQuantizationHolder:
85+
return holder_type(activation_quantizers[0])
86+
elif holder_type == PytorchPreservingActivationQuantizationHolder:
87+
return holder_type(activation_quantizers[0], quantization_bypass=True)
8488
Logger.critical(
8589
f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
8690
f'were found for node {node}')
@@ -96,13 +100,14 @@ def get_exportable_pytorch_model(graph: Graph):
96100
Returns:
97101
Fully quantized PyTorch model.
98102
"""
103+
fw_impl = C.pytorch.pytorch_implementation.PytorchImplementation()
99104
exportable_model, user_info = PyTorchModelBuilder(graph=graph,
100105
wrapper=lambda n, m:
101106
fully_quantized_wrapper(n, m,
102-
fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation()),
103-
get_activation_quantizer_holder_fn=lambda n:
104-
get_activation_quantizer_holder(n,
105-
fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation())).build_model()
107+
fw_impl=fw_impl),
108+
get_activation_quantizer_holder_fn=lambda n, holder_type:
109+
get_activation_quantizer_holder(n, holder_type,
110+
fw_impl=fw_impl)).build_model()
106111

107112
Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify its accuracy.\n"
108113
"Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"

model_compression_toolkit/gptq/pytorch/gptq_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,12 @@ def gptq_wrapper(self,
197197
# quantized, do we need to wrap them as well?
198198
return layer
199199

200-
def get_activation_quantizer_holder(self, n: BaseNode) -> Callable:
200+
def get_activation_quantizer_holder(self, n: BaseNode, holder_type: PytorchActivationQuantizationHolder = PytorchActivationQuantizationHolder) -> Callable:
201201
"""
202202
Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
203203
Args:
204204
n: Node to attach a PytorchActivationQuantizationHolder to its output.
205+
holder_type: The type of the activation quantization holder to use.
205206
Returns:
206207
A PytorchActivationQuantizationHolder module for the node's activation quantization.
207208
"""
@@ -213,7 +214,7 @@ def get_activation_quantizer_holder(self, n: BaseNode) -> Callable:
213214
f"but {len(activation_quantizers)} were found for node {n.name}. "
214215
f"Ensure the node is configured with a single activation quantizer.")
215216
quantizer = self.gradual_act_quantizer_wrapper_factory(activation_quantizers[0])
216-
return PytorchActivationQuantizationHolder(quantizer)
217+
return holder_type(quantizer)
217218

218219
def build_gptq_model(self):
219220
"""

model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030

3131

3232
def get_activation_quantizer_holder(n: common.BaseNode,
33-
qat_config: QATConfig) -> Callable:
33+
qat_config: QATConfig, holder_type: PytorchActivationQuantizationHolder = PytorchActivationQuantizationHolder) -> Callable:
3434
"""
3535
Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
3636
If the layer is not supposed to be wrapped with activation quantizers - return None.
3737
3838
Args:
3939
n: Node for which to retrieve an ActivationQuantizationHolder to attach to its output.
4040
qat_config: QAT configuration (for example, training methods).
41+
holder_type: The type of the activation quantization holder to use.
4142
4243
Returns:
4344
A ActivationQuantizationHolder layer for the node's activation quantization.
@@ -49,7 +50,7 @@ def get_activation_quantizer_holder(n: common.BaseNode,
4950
# thus we make sure this is the only possible case (unless it's a node with no activation
5051
# quantization, which in this case has an empty list).
5152
if len(activation_quantizers) == 1:
52-
return PytorchActivationQuantizationHolder(activation_quantizers[0])
53+
return holder_type(activation_quantizers[0])
5354
Logger.critical(f'ActivationQuantizationHolder supports only a single quantizer, but ({len(activation_quantizers)}) quantizers were found for node {n}.')
5455

5556

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
2121
AttachTpcToFramework
2222

23-
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
23+
from edgemdt_cl.keras.object_detection.ssd_post_process import SSDPostProcess
2424

2525
if version.parse(tf.__version__) >= version.parse("2.13"):
2626
from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
3333
AttachTpcToFramework
3434
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
35-
from sony_custom_layers.pytorch import MulticlassNMS, MulticlassNMSWithIndices
35+
from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices
3636

3737

3838
class AttachTpcToPytorch(AttachTpcToFramework):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ scipy
1111
protobuf
1212
mct-quantizers-nightly
1313
pydantic>=2.0
14-
sony-custom-layers-dev==0.4.0.dev6
14+
edge-mdt-cl-dev
1515

0 commit comments

Comments
 (0)