Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/nncf/quantization/algorithms/weight_compression/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ def apply(
prev_weight = self._backend_entity.get_weight(merge_node, prev_weight_port_id, model, graph)

prev_statistics = statistics[merge_node.node_name]
scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics)
scale = self._data_aware_step(
wp, weight, statistics[k], prev_weight, prev_statistics, weight_port_id, graph
)

w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])
Expand Down Expand Up @@ -210,7 +212,9 @@ def apply(

return transformed_model

def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statistics=None):
def _data_aware_step(
self, wp, weight, statistics, prev_weight=None, prev_statistics=None, weight_port_id=None, graph=None
):
alpha_step = (self._alpha_max - self._alpha_min) / self._steps
config = wp.compression_config
s, X = process_stats(statistics, self._subset_size)
Expand All @@ -220,6 +224,9 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
reduction_axis = wp.reduction_axes[0]

# Get transpose_b value from backend to handle weight shape correctly in a backend-agnostic way
transpose_b = self._backend_entity.get_weight_transpose_b(wp.node_with_weight, weight_port_id, graph)

prev_s, prev_w = None, None
if prev_statistics is not None and prev_weight is not None:
prev_s, _ = process_stats(prev_statistics, self._subset_size)
Expand All @@ -239,9 +246,10 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis

groups_to_correct = list(groups_to_correct)

if reduction_axis == 0:
weight = fns.transpose(weight)
reduction_axis = 1
# Remove the old transpose logic - we now rely on explicit transpose_b handling
# if reduction_axis == 0:
# weight = fns.transpose(weight)
# reduction_axis = 1

shape_vector = fns.mean(X, axis=1)
scale = fns.ones_like(shape_vector)
Expand All @@ -257,7 +265,11 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
a_max = 1e2
gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

gweight = weight[:, offset : offset + group_size]
# Slice weight block along the input-channel dimension taking transpose_b into account
if transpose_b:
gweight = weight[:, offset : offset + group_size]
else:
gweight = weight[offset : offset + group_size, :]
gacts = X[offset : offset + group_size, :]

fp32_out = fns.matmul(gweight, gacts)
Expand All @@ -274,18 +286,14 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
) # take the threshold from the fp16 type with some margin
# per channel magnitudes for the previous MatMul
# mean(abs(prev_weight)) * max(abs((prev_activation))) * prev_weight.shape[reduction_axis]
magnitudes = (
(prev_w[offset : offset + group_size] / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
)
prev_w_slice = prev_w[offset : offset + group_size]
magnitudes = (prev_w_slice / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
if magnitudes.max() >= threshold:
cur_scale = AWQ._clamp_scale(
magnitudes,
threshold,
cur_scale,
prev_w[offset : offset + group_size]
* prev_s
* prev_weight.shape[reduction_axis]
/ threshold,
prev_w_slice * prev_s * prev_weight.shape[reduction_axis] / threshold,
)

weights_to_fake_quantize = gweight * cur_scale
Expand All @@ -307,7 +315,8 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
alpha += alpha_step

if best_scale is not None:
scale.data[offset : offset + group_size] = best_scale.data
# Assign best_scale to the corresponding slice of the scale vector
scale[offset : offset + group_size] = best_scale

return scale

Expand Down
14 changes: 14 additions & 0 deletions src/nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ def get_reduction_axes(node_with_weight: NNCFNode, weight_port_id: int, graph: N
:return: Reduction shape in tuple format or None if not applicable.
"""

@staticmethod
@abstractmethod
def get_weight_transpose_b(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> bool:
    """
    Returns whether the weight input of the given node is treated as transposed (transpose_b=True).

    This is backend-specific and abstracts away how the underlying framework stores MatMul/Gemm
    attributes, so callers can reason about weight layout (e.g. [Out, In] vs [In, Out]) without
    inspecting framework-level node attributes themselves.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The input port ID that corresponds to the weight.
    :param graph: The model graph.
    :return: True if the backend treats the weight as transposed (e.g., [Out, In]), False otherwise.
    """

@staticmethod
@abstractmethod
def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> list[tuple[str, int]]:
Expand Down
113 changes: 83 additions & 30 deletions src/nncf/quantization/algorithms/weight_compression/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,25 @@ def _quantize_weights(
if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes:
msg = "Convolution metatypes are not supported"
raise RuntimeError(msg)
if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]:
msg = "Transpose is not supported"
raise RuntimeError(msg)

weight_tensor = self._backend_entity.get_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph
)
weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)

# Get transpose_b value via backend to handle weight shape correctly in a backend-agnostic way
transpose_b = self._backend_entity.get_weight_transpose_b(
wc_params.node_with_weight, wc_params.weight_port_id, graph
)

dead_indices = fns.diag(hessian) == 0
hessian[dead_indices, dead_indices] = 1
weight_tensor[:, dead_indices] = 0

# Zero out dead indices along the input-channel dimension
if transpose_b:
weight_tensor[:, dead_indices] = 0
else:
weight_tensor[dead_indices, :] = 0

scales = []
zero_points = []
Expand All @@ -235,7 +242,7 @@ def _quantize_weights(
group_size = (
wc_params.compression_config.group_size
if wc_params.compression_config.group_size != -1
else weight_tensor.shape[1]
else (weight_tensor.shape[1] if transpose_b else weight_tensor.shape[0])
)
reduction_axes = wc_params.reduction_axes
block_compression_config = WeightCompressionConfig(
Expand All @@ -254,35 +261,50 @@ def _quantize_weights(
i2 = min(i1 + self._block_size, columns)
count = i2 - i1

weight_block = weight_tensor[:, i1:i2].clone()
# Extract weight block along the input-channel dimension
if transpose_b:
weight_block = weight_tensor[:, i1:i2].clone()
else:
weight_block = weight_tensor[i1:i2, :].clone()
quantized_block = fns.zeros_like(weight_block)
error_block = fns.zeros_like(weight_block)
loss_block = fns.zeros_like(weight_block)
hessian_inv_block = hessian_inv[i1:i2, i1:i2]

for i in range(count):
weight_col = weight_block[:, i]
if transpose_b:
weight_col = weight_block[:, i]
else:
weight_col = weight_block[i, :]
hessian_diag_val = hessian_inv_block[i, i]

if (i1 + i) % group_size == 0:
if not block_compression_config.is_integer:
if transpose_b:
weight_slice = weight_tensor[:, i1 + i : i1 + i + group_size]
else:
weight_slice = weight_tensor[i1 + i : i1 + i + group_size, :]
scale = calculate_float_quantization_params(
weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, block_compression_config
weight_slice, reduction_axes, block_compression_config
)
scales.append(scale)
else:
if transpose_b:
weight_slice = weight_tensor[:, i1 + i : i1 + i + group_size]
else:
weight_slice = weight_tensor[i1 + i : i1 + i + group_size, :]
if self._scale_estimation and block_compression_config.num_bits == 4:
activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
scale, zero_point = ScaleEstimation.calculate_quantization_params(
wc_statistics,
weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
weight_slice,
reduction_axes,
block_compression_config,
)
else:
scale, zero_point = calculate_integer_quantization_params(
weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
weight_slice,
reduction_axes,
block_compression_config,
)
Expand All @@ -303,19 +325,44 @@ def _quantize_weights(
precomputed_zero_point=zero_points[-1],
)
quantized_col = fns.flatten(quantized_col)
quantized_block[:, i] = quantized_col
loss_block[:, i] = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
if transpose_b:
quantized_block[:, i] = quantized_col
else:
quantized_block[i, :] = quantized_col
loss_col = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
if transpose_b:
loss_block[:, i] = loss_col
else:
loss_block[i, :] = loss_col

error_col = (weight_col - quantized_col) / hessian_diag_val
weight_block[:, i:] -= fns.matmul(
fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
)
error_block[:, i] = error_col

quantized_tensor[:, i1:i2] = quantized_block
losses[:, i1:i2] = loss_block / 2

weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
if transpose_b:
weight_block[:, i:] -= fns.matmul(
fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
)
error_block[:, i] = error_col
else:
weight_block[i:, :] -= fns.matmul(
fns.unsqueeze(error_col, 0), fns.unsqueeze(hessian_inv_block[i:, i], 1)
)
error_block[i, :] = error_col

if transpose_b:
quantized_tensor[:, i1:i2] = quantized_block
losses[:, i1:i2] = loss_block / 2
else:
quantized_tensor[i1:i2, :] = quantized_block
losses[i1:i2, :] = loss_block / 2

# Update remaining weights with error propagation
if transpose_b:
weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
else:
# For transpose_b=False: error_block shape is [i2-i1, out_features]
# hessian_inv[i2:, i1:i2] shape is [columns-i2, i2-i1]
# We need to transpose error_block to get [out_features, i2-i1]
# Then: hessian_inv[i2:, i1:i2] @ error_block^T gives [columns-i2, out_features]
weight_tensor[i2:, :] -= fns.matmul(hessian_inv[i2:, i1:i2], fns.transpose(error_block))

quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
self._backend_entity.set_weight(
Expand All @@ -325,13 +372,19 @@ def _quantize_weights(
scales = fns.stack(scales, axis=1)
if wc_params.compression_config.group_size == -1:
scales = fns.squeeze(scales, axis=-1)
if wc_params.compression_config.mode in [
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT4_ASYM,
]:
zero_points = fns.stack(zero_points, axis=1)

zero_points_tensor = None
if (
zero_points
and zero_points[0] is not None
and wc_params.compression_config.mode
in [
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT4_ASYM,
]
):
zero_points_tensor = fns.stack(zero_points, axis=1)
if wc_params.compression_config.group_size == -1:
zero_points = fns.squeeze(zero_points, axis=-1)
else:
zero_points = None
return scales, zero_points
zero_points_tensor = fns.squeeze(zero_points_tensor, axis=-1)

return scales, zero_points_tensor
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pandas as pd

import nncf
from nncf.common.graph import NNCFGraph
from nncf.common.logging import nncf_logger
from nncf.common.utils.debug import DEBUG_LOG_DIR
from nncf.common.utils.debug import is_debug
Expand Down Expand Up @@ -108,19 +109,30 @@ def is_applicable(self, wc_params: WeightCompressionParameters):
return wc_params.compression_config.num_bits == 4

def calculate_adapters(
self, weight: Tensor, compressed_weight: CompressedWeight, wc_params: WeightCompressionParameters
self,
weight: Tensor,
compressed_weight: CompressedWeight,
wc_params: WeightCompressionParameters,
graph: NNCFGraph,
) -> tuple[Tensor, Tensor, list[float]]:
"""
Calculates low rank matrices for a given original and compressed weights.

:param weight: original floating-point weight matrix.
:param compressed_weight: compressed weight matrix.
:param wc_params: parameters of weight compression.
:param graph: The model graph.
:return: two low rank matrices in the order of execution of corresponding linear layers.
"""
layer_name = wc_params.node_with_weight.node_name
layer_statistics = self._statistics[layer_name]
is_debug = self._debug_interface is not None

# Get transpose_b value via backend to handle weight shape correctly in a backend-agnostic way
transpose_b = self._backend_entity.get_weight_transpose_b(
wc_params.node_with_weight, wc_params.weight_port_id, graph
)

lora_A, lora_B, mean_noises = self.calculate_low_rank_matrices(
weight,
compressed_weight,
Expand All @@ -129,6 +141,7 @@ def calculate_adapters(
self._lora_correction_params,
layer_statistics,
is_debug,
transpose_b=transpose_b,
)
if is_debug:
self._debug_interface.add_noises(layer_name, mean_noises)
Expand All @@ -143,6 +156,7 @@ def calculate_low_rank_matrices(
lora_correction_params: AdvancedLoraCorrectionParameters,
layer_statistics: WCTensorStatistic,
is_debug: Optional[bool] = False,
transpose_b: bool = True, # Add this parameter with default True for backward compatibility
):
"""
Calculates low rank matrices for a given original and compressed weights.
Expand Down Expand Up @@ -190,7 +204,8 @@ def calculate_low_rank_matrices(

# O stands for output dimension, H - input dimension or hidden size, SS - samples size, R - rank.
# reduction axes is all axes except output dimension in linear/conv layers.
if reduction_axes[0] == 1:
# Use transpose_b directly instead of inferring from reduction_axes
if not transpose_b:
svd_residual = fns.transpose(svd_residual)
residual = svd_residual.clone() # [H, O]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,26 @@ def get_reduction_axes(node_with_weight: NNCFNode, weight_port_id: int, graph: N
channel_axes = (0,) + channel_axes
return get_reduction_axes(channel_axes, const_shape)

@staticmethod
def get_weight_transpose_b(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> bool:
    """
    Returns the equivalent of transpose_b for ONNX MatMul/Gemm nodes.

    For a Gemm node whose weight feeds the B input (port 1), the answer is taken from the
    node's transB attribute, as defined by the ONNX Gemm spec (transB=1 means B is transposed).
    For every other case (MatMul and other ops), the stored initializer already follows the
    layout expected by the backend kernels, so the weight is reported as transposed.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The input port ID that corresponds to the weight.
    :param graph: The model graph.
    :return: True if the weight is treated as transposed, False otherwise.
    """
    weight_is_gemm_b_input = (
        node_with_weight.metatype is onnx_metatypes.ONNXGemmMetatype and weight_port_id == 1
    )
    if not weight_is_gemm_b_input:
        # MatMul and other ops: the initializer layout matches the kernel contract,
        # so report transpose_b=True by default.
        return True

    # Gemm-specific handling: respect the transB attribute from the ONNX Gemm spec.
    attrs = node_with_weight.layer_attributes.node_attrs
    return bool(attrs.get("transB", 0))

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint:
return ONNXTargetPoint(target_type, target_node_name, port_id)
Expand Down
Loading