Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit ce8a677

Browse files
bfineran, markurtz, jeanniefinks, natuan, and mgoin
authored
[hotfix 0.10.1] transformers export and QAT flow fixes (#549)
* Update README.md for transformers to note the quantization conversion issue (#539)
  * Update README.md
  * Update integrations/huggingface-transformers/README.md
  Co-authored-by: Jeannie Finks <[email protected]>
  Co-authored-by: Jeannie Finks <[email protected]>
* Enforce order on input keys to export (#545)
  * Enforce order on input keys to export
  * Warn if input dropped from onnx export
* Restrict mistune version to fix docs build (#547)
* quantization fixes for transformers flows (#548)
  * quantization fixes for transformers flows
  * match on class name instead
  * quality
* set release branch version to 0.10.1
* Revert "Update README.md for transformers to note the quantization conversion issue (#539)"
  This reverts commit 9304997.

Co-authored-by: Mark Kurtz <[email protected]>
Co-authored-by: Jeannie Finks <[email protected]>
Co-authored-by: Tuan Nguyen <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent b09c6d0 commit ce8a677

File tree

6 files changed

+91
-20
lines changed

6 files changed

+91
-20
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"flake8==3.9.2",
7171
"isort==5.8.0",
7272
"m2r2~=0.2.7",
73+
"mistune==0.8.4",
7374
"myst-parser~=0.14.0",
7475
"rinohtype~=0.4.2",
7576
"sphinx~=3.5.0",

src/sparseml/pytorch/optim/modifier_quantization.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class QuantizationModifier(ScheduledModifier):
100100
transformer based models such as BERT where the quantized MatMul outputs
101101
are kept at 32 bits of precision and fake quantizing the outputs harm training
102102
recovery. Default is True
103+
:param exclude_module_types: optional list of module class names
104+
to not propagate quantization configs to. Default is None
103105
"""
104106

105107
def __init__(
@@ -114,6 +116,7 @@ def __init__(
114116
quantize_embeddings: bool = True,
115117
reduce_range: bool = False,
116118
quantize_linear_activations: bool = True,
119+
exclude_module_types: Union[List[str], None] = None,
117120
):
118121
if torch_quantization is None or torch_intrinsic is None:
119122
raise RuntimeError(
@@ -138,6 +141,7 @@ def __init__(
138141
self._quantize_embeddings = quantize_embeddings
139142
self._reduce_range = reduce_range
140143
self._quantize_linear_activations = quantize_linear_activations
144+
self._exclude_module_types = exclude_module_types
141145

142146
self._modules_to_quantize = None
143147
self._qat_enabled = False
@@ -278,6 +282,14 @@ def quantize_linear_activations(self) -> bool:
278282
"""
279283
return self._quantize_linear_activations
280284

285+
@ModifierProp()
286+
def exclude_module_types(self) -> Union[List[str], None]:
287+
"""
288+
:return: optional list of module class names to not propagate
289+
quantization configs to. Default is None
290+
"""
291+
return self._exclude_module_types
292+
281293
def initialize(
282294
self,
283295
module: Module,
@@ -423,10 +435,15 @@ def _enable_module_qat(self, module: Module):
423435
if not self._quantize_linear_activations:
424436
remove_activation_qat_by_layer_name(quant_module, ["Linear"])
425437

438+
# remove qconfigs for module types in exclude_module_types
439+
if self._exclude_module_types:
440+
self._strip_excluded_module_qconfigs(module)
441+
426442
# set modules with proper qconfigs to QAT mode
427443
torch_quantization.prepare_qat(module, inplace=True)
428444
if self._quantize_embeddings:
429445
prepare_embeddings_qat(module, reduce_range=self._reduce_range)
446+
430447
self._qat_enabled = True
431448

432449
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
@@ -443,6 +460,16 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
443460
and not self._bn_stats_frozen
444461
)
445462

463+
def _strip_excluded_module_qconfigs(self, module: Module):
464+
if not self._exclude_module_types:
465+
return
466+
excluded_classes = set(self._exclude_module_types)
467+
for submodule in module.modules():
468+
if submodule.__class__.__name__ in excluded_classes and hasattr(
469+
submodule, "qconfig"
470+
):
471+
submodule.qconfig = None
472+
446473
def _validate_params(self):
447474
if (
448475
self._disable_quantization_observer_epoch is not None

src/sparseml/pytorch/utils/quantization/quantize_qat_export.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,11 @@ def _delete_repeated_qat_blocks(model: ModelProto):
264264
nodes_to_delete.append(dequant_node_1)
265265

266266
for n in nodes_to_delete:
267-
delete_quant_node(model, n)
267+
delete_quant_node(model, n, keep_params=True)
268+
269+
# cleanup graph
270+
graph.update()
271+
graph.delete_unused_initializers()
268272

269273

270274
def _attribute_to_kwarg(attribute: onnx.AttributeProto):
@@ -1214,12 +1218,14 @@ def _quantize_qat_embedding(model: ModelProto):
12141218
qdq_output = False
12151219

12161220
if qdq_output:
1221+
# forward gather output to dequant input
1222+
output_dequant_node.input[0] = gather_node.output[0]
1223+
output_dequant_node.input[1] = input_quant_node.input[1]
1224+
output_dequant_node.input[2] = input_quant_node.input[2]
12171225
# delete unnecessary quantize and dequantize ops
1218-
delete_quant_node(model, input_quant_node, keep_params=False)
1226+
delete_quant_node(model, input_quant_node, keep_params=True)
12191227
delete_quant_node(model, input_dequant_node, keep_params=False)
12201228
delete_quant_node(model, output_quant_node, keep_params=False)
1221-
# forward gather output to dequant input
1222-
output_dequant_node.input[0] = gather_node.output[0]
12231229

12241230
else:
12251231
# use input dequant to dequantize output
@@ -1265,7 +1271,10 @@ def _remove_duplicate_quantize_ops(model: ModelProto):
12651271
_replace_input_id_model(
12661272
model, remove_node.output[0], keep_node.output[0]
12671273
)
1268-
remove_node_and_params_from_graph(model, remove_node)
1274+
delete_quant_node(model, remove_node, keep_params=True)
1275+
# cleanup graph
1276+
graph.update()
1277+
graph.delete_unused_initializers()
12691278

12701279

12711280
def _cleanup_unused_quants(model: ModelProto):
@@ -1296,15 +1305,18 @@ def _cleanup_unused_quants(model: ModelProto):
12961305
continue
12971306

12981307
# Forward QuantizeLinear input to DequantizeLinear output
1299-
for child in dequant_children:
1300-
_replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
1308+
_replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
13011309

13021310
# Remove QuantizeLinear->DequantizeLinear block
13031311
nodes_to_delete.append(quant_node)
13041312
nodes_to_delete.append(dequant_node)
13051313

13061314
for n in nodes_to_delete:
1307-
delete_quant_node(model, n)
1315+
delete_quant_node(model, n, keep_params=True)
1316+
1317+
# update graph
1318+
graph.update()
1319+
graph.delete_unused_initializers()
13081320

13091321

13101322
def quantize_torch_qat_export(

src/sparseml/transformers/export.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
"""
5656

5757
import argparse
58+
import collections
59+
import inspect
5860
import logging
5961
import math
6062
import os
@@ -180,13 +182,32 @@ def export_transformer_to_onnx(
180182
inputs = tokenizer(
181183
"", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
182184
).data # Dict[Tensor]
185+
186+
# Rearrange inputs' keys to match those defined by model foward func, which
187+
# seem to define how the order of inputs is determined in the exported model
188+
forward_args_spec = inspect.getfullargspec(model.__class__.forward)
189+
dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
190+
inputs = collections.OrderedDict(
191+
[
192+
(f, inputs[f][0].reshape(1, -1))
193+
for f in forward_args_spec.args
194+
if f in inputs
195+
]
196+
)
197+
if dropped:
198+
_LOGGER.warning(
199+
"The following inputs were not present in the model forward function "
200+
f"and therefore dropped from ONNX export: {dropped}"
201+
)
202+
183203
inputs_shapes = {
184204
key: (
185205
f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
186206
f"{list(val.shape) if hasattr(val, 'shape') else 'unknown'}"
187207
)
188208
for key, val in inputs.items()
189209
}
210+
190211
_LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}")
191212

192213
# run export

src/sparseml/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datetime import date
2020

2121

22-
version_base = "0.10.0"
22+
version_base = "0.10.1"
2323
is_release = False # change to True to set the generated version as a release version
2424

2525

tests/sparseml/pytorch/optim/test_modifier_quantization.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454
start_epoch=0.0,
5555
quantize_linear_activations=False,
5656
),
57+
lambda: QuantizationModifier(
58+
start_epoch=0.0,
59+
exclude_module_types=["Linear"],
60+
),
5761
]
5862

5963

@@ -67,9 +71,13 @@ def _is_quantiable_module(module):
6771
return isinstance(module, (Conv2d, Linear))
6872

6973

70-
def _test_quantizable_module(
71-
module, qat_expected, reduce_range, quantize_linear_activations
72-
):
74+
def _test_quantizable_module(module, qat_expected, modifier):
75+
reduce_range = modifier.reduce_range
76+
quantize_linear_activations = modifier.quantize_linear_activations
77+
78+
excluded_types = modifier.exclude_module_types or []
79+
qat_expected = qat_expected and module.__class__.__name__ not in excluded_types
80+
7381
if qat_expected:
7482
assert hasattr(module, "qconfig") and module.qconfig is not None
7583
assert hasattr(module, "weight_fake_quant") and (
@@ -97,12 +105,7 @@ def _test_qat_applied(modifier, model):
97105
submodules = [""]
98106
for module in model.modules():
99107
if _is_quantiable_module(module):
100-
_test_quantizable_module(
101-
module,
102-
True,
103-
modifier.reduce_range,
104-
modifier.quantize_linear_activations,
105-
)
108+
_test_quantizable_module(module, True, modifier)
106109
else:
107110
assert not hasattr(model, "qconfig") or model.qconfig is None
108111
submodules = modifier.submodules
@@ -112,8 +115,7 @@ def _test_qat_applied(modifier, model):
112115
_test_quantizable_module(
113116
module,
114117
_is_valid_submodule(name, submodules),
115-
modifier.reduce_range,
116-
modifier.quantize_linear_activations,
118+
modifier,
117119
)
118120

119121

@@ -207,6 +209,7 @@ def test_quantization_modifier_yaml():
207209
quantize_embeddings = False
208210
reduce_range = True
209211
quantize_linear_activations = False
212+
exclude_module_types = ["LayerNorm", "Tanh"]
210213
yaml_str = f"""
211214
!QuantizationModifier
212215
start_epoch: {start_epoch}
@@ -217,6 +220,7 @@ def test_quantization_modifier_yaml():
217220
quantize_embeddings: {quantize_embeddings}
218221
reduce_range: {reduce_range}
219222
quantize_linear_activations: {quantize_linear_activations}
223+
exclude_module_types: {exclude_module_types}
220224
"""
221225
yaml_modifier = QuantizationModifier.load_obj(
222226
yaml_str
@@ -233,6 +237,7 @@ def test_quantization_modifier_yaml():
233237
quantize_embeddings=quantize_embeddings,
234238
reduce_range=reduce_range,
235239
quantize_linear_activations=quantize_linear_activations,
240+
exclude_module_types=exclude_module_types,
236241
)
237242

238243
assert isinstance(yaml_modifier, QuantizationModifier)
@@ -276,3 +281,8 @@ def test_quantization_modifier_yaml():
276281
== serialized_modifier.quantize_linear_activations
277282
== obj_modifier.quantize_linear_activations
278283
)
284+
assert (
285+
yaml_modifier.exclude_module_types
286+
== serialized_modifier.exclude_module_types
287+
== obj_modifier.exclude_module_types
288+
)

0 commit comments

Comments (0)