32 changes: 18 additions & 14 deletions src/nncf/quantization/algorithms/weight_compression/awq.py
@@ -29,6 +29,8 @@
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.tensor_slicing import get_weight_slice
from nncf.quantization.algorithms.weight_compression.tensor_slicing import set_weight_slice
from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight
from nncf.quantization.passes import transform_to_inference_graph
@@ -181,7 +183,7 @@ def apply(
prev_weight = self._backend_entity.get_weight(merge_node, prev_weight_port_id, model, graph)

prev_statistics = statistics[merge_node.node_name]
scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics)
scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics, weight_port_id)

w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])
@@ -210,7 +212,7 @@ def apply(

return transformed_model

def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statistics=None):
def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statistics=None, weight_port_id=None):
alpha_step = (self._alpha_max - self._alpha_min) / self._steps
config = wp.compression_config
s, X = process_stats(statistics, self._subset_size)
@@ -220,6 +222,9 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
reduction_axis = wp.reduction_axes[0]

# Get transpose_b value to handle weight shape correctly
transpose_b = wp.node_with_weight.layer_attributes.constant_attributes[weight_port_id]["transpose"]

prev_s, prev_w = None, None
if prev_statistics is not None and prev_weight is not None:
prev_s, _ = process_stats(prev_statistics, self._subset_size)
@@ -239,9 +244,10 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis

groups_to_correct = list(groups_to_correct)

if reduction_axis == 0:
weight = fns.transpose(weight)
reduction_axis = 1
# The explicit transpose for reduction_axis == 0 is no longer needed:
# get_weight_slice/set_weight_slice resolve the weight layout from transpose_b.

shape_vector = fns.mean(X, axis=1)
scale = fns.ones_like(shape_vector)
@@ -257,7 +263,8 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
a_max = 1e2
gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

gweight = weight[:, offset : offset + group_size]
# Use get_weight_slice instead of hardcoded slicing
gweight = get_weight_slice(weight, slice(offset, offset + group_size), transpose_b)
gacts = X[offset : offset + group_size, :]

fp32_out = fns.matmul(gweight, gacts)
@@ -274,18 +281,14 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
) # take the threshold from the fp16 type with some margin
# per channel magnitudes for the previous MatMul
# mean(abs(prev_weight)) * max(abs((prev_activation))) * prev_weight.shape[reduction_axis]
magnitudes = (
(prev_w[offset : offset + group_size] / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
)
prev_w_slice = prev_w[offset : offset + group_size]
magnitudes = (prev_w_slice / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
if magnitudes.max() >= threshold:
cur_scale = AWQ._clamp_scale(
magnitudes,
threshold,
cur_scale,
prev_w[offset : offset + group_size]
* prev_s
* prev_weight.shape[reduction_axis]
/ threshold,
prev_w_slice * prev_s * prev_weight.shape[reduction_axis] / threshold,
)

weights_to_fake_quantize = gweight * cur_scale
@@ -307,7 +310,8 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
alpha += alpha_step

if best_scale is not None:
scale.data[offset : offset + group_size] = best_scale.data
# Use set_weight_slice for assignment
set_weight_slice(scale, slice(offset, offset + group_size), best_scale, transpose_b)

return scale

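The get_weight_slice / set_weight_slice helpers imported at the top of awq.py and gptq.py live in a new tensor_slicing module that is not part of this diff. Below is a minimal sketch of the behaviour the call sites appear to rely on; the function bodies, the handling of 1-D tensors, and the NumPy-style indexing are assumptions, not the actual module contents.

# Hypothetical sketch of the tensor_slicing helpers, inferred from their call sites.
# Assumption: transpose_b=True means the weight is stored as [out_features, in_features]
# (input-feature axis = 1); transpose_b=False means [in_features, out_features]
# (input-feature axis = 0). 1-D tensors have only a single axis to index.

def get_weight_slice(weight, index, transpose_b):
    """Read a slice of `weight` along its input-feature axis."""
    if weight.ndim == 1 or not transpose_b:
        return weight[index]
    return weight[:, index]


def set_weight_slice(weight, index, value, transpose_b):
    """Write `value` into `weight` along its input-feature axis, in place."""
    if weight.ndim == 1 or not transpose_b:
        weight[index] = value
    else:
        weight[:, index] = value

Under this reading, get_weight_slice(weight, slice(offset, offset + group_size), transpose_b) in awq.py selects the same group of input channels that weight[:, offset : offset + group_size] used to select after the old transpose.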
86 changes: 58 additions & 28 deletions src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -27,6 +27,8 @@
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
from nncf.quantization.algorithms.weight_compression.tensor_slicing import get_weight_slice
from nncf.quantization.algorithms.weight_compression.tensor_slicing import set_weight_slice
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params
from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight
@@ -212,18 +214,22 @@ def _quantize_weights(
if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes:
msg = "Convolution metatypes are not supported"
raise RuntimeError(msg)
if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]:
msg = "Transpose is not supported"
raise RuntimeError(msg)

weight_tensor = self._backend_entity.get_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph
)
weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)

# Get transpose_b value to handle weight shape correctly
transpose_b = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id][
"transpose"
]

dead_indices = fns.diag(hessian) == 0
hessian[dead_indices, dead_indices] = 1
weight_tensor[:, dead_indices] = 0

# Zero out dead indices using utility helper
set_weight_slice(weight_tensor, dead_indices, 0, transpose_b)

scales = []
zero_points = []
@@ -235,7 +241,7 @@ def _quantize_weights(
group_size = (
wc_params.compression_config.group_size
if wc_params.compression_config.group_size != -1
else weight_tensor.shape[1]
else (weight_tensor.shape[1] if transpose_b else weight_tensor.shape[0])
)
reduction_axes = wc_params.reduction_axes
block_compression_config = WeightCompressionConfig(
@@ -254,35 +260,38 @@ def _quantize_weights(
i2 = min(i1 + self._block_size, columns)
count = i2 - i1

weight_block = weight_tensor[:, i1:i2].clone()
# Extract weight block using utility helper
weight_block = get_weight_slice(weight_tensor, slice(i1, i2), transpose_b).clone()
quantized_block = fns.zeros_like(weight_block)
error_block = fns.zeros_like(weight_block)
loss_block = fns.zeros_like(weight_block)
hessian_inv_block = hessian_inv[i1:i2, i1:i2]

for i in range(count):
weight_col = weight_block[:, i]
weight_col = get_weight_slice(weight_block, i, transpose_b)
hessian_diag_val = hessian_inv_block[i, i]

if (i1 + i) % group_size == 0:
if not block_compression_config.is_integer:
weight_slice = get_weight_slice(weight_tensor, slice(i1 + i, i1 + i + group_size), transpose_b)
scale = calculate_float_quantization_params(
weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, block_compression_config
weight_slice, reduction_axes, block_compression_config
)
scales.append(scale)
else:
weight_slice = get_weight_slice(weight_tensor, slice(i1 + i, i1 + i + group_size), transpose_b)
if self._scale_estimation and block_compression_config.num_bits == 4:
activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
scale, zero_point = ScaleEstimation.calculate_quantization_params(
wc_statistics,
weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
weight_slice,
reduction_axes,
block_compression_config,
)
else:
scale, zero_point = calculate_integer_quantization_params(
weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
weight_slice,
reduction_axes,
block_compression_config,
)
@@ -303,19 +312,34 @@ def _quantize_weights(
precomputed_zero_point=zero_points[-1],
)
quantized_col = fns.flatten(quantized_col)
quantized_block[:, i] = quantized_col
loss_block[:, i] = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
set_weight_slice(quantized_block, i, quantized_col, transpose_b)
loss_col = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
set_weight_slice(loss_block, i, loss_col, transpose_b)

error_col = (weight_col - quantized_col) / hessian_diag_val
weight_block[:, i:] -= fns.matmul(
fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
)
error_block[:, i] = error_col
if transpose_b:
weight_block[:, i:] -= fns.matmul(
fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
)
set_weight_slice(error_block, i, error_col, transpose_b)
else:
weight_block[i:, :] -= fns.matmul(
fns.unsqueeze(error_col, 0), fns.unsqueeze(hessian_inv_block[i:, i], 1)
)
set_weight_slice(error_block, i, error_col, transpose_b)

quantized_tensor[:, i1:i2] = quantized_block
losses[:, i1:i2] = loss_block / 2
set_weight_slice(quantized_tensor, slice(i1, i2), quantized_block, transpose_b)
set_weight_slice(losses, slice(i1, i2), loss_block / 2, transpose_b)

weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
# Update remaining weights with error propagation
if transpose_b:
weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
else:
# For transpose_b=False: weight_tensor is [columns, out_features], so the update for
# weight_tensor[i2:, :] must have shape [columns - i2, out_features]. It is formed by
# multiplying hessian_inv[i2:, i1:i2] ([columns - i2, i2 - i1]) with an
# [i2 - i1, out_features] view of the error block.
weight_tensor[i2:, :] -= fns.matmul(hessian_inv[i2:, i1:i2], fns.transpose(error_block))

quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
self._backend_entity.set_weight(
@@ -325,13 +349,19 @@ def _quantize_weights(
scales = fns.stack(scales, axis=1)
if wc_params.compression_config.group_size == -1:
scales = fns.squeeze(scales, axis=-1)
if wc_params.compression_config.mode in [
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT4_ASYM,
]:
zero_points = fns.stack(zero_points, axis=1)

zero_points_tensor = None
if (
zero_points
and zero_points[0] is not None
and wc_params.compression_config.mode
in [
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT4_ASYM,
]
):
zero_points_tensor = fns.stack(zero_points, axis=1)
if wc_params.compression_config.group_size == -1:
zero_points = fns.squeeze(zero_points, axis=-1)
else:
zero_points = None
return scales, zero_points
zero_points_tensor = fns.squeeze(zero_points_tensor, axis=-1)

return scales, zero_points_tensor
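
As a self-contained sanity check of the error-propagation step, the sketch below spells out the two matrix products the branches above target. The sizes are illustrative, NumPy stands in for the nncf fns wrappers, and the orientation of the real error_block depends on the tensor_slicing helpers, which are not shown in this diff.

import numpy as np

out_features, columns = 8, 16                      # illustrative sizes only
i1, i2 = 0, 4                                      # current block of input channels
H_inv = np.random.rand(columns, columns)           # inverse Hessian over input channels
E = np.random.rand(out_features, i2 - i1)          # one error column per channel in the block

W_t = np.random.rand(out_features, columns)        # transpose_b=True layout
W_t[:, i2:] -= E @ H_inv[i1:i2, i2:]               # [out, count] @ [count, columns - i2]

W_f = np.random.rand(columns, out_features)        # transpose_b=False layout
W_f[i2:, :] -= H_inv[i2:, i1:i2] @ E.T             # [columns - i2, count] @ [count, out]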
Original file line number Diff line number Diff line change
@@ -121,6 +121,12 @@ def calculate_adapters(
layer_name = wc_params.node_with_weight.node_name
layer_statistics = self._statistics[layer_name]
is_debug = self._debug_interface is not None

# Get transpose_b value to handle weight shape correctly
transpose_b = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id][
"transpose"
]

lora_A, lora_B, mean_noises = self.calculate_low_rank_matrices(
weight,
compressed_weight,
@@ -129,6 +135,7 @@
self._lora_correction_params,
layer_statistics,
is_debug,
transpose_b=transpose_b,
)
if is_debug:
self._debug_interface.add_noises(layer_name, mean_noises)
@@ -143,6 +150,7 @@ def calculate_low_rank_matrices(
lora_correction_params: AdvancedLoraCorrectionParameters,
layer_statistics: WCTensorStatistic,
is_debug: Optional[bool] = False,
transpose_b: bool = True,  # Defaults to True for backward compatibility
):
"""
Calculates low rank matrices for a given original and compressed weights.
@@ -190,7 +198,8 @@

# O stands for output dimension, H - input dimension or hidden size, SS - samples size, R - rank.
# reduction axes is all axes except output dimension in linear/conv layers.
if reduction_axes[0] == 1:
# Use transpose_b directly instead of inferring from reduction_axes
# (reduction_axes[0] == 1 corresponds to transpose_b=True, i.e. an [O, H] weight)
if transpose_b:
svd_residual = fns.transpose(svd_residual)
residual = svd_residual.clone() # [H, O]

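For context, the condition above relies on the mapping between the MatMul layout flag and the reduction axis (the file's own comment notes that reduction axes are all axes except the output dimension). A brief illustration of that mapping, with the helper name chosen here purely for illustration:

# transpose_b=True  -> weight stored as [out_features, in_features], input axis is 1
# transpose_b=False -> weight stored as [in_features, out_features], input axis is 0
def input_channel_axis(transpose_b: bool) -> int:   # hypothetical helper, illustration only
    return 1 if transpose_b else 0

assert input_channel_axis(True) == 1                 # matches the old check `reduction_axes[0] == 1`
assert input_channel_axis(False) == 0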
Original file line number Diff line number Diff line change
@@ -63,6 +63,7 @@
from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
from nncf.tensor import Tensor
from nncf.tensor import functions as fns
from nncf.tensor.definitions import TensorDataType
from nncf.tensor.functions.openvino_numeric import DTYPE_MAP_REV

@@ -178,6 +179,16 @@ def insert_adapters(
should_add_convert_node = activation_dtype != ov.Type.f16
mm_node = self.name_to_node_mapping[wc_params.node_with_weight.node_name]

# Get the original MatMul's transpose attributes
node_attributes = mm_node.get_attributes()
transpose_a = node_attributes.get("transpose_a", False)
transpose_b = node_attributes.get("transpose_b", True) # Default to True for backward compatibility

# Transpose lora_B if the original MatMul had transpose_b=False
# This ensures the matrix multiplication A_MM @ B_W has compatible dimensions
if not transpose_b:
lora_B = fns.transpose(lora_B)

if int8_lora:
const_node_name = wc_params.node_with_weight.node_name
int8_compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT8_ASYM, group_size=-1)
@@ -203,7 +214,9 @@
A_W = opset.constant(lora_A.data)
B_W = opset.constant(lora_B.data)

A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
# LoRA adapters: input @ A^T @ B^T
# Always keep transpose_b=True to ensure the adapter aligns with the MatMul output shape
A_MM = opset.matmul(input_node, A_W, transpose_a=transpose_a, transpose_b=True)
B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)

node_output_port = mm_node.output(0)
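A quick shape check of the adapter chain emitted above, with illustrative NumPy arrays (transpose_a is taken as False for brevity): for both adapter MatMuls to use transpose_b=True, lora_A must be [rank, hidden] and lora_B must be [out, rank], which is why lora_B is transposed when it arrives in the opposite layout.

import numpy as np

batch, H, O, R = 2, 16, 8, 4                        # illustrative sizes only
x = np.random.rand(batch, H)                        # activation entering the original MatMul
A_W = np.random.rand(R, H)                          # lora_A
B_W = np.random.rand(O, R)                          # lora_B after the optional transpose above

a_mm = x @ A_W.T                                    # opset.matmul(x, A_W, transpose_b=True) -> [batch, R]
b_mm = a_mm @ B_W.T                                 # opset.matmul(a_mm, B_W, transpose_b=True) -> [batch, O]
assert b_mm.shape == (batch, O)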
Original file line number Diff line number Diff line change
@@ -141,6 +141,9 @@ def apply(

weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)

# Get transpose_b value to handle weight shape correctly
transpose_b = wp.node_with_weight.layer_attributes.constant_attributes[weight_port_id]["transpose"]

scale, zero_point = self.calculate_quantization_params(
stats,
weight,
@@ -150,6 +153,7 @@
self._initial_steps,
self._scale_steps,
self._weight_penalty,
transpose_b=transpose_b,
)
res[weight_name] = CompressedWeight(None, scale, zero_point, None)

@@ -165,6 +169,7 @@ def calculate_quantization_params(
initial_steps: int = 5,
scale_steps: int = 10,
weight_penalty: float = -1.0,
transpose_b: bool = True,  # Defaults to True for backward compatibility
) -> Tensor:
"""
Calculates the quantization parameters for a given set of weights and activations.
@@ -199,7 +204,8 @@
is_3d_weight = len(weight.shape) == 3

was_transposed = False
if reduction_axis == 0 or (reduction_axis == 1 and is_3d_weight):
# Use transpose_b directly instead of inferring from reduction_axis
if not transpose_b or (reduction_axis == 1 and is_3d_weight):
# Weights
# 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
# 2D: [hidden_dimension, out_features] -> [out_features, hidden_dimension]
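For reference, the layout normalization described in the comments above amounts to the following transposes, shown here with illustrative NumPy arrays and made-up sizes:

import numpy as np

hidden, out, experts = 16, 8, 4
w2d = np.random.rand(hidden, out)                   # 2D weight with transpose_b=False
w3d = np.random.rand(experts, hidden, out)          # 3D expert weight with reduction_axis == 1

w2d_t = np.transpose(w2d)                           # -> [out_features, hidden_dimension]
w3d_t = np.transpose(w3d, (0, 2, 1))                # -> [num_experts, out_features, hidden_dimension]
assert w2d_t.shape == (out, hidden) and w3d_t.shape == (experts, out, hidden)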