diff --git a/src/nncf/quantization/algorithms/weight_compression/awq.py b/src/nncf/quantization/algorithms/weight_compression/awq.py
index 508ad57060d..b2950f9787a 100644
--- a/src/nncf/quantization/algorithms/weight_compression/awq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/awq.py
@@ -29,6 +29,8 @@
 from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
+from nncf.quantization.algorithms.weight_compression.tensor_slicing import get_weight_slice
+from nncf.quantization.algorithms.weight_compression.tensor_slicing import set_weight_slice
 from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight
 from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight
 from nncf.quantization.passes import transform_to_inference_graph
@@ -181,7 +183,7 @@ def apply(
                 prev_weight = self._backend_entity.get_weight(merge_node, prev_weight_port_id, model, graph)
                 prev_statistics = statistics[merge_node.node_name]
 
-            scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics)
+            scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics, weight_port_id)
 
             w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
             a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])
@@ -210,7 +212,7 @@ def apply(
 
         return transformed_model
 
-    def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statistics=None):
+    def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statistics=None, weight_port_id=None):
         alpha_step = (self._alpha_max - self._alpha_min) / self._steps
         config = wp.compression_config
         s, X = process_stats(statistics, self._subset_size)
@@ -220,6 +222,9 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
         assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
         reduction_axis = wp.reduction_axes[0]
 
+        # Get transpose_b value to handle weight shape correctly
+        transpose_b = wp.node_with_weight.layer_attributes.constant_attributes[weight_port_id]["transpose"]
+
         prev_s, prev_w = None, None
         if prev_statistics is not None and prev_weight is not None:
             prev_s, _ = process_stats(prev_statistics, self._subset_size)
@@ -239,9 +244,10 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
 
         groups_to_correct = list(groups_to_correct)
 
-        if reduction_axis == 0:
-            weight = fns.transpose(weight)
-            reduction_axis = 1
+        # Remove the old transpose logic - we'll use get_weight_slice instead
+        # if reduction_axis == 0:
+        #     weight = fns.transpose(weight)
+        #     reduction_axis = 1
 
         shape_vector = fns.mean(X, axis=1)
         scale = fns.ones_like(shape_vector)
@@ -257,7 +263,8 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
             a_max = 1e2
             gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)
 
-            gweight = weight[:, offset : offset + group_size]
+            # Use get_weight_slice instead of hardcoded slicing
+            gweight = get_weight_slice(weight, slice(offset, offset + group_size), transpose_b)
             gacts = X[offset : offset + group_size, :]
             fp32_out = fns.matmul(gweight, gacts)
 
@@ -274,18 +281,14 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
                     )  # take the threshold from the fp16 type with some margin
                    # per channel magnitudes for the previous MatMul
                    # mean(abs(prev_weight)) * max(abs((prev_activation))) * prev_weight.shape[reduction_axis]
-                    magnitudes = (
-                        (prev_w[offset : offset + group_size] / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
-                    )
+                    prev_w_slice = prev_w[offset : offset + group_size]
+                    magnitudes = (prev_w_slice / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
                     if magnitudes.max() >= threshold:
                         cur_scale = AWQ._clamp_scale(
                             magnitudes,
                             threshold,
                             cur_scale,
-                            prev_w[offset : offset + group_size]
-                            * prev_s
-                            * prev_weight.shape[reduction_axis]
-                            / threshold,
+                            prev_w_slice * prev_s * prev_weight.shape[reduction_axis] / threshold,
                         )
 
                 weights_to_fake_quantize = gweight * cur_scale
@@ -307,7 +310,8 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
                 alpha += alpha_step
 
             if best_scale is not None:
-                scale.data[offset : offset + group_size] = best_scale.data
+                # Use set_weight_slice for assignment
+                set_weight_slice(scale, slice(offset, offset + group_size), best_scale, transpose_b)
 
         return scale
 
diff --git a/src/nncf/quantization/algorithms/weight_compression/gptq.py b/src/nncf/quantization/algorithms/weight_compression/gptq.py
index b90f2e0574b..002cb356bcd 100644
--- a/src/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -27,6 +27,8 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
 from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
+from nncf.quantization.algorithms.weight_compression.tensor_slicing import get_weight_slice
+from nncf.quantization.algorithms.weight_compression.tensor_slicing import set_weight_slice
 from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params
 from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params
 from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight
@@ -212,18 +214,22 @@ def _quantize_weights(
         if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes:
             msg = "Convolution metatypes are not supported"
             raise RuntimeError(msg)
-        if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]:
-            msg = "Transpose is not supported"
-            raise RuntimeError(msg)
 
         weight_tensor = self._backend_entity.get_weight(
             wc_params.node_with_weight, wc_params.weight_port_id, model, graph
         )
         weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
 
+        # Get transpose_b value to handle weight shape correctly
+        transpose_b = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id][
+            "transpose"
+        ]
+
         dead_indices = fns.diag(hessian) == 0
         hessian[dead_indices, dead_indices] = 1
-        weight_tensor[:, dead_indices] = 0
+
+        # Zero out dead indices using utility helper
+        set_weight_slice(weight_tensor, dead_indices, 0, transpose_b)
 
         scales = []
         zero_points = []
@@ -235,7 +241,7 @@ def _quantize_weights(
         group_size = (
             wc_params.compression_config.group_size
             if wc_params.compression_config.group_size != -1
-            else weight_tensor.shape[1]
+            else (weight_tensor.shape[1] if transpose_b else weight_tensor.shape[0])
         )
         reduction_axes = wc_params.reduction_axes
         block_compression_config = WeightCompressionConfig(
@@ -254,35 +260,38 @@ def _quantize_weights(
             i2 = min(i1 + self._block_size, columns)
             count = i2 - i1
 
-            weight_block = weight_tensor[:, i1:i2].clone()
+            # Extract weight block using utility helper
+            weight_block = get_weight_slice(weight_tensor, slice(i1, i2), transpose_b).clone()
             quantized_block = fns.zeros_like(weight_block)
             error_block = fns.zeros_like(weight_block)
             loss_block = fns.zeros_like(weight_block)
             hessian_inv_block = hessian_inv[i1:i2, i1:i2]
 
             for i in range(count):
-                weight_col = weight_block[:, i]
+                weight_col = get_weight_slice(weight_block, i, transpose_b)
                 hessian_diag_val = hessian_inv_block[i, i]
 
                 if (i1 + i) % group_size == 0:
                     if not block_compression_config.is_integer:
+                        weight_slice = get_weight_slice(weight_tensor, slice(i1 + i, i1 + i + group_size), transpose_b)
                         scale = calculate_float_quantization_params(
-                            weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, block_compression_config
+                            weight_slice, reduction_axes, block_compression_config
                         )
                         scales.append(scale)
                     else:
+                        weight_slice = get_weight_slice(weight_tensor, slice(i1 + i, i1 + i + group_size), transpose_b)
                         if self._scale_estimation and block_compression_config.num_bits == 4:
                             activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
                             wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
                             scale, zero_point = ScaleEstimation.calculate_quantization_params(
                                 wc_statistics,
-                                weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
+                                weight_slice,
                                 reduction_axes,
                                 block_compression_config,
                             )
                         else:
                             scale, zero_point = calculate_integer_quantization_params(
-                                weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
+                                weight_slice,
                                 reduction_axes,
                                 block_compression_config,
                             )
@@ -303,19 +312,34 @@ def _quantize_weights(
                     precomputed_zero_point=zero_points[-1],
                 )
                 quantized_col = fns.flatten(quantized_col)
-                quantized_block[:, i] = quantized_col
-                loss_block[:, i] = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
+                set_weight_slice(quantized_block, i, quantized_col, transpose_b)
+                loss_col = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
+                set_weight_slice(loss_block, i, loss_col, transpose_b)
 
                 error_col = (weight_col - quantized_col) / hessian_diag_val
-                weight_block[:, i:] -= fns.matmul(
-                    fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
-                )
-                error_block[:, i] = error_col
+                if transpose_b:
+                    weight_block[:, i:] -= fns.matmul(
+                        fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
+                    )
+                    set_weight_slice(error_block, i, error_col, transpose_b)
+                else:
+                    weight_block[i:, :] -= fns.matmul(
+                        fns.unsqueeze(hessian_inv_block[i:, i], 1), fns.unsqueeze(error_col, 0)
+                    )
+                    set_weight_slice(error_block, i, error_col, transpose_b)
 
-            quantized_tensor[:, i1:i2] = quantized_block
-            losses[:, i1:i2] = loss_block / 2
+            set_weight_slice(quantized_tensor, slice(i1, i2), quantized_block, transpose_b)
+            set_weight_slice(losses, slice(i1, i2), loss_block / 2, transpose_b)
 
-            weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
+            # Update remaining weights with error propagation
+            if transpose_b:
+                weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
+            else:
+                # For transpose_b=False: error_block has shape [i2-i1, out_features] and
+                # hessian_inv[i2:, i1:i2] has shape [columns-i2, i2-i1], so
+                # hessian_inv[i2:, i1:i2] @ error_block gives [columns-i2, out_features],
+                # matching the remaining rows of weight_tensor.
+                weight_tensor[i2:, :] -= fns.matmul(hessian_inv[i2:, i1:i2], error_block)
 
         quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
         self._backend_entity.set_weight(
@@ -325,13 +349,19 @@ def _quantize_weights(
         scales = fns.stack(scales, axis=1)
         if wc_params.compression_config.group_size == -1:
             scales = fns.squeeze(scales, axis=-1)
-        if wc_params.compression_config.mode in [
-            CompressWeightsMode.INT8_ASYM,
-            CompressWeightsMode.INT4_ASYM,
-        ]:
-            zero_points = fns.stack(zero_points, axis=1)
+
+        zero_points_tensor = None
+        if (
+            zero_points
+            and zero_points[0] is not None
+            and wc_params.compression_config.mode
+            in [
+                CompressWeightsMode.INT8_ASYM,
+                CompressWeightsMode.INT4_ASYM,
+            ]
+        ):
+            zero_points_tensor = fns.stack(zero_points, axis=1)
             if wc_params.compression_config.group_size == -1:
-                zero_points = fns.squeeze(zero_points, axis=-1)
-        else:
-            zero_points = None
-        return scales, zero_points
+                zero_points_tensor = fns.squeeze(zero_points_tensor, axis=-1)
+
+        return scales, zero_points_tensor
diff --git a/src/nncf/quantization/algorithms/weight_compression/lora_correction.py b/src/nncf/quantization/algorithms/weight_compression/lora_correction.py
index 0fe478dfab5..2035b9068d5 100644
--- a/src/nncf/quantization/algorithms/weight_compression/lora_correction.py
+++ b/src/nncf/quantization/algorithms/weight_compression/lora_correction.py
@@ -121,6 +121,12 @@ def calculate_adapters(
         layer_name = wc_params.node_with_weight.node_name
         layer_statistics = self._statistics[layer_name]
         is_debug = self._debug_interface is not None
+
+        # Get transpose_b value to handle weight shape correctly
+        transpose_b = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id][
+            "transpose"
+        ]
+
         lora_A, lora_B, mean_noises = self.calculate_low_rank_matrices(
             weight,
             compressed_weight,
@@ -129,6 +135,7 @@ def calculate_adapters(
             self._lora_correction_params,
             layer_statistics,
             is_debug,
+            transpose_b=transpose_b,
         )
         if is_debug:
             self._debug_interface.add_noises(layer_name, mean_noises)
@@ -143,6 +150,7 @@ def calculate_low_rank_matrices(
         lora_correction_params: AdvancedLoraCorrectionParameters,
         layer_statistics: WCTensorStatistic,
         is_debug: Optional[bool] = False,
+        transpose_b: bool = True,  # Add this parameter with default True for backward compatibility
     ):
         """
         Calculates low rank matrices for a given original and compressed weights.
@@ -190,7 +198,8 @@ def calculate_low_rank_matrices(
 
         # O stands for output dimension, H - input dimension or hidden size, SS - samples size, R - rank.
         # reduction axes is all axes except output dimension in linear/conv layers.
-        if reduction_axes[0] == 1:
+        # Use transpose_b directly instead of inferring from reduction_axes
+        if not transpose_b:
             svd_residual = fns.transpose(svd_residual)
         residual = svd_residual.clone()  # [H, O]
 
diff --git a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index 3ec241b36c6..2fe1dd5dd0a 100644
--- a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -63,6 +63,7 @@
 from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
 from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
 from nncf.tensor import Tensor
+from nncf.tensor import functions as fns
 from nncf.tensor.definitions import TensorDataType
 from nncf.tensor.functions.openvino_numeric import DTYPE_MAP_REV
 
@@ -178,6 +179,16 @@ def insert_adapters(
         should_add_convert_node = activation_dtype != ov.Type.f16
         mm_node = self.name_to_node_mapping[wc_params.node_with_weight.node_name]
 
+        # Get the original MatMul's transpose attributes
+        node_attributes = mm_node.get_attributes()
+        transpose_a = node_attributes.get("transpose_a", False)
+        transpose_b = node_attributes.get("transpose_b", True)  # Default to True for backward compatibility
+
+        # Transpose lora_B if the original MatMul had transpose_b=False
+        # This ensures the matrix multiplication A_MM @ B_W has compatible dimensions
+        if not transpose_b:
+            lora_B = fns.transpose(lora_B)
+
         if int8_lora:
             const_node_name = wc_params.node_with_weight.node_name
             int8_compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT8_ASYM, group_size=-1)
@@ -203,7 +214,9 @@ def insert_adapters(
         A_W = opset.constant(lora_A.data)
         B_W = opset.constant(lora_B.data)
 
-        A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
+        # LoRA adapters: input @ A^T @ B^T
+        # Always keep transpose_b=True to ensure the adapter aligns with the MatMul output shape
+        A_MM = opset.matmul(input_node, A_W, transpose_a=transpose_a, transpose_b=True)
         B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)
 
         node_output_port = mm_node.output(0)
diff --git a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py
index 4ad557b9868..565790c95a6 100644
--- a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py
+++ b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -141,6 +141,9 @@ def apply(
 
             weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)
 
+            # Get transpose_b value to handle weight shape correctly
+            transpose_b = wp.node_with_weight.layer_attributes.constant_attributes[weight_port_id]["transpose"]
+
             scale, zero_point = self.calculate_quantization_params(
                 stats,
                 weight,
@@ -150,6 +153,7 @@ def apply(
                 self._initial_steps,
                 self._scale_steps,
                 self._weight_penalty,
+                transpose_b=transpose_b,
             )
             res[weight_name] = CompressedWeight(None, scale, zero_point, None)
 
@@ -165,6 +169,7 @@ def calculate_quantization_params(
         initial_steps: int = 5,
         scale_steps: int = 10,
         weight_penalty: float = -1.0,
+        transpose_b: bool = True,  # Add this parameter with default True for backward compatibility
     ) -> Tensor:
         """
         Calculates the quantization parameters for a given set of weights and activations.
@@ -199,7 +204,8 @@ def calculate_quantization_params(
         is_3d_weight = len(weight.shape) == 3
 
         was_transposed = False
-        if reduction_axis == 0 or (reduction_axis == 1 and is_3d_weight):
+        # Use transpose_b directly instead of inferring from reduction_axis
+        if not transpose_b or (reduction_axis == 1 and is_3d_weight):
             # Weights
             # 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
             # 2D: [hidden_dimension, out_features] -> [out_features, hidden_dimension]
diff --git a/src/nncf/quantization/algorithms/weight_compression/tensor_slicing.py b/src/nncf/quantization/algorithms/weight_compression/tensor_slicing.py
new file mode 100644
index 00000000000..ce2360046fb
--- /dev/null
+++ b/src/nncf/quantization/algorithms/weight_compression/tensor_slicing.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from nncf.tensor import Tensor
+
+# slice is a built-in type, so we don't need to import it.
+# slice_obj can be: an int (index), a slice (start:end), or a Tensor/Array (mask/indices)
+
+
+def get_weight_slice(
+    weight: Tensor,
+    slice_obj: Union[int, slice, Tensor],
+    is_transposed: bool,
+) -> Tensor:
+    """
+    Generic helper to get a subset of weights along the input channel dimension.
+
+    :param weight: The weight tensor.
+    :param slice_obj: An integer index, a slice(start, end), or a boolean mask/index tensor.
+    :param is_transposed: True if weight is [Out, In], False if [In, Out].
+    :return: A slice of the weight tensor.
+    """
+    if is_transposed:
+        return weight[:, slice_obj]
+    return weight[slice_obj, :]
+
+
+def set_weight_slice(
+    weight: Tensor,
+    slice_obj: Union[int, slice, Tensor],
+    value: Union[Tensor, float, int],
+    is_transposed: bool,
+) -> None:
+    """
+    Generic helper to set a subset of weights along the input channel dimension.
+
+    :param weight: The target tensor to modify in-place.
+    :param slice_obj: An integer index, a slice(start, end), or a boolean mask/index tensor.
+    :param value: The value(s) to assign.
+    :param is_transposed: True if weight is [Out, In], False if [In, Out].
+ """ + if is_transposed: + weight[:, slice_obj] = value + else: + weight[slice_obj, :] = value diff --git a/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py index d0c96e952fb..be9b2c710e3 100644 --- a/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -643,7 +643,11 @@ def _calculate_integer_quantized_weight( compressed_weights = weight / scale if zero_point is not None: - compressed_weights += zero_point.astype(weight.dtype) + zp = zero_point.astype(weight.dtype) + if zp.ndim < compressed_weights.ndim: + new_shape = list(zp.shape) + [1] * (compressed_weights.ndim - zp.ndim) + zp = fns.reshape(zp, new_shape) + compressed_weights += zp compressed_weights = fns.round(compressed_weights) compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype) diff --git a/tests/openvino/native/quantization/test_utils_slice_weight.py b/tests/openvino/native/quantization/test_utils_slice_weight.py new file mode 100644 index 00000000000..ecd2389538c --- /dev/null +++ b/tests/openvino/native/quantization/test_utils_slice_weight.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import openvino as ov +import pytest +from openvino import opset13 as opset + +import nncf +from nncf import CompressWeightsMode + + +def get_transpose_b_false_model(): + """Creates model with [In, Out] weight layout (transpose_b=False)""" + input_shape = [1, 32] + input_node = opset.parameter(input_shape, name="Input") + # Weight shape [32, 16] -> Input=32, Output=16 + weight_data = np.random.rand(32, 16).astype(np.float32) + matmul_node = opset.matmul(input_node, weight_data, transpose_a=False, transpose_b=False, name="MatMul") + result_node = opset.result(matmul_node, name="Result") + return ov.Model([result_node], [input_node], "transpose_b_false_model") + + +@pytest.mark.parametrize( + "params", [{"awq": True}, {"gptq": True}, {"scale_estimation": True}, {"lora_correction": True}] +) +def test_compress_weights_algorithms_transpose_b_false(params): + """ + Checks that ALL data-aware algorithms support transpose_b=False + without crashing. + """ + model = get_transpose_b_false_model() + + # Dummy dataset for calibration + dataset = nncf.Dataset([np.random.rand(1, 32).astype(np.float32) for _ in range(3)]) + + # We use INT4_ASYM as it supports all these advanced algorithms + try: + nncf.compress_weights( + model, + mode=CompressWeightsMode.INT4_ASYM, + dataset=dataset, + subset_size=2, + **params, # Unpacks to awq=True, gptq=True, etc. + ) + except Exception as e: + pytest.fail(f"Algorithm {list(params.keys())[0]} failed for transpose_b=False. Error: {e}")