Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 47d9472

Browse files
committed
QAT and quant postprocessing for torch.nn.Embedding (#374)
* QAT and quant postprocessing for torch.nn.Embedding * cleanup * residual optim and logging fixes * response to comments
1 parent a1fda05 commit 47d9472

File tree

6 files changed

+222
-24
lines changed

6 files changed

+222
-24
lines changed

src/sparseml/onnx/utils/graph_optimizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,12 @@ def quantize_resnet_identity_add_inputs(quantized_model: onnx.ModelProto) -> boo
165165
quantize_node = get_quantize_parent_for_dequantize_node(
166166
quantized_model, dequantize_node
167167
)
168+
168169
# check that the quantize block takes input from the same relu
169-
if quantize_node.input[0] != other_input_node.output[0]:
170+
if (
171+
quantize_node is None
172+
or quantize_node.input[0] != other_input_node.output[0]
173+
):
170174
continue
171175

172176
# create de-quantize node for identity

src/sparseml/pytorch/optim/modifier_quantization.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
configure_module_qat_wrappers,
4242
fuse_module_conv_bn_relus,
4343
get_qat_qconfig,
44+
prepare_embeddings_qat,
4445
)
4546

4647

@@ -80,6 +81,10 @@ class QuantizationModifier(ScheduledModifier):
8081
exception. For compatibility with YAML serialization only.
8182
:param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed
8283
to the model fusing function
84+
:param quantize_embeddings: if True, will perform QAT on torch.nn.Embedding layers
85+
using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
86+
quantize embedding weights. Default is True. Models without embedding layers
87+
will be unaffected
8388
"""
8489

8590
def __init__(
@@ -91,6 +96,7 @@ def __init__(
9196
freeze_bn_stats_epoch: Union[float, None] = None,
9297
end_epoch: float = -1,
9398
model_fuse_fn_kwargs: Dict[str, Any] = None,
99+
quantize_embeddings: bool = True,
94100
):
95101
if torch_quantization is None or torch_intrinsic is None:
96102
raise RuntimeError(
@@ -112,6 +118,7 @@ def __init__(
112118
self._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {}
113119
self._disable_quantization_observer_epoch = disable_quantization_observer_epoch
114120
self._freeze_bn_stats_epoch = freeze_bn_stats_epoch
121+
self._quantize_embeddings = quantize_embeddings
115122

116123
self._modules_to_quantize = None
117124
self._qat_enabled = False
@@ -140,7 +147,7 @@ def submodules(self) -> Union[List[str], None]:
140147
def submodules(self, value: Union[List[str], None]):
141148
"""
142149
:params value: List of submodule names to perform QAT on. Set None to quantize
143-
entire model
150+
entire model
144151
"""
145152
self._submodules = value
146153
if isinstance(self._submodules, list):
@@ -151,18 +158,18 @@ def submodules(self, value: Union[List[str], None]):
151158
def model_fuse_fn_name(self) -> Union[str, None]:
152159
"""
153160
:return: Name of model function to fuse the model in place prior
154-
to performing QAT. None to uses the default function
155-
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
161+
to performing QAT. None to uses the default function
162+
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
156163
"""
157164
return self._model_fuse_fn_name
158165

159166
@model_fuse_fn_name.setter
160167
def model_fuse_fn_name(self, value: Union[str, None]):
161168
"""
162169
:params value: Name of model function to fuse the model in place prior
163-
to performing QAT. Set None to use the default function
164-
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
165-
to skip module fusing.
170+
to performing QAT. Set None to use the default function
171+
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
172+
to skip module fusing.
166173
"""
167174
self._model_fuse_fn_name = value
168175
if (
@@ -176,17 +183,17 @@ def model_fuse_fn_name(self, value: Union[str, None]):
176183
def disable_quantization_observer_epoch(self) -> Union[float, None]:
177184
"""
178185
:return: Epoch to disable updates to the module's
179-
quantization observers. After this point, quantized weights and zero points will
180-
not be updated. When None, observers never disabled during QAT
186+
quantization observers. After this point, quantized weights and zero points
187+
will not be updated. When None, observers never disabled during QAT
181188
"""
182189
return self._disable_quantization_observer_epoch
183190

184191
@disable_quantization_observer_epoch.setter
185192
def disable_quantization_observer_epoch(self, value: Union[float, None]):
186193
"""
187194
:params value: Epoch to disable updates to the module's
188-
quantization observers. After this point, quantized weights and zero points will
189-
not be updated. Set None to not disable observers during QAT
195+
quantization observers. After this point, quantized weights and zero points
196+
will not be updated. Set None to not disable observers during QAT
190197
"""
191198
self._disable_quantization_observer_epoch = value
192199
self._validate_params()
@@ -195,19 +202,37 @@ def disable_quantization_observer_epoch(self, value: Union[float, None]):
195202
def freeze_bn_stats_epoch(self) -> Union[float, None]:
196203
"""
197204
:return: Epoch to stop the tracking of batch norm stats. When
198-
None, batch norm stats are track for all of training
205+
None, batch norm stats are track for all of training
199206
"""
200207
return self._freeze_bn_stats_epoch
201208

202209
@freeze_bn_stats_epoch.setter
203210
def freeze_bn_stats_epoch(self, value: Union[float, None]):
204211
"""
205212
:params value: Epoch to stop the tracking of batch norm stats. Set
206-
None to not stop tracking batch norm stats during QAT
213+
None to not stop tracking batch norm stats during QAT
207214
"""
208215
self._freeze_bn_stats_epoch = value
209216
self._validate_params()
210217

218+
@ModifierProp()
219+
def quantize_embeddings(self) -> bool:
220+
"""
221+
:return: if True, will perform QAT on torch.nn.Embedding layers
222+
using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
223+
quantize embedding weights
224+
"""
225+
return self._freeze_bn_stats_epoch
226+
227+
@quantize_embeddings.setter
228+
def quantize_embeddings(self, value: bool):
229+
"""
230+
:params value: if True, will perform QAT on torch.nn.Embedding layers
231+
using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
232+
quantize embedding weights
233+
"""
234+
self._quantize_embeddings = value
235+
211236
def initialize(
212237
self,
213238
module: Module,
@@ -350,6 +375,8 @@ def _enable_module_qat(self, module: Module):
350375
add_quant_dequant(quant_module)
351376
# set model to QAT mode
352377
torch_quantization.prepare_qat(quant_module, inplace=True)
378+
if self._quantize_embeddings:
379+
prepare_embeddings_qat(quant_module)
353380
self._qat_enabled = True
354381

355382
def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:

src/sparseml/pytorch/utils/quantization/helpers.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Any, Callable, List, Union
2121

2222
import torch
23-
from torch.nn import BatchNorm2d, Conv2d, Module, ReLU
23+
from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU
2424

2525

2626
try:
@@ -40,6 +40,7 @@
4040
"add_quant_dequant",
4141
"get_qat_qconfig",
4242
"fuse_module_conv_bn_relus",
43+
"prepare_embeddings_qat",
4344
]
4445

4546

@@ -318,11 +319,15 @@ def add_quant_dequant(module):
318319

319320
def get_qat_qconfig(
320321
symmetric_activations: bool = False,
322+
symmetric_weights: bool = True,
321323
) -> "torch.quantization.QConfig":
322324
"""
323325
:param symmetric_activations: if True, activations will have a symmetric
324-
quantization range with zero point set to 128. Otherwise activations
326+
UINT8 quantization range with zero point set to 128. Otherwise activations
325327
will use asymmetric quantization with any zero point. Default is False
328+
:param symmetric_weights: if True, weights will have a symmetric
329+
INT8 quantization range with zero point set to 0. Otherwise activations
330+
will use asymmetric quantization with any zero point. Default is True
326331
:return: A QAT fake quantization config for symmetric weight quantization and
327332
asymmetric activation quantization. The difference between this and
328333
torch.quantization.default_qat_qconfig is that the activation observer
@@ -339,7 +344,17 @@ def get_qat_qconfig(
339344
qscheme=activation_qscheme,
340345
reduce_range=False,
341346
)
342-
weight_observer = torch_quantization.default_weight_fake_quant
347+
weight_qscheme = (
348+
torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine
349+
)
350+
weight_observer = torch_quantization.FakeQuantize.with_args(
351+
observer=torch_quantization.MovingAverageMinMaxObserver,
352+
quant_min=-128,
353+
quant_max=127,
354+
dtype=torch.qint8,
355+
qscheme=weight_qscheme,
356+
reduce_range=False,
357+
)
343358
return torch_quantization.QConfig(
344359
activation=activation_observer,
345360
weight=weight_observer,
@@ -423,6 +438,44 @@ def fuse_module_conv_bn_relus(
423438
return module
424439

425440

441+
def prepare_embeddings_qat(
442+
module: Module,
443+
qconfig: "torch.quantization.QConfig" = None,
444+
):
445+
"""
446+
adds a fake quantize call to the weights of any Embedding modules in the given
447+
module
448+
449+
:param module: module to run QAT for the embeddings of
450+
:param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8
451+
asymmetric range
452+
"""
453+
if qconfig is None:
454+
qconfig = get_qat_qconfig(symmetric_weights=False)
455+
for submodule in module.modules():
456+
if type(submodule) is Embedding:
457+
_prepare_qat_embedding(submodule, qconfig)
458+
459+
460+
def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"):
461+
embedding.weight_fake_quant = qconfig.weight()
462+
463+
def _qat_forward(self, input: torch.Tensor) -> torch.Tensor:
464+
return torch.nn.functional.embedding(
465+
input,
466+
self.weight_fake_quant(self.weight),
467+
self.padding_idx,
468+
self.max_norm,
469+
self.norm_type,
470+
self.scale_grad_by_freq,
471+
self.sparse,
472+
)
473+
474+
# bind qat forward to embedding
475+
qat_forward_bound = _qat_forward.__get__(embedding, embedding.__class__)
476+
setattr(embedding, "forward", qat_forward_bound)
477+
478+
426479
def _set_submodule(root_module, sub_module_path, sub_module):
427480
current_module = root_module
428481
sub_module_path = sub_module_path.split(".")

src/sparseml/pytorch/utils/quantization/quantize_qat_export.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,100 @@ def _convert_quantizable_ops(model: ModelProto):
909909
)
910910

911911

912+
def _quantize_qat_embedding(model: ModelProto):
913+
"""
914+
A pass for quantizing qat embeddings
915+
916+
Starting with:
917+
| INPUT QuantizeLinear (with constant embedding)
918+
| | |
919+
| | DequantizeLinear
920+
| | |
921+
| Gather
922+
| |
923+
| QuantizeLinear
924+
| |
925+
| DequantizeLinear
926+
| |
927+
| OUTPUT
928+
929+
Converts to:
930+
| INPUT
931+
| |
932+
| Gather(UINT8 data initializer)
933+
| |
934+
| DequantizeLinear
935+
| |
936+
| OUTPUT
937+
"""
938+
graph = ONNXGraph(model)
939+
gather_nodes = [node for node in model.graph.node if node.op_type == "Gather"]
940+
941+
converted_nodes = 0
942+
for gather_node in gather_nodes:
943+
# find input quant and dequant nodes
944+
input_dequant_node = graph.get_node_single_parent(gather_node, 0)
945+
if not input_dequant_node or input_dequant_node.op_type != "DequantizeLinear":
946+
continue
947+
input_quant_node = graph.get_node_single_parent(input_dequant_node, 0)
948+
if not input_quant_node or input_quant_node.op_type != "QuantizeLinear":
949+
continue
950+
# find embedding weights, sclae, and zero point
951+
embedding_initializer = graph.get_init_by_name(input_quant_node.input[0])
952+
scale_initializer = graph.get_init_by_name(input_quant_node.input[1])
953+
zp_initializer = graph.get_init_by_name(input_quant_node.input[2])
954+
if not embedding_initializer or not scale_initializer or not zp_initializer:
955+
continue
956+
957+
# quantize embedding
958+
embedding = numpy_helper.to_array(embedding_initializer)
959+
scale = numpy_helper.to_array(scale_initializer)
960+
zero_point = numpy_helper.to_array(zp_initializer)
961+
embedding_quant = _quantize_array(embedding, scale, zero_point)
962+
embedding_quant_initializer = numpy_helper.from_array(
963+
embedding_quant, name=f"{embedding_initializer.name}_quant"
964+
)
965+
966+
# update graph
967+
model.graph.initializer.append(embedding_quant_initializer)
968+
gather_node.input[0] = embedding_quant_initializer.name
969+
970+
# detect QDQ block on output
971+
output_quant_node = graph.get_node_single_child(gather_node)
972+
if output_quant_node and output_quant_node.op_type == "QuantizeLinear":
973+
output_dequant_node = graph.get_node_single_child(output_quant_node)
974+
qdq_output = (
975+
output_dequant_node
976+
and output_dequant_node.op_type == "DequantizeLinear"
977+
)
978+
else:
979+
qdq_output = False
980+
981+
if qdq_output:
982+
# delete unnecessary quantize and dequantize ops
983+
delete_quant_node(model, input_quant_node, keep_params=False)
984+
delete_quant_node(model, input_dequant_node, keep_params=False)
985+
delete_quant_node(model, output_quant_node, keep_params=False)
986+
# forward gather output to dequant input
987+
output_dequant_node.input[0] = gather_node.output[0]
988+
989+
else:
990+
# use input dequant to dequantize output
991+
embedding_quant_output_id = f"{gather_node.output[0]}_quant"
992+
input_dequant_node.input[0] = embedding_quant_output_id
993+
input_dequant_node.output[0] = gather_node.output[0]
994+
gather_node.output[0] = embedding_quant_output_id
995+
996+
delete_quant_node(model, input_quant_node, keep_params=False)
997+
graph.update()
998+
converted_nodes += 1
999+
1000+
graph.delete_unused_initializers()
1001+
1002+
if converted_nodes > 0:
1003+
_LOGGER.info(f"Converted {converted_nodes} QAT embedding ops to UINT8")
1004+
1005+
9121006
def _replace_input_id_model(model: ModelProto, old_id: str, new_id: str):
9131007
for node in model.graph.node:
9141008
for idx, inp in enumerate(node.input):
@@ -996,6 +1090,7 @@ def quantize_torch_qat_export(
9961090
_convert_quantizable_matmul(model)
9971091
_convert_quantizable_matmul_and_add(model)
9981092
_convert_quantizable_ops(model)
1093+
_quantize_qat_embedding(model)
9991094
quantize_resnet_identity_add_inputs(model)
10001095
quantized_residual_add_optim(model)
10011096
_remove_duplicate_quantize_ops(model)

0 commit comments

Comments
 (0)