Skip to content

Commit 808f692

Browse files
committed
set reuse in exporter
1 parent 2a2310d commit 808f692

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def export(self) -> None:
7373
"""
7474
for layer in self.model.children():
7575
self.is_layer_exportable_fn(layer)
76+
# Set reuse for weight quantizers if quantizer is reused
77+
if isinstance(layer, PytorchQuantizationWrapper):
78+
for _, quantizer in layer.weights_quantizers.items():
79+
if quantizer.reuse:
80+
quantizer.enable_reuse_quantizer()
7681

7782
# Set forward that is used during onnx export.
7883
# If _use_onnx_custom_quantizer_ops is set to True, the quantizer forward function will use
@@ -116,6 +121,13 @@ def export(self) -> None:
116121
dynamic_axes={'input': {0: 'batch_size'},
117122
'output': {0: 'batch_size'}})
118123

124+
for layer in self.model.children():
125+
# Disable reuse for weight quantizers if the quantizer was reused
126+
if isinstance(layer, PytorchQuantizationWrapper):
127+
for _, quantizer in layer.weights_quantizers.items():
128+
if quantizer.reuse:
129+
quantizer.disable_reuse_quantizer()
130+
119131
def _enable_onnx_custom_ops_export(self):
120132
"""
121133
Enable the custom implementation forward in quantizers, so it is exported

model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def fully_quantized_wrapper(node: common.BaseNode,
4848
# Set reuse for weight quantizers if node is reused
4949
for _, quantizer in weight_quantizers.items():
5050
if node.reuse_group:
51-
quantizer.enable_reuse_quantizer()
51+
quantizer.reuse = True
5252
# for positional weights we need to extract the weight's value.
5353
weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
5454
for attr in weight_quantizers if isinstance(attr, int)}

0 commit comments

Comments
 (0)