Skip to content

Commit 2c79362

Browse files
AWQ Support for ONNX Backend (#3571)
### Changes AWQ Support for ONNX Backend ### Reason for changes Ref: 168332 ### Related tickets Ref: 168332 ### Tests - tinyllama_data_free_awq_backend_ONNX
1 parent ce864f4 commit 2c79362

File tree

9 files changed

+124
-5
lines changed

9 files changed

+124
-5
lines changed

src/nncf/onnx/graph/model_transformer.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nncf.onnx.graph.onnx_helper import get_tensor
3030
from nncf.onnx.graph.transformations.commands import ONNXInitializerUpdateCommand
3131
from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand
32+
from nncf.onnx.graph.transformations.commands import ONNXMultiplyInsertionCommand
3233
from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand
3334
from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand
3435
from nncf.onnx.graph.transformations.commands import ONNXQuantizerInsertionCommand
@@ -91,6 +92,7 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
9192
initializer_update_transformations = []
9293
qdq_node_removing_transformations = []
9394
model_extraction_transformation = None
95+
multiply_insert_transformations = []
9496
transformations = transformation_layout.transformations
9597
# No transformation applied
9698
if not transformations:
@@ -106,15 +108,24 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
106108
qdq_node_removing_transformations.append(transformation)
107109
elif isinstance(transformation, ONNXInitializerUpdateCommand):
108110
initializer_update_transformations.append(transformation)
111+
elif isinstance(transformation, ONNXMultiplyInsertionCommand):
112+
multiply_insert_transformations.append(transformation)
109113
# Inplace transformations, using deepcopy of model
110-
if quantizer_insert_transformations or initializer_update_transformations or qdq_node_removing_transformations:
114+
if (
115+
quantizer_insert_transformations
116+
or initializer_update_transformations
117+
or qdq_node_removing_transformations
118+
or multiply_insert_transformations
119+
):
111120
model = deepcopy(self._model)
112121
if quantizer_insert_transformations:
113122
model = self._apply_quantizer_insertion_transformations(model, quantizer_insert_transformations)
114123
if qdq_node_removing_transformations:
115124
model = self._apply_qdq_node_removing_transformations(model, qdq_node_removing_transformations)
116125
if initializer_update_transformations:
117126
model = self._apply_initializer_update_transformations(model, initializer_update_transformations)
127+
if multiply_insert_transformations:
128+
model = self._apply_multiply_insertion_transformations(model, multiply_insert_transformations)
118129
# Transformations that create new model
119130
if output_insert_transformations:
120131
model = self._apply_output_insertion_transformations(output_insert_transformations)
@@ -459,6 +470,48 @@ def _apply_qdq_node_removing_transformations(
459470

460471
return model
461472

473+
@staticmethod
474+
def _apply_multiply_insertion_transformations(
475+
model: onnx.ModelProto, transformations: list[ONNXMultiplyInsertionCommand]
476+
) -> onnx.ModelProto:
477+
"""
478+
Inserts Multiply with provided value for corresponding layer.
479+
480+
:param transformations: List of the smooth insertion transformations.
481+
:returns: Transformed model with Multiply nodes.
482+
"""
483+
node_name_to_node = get_name_to_node_map(model)
484+
485+
for transformation in transformations:
486+
target_node_name = transformation.target_point.target_node_name
487+
target_output_port = transformation.target_point.port_id
488+
target_node = node_name_to_node[target_node_name]
489+
output_tensor_name = target_node.output[target_output_port]
490+
491+
# Create a new initializer for the scale constant
492+
scale_tensor_name = f"{transformation.multiply_node_name}_scale"
493+
scale_tensor = onnx.numpy_helper.from_array(transformation.scale_value, name=scale_tensor_name)
494+
model.graph.initializer.append(scale_tensor)
495+
496+
# Create a new Multiply node
497+
mul_output_name = f"{transformation.multiply_node_name}_output"
498+
mul_node = onnx.helper.make_node(
499+
"Mul",
500+
inputs=[output_tensor_name, scale_tensor_name],
501+
outputs=[mul_output_name],
502+
name=transformation.multiply_node_name,
503+
)
504+
target_index = get_node_index(model, target_node_name)
505+
model.graph.node.insert(target_index + 1, mul_node)
506+
507+
for name in transformation.destination_node_names:
508+
node = node_name_to_node[name]
509+
for i, input_name in enumerate(node.input):
510+
if input_name == output_tensor_name:
511+
node.input[i] = mul_output_name
512+
513+
return model
514+
462515

463516
def set_initializer(initializer_name: str, model: onnx.ModelProto, new_value: np.ndarray) -> None:
464517
"""

src/nncf/onnx/graph/transformations/command_creation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from nncf.common.graph.transformations.command_creation import CommandCreator
1717
from nncf.common.graph.transformations.commands import TargetType
1818
from nncf.onnx.graph.transformations.commands import ONNXInitializerUpdateCommand
19+
from nncf.onnx.graph.transformations.commands import ONNXMultiplyInsertionCommand
1920
from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand
2021
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
2122

@@ -59,3 +60,15 @@ def create_command_to_update_weight(
5960
@staticmethod
6061
def create_command_to_insert_bias(node_without_bias, bias_value):
6162
raise NotImplementedError
63+
64+
@staticmethod
65+
def multiply_insertion_command(
66+
source_node: NNCFNode,
67+
destination_nodes: list[NNCFNode],
68+
source_out_port: int,
69+
scale_value: np.ndarray,
70+
multiply_node_name: str,
71+
) -> ONNXMultiplyInsertionCommand:
72+
target_point = ONNXTargetPoint(TargetType.POST_LAYER_OPERATION, source_node.node_name, source_out_port)
73+
destination_node_names = [d.node_name for d in destination_nodes]
74+
return ONNXMultiplyInsertionCommand(target_point, scale_value, destination_node_names, multiply_node_name)

src/nncf/onnx/graph/transformations/commands.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,27 @@ def __init__(self, target_point: ONNXTargetPoint):
113113
:param target_point: The TargetPoint instance for the layer that contains information for removing.
114114
"""
115115
super().__init__(TransformationType.REMOVE, target_point)
116+
117+
118+
class ONNXMultiplyInsertionCommand(ONNXInsertionCommand):
119+
"""
120+
Inserts Multiply nodes before the corresponding nodes.
121+
"""
122+
123+
def __init__(
124+
self,
125+
target_point: ONNXTargetPoint,
126+
scale_value: np.ndarray,
127+
destination_node_names: list[str],
128+
multiply_node_name: str,
129+
):
130+
"""
131+
:param target_point: The TargetPoint instance for the insertion that contains layer's information.
132+
:param scale_value: Scale value for Multiply layer.
133+
:param destination_node_names: New layer consumers.
134+
:param multiply_node_name: New layer name.
135+
"""
136+
super().__init__(target_point, None)
137+
self.scale_value = scale_value
138+
self.destination_node_names = destination_node_names
139+
self.multiply_node_name = multiply_node_name

src/nncf/quantization/algorithms/weight_compression/awq.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def _set_backend_entity(
114114
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXAWQAlgoAlgoBackend
115115

116116
self._backend_entity = FXAWQAlgoAlgoBackend()
117+
elif model_backend == BackendType.ONNX:
118+
from nncf.quantization.algorithms.weight_compression.onnx_backend import ONNXAWQAlgoAlgoBackend
119+
120+
self._backend_entity = ONNXAWQAlgoAlgoBackend(model)
117121
else:
118122
msg = f"Cannot return backend-specific AWQ entity because {model_backend.value} is not supported!"
119123
raise nncf.UnsupportedBackendError(msg)

src/nncf/quantization/algorithms/weight_compression/onnx_backend.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
3131
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
3232
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
33+
from nncf.onnx.graph.metatypes import onnx_metatypes
34+
from nncf.onnx.graph.metatypes.groups import ATOMIC_ACTIVATIONS_OPERATIONS
3335
from nncf.onnx.graph.metatypes.groups import CONVOLUTION_METATYPES
3436
from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
3537
from nncf.onnx.graph.model_transformer import remove_initializer
@@ -43,11 +45,14 @@
4345
from nncf.onnx.graph.onnx_helper import get_tensor_value
4446
from nncf.onnx.graph.onnx_helper import pack_4_bits
4547
from nncf.onnx.graph.onnx_helper import pack_int4_to_uint8
48+
from nncf.onnx.graph.transformations.command_creation import ONNXCommandCreator
4649
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
4750
from nncf.onnx.quantization.ignored_patterns import create_rope
4851
from nncf.parameters import CompressionFormat
4952
from nncf.parameters import CompressWeightsMode
5053
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
54+
from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
55+
from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
5156
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
5257
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
5358
from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm
@@ -181,7 +186,7 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
181186
def set_weight(
182187
self, node_with_weight: NNCFNode, weight_port_id: int, model: onnx.ModelProto, graph: NNCFGraph, weight: Tensor
183188
):
184-
node = self.name_to_node_map[node_with_weight.target_node_name]
189+
node = self.name_to_node_map[node_with_weight.node_name]
185190
initializer_name = node.input[weight_port_id]
186191
set_initializer(initializer_name, model, weight.data)
187192

@@ -464,3 +469,19 @@ def _replace_matmul_with_matmulnbits(
464469
@staticmethod
465470
def get_ignored_patterns() -> GraphPattern:
466471
return create_rope()
472+
473+
474+
class ONNXAWQAlgoAlgoBackend(AWQAlgoBackend, ONNXWeightCompressionAlgoBackend):
475+
@staticmethod
476+
def get_awq_patterns() -> dict[str, Callable]:
477+
return get_awq_patterns(
478+
onnx_metatypes.ONNXMatMulMetatype, onnx_metatypes.ONNXMulLayerMetatype, ATOMIC_ACTIVATIONS_OPERATIONS
479+
)
480+
481+
@staticmethod
482+
def scale_insertion_command(
483+
source_node: NNCFNode, next_nodes: list[NNCFNode], source_node_output_port: int, scale: np.ndarray
484+
):
485+
return ONNXCommandCreator.multiply_insertion_command(
486+
source_node, next_nodes, source_node_output_port, scale, f"{source_node.node_name}/awq_mul"
487+
)

src/nncf/quantization/quantize_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,6 @@ def compress_weights(
634634
raise nncf.ParameterNotSupportedError(msg)
635635

636636
options = {
637-
"awq": awq,
638637
"scale_estimation": scale_estimation,
639638
"gptq": gptq,
640639
"lora_correction": lora_correction,

tests/post_training/data/wc_reference_data.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,7 @@ tinyllama_data_free_awq_backend_TORCH:
118118
metric_value: 0.85466
119119
num_int4: 94
120120
num_int8: 124
121+
tinyllama_data_free_awq_backend_ONNX:
122+
metric_value: 0.82562
123+
num_int4: 264
124+
num_int8: 84

tests/post_training/data/wc_test_durations.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_OV]": 164,
1616
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_TORCH]": 210,
1717
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_backend_ONNX]": 182,
18-
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_opset19_backend_ONNX]": 512
18+
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_opset19_backend_ONNX]": 512,
19+
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_ONNX]": 154
1920
}

tests/post_training/model_scope.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@
579579
),
580580
},
581581
# TODO: (andreyanufr) add torch.fx backend
582-
"backends": [BackendType.OV, BackendType.TORCH],
582+
"backends": [BackendType.OV, BackendType.TORCH, BackendType.ONNX],
583583
},
584584
]
585585

0 commit comments

Comments
 (0)