
Commit 6dee0b0

[WC] WC/Mixed Precision/AWQ transpose_a support (#3794)

### Changes

* Weight compression / mixed precision `transpose_a` support
* AWQ `transpose_a` support
* `process_statistics` `transpose_a` param support
* [ONNX] AWQ Gemm support

### Reason for changes

* To apply WC / mixed precision / AWQ to the Mamba model family

### Related tickets

173277

### Tests

* tests/cross_fw/test_templates/template_test_weights_compression.py::test_mixed_precision expanded with a `transpose_a` param to check the base WC / mixed precision algo
* tests/cross_fw/test_templates/template_test_weights_compression.py::test_awq_scale_reference expanded with `transpose_a` and `non_mergable_pattern` to check the non-mergeable AWQ branch and activation transpose support
* tests/cross_fw/test_templates/template_test_weights_compression.py::test_process_stats refactored to test `act_ch_axis` support in the `process_statistics` fn
* tests/cross_fw/test_templates/template_test_weights_compression.py::test_compression_skipped_with_transposed_activations moved to common to test that ONNX/OV fail with an appropriate error when an unsupported `transpose_a` model is supplied to an algorithm

---------

Co-authored-by: andreyanufr <andrey.anufriev@intel.com>
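For context on what `transpose_a` means here: a MatMul with `transpose_a=True` computes `Y = transpose(A) @ W`, so the activation's hidden (channel) dimension is the second-to-last axis instead of the last one. A minimal NumPy sketch of the shape consequence (hypothetical shapes, not code from this PR):

```python
import numpy as np

tokens, hidden, out_dim = 128, 768, 3072
weight = np.random.rand(hidden, out_dim)

a_regular = np.random.rand(tokens, hidden)     # channel axis is -1
a_transposed = np.random.rand(hidden, tokens)  # channel axis is -2 (transpose_a layout)

y_regular = a_regular @ weight                 # plain MatMul
y_transposed = a_transposed.T @ weight         # what MatMul(transpose_a=True) computes
assert y_regular.shape == y_transposed.shape == (tokens, out_dim)
```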
1 parent 7438f86 commit 6dee0b0

File tree

20 files changed: +659 −170 lines changed


src/nncf/quantization/algorithms/weight_compression/activation_stats.py

Lines changed: 5 additions & 2 deletions

@@ -17,12 +17,13 @@
 from nncf.tensor import functions as fns


-def process_stats(stats: WCTensorStatistic, subset_size: int) -> tuple[Tensor, Tensor]:
+def process_stats(stats: WCTensorStatistic, subset_size: int, act_ch_axis: int = -1) -> tuple[Tensor, Tensor]:
     """
     A function for processing activations. Shared between AWQ, Scale Estimation and LoRA Correction algorithms.

     :param stats: An object containing statistics for the layer.
     :param subset_size: The number of samples for AWQ.
+    :param act_ch_axis: The activation channel axis.
     :return: tuple of the following tensors:
         s - maximum channel magnitude across samples [HiddenDim]
         X - average channel magnitude across tokens in the sequence [HiddenDim, min(SampleSize, ~subset_size)]

@@ -41,7 +42,9 @@ def process_stats(stats: WCTensorStatistic, subset_size: int) -> tuple[Tensor, T

     # Prevent high memory and time consumption by sampling
     if X_full.shape[sample_axis] > subset_size:
-        lens = [reduce(mul, shape[:-1], 1) for shape in stats.shape_values]
+        lens = [
+            reduce(mul, shape[:act_ch_axis] + shape[act_ch_axis % len(shape) + 1 :], 1) for shape in stats.shape_values
+        ]
         step = X_full.shape[sample_axis] // subset_size
         idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
         X = X_full[..., idxs]
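The new `lens` expression generalizes the old `reduce(mul, shape[:-1], 1)` from "product of all but the last dimension" to "product of all but the channel dimension". A small standalone sketch of the same arithmetic (hypothetical helper name, plain Python):

```python
from functools import reduce
from operator import mul

def non_channel_size(shape: tuple[int, ...], act_ch_axis: int = -1) -> int:
    # Equivalent to reduce(mul, shape[:act_ch_axis] + shape[act_ch_axis % len(shape) + 1:], 1):
    # multiply every dimension except the channel axis.
    ch = act_ch_axis % len(shape)  # normalize a negative axis to its positive index
    return reduce(mul, (d for i, d in enumerate(shape) if i != ch), 1)

# Default act_ch_axis=-1 reproduces the old reduce(mul, shape[:-1], 1):
assert non_channel_size((2, 128, 768)) == 2 * 128
# With transposed activations the channel axis can be -2 instead:
assert non_channel_size((2, 768, 128), act_ch_axis=-2) == 2 * 128
```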

src/nncf/quantization/algorithms/weight_compression/algorithm.py

Lines changed: 31 additions & 17 deletions

@@ -959,9 +959,9 @@ def get_weight_compression_parameters(
             # MoE operations are usually matmuls, so the check for matmul metatype is done
             # This is to avoid raising the error for non-MoE cases with 3D weights.
             parsed_ov_version = f"{ov_version[0]}.{ov_version[1]}.{ov_version[2]}-{ov_version[3]}"
-            msg = f"""NNCF compression algorithms do not support 3D weights with current version of
-            OpenVINO {parsed_ov_version} due to a known issue in statistics collection
-            Ticket - 176465. Please update to the latest OpenVINO nightly version.
+            msg = f"""NNCF compression algorithms do not support 3D weights with current version of
+                OpenVINO {parsed_ov_version} due to a known issue in statistics collection
+                Ticket - 176465. Please update to the latest OpenVINO nightly version.
             Node with weight: {node.node_name}."""
             raise nncf.UnsupportedModelError(msg)

@@ -1087,6 +1087,11 @@ def apply_with_parameters(
         )

         if self._lora_correction:
+            for wc_params in all_weight_params:
+                if self._backend_entity.matmul_has_transposed_activations(wc_params.node_with_weight, graph):
+                    msg = "Transposed activations are not supported yet for the LoRA correction algorithm"
+                    raise nncf.UnsupportedModelError(msg)
+
             lora_correction_params = self._advanced_parameters.lora_correction_params
             lora_correction_algo = LoraCorrectionAlgorithm(statistics, lora_correction_params)
             description += " with correction of low-rank adapters"

@@ -1128,19 +1133,21 @@ def apply_with_parameters(
         )
         return transformed_model

-    def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int]:
+    def _get_activation_node_port_and_channel(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int, int]:
         """
-        This method returns the activation layer and corresponding port id for the node.
+        This method returns the activation layer, corresponding port id and channel axis for the given node.

         :param node: NNCFGraph node for which the activation is sought.
         :param nncf_graph: NNCFGraph instance with the node.
-        :return: Tuple with the activation node and port id.
+        :return: Tuple with the activation node, port id and channel axis.
         """
         activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph)
         activation_edge = nncf_graph.get_input_edge_by_port_id(node, activation_port)
         activation_node = activation_edge.from_node
-        port_id = activation_edge.output_port_id
-        return activation_node, port_id
+        activation_channel_axis = self._backend_entity.get_activation_channel_axis(
+            node, activation_edge.input_port_id, activation_edge.tensor_shape
+        )
+        return activation_node, activation_edge.output_port_id, activation_channel_axis

     def get_matmul_input_to_output_nodes_map(
         self, matmul_nodes: list[NNCFNode], graph: NNCFGraph

@@ -1161,8 +1168,8 @@ def get_matmul_input_to_output_nodes_map(
         """
         matmul_input_to_output_nodes_map = defaultdict(list)
         for node in matmul_nodes:
-            act_node, output_port_id = self._get_activation_node_and_port(node, graph)
-            matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node)
+            act_node, output_port_id, act_channel_axis = self._get_activation_node_port_and_channel(node, graph)
+            matmul_input_to_output_nodes_map[(act_node, output_port_id, act_channel_axis)].append(node)
         return matmul_input_to_output_nodes_map

     def get_compression_nodes_info(

@@ -1230,7 +1237,11 @@ def get_statistic_points(

         # Statistics for data aware algorithms
         if self._data_aware_compression:
-            for (node, output_port_id), node_with_weights in matmul_input_to_output_nodes_map.items():
+            for (
+                node,
+                output_port_id,
+                input_channel_axis,
+            ), node_with_weights in matmul_input_to_output_nodes_map.items():
                 statistic_point = self._backend_entity.target_point(
                     TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id
                 )

@@ -1245,13 +1256,16 @@ def get_statistic_points(
                 ]
                 all_weight_dims.extend(weight_dims)

-                # by default, reduce activations across all but the last dimension. The last dimension is
-                # assumed to be the hidden size dimension.
+                # Reduce activations across all but the hidden dimension.
                 n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
-                reduction_axes = tuple(range(n_dims - 1))
+                # negative axis (e.g. -1 for the last axis) is converted into the corresponding positive value
+                input_channel_axis = input_channel_axis % n_dims
+                reduction_axes = tuple(i for i in range(n_dims) if i != input_channel_axis)

-                # For 3D weights, hidden dimension is the second dimension. Reduce by all other dimensions
-                reduction_axes = (1,) if any(weight_dim == 3 for weight_dim in all_weight_dims) else reduction_axes
+                # For 3D weights, keep the batch dimension
+                if any(weight_dim == 3 for weight_dim in all_weight_dims):
+                    assert len(reduction_axes) == 2
+                    reduction_axes = reduction_axes[1:]

                 stat_collector = self._backend_entity.mean_statistic_collector(
                     reduction_axes=reduction_axes, subset_size=self._subset_size

@@ -1291,7 +1305,7 @@ def _get_statistics_for_weights_compression(
         # Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions,
         # shape is an original shape of an activation before reduction, n is the size of the dataset (or subset_size).
         statistics = {}
-        for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items():
+        for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items():
             tensor_collectors = list(
                 statistic_points.get_algo_statistics_for_node(
                     act_node.node_name,
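The key change above is how `reduction_axes` is derived: instead of hard-coding "all but the last axis", statistics are now reduced over every axis except the activation channel axis. A standalone sketch of that selection logic (hypothetical function name):

```python
def reduction_axes_for_activation(n_dims: int, input_channel_axis: int) -> tuple[int, ...]:
    # Reduce over every axis except the channel (hidden) axis.
    input_channel_axis = input_channel_axis % n_dims  # normalize a negative axis
    return tuple(i for i in range(n_dims) if i != input_channel_axis)

# Default layout [batch, seq, hidden]: channel axis -1 -> reduce over (0, 1)
assert reduction_axes_for_activation(3, -1) == (0, 1)
# transpose_a layout [batch, hidden, seq]: channel axis -2 -> reduce over (0, 2)
assert reduction_axes_for_activation(3, -2) == (0, 2)
```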

src/nncf/quantization/algorithms/weight_compression/awq.py

Lines changed: 23 additions & 7 deletions

@@ -170,6 +170,8 @@ def apply(
         weight_dtype = weight.dtype
         weight = weight.astype(TensorDataType.float32)

+        act_ch_axis, act_shape = self._get_activation_channel_axis_and_shape(graph, wp)
+
         if is_data_free:
             scale = self._data_free_step(weight, 1 - wp.reduction_axes[0])
         else:

@@ -181,24 +183,28 @@ def apply(
                 prev_weight = self._backend_entity.get_weight(merge_node, prev_weight_port_id, model, graph)

                 prev_statistics = statistics[merge_node.node_name]
-            scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics)
+            scale = self._data_aware_step(wp, weight, statistics[k], act_ch_axis, prev_weight, prev_statistics)

         w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
-        a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])
+        a_scale = 1.0 / scale

         scaled_weight = (weight * w_scale).astype(weight_dtype)
         self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)

         if is_mergeable:  # for MatMul->Multiply->MatMul pattern the scale is merged to the first MatMul
             for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph):
                 merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
+                a_scale = fns.unsqueeze(a_scale, wp.reduction_axes[0])
                 merge_weight = (merge_weight * a_scale).astype(weight_dtype)
                 self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
-            a_scale = fns.transpose(a_scale)
         else:  # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
-            a_scale = fns.transpose(a_scale).astype(weight_dtype)
+            # Calculate the activation scale shape
+            a_scale_shape = [scale.shape[0] if axis == act_ch_axis else 1 for axis in range(len(act_shape))]
+            a_scale = fns.reshape(a_scale, tuple(a_scale_shape))
+
             next_nodes = graph.get_next_nodes(merge_node)
             source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
+
             scale_insertion_command = self._backend_entity.scale_insertion_command(
                 merge_node, next_nodes, source_node_output_port, a_scale.data
             )

@@ -210,10 +216,10 @@ def apply(

         return transformed_model

-    def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statistics=None):
+    def _data_aware_step(self, wp, weight, statistics, act_ch_axis, prev_weight=None, prev_statistics=None):
         alpha_step = (self._alpha_max - self._alpha_min) / self._steps
         config = wp.compression_config
-        s, X = process_stats(statistics, self._subset_size)
+        s, X = process_stats(statistics, self._subset_size, act_ch_axis)
         s = s.astype(TensorDataType.float32)
         X = X.astype(TensorDataType.float32)

@@ -222,7 +228,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis

         prev_s, prev_w = None, None
         if prev_statistics is not None and prev_weight is not None:
-            prev_s, _ = process_stats(prev_statistics, self._subset_size)
+            prev_s, _ = process_stats(prev_statistics, self._subset_size, act_ch_axis)
             prev_s = prev_s.astype(TensorDataType.float32).max().item()
             prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis)

@@ -311,6 +317,16 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis

         return scale

+    def _get_activation_channel_axis_and_shape(
+        self, graph: NNCFGraph, wp: WeightCompressionParameters
+    ) -> tuple[int, tuple[int, ...]]:
+        activation_port_id = self._backend_entity.get_activation_port_id(wp.node_with_weight, graph)
+        act_shape = graph.get_input_edge_by_port_id(wp.node_with_weight, activation_port_id).tensor_shape
+        act_ch_axis = self._backend_entity.get_activation_channel_axis(
+            wp.node_with_weight, activation_port_id, act_shape
+        )
+        return act_ch_axis % len(act_shape), act_shape
+
     @staticmethod
     def _clamp_scale(magnitudes, threshold, scale, clamped_scale):
         return fns.where(magnitudes < threshold, scale, clamped_scale)
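In the non-mergeable branch, the activation scale is no longer transposed; it is reshaped so that its only non-unit dimension lands on the activation channel axis, which lets the inserted Multiply broadcast correctly for both regular and transposed activations. A NumPy sketch of that reshape (hypothetical helper and shapes):

```python
import numpy as np

def activation_scale(scale: np.ndarray, act_shape: tuple[int, ...], act_ch_axis: int) -> np.ndarray:
    # Mirror of the new branch: invert the scale and place it on the channel axis,
    # with size-1 dimensions everywhere else so it broadcasts over the activation.
    a_scale = 1.0 / scale
    a_scale_shape = [scale.shape[0] if axis == act_ch_axis else 1 for axis in range(len(act_shape))]
    return a_scale.reshape(a_scale_shape)

hidden = 768
# Regular [batch, seq, hidden] activation, channel axis 2:
assert activation_scale(np.ones(hidden), (1, 128, hidden), 2).shape == (1, 1, hidden)
# Transposed [batch, hidden, seq] activation, channel axis 1:
assert activation_scale(np.ones(hidden), (1, hidden, 128), 1).shape == (1, hidden, 1)
```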

src/nncf/quantization/algorithms/weight_compression/backend.py

Lines changed: 23 additions & 0 deletions

@@ -110,6 +110,17 @@ def get_weight(self, node_with_weight: NNCFNode, weight_port_id: int, model: TMo
         :return: The weight tensor.
         """

+    @abstractmethod
+    def matmul_has_transposed_activations(self, matmul: NNCFNode, graph: NNCFGraph) -> bool:
+        """
+        Checks whether the activation input of a MatMul operation is transposed.
+
+        :param matmul: MatMul NNCFGraph node.
+        :param graph: The model graph associated with the model.
+        :return: True if the node is a MatMul node and its activation input is transposed,
+            False otherwise.
+        """
+
     @abstractmethod
     def get_weight_dtype(
         self, node_with_weight: NNCFNode, weight_port_id: int, model: TModel, graph: NNCFGraph

@@ -273,6 +284,18 @@ def get_ignored_patterns() -> GraphPattern:
         :return: backend-specific ignored patterns.
         """

+    @staticmethod
+    @abstractmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        """
+        Returns the axis of the activation tensor that corresponds to its channel.
+
+        :param node: NNCFNode instance.
+        :param port_id: Port ID for input.
+        :param input_shape: Shape of the input.
+        :return: Channel axis number.
+        """
+

 class AWQAlgoBackend(WeightCompressionAlgoBackend):
     @staticmethod
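A hypothetical minimal implementation of the two new abstract methods, only to illustrate the contract (the real ONNX and OpenVINO implementations follow below; attribute names here are assumptions, not NNCF API):

```python
class ToyBackend:  # hypothetical, not a real NNCF backend
    def matmul_has_transposed_activations(self, matmul, graph) -> bool:
        # A backend decides from its own graph metadata whether input A is transposed.
        return bool(getattr(matmul.layer_attributes, "transpose_a", False))

    @staticmethod
    def get_activation_channel_axis(node, port_id, input_shape) -> int:
        # Channel axis is the last one by default, second-to-last when transposed.
        return -2 if getattr(node.layer_attributes, "transpose_a", False) else -1
```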

src/nncf/quantization/algorithms/weight_compression/gptq.py

Lines changed: 5 additions & 0 deletions

@@ -124,6 +124,11 @@ def apply(
                 CompressWeightsMode.INT8_SYM,
             ]:
                 continue
+
+            if self._backend_entity.matmul_has_transposed_activations(wc_params.node_with_weight, graph):
+                msg = "Transposed activations are not supported yet for the GPTQ algorithm"
+                raise nncf.UnsupportedModelError(msg)
+
             _, input_tensors = next(iter(inputs.items()))
             hessian = self._calculate_hessian(node, input_tensors)
             scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)

src/nncf/quantization/algorithms/weight_compression/mixed_precision.py

Lines changed: 1 addition & 1 deletion

@@ -279,7 +279,7 @@ def get_statistic_points(
         self._set_backend_entity(model)

         statistic_container = StatisticPointsContainer()
-        for act_node, output_port_id in nodes_and_port_ids:
+        for act_node, output_port_id, _ in nodes_and_port_ids:
            n_dims = len(graph.get_output_edges_by_port_id(act_node, output_port_id)[0].tensor_shape)
            if n_dims < 2:
                msg = (

src/nncf/quantization/algorithms/weight_compression/onnx_backend.py

Lines changed: 19 additions & 3 deletions

@@ -38,6 +38,7 @@
 from nncf.onnx.graph.model_transformer import remove_initializer
 from nncf.onnx.graph.model_transformer import remove_node
 from nncf.onnx.graph.model_transformer import set_initializer
+from nncf.onnx.graph.node_utils import get_act_quantization_axis
 from nncf.onnx.graph.node_utils import get_weight_quantization_axis
 from nncf.onnx.graph.onnx_helper import ONNX_DTYPE_TO_NNCF_DTYPE
 from nncf.onnx.graph.onnx_helper import get_name_to_node_map

@@ -186,6 +187,13 @@ def get_weight(
         weight_tensor = get_tensor_value(model, weight_name)
         return Tensor(weight_tensor)

+    def matmul_has_transposed_activations(self, matmul: NNCFNode, graph: NNCFGraph) -> bool:
+        if matmul.metatype != metatypes.ONNXGemmMetatype:
+            return False
+        act_port_id = self.get_activation_port_id(matmul, graph)
+        trans_attr = "transB" if act_port_id else "transA"
+        return matmul.layer_attributes.node_attrs[trans_attr]
+
     def get_weight_dtype(
         self, node_with_weight: NNCFNode, weight_port_id: int, model: onnx.ModelProto, graph: NNCFGraph
     ) -> TensorDataType:

@@ -301,6 +309,10 @@ def filter_func(point: StatisticPoint) -> bool:

         return filter_func

+    @staticmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        return get_act_quantization_axis(node, port_id)
+
     def insert_adapters(
         self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
     ) -> None:

@@ -503,9 +515,13 @@ def get_ignored_patterns() -> GraphPattern:
 class ONNXAWQAlgoAlgoBackend(AWQAlgoBackend, ONNXWeightCompressionAlgoBackend):
     @staticmethod
     def get_awq_patterns() -> dict[str, Callable]:
-        return get_awq_patterns(
-            onnx_metatypes.ONNXMatMulMetatype, onnx_metatypes.ONNXMulLayerMetatype, ATOMIC_ACTIVATIONS_OPERATIONS
-        )
+        patterns = {}
+        for mm_metatype in (onnx_metatypes.ONNXMatMulMetatype, onnx_metatypes.ONNXGemmMetatype):
+            p = get_awq_patterns(mm_metatype, onnx_metatypes.ONNXMulLayerMetatype, ATOMIC_ACTIVATIONS_OPERATIONS)
+            p = {f"{mm_metatype.__name__}_{k}": v for k, v in p.items()}
+            patterns.update(p)
+
+        return patterns

     @staticmethod
     def scale_insertion_command(
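The ONNX implementation relies on Gemm semantics: Gemm computes `Y = alpha * op(A) @ op(B) + beta * C`, where `op()` transposes its input when the `transA`/`transB` attribute is set. Since the activation may arrive on either input port, the attribute to inspect depends on the activation port id; a short sketch of the mapping used above (hypothetical function name):

```python
def gemm_transpose_attr_for_activation(act_port_id: int) -> str:
    # Gemm: Y = alpha * op(A) @ op(B) + beta * C; op() transposes when
    # transA / transB is set. Pick the attribute matching the activation port.
    return "transB" if act_port_id else "transA"

assert gemm_transpose_attr_for_activation(0) == "transA"  # activation is input A
assert gemm_transpose_attr_for_activation(1) == "transB"  # activation is input B
```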

src/nncf/quantization/algorithms/weight_compression/openvino_backend.py

Lines changed: 10 additions & 4 deletions

@@ -13,7 +13,6 @@
 import openvino as ov
 from openvino import opset13 as opset

-import nncf
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import OperatorMetatype

@@ -35,6 +34,7 @@
 from nncf.openvino.graph.node_utils import convert_op
 from nncf.openvino.graph.node_utils import create_ov_codebook_subgraph
 from nncf.openvino.graph.node_utils import create_ov_const_from_tensor
+from nncf.openvino.graph.node_utils import get_activation_channel_axis
 from nncf.openvino.graph.node_utils import get_const_value_as_numpy_tensor
 from nncf.openvino.graph.node_utils import get_const_value_as_ov_tensor
 from nncf.openvino.graph.node_utils import get_weight_channel_axes

@@ -119,9 +119,6 @@ def mean_statistic_collector(

     @staticmethod
     def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
-        if node.layer_attributes.input_attributes["transpose"]:
-            msg = "Transposed input is not supported"
-            raise nncf.UnsupportedModelError(msg)
         constant_ports = node.layer_attributes.get_const_port_ids()
         activation_ports = [
             e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports

@@ -143,6 +140,11 @@ def get_weight(self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.
         weight_tensor = get_const_value_as_numpy_tensor(weight_node)
         return Tensor(weight_tensor)

+    def matmul_has_transposed_activations(self, matmul: NNCFNode, graph: NNCFGraph) -> bool:
+        if matmul.metatype != om.OVMatMulMetatype:
+            return False
+        return matmul.layer_attributes.input_attributes["transpose"]
+
     def get_weight_dtype(
         self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.Model, graph: NNCFGraph
     ) -> TensorDataType:

@@ -378,6 +380,10 @@ def get_ignored_patterns() -> GraphPattern:
         pattern.add_pattern_alternative(create_sam_pe())
         return pattern

+    @staticmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        return get_activation_channel_axis(node, port_id, input_shape)
+

 class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend):
     """

src/nncf/quantization/algorithms/weight_compression/scale_estimation.py

Lines changed: 4 additions & 0 deletions

@@ -139,6 +139,10 @@ def apply(
                 continue
             _, weight_port_id = weight_data[0]

+            if self._backend_entity.matmul_has_transposed_activations(wp.node_with_weight, graph):
+                msg = "Transposed activations are not supported yet for the Scale Estimation algorithm"
+                raise nncf.UnsupportedModelError(msg)
+
             weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)

             scale, zero_point = self.calculate_quantization_params(
