33 commits
3122842
enable awq
anzr299 Dec 3, 2025
63e221a
Merge branch 'openvinotoolkit:develop' into an/3d_weights_awq
anzr299 Dec 5, 2025
59033c1
update scale unsqueeze logic
anzr299 Dec 5, 2025
608b2e2
mergeable fix
anzr299 Dec 5, 2025
1948c95
fix the fix
anzr299 Dec 5, 2025
b5bbe15
add ignored nodes awq test for torch and torch fx
anzr299 Dec 5, 2025
10bdd6f
add remaining awq tests
anzr299 Dec 5, 2025
cad65db
Merge branch 'openvinotoolkit:develop' into an/3d_weights_awq
anzr299 Dec 7, 2025
9f511a4
Update awq.py
anzr299 Dec 7, 2025
b62074f
add 3d matmul model to onnx
anzr299 Dec 8, 2025
0bb9c9e
Merge branch 'an/3d_weights_awq' of https://github.com/anzr299/nncf i…
anzr299 Dec 8, 2025
b364b4d
fix onnx model test
anzr299 Dec 8, 2025
4249626
fix some tests
anzr299 Dec 8, 2025
cc44179
add ov model
anzr299 Dec 8, 2025
a934313
add model
anzr299 Dec 8, 2025
48c499c
xfail openvino test
anzr299 Dec 8, 2025
ba8b725
fix condition for is_mergeable
anzr299 Dec 8, 2025
667716e
fix mergeable issue
anzr299 Dec 8, 2025
2122b11
add act model for openvino; include data free test and call max varia…
anzr299 Dec 8, 2025
3e92f47
add torch and torch fx act linear model tests
anzr299 Dec 8, 2025
d44f6d5
fix data shape for OV model
anzr299 Dec 8, 2025
090f0f5
fix awq data free
anzr299 Dec 8, 2025
7222304
add check for opset version when weights are 3D
anzr299 Dec 8, 2025
211c806
xfail openvino case
anzr299 Dec 8, 2025
0d96516
fix test
anzr299 Dec 8, 2025
0b1019c
remove extra comments
anzr299 Dec 8, 2025
80c1ec2
fix
anzr299 Dec 8, 2025
1c76e5f
fix dynamic shapes
anzr299 Dec 8, 2025
b51d878
add xfail for last test
anzr299 Dec 8, 2025
db8417b
check dynamic dimensions correctly
anzr299 Dec 8, 2025
9173dfb
fix onnx backend formatting of weights
anzr299 Dec 8, 2025
7a9a128
fix reduction axes
anzr299 Dec 9, 2025
e605827
fix onnx test
anzr299 Dec 9, 2025
93 changes: 62 additions & 31 deletions src/nncf/quantization/algorithms/weight_compression/awq.py
@@ -159,19 +159,35 @@ def apply(
weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
if len(weight_data) != 1: # not supported by the algorithm
continue
is_mergeable = self._backend_entity.is_node_with_weights(merge_node, graph)

nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}")

_, weight_port_id = weight_data[0]

weight = self._backend_entity.get_weight(
wp.node_with_weight, weight_port_id, model, graph
) # get_const_value(wp.weight_node)
weight_dtype = weight.dtype
weight = weight.astype(TensorDataType.float32)

is_mergeable = False
if self._backend_entity.is_node_with_weights(merge_node, graph):
mergeable_node_weight_data = self._backend_entity.get_weight_names_and_port_ids(merge_node, graph)
merge_node_weight_dims = [
len(self._backend_entity.get_weight_shape(merge_node, port_id, graph))
for _, port_id in mergeable_node_weight_data
]
is_mergeable = len(weight.shape) in merge_node_weight_dims

nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}")

weight_dim = len(weight.shape)
if is_data_free:
scale = self._data_free_step(weight, 1 - wp.reduction_axes[0])
# This formula comes from a simple generalization over the weight dimensionality:
# the axis is (constant - reduction_axes), where the constant is the (n-1)th odd
# number, 2(n-1) - 1 = 2n - 3, and n is the weight dimension.
# Example: 2D -> 1 - reduction_axes (reduction_axes=1 -> 1-1=0; reduction_axes=0 -> 1-0=1)
# 3D -> 3 - reduction_axes (reduction_axes=1 -> 3-1=2; reduction_axes=2 -> 3-2=1)
# 4D -> 5 - reduction_axes (reduction_axes=1 -> 5-1=4; reduction_axes=3 -> 5-3=2)
scale = self._data_free_step(weight, (weight_dim * 2) - 3 - wp.reduction_axes[0])
else:
prev_weight, prev_statistics = None, None
if is_mergeable:
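Note on the axis arithmetic: the two matmul axes of an n-D weight are n-2 and n-1, and their sum is 2n - 3, so (weight_dim * 2) - 3 - reduction_axis is simply the non-reduction matmul axis. A minimal numpy sketch of this rule (numpy stands in for nncf.tensor functions; the layouts and shapes are illustrative assumptions, not the actual backend layouts):

import numpy as np

def other_matmul_axis(ndim: int, reduction_axis: int) -> int:
    # The two matmul axes are n-2 and n-1; their sum is 2n - 3.
    return (ndim * 2) - 3 - reduction_axis

w2 = np.ones((8, 16))     # assumed [out_features, hidden_dimension], reduction_axis = 1
w3 = np.ones((4, 16, 8))  # assumed [num_experts, hidden_dimension, out_features], reduction_axis = 1
for w, r in ((w2, 1), (w3, 1)):
    scale = np.abs(w).mean(axis=other_matmul_axis(w.ndim, r))      # per-hidden-channel scale
    w_scale = np.expand_dims(scale, other_matmul_axis(w.ndim, r))  # re-insert the collapsed axis
    a_scale = np.expand_dims(1.0 / scale, r)
    assert (w * w_scale).shape == w.shape  # w_scale broadcasts against the weight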
@@ -183,7 +199,7 @@ def apply(
prev_statistics = statistics[merge_node.node_name]
scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics)

w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
w_scale = fns.unsqueeze(scale, (weight_dim * 2) - 3 - wp.reduction_axes[0])
a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])

scaled_weight = (weight * w_scale).astype(weight_dtype)
@@ -194,9 +210,9 @@ def apply(
merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
merge_weight = (merge_weight * a_scale).astype(weight_dtype)
self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
a_scale = fns.transpose(a_scale)
a_scale = fns.moveaxis(a_scale, -1, -2)
else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
a_scale = fns.transpose(a_scale).astype(weight_dtype)
a_scale = fns.moveaxis(a_scale, -1, -2).astype(weight_dtype)
next_nodes = graph.get_next_nodes(merge_node)
source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
scale_insertion_command = self._backend_entity.scale_insertion_command(
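The transpose-to-moveaxis change matters once a batch/expert dimension exists: transpose with no axes reverses every dimension, while moveaxis(-1, -2) swaps only the two matmul axes. A quick numpy illustration (assuming equivalent semantics to fns.transpose/fns.moveaxis):

import numpy as np

a = np.ones((4, 1, 16))              # [num_experts, 1, hidden_dimension]
print(a.transpose().shape)           # (16, 1, 4) -- reverses all axes, shuffles the batch dim
print(np.moveaxis(a, -1, -2).shape)  # (4, 16, 1) -- swaps only the two matmul axes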
@@ -217,49 +233,63 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
s = s.astype(TensorDataType.float32)
X = X.astype(TensorDataType.float32)

is_2d_weight = weight.ndim == 2

assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
reduction_axis = wp.reduction_axes[0]

if is_2d_weight:
s = fns.unsqueeze(s, 0)
X = fns.unsqueeze(X, 0)
weight = fns.unsqueeze(weight, 0)
prev_weight = fns.unsqueeze(prev_weight, 0) if prev_weight is not None else None
reduction_axis += 1

prev_s, prev_w = None, None
if prev_statistics is not None and prev_weight is not None:
prev_s, _ = process_stats(prev_statistics, self._subset_size)
prev_s = prev_s.astype(TensorDataType.float32).max().item()
prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis)

top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:top_k]
top_k = max(int(s.shape[-1] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:, :top_k]

group_size = config.group_size
if group_size == -1:
group_size = s.shape[0]
group_size = s.shape[-1]

groups_to_correct = set()
for idx in topk_idxs:
groups_to_correct.add(idx.data // group_size)
for batch_idx in range(topk_idxs.shape[0]):
for k_idx in range(topk_idxs.shape[1]):
idx = topk_idxs[batch_idx, k_idx].item()
group_idx = idx // group_size
groups_to_correct.add((batch_idx, group_idx))

groups_to_correct = list(groups_to_correct)

if reduction_axis == 0:
weight = fns.transpose(weight)
reduction_axis = 1
if reduction_axis == 1:
# Weights
# 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
# 2D: [1, hidden_dimension, out_features] -> [1, out_features, hidden_dimension]
weight = fns.moveaxis(weight, -1, -2)
reduction_axis = weight.ndim - 1

shape_vector = fns.mean(X, axis=1)
shape_vector = fns.mean(X, axis=-1)
scale = fns.ones_like(shape_vector)

awq_config = deepcopy(config)
awq_config.group_size = -1

for gi in groups_to_correct:
for batch_idx, gi in groups_to_correct:
offset = gi * group_size
gscale = s[offset : offset + group_size]
gscale = s[batch_idx, offset : offset + group_size]
gweight = weight[batch_idx, :, offset : offset + group_size]
gacts = X[batch_idx, offset : offset + group_size, :]

a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
a_max = 1e2
gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

gweight = weight[:, offset : offset + group_size]
gacts = X[offset : offset + group_size, :]

fp32_out = fns.matmul(gweight, gacts)
min_diff = fns.max(fns.abs(fp32_out))
best_scale = None
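With the leading batch/expert dimension, the strongest channels are now picked per batch row and addressed as (batch_idx, group_idx) pairs. A toy worked example of this mapping (plain numpy, made-up magnitudes):

import numpy as np

s = np.array([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
              [0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]])
group_size, top_k = 4, 2
topk_idxs = np.argsort(-s, axis=-1)[:, :top_k]   # [[1, 3], [6, 4]]
groups_to_correct = {
    (b, int(topk_idxs[b, k]) // group_size)
    for b in range(topk_idxs.shape[0])
    for k in range(topk_idxs.shape[1])
}
print(sorted(groups_to_correct))                 # [(0, 0), (1, 1)]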
@@ -275,28 +305,26 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
# per channel magnitudes for the previous MatMul
# mean(abs(prev_weight)) * max(abs((prev_activation))) * prev_weight.shape[reduction_axis]
magnitudes = (
(prev_w[offset : offset + group_size] / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
(prev_w[batch_idx, offset : offset + group_size] / cur_scale)
* prev_s
* prev_weight.shape[reduction_axis]
)
if magnitudes.max() >= threshold:
cur_scale = AWQ._clamp_scale(
magnitudes,
threshold,
cur_scale,
prev_w[offset : offset + group_size]
prev_w[batch_idx, offset : offset + group_size]
* prev_s
* prev_weight.shape[reduction_axis]
/ threshold,
)

weights_to_fake_quantize = gweight * cur_scale
if not config.is_integer:
g_decompressed_weighs = float_quantize_dequantize_weight(
weights_to_fake_quantize, awq_config, reduction_axis
)
g_decompressed_weighs = float_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
else:
g_decompressed_weighs = integer_quantize_dequantize_weight(
weights_to_fake_quantize, awq_config, reduction_axis
)
g_decompressed_weighs = integer_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
sacts = gacts / fns.unsqueeze(cur_scale, 1)

cur_out = fns.matmul(g_decompressed_weighs, sacts)
@@ -307,7 +335,10 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
alpha += alpha_step

if best_scale is not None:
scale.data[offset : offset + group_size] = best_scale.data
scale.data[batch_idx, offset : offset + group_size] = best_scale.data

if is_2d_weight:
scale = fns.squeeze(scale, 0) # [1, hidden_dim] -> [hidden_dim]

return scale
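The is_2d_weight lifting at the top of the function and the squeeze on the way out let one batched implementation serve both 2D and 3D (MoE-style) weights. A minimal sketch of that pattern (numpy, illustrative shapes):

import numpy as np

def as_batched(weight: np.ndarray, reduction_axis: int):
    # Lift a 2D weight to a batch of one so the batched code path serves both cases.
    if weight.ndim == 2:
        return weight[None, ...], reduction_axis + 1, True
    return weight, reduction_axis, False

w, r, was_2d = as_batched(np.ones((8, 16)), 1)
assert w.shape == (1, 8, 16) and r == 2
scale = np.ones(w.shape[:2])      # stand-in for the computed per-(batch, channel) scale
if was_2d:
    scale = np.squeeze(scale, 0)  # [1, hidden_dimension] -> [hidden_dimension]
assert scale.shape == (8,)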

@@ -109,15 +109,16 @@ def _preprocess_compressed_weight(
scale = compressed_weight.scale
zero_point = compressed_weight.zero_point

axis = 1 if dequantize_block_size else None
# For 3D weights, the squeeze axis shifts by one compared to 2D because of the leading batch dim
axis = 1 + len(scale.shape) % 3 if dequantize_block_size else None
scale = scale.squeeze(axis=axis)
if zero_point is not None:
zero_point = zero_point.squeeze(axis=axis)

if apply_transpose:
scale = fns.transpose(scale)
scale = fns.moveaxis(scale, -1, -2)
if zero_point is not None:
zero_point = fns.transpose(zero_point)
zero_point = fns.moveaxis(zero_point, -1, -2)

if zero_point is not None:
zero_point = zero_point.astype(tensor.dtype)
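The 1 + len(scale.shape) % 3 expression is compact but opaque; a quick sanity check of what it selects (assuming, as the squeeze implies, that the block-size singleton sits second-to-last in the scale):

# For the supported scale ranks -- 3D scale for a 2D weight, 4D scale for a
# 3D weight -- the expression picks the singleton axis just before the last
# one, i.e. it coincides with ndim - 2 on those ranks.
for ndim in (3, 4):
    assert 1 + ndim % 3 == ndim - 2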
@@ -259,6 +260,10 @@ def transform_model(
# For opsets earlier than 21, we use the `MatMulNBits` operation from ONNX Runtime contrib operators.
# See https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md
if opset_version < 21 and dequantize_block_size > 0:
if len(weight.shape) == 3:
msg = """ONNX does not support 3D weights for opset version < 21.
Please use a higher opset version or per-channel quantization"""
raise nncf.ParameterNotSupportedError(msg)
compressed_weight, scale, zero_point = self._preprocess_compressed_weight(
compressed_weight, weight.shape, dequantize_block_size=None, apply_transpose=True
)
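A hedged usage sketch of how a caller would steer around this guard, using the nncf.compress_weights API exercised by the tests below (the onnx_model handle and parameter values are hypothetical):

import nncf

# `onnx_model` is a placeholder for a model exported with a given opset.
compressed = nncf.compress_weights(
    onnx_model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    group_size=128,  # block-wise; with 3D weights this path needs opset >= 21
)
# With opset < 21, either re-export the model with opset >= 21 or fall back to
# per-channel quantization (group_size=-1) to avoid ParameterNotSupportedError.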
@@ -354,7 +354,7 @@ def test_scale_estimation_outlier_channel_has_lowest_error(self, mocker):
# AWQ Tests
@staticmethod
@abstractmethod
def get_awq_act_model(with_multiply, n_layers):
def get_awq_act_model(is_3d_weights, with_multiply, n_layers):
"Returns a backend model for test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul."

@staticmethod
@@ -366,13 +366,16 @@ def get_num_multiply_from_awq(model: TModel) -> int:
def int4_mode(self, request):
return None

@pytest.mark.parametrize("is_3d_weights", [True, False])
@pytest.mark.parametrize("with_multiply", (True, False))
def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(self, int4_mode, with_multiply, mocker):
def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(
self, int4_mode, with_multiply, is_3d_weights, mocker
):
n_layers = 8
n_awq_target = n_layers - 1 # first MatMul is always int8
model = self.get_awq_act_model(with_multiply, n_layers)
model = self.get_awq_act_model(is_3d_weights, with_multiply, n_layers)

dataset = Dataset([self.to_tensor(np.ones([1, 8, 8], dtype=np.float32))], self.get_transform_func())
dataset = Dataset([self.to_tensor(np.ones([2, 8, 8], dtype=np.float32))], self.get_transform_func())

with SpyWeightCompressionStatisticsContext(mocker):
model = compress_weights(model, mode=int4_mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)
@@ -382,8 +385,11 @@ def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(self, int

@staticmethod
@abstractmethod
def get_awq_model() -> TModel:
"Returns a backend model for test_awq_with_ignored_scope."
def get_awq_model(is_3d_weights) -> TModel:
"""
Returns a backend model for test_awq_with_ignored_scope."
:param is_3d_weights: The model has 3d weights
"""

@staticmethod
@abstractmethod
@@ -402,16 +408,19 @@ def get_num_int4_group_sizes(model: TModel) -> dict[int, int]:

@staticmethod
@abstractmethod
def get_ignored_scope_name() -> str:
def get_ignored_scope_name(is_3d_weights) -> str:
"Returns ignored scope name for test_awq_with_ignored_scope."

def test_awq_with_ignored_scope(self, mocker):
model = self.get_awq_model()
@pytest.mark.parametrize("is_3d_weights", [True, False])
def test_awq_with_ignored_scope(self, mocker, is_3d_weights):
model = self.get_awq_model(is_3d_weights)
sz = 8
n_samples = 10

input_shape = [2, 8, sz]

dataset = Dataset(
[self.to_tensor(np.ones([1, 8, sz], dtype=np.float32)) for i in range(n_samples)],
[self.to_tensor(np.ones(input_shape, dtype=np.float32)) for i in range(n_samples)],
self.get_transform_func(),
)

@@ -423,12 +432,12 @@ def test_awq_with_ignored_scope(self, mocker):
group_size=-1,
dataset=dataset,
awq=True,
ignored_scope=IgnoredScope(names=[self.get_ignored_scope_name()]),
ignored_scope=IgnoredScope(names=[self.get_ignored_scope_name(is_3d_weights)]),
)

int4_ref_num_compressed = 4  # 6 MatMuls in total; the last MatMul is always int8 and one is ignored
int4_num_nodes = self.get_num_int4_nodes(compressed_model)
assert int4_num_nodes == int4_ref_num_compressed
assert int4_num_nodes == int4_ref_num_compressed, int4_num_nodes

def test_rope_weight_compression(self):
model = self.get_RoPE_model()
@@ -473,14 +482,15 @@ def test_sam_pe_weight_compression(self):

@staticmethod
@abstractmethod
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
def get_reference_for_test_awq_scale_reference(is_3d_weights) -> dict[str, Tensor]:
"Returns reference for test_awq_scale_reference."

def test_awq_scale_reference(self, monkeypatch, mocker):
@pytest.mark.parametrize("is_3d_weights", [True, False])
def test_awq_scale_reference(self, monkeypatch, mocker, is_3d_weights):
monkeypatch.setattr("nncf.quantization.algorithms.weight_compression.algorithm.AWQ", SpyAWQ)
model = self.get_awq_model()
model = self.get_awq_model(is_3d_weights)

input = 0.01 * np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8) + 0.02
input = 0.01 * np.arange(0, 2 * 4 * 8, dtype=np.float32).reshape(2, 4, 8) + 0.02
input = self.to_tensor(input)
dataset = Dataset([input], self.get_transform_func())

Expand All @@ -495,7 +505,7 @@ def test_awq_scale_reference(self, monkeypatch, mocker):
)
assert spy_instance is not None
for node_name, scales in spy_instance._scale_per_target_node.items():
assert fns.allclose(scales, self.get_reference_for_test_awq_scale_reference()[node_name])
assert fns.allclose(scales, self.get_reference_for_test_awq_scale_reference(is_3d_weights)[node_name])

@pytest.mark.parametrize(
["group_size", "fallback_mode", "min_adjusted_group_size", "expected_outcome"],
@@ -619,14 +629,15 @@ def test_group_size_fallback_modes(
f"Expected {ref_num_group_sizes} group size values, but got {num_group_sizes}."
)

@pytest.mark.parametrize("dataset", [None, np.ones([1, 8, 8], dtype=np.float32)])
@pytest.mark.parametrize("is_3d_weights", [True, False])
@pytest.mark.parametrize("dataset", [None, np.ones([2, 8, 8], dtype=np.float32)])
@pytest.mark.parametrize("prefer_data_aware_scaling", [True, False])
def test_data_free_awq(self, dataset, prefer_data_aware_scaling, mocker):
input_data = np.ones([1, 8, 8], dtype=np.float32)
def test_data_free_awq(self, dataset, prefer_data_aware_scaling, is_3d_weights, mocker):
input_data = np.ones([2, 8, 8], dtype=np.float32)

n_layers = 8
n_awq_target = n_layers - 1 # first MatMul is always int8
model = self.get_awq_act_model(True, n_layers)
model = self.get_awq_act_model(is_3d_weights, True, n_layers)
model = self.wrap_model(model, input_data)

if dataset is not None:
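The is_3d_weights fixtures assume backend models whose MatMul weights carry a leading expert/batch dimension. A hypothetical minimal PyTorch sketch of such a model (names and shapes are illustrative, not the actual test fixtures):

import torch

class Batched3DMatMulModel(torch.nn.Module):
    # Each layer holds a [num_experts, hidden, hidden] weight, matching the
    # [2, 8, 8] inputs used by the tests above.
    def __init__(self, n_layers: int = 8, experts: int = 2, hidden: int = 8):
        super().__init__()
        self.weights = torch.nn.ParameterList(
            torch.nn.Parameter(torch.randn(experts, hidden, hidden)) for _ in range(n_layers)
        )

    def forward(self, x):               # x: [experts, seq_len, hidden]
        for w in self.weights:
            x = torch.matmul(x, w)      # batched matmul over the expert dim
        return x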