Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 0fa9f72

Browse files
anmarques and bfineran authored
Cherry-pick ONNX export update into release 1.0 (#936)
* Bump up version id
* Fix for ONNX export for quantized BERT models (#935)
* Remove quantization of identity branch on BERT models
* Style and quality fixes.
* Update src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
Co-authored-by: Benjamin Fineran <[email protected]>
* Removed unused function
Co-authored-by: Benjamin Fineran <[email protected]>
Co-authored-by: Benjamin Fineran <[email protected]>
1 parent fc4c771 commit 0fa9f72

File tree

3 files changed

+22
-93
lines changed

3 files changed

+22
-93
lines changed

src/sparseml/onnx/utils/graph_optimizer.py

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
__all__ = [
4141
"fold_conv_bns",
4242
"quantize_resnet_identity_add_inputs",
43-
"quantized_residual_add_optim",
4443
]
4544

4645

@@ -202,91 +201,6 @@ def quantize_resnet_identity_add_inputs(quantized_model: onnx.ModelProto) -> boo
202201
return optimization_made
203202

204203

205-
def quantized_residual_add_optim(quantized_model: onnx.ModelProto) -> bool:
206-
"""
207-
This optimization adds a quant/dequant block to the identity branch of a
208-
residual whose non-identity branch is quantized. This enables the add at the
209-
end of the residual to be fused at runtime.
210-
211-
Function will match to any node who has two children nodes - one add node
212-
and one quantize node whose branch eventually leads to the other add node.
213-
214-
:param quantized_model: A loaded quantized model to perform this optimization on
215-
:return: True if an in-place optimization was made
216-
"""
217-
graph = ONNXGraph(quantized_model)
218-
optimization_made = False
219-
for node in quantized_model.graph.node:
220-
children_nodes = graph.get_node_children(node)
221-
if len(children_nodes) != 2:
222-
continue
223-
224-
add_node = [node for node in children_nodes if node.op_type == "Add"]
225-
quant_node = [
226-
node for node in children_nodes if node.op_type == "QuantizeLinear"
227-
]
228-
if not add_node or not quant_node:
229-
continue
230-
add_node = add_node[0]
231-
quant_node = quant_node[0]
232-
233-
# verify that quant_node eventually leads to add_node
234-
curr_node = [quant_node]
235-
iter = 0
236-
max_iter = 20 # avoid cycles
237-
while curr_node and curr_node[0] != add_node and iter < max_iter:
238-
curr_node = graph.get_node_children(curr_node[0])
239-
iter += 1
240-
if curr_node[0] != add_node:
241-
continue
242-
243-
# create de-quantize node for identity
244-
dequant_node = _make_dequant_node_for_quant(quant_node)
245-
246-
# update graph
247-
identity_edge_idx = 0 if add_node.input[0] == node.output[0] else 1
248-
graph.add_node(dequant_node)
249-
graph.update_node_input(add_node, dequant_node.output[0], identity_edge_idx)
250-
optimization_made = True
251-
252-
# if any of the add children have are a quantize op while others aren't
253-
# add a quant/dequant block to the non quantized paths to allow for fusion
254-
# of the add
255-
add_node_children = graph.get_node_children(add_node)
256-
add_node_quant_child_idx = [
257-
idx
258-
for idx, node in enumerate(add_node_children)
259-
if node.op_type == "QuantizeLinear"
260-
]
261-
if not add_node_quant_child_idx or all(
262-
n.op_type == "Add" or n.op_type == "QuantizeLinear"
263-
for n in add_node_children
264-
):
265-
# no quant child node, or all child nodes are quant/add nodes
266-
continue
267-
268-
# make dequant pair node for quant child and add to graph
269-
add_node_dequant_child = _make_dequant_node_for_quant(
270-
add_node_children[add_node_quant_child_idx[0]]
271-
)
272-
graph.add_node(add_node_dequant_child)
273-
274-
# update all non quant node children to take the quant/dequant block as input
275-
for add_child_node in add_node_children:
276-
if add_child_node.op_type == "QuantizeLinear":
277-
continue
278-
add_node_id_idx = [
279-
idx
280-
for idx, output_id in enumerate(add_child_node.input)
281-
if output_id == add_node.output[0]
282-
][0]
283-
graph.update_node_input(
284-
add_child_node, add_node_dequant_child.output[0], add_node_id_idx
285-
)
286-
287-
return optimization_made
288-
289-
290204
def _make_dequant_node_for_quant(quant_node: onnx.NodeProto) -> onnx.NodeProto:
291205
return onnx.helper.make_node(
292206
"DequantizeLinear",

src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import numpy
2727
import onnx
28+
import torch
2829
from onnx import ModelProto, NodeProto, numpy_helper
2930

3031
from sparseml.onnx.utils import (
@@ -34,7 +35,6 @@
3435
get_node_attributes,
3536
get_node_output_nodes,
3637
quantize_resnet_identity_add_inputs,
37-
quantized_residual_add_optim,
3838
remove_node_and_params_from_graph,
3939
swap_node_output,
4040
update_model_param,
@@ -323,9 +323,21 @@ def _attribute_to_kwarg(attribute: onnx.AttributeProto):
323323
def _quantize_array(
324324
array: numpy.ndarray, scale: float, zero_point: int, dtype: Any = numpy.uint8
325325
) -> numpy.ndarray:
326-
dmin = numpy.iinfo(dtype).min
327-
dmax = numpy.iinfo(dtype).max
328-
return ((array / scale).round() + zero_point).clip(dmin, dmax).astype(dtype)
326+
if dtype == numpy.uint8:
327+
tensor_dtype = torch.quint8
328+
elif dtype == numpy.int8:
329+
tensor_dtype = torch.qint8
330+
elif dtype == numpy.int32:
331+
tensor_dtype = torch.qint32
332+
333+
tensor = torch.Tensor(array).to(torch.float32)
334+
if isinstance(scale, numpy.ndarray):
335+
scale = scale.item()
336+
if isinstance(zero_point, numpy.ndarray):
337+
zero_point = zero_point.item()
338+
339+
quant_tensor = torch.quantize_per_tensor(tensor, scale, zero_point, tensor_dtype)
340+
return quant_tensor.int_repr().numpy()
329341

330342

331343
def _convert_quantizable_conv(
@@ -450,6 +462,7 @@ def _convert_quantizable_gemm(
450462
weight_quantize_params.target,
451463
weight_quantize_params.scale,
452464
weight_quantize_params.zero_point,
465+
weight_quantize_params.zero_point.dtype,
453466
)
454467
quantized_weight = quantized_weight.transpose() # Gemm has implicit transpose
455468
quantized_weight_name = "{}.weight_quantized".format(gemm_node.name)
@@ -732,6 +745,7 @@ def _add_quantized_conv_matmul_add_ops(
732745
weight_quantize_params.target,
733746
weight_quantize_params.scale,
734747
weight_quantize_params.zero_point,
748+
weight_quantize_params.zero_point.dtype,
735749
)
736750
if transpose_weight:
737751
quantized_weight = quantized_weight.transpose()
@@ -1404,7 +1418,9 @@ def _quantize_qat_embedding(model: ModelProto):
14041418
embedding = numpy_helper.to_array(embedding_initializer)
14051419
scale = numpy_helper.to_array(scale_initializer)
14061420
zero_point = numpy_helper.to_array(zp_initializer)
1407-
embedding_quant = _quantize_array(embedding, scale, zero_point)
1421+
embedding_quant = _quantize_array(
1422+
embedding, scale, zero_point, zero_point.dtype
1423+
)
14081424
embedding_quant_initializer = numpy_helper.from_array(
14091425
embedding_quant, name=f"{embedding_initializer.name}_quant"
14101426
)
@@ -1569,7 +1585,6 @@ def quantize_torch_qat_export(
15691585
_convert_quantizable_gemm_no_activations(model)
15701586
_quantize_qat_embedding(model)
15711587
quantize_resnet_identity_add_inputs(model)
1572-
quantized_residual_add_optim(model)
15731588
_remove_duplicate_quantize_ops(model)
15741589
_cleanup_unused_quants(model)
15751590

src/sparseml/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datetime import date
2020

2121

22-
version_base = "1.0.0"
22+
version_base = "1.0.1"
2323
is_release = False # change to True to set the generated version as a release version
2424

2525

0 commit comments

Comments
 (0)