Skip to content

Commit 88445b3

Browse files
authored
Support 3D Weights in AWQ Algorithm (openvinotoolkit#3728)
### Changes The core idea of this change is to first unsqueeze the weights so that they become 3D. Even the 2D weights. Then the rest of the algorithm implementation is changed such that it expects the weight shape to be 3D. Earlier we traversed each group in a weight individually. But now, since we want to find the scales per-channel as well as per-expert, we traverse by group index as well as batch/expert index (this is just 1 for 2D weights, so the behavior is the same as before). ### Reason for changes Support AWQ for models with 3D weights such as MoE models. ### Related tickets 175789 & 175212 ### Tests Current AWQ tests were extended to include the AWQ test models with 3D weights. **Accuracy Evaluation Results:** Model: Qwen/Qwen3-30B-A3B NNCF Backend: OpenVINO Higher is better. Task: gsm8k Limit: 100 Max New Tokens: 10000 OpenVINO version: 2026.0.0.dev20260102 n-shots: 5(default) Model | Filter | Score (exact_match) | Stderr -- | -- | -- | -- FP16 | flexible-extract | 0.92 | 0.0273   | strict-match | 0.82 | 0.0386 INT4 SYM Per-Channel (no AWQ) | flexible-extract | 0.83 | 0.0378   | strict-match | 0.27 | 0.0446 INT4 SYM Per-Channel (AWQ data-free) | flexible-extract | 0.83 | 0.0378   | strict-match | 0.22 | 0.0416 INT4 Per-Channel (AWQ data-aware) | flexible-extract | 0.83 | 0.0378   | strict-match | 0.35 | 0.0479 Comparison of accuracy with `meta-llama/Llama-3.2-1B-Instruct` on Develop and this branch Variant | bits_per_byte | byte_perplexity | word_perplexity -- | -- | -- | -- This Branch (Data Aware) | 0.7774 | 1.7141 | 17.8427 This Branch (Data Free) | 0.7774 | 1.7141 | 17.8427 develop (Data Aware) | 0.7774 | 1.7141 | 17.8427 develop (Data Free) | 0.7774 | 1.7141 | 17.8427 WC Conformance test: https://github.com/openvinotoolkit/nncf/actions/runs/20883502496 - Pass WC Example Test: https://github.com/openvinotoolkit/nncf/actions/runs/20883506117 - Pass
1 parent 2c068c5 commit 88445b3

File tree

8 files changed

+732
-245
lines changed

8 files changed

+732
-245
lines changed

src/nncf/quantization/algorithms/weight_compression/awq.py

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,8 @@ def apply(
159159
weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
160160
if len(weight_data) != 1: # not supported by the algorithm
161161
continue
162-
is_mergeable = self._backend_entity.is_node_with_weights(merge_node, graph)
163-
164-
nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}")
165-
166162
_, weight_port_id = weight_data[0]
163+
167164
weight = self._backend_entity.get_weight(
168165
wp.node_with_weight, weight_port_id, model, graph
169166
) # get_const_value(wp.weight_node)
@@ -172,8 +169,26 @@ def apply(
172169

173170
act_ch_axis, act_shape = self._get_activation_channel_axis_and_shape(graph, wp)
174171

172+
is_mergeable = False
173+
if self._backend_entity.is_node_with_weights(merge_node, graph):
174+
mergeable_node_weight_data = self._backend_entity.get_weight_names_and_port_ids(merge_node, graph)
175+
merge_node_weight_ndims = [
176+
len(self._backend_entity.get_weight_shape(merge_node, port_id, graph))
177+
for _, port_id in mergeable_node_weight_data
178+
]
179+
is_mergeable = len(weight.shape) in merge_node_weight_ndims
180+
181+
nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}")
182+
183+
weight_ndim = len(weight.shape)
184+
# Weights scale reduction formula:
185+
# 2(n-1)-1 -> 2n-3
186+
# Example: 2D -> 1 - reduction_axes (reduction_axes=1) = 0
187+
# 3D -> 3 - reduction_axes (reduction_axes=1) = 2
188+
# 4D -> 5 - reduction_axes (reduction_axes=1) = 4
189+
weight_scale_reduction_axes = (weight_ndim * 2) - 3 - wp.reduction_axes[0]
175190
if is_data_free:
176-
scale = self._data_free_step(weight, 1 - wp.reduction_axes[0])
191+
scale = self._data_free_step(weight, axis=weight_scale_reduction_axes)
177192
else:
178193
prev_weight, prev_statistics = None, None
179194
if is_mergeable:
@@ -185,7 +200,7 @@ def apply(
185200
prev_statistics = statistics[merge_node.node_name]
186201
scale = self._data_aware_step(wp, weight, statistics[k], act_ch_axis, prev_weight, prev_statistics)
187202

188-
w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
203+
w_scale = fns.unsqueeze(scale, weight_scale_reduction_axes)
189204
a_scale = 1.0 / scale
190205

191206
scaled_weight = (weight * w_scale).astype(weight_dtype)
@@ -198,9 +213,18 @@ def apply(
198213
merge_weight = (merge_weight * a_scale).astype(weight_dtype)
199214
self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
200215
else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
201-
# Calculate the activation scale shape
202-
a_scale_shape = [scale.shape[0] if axis == act_ch_axis else 1 for axis in range(len(act_shape))]
203-
a_scale = fns.reshape(a_scale, tuple(a_scale_shape))
216+
act_ndim = len(act_shape)
217+
scale_shape = a_scale.shape
218+
# Only the last dim in the activation scale is for channel. The others are for batch
219+
batch_dims = iter(scale_shape[:-1])
220+
# For the last dim of the scale which is assumed channel, we place it as it is
221+
# For the rest of the elements we iterate the batch dims and place accordingly
222+
# And once we finish, we start placing ones if the current dimension is not
223+
# channel axis And it is not a batch dim, we place 1.
224+
act_scale_shape = tuple(
225+
scale_shape[-1] if dim == act_ch_axis else next(batch_dims, 1) for dim in range(act_ndim)
226+
)
227+
a_scale = fns.reshape(a_scale, act_scale_shape)
204228

205229
next_nodes = graph.get_next_nodes(merge_node)
206230
source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
@@ -223,49 +247,63 @@ def _data_aware_step(self, wp, weight, statistics, act_ch_axis, prev_weight=None
223247
s = s.astype(TensorDataType.float32)
224248
X = X.astype(TensorDataType.float32)
225249

250+
is_2d_weight = weight.ndim == 2
251+
226252
assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
227253
reduction_axis = wp.reduction_axes[0]
228254

255+
if is_2d_weight:
256+
s = fns.unsqueeze(s, 0)
257+
X = fns.unsqueeze(X, 0)
258+
weight = fns.unsqueeze(weight, 0)
259+
prev_weight = fns.unsqueeze(prev_weight, 0) if prev_weight is not None else None
260+
reduction_axis += 1
261+
229262
prev_s, prev_w = None, None
230263
if prev_statistics is not None and prev_weight is not None:
231264
prev_s, _ = process_stats(prev_statistics, self._subset_size, act_ch_axis)
232265
prev_s = prev_s.astype(TensorDataType.float32).max().item()
233266
prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis)
234267

235-
top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
236-
topk_idxs = fns.argsort(-s)[:top_k]
268+
top_k = max(int(s.shape[-1] * self._percent_to_apply), 1)
269+
topk_idxs = fns.argsort(-s)[:, :top_k]
237270

238271
group_size = config.group_size
239272
if group_size == -1:
240-
group_size = s.shape[0]
273+
group_size = s.shape[-1]
241274

242275
groups_to_correct = set()
243-
for idx in topk_idxs:
244-
groups_to_correct.add(idx.data // group_size)
276+
for batch_idx in range(topk_idxs.shape[0]):
277+
for k_idx in range(topk_idxs.shape[1]):
278+
idx = topk_idxs[batch_idx, k_idx].item()
279+
group_idx = idx // group_size
280+
groups_to_correct.add((batch_idx, group_idx))
245281

246282
groups_to_correct = list(groups_to_correct)
247283

248-
if reduction_axis == 0:
249-
weight = fns.transpose(weight)
250-
reduction_axis = 1
284+
if reduction_axis == 1:
285+
# Weights
286+
# 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
287+
# 2D: [1, hidden_dimension, out_features] -> [1, out_features, hidden_dimension]
288+
weight = fns.moveaxis(weight, -1, -2)
289+
reduction_axis = weight.ndim - 1
251290

252-
shape_vector = fns.mean(X, axis=1)
291+
shape_vector = fns.mean(X, axis=-1)
253292
scale = fns.ones_like(shape_vector)
254293

255294
awq_config = deepcopy(config)
256295
awq_config.group_size = -1
257296

258-
for gi in groups_to_correct:
297+
for batch_idx, gi in groups_to_correct:
259298
offset = gi * group_size
260-
gscale = s[offset : offset + group_size]
299+
gscale = s[batch_idx, offset : offset + group_size]
300+
gweight = weight[batch_idx, :, offset : offset + group_size]
301+
gacts = X[batch_idx, offset : offset + group_size, :]
261302

262303
a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
263304
a_max = 1e2
264305
gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)
265306

266-
gweight = weight[:, offset : offset + group_size]
267-
gacts = X[offset : offset + group_size, :]
268-
269307
fp32_out = fns.matmul(gweight, gacts)
270308
min_diff = fns.max(fns.abs(fp32_out))
271309
best_scale = None
@@ -281,28 +319,26 @@ def _data_aware_step(self, wp, weight, statistics, act_ch_axis, prev_weight=None
281319
# per channel magnitudes for the previous MatMul
282320
# mean(abs(prev_weight)) * max(abs((prev_activation))) * prev_weight.shape[reduction_axis]
283321
magnitudes = (
284-
(prev_w[offset : offset + group_size] / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
322+
(prev_w[batch_idx, offset : offset + group_size] / cur_scale)
323+
* prev_s
324+
* prev_weight.shape[reduction_axis]
285325
)
286326
if magnitudes.max() >= threshold:
287327
cur_scale = AWQ._clamp_scale(
288328
magnitudes,
289329
threshold,
290330
cur_scale,
291-
prev_w[offset : offset + group_size]
331+
prev_w[batch_idx, offset : offset + group_size]
292332
* prev_s
293333
* prev_weight.shape[reduction_axis]
294334
/ threshold,
295335
)
296336

297337
weights_to_fake_quantize = gweight * cur_scale
298338
if not config.is_integer:
299-
g_decompressed_weighs = float_quantize_dequantize_weight(
300-
weights_to_fake_quantize, awq_config, reduction_axis
301-
)
339+
g_decompressed_weighs = float_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
302340
else:
303-
g_decompressed_weighs = integer_quantize_dequantize_weight(
304-
weights_to_fake_quantize, awq_config, reduction_axis
305-
)
341+
g_decompressed_weighs = integer_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
306342
sacts = gacts / fns.unsqueeze(cur_scale, 1)
307343

308344
cur_out = fns.matmul(g_decompressed_weighs, sacts)
@@ -313,7 +349,10 @@ def _data_aware_step(self, wp, weight, statistics, act_ch_axis, prev_weight=None
313349
alpha += alpha_step
314350

315351
if best_scale is not None:
316-
scale.data[offset : offset + group_size] = best_scale.data
352+
scale.data[batch_idx, offset : offset + group_size] = best_scale.data
353+
354+
if is_2d_weight:
355+
scale = fns.squeeze(scale, 0) # [1, hidden_dim] -> [hidden_dim]
317356

318357
return scale
319358

src/nncf/quantization/algorithms/weight_compression/onnx_backend.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,16 @@ def _preprocess_compressed_weight(
110110
scale = compressed_weight.scale
111111
zero_point = compressed_weight.zero_point
112112

113-
axis = 1 if dequantize_block_size else None
113+
# For 3D weights, we need to squeeze at the next dimension compared to 2D because of batch dim
114+
axis = 1 + len(scale.shape) % 3 if dequantize_block_size else None
114115
scale = scale.squeeze(axis=axis)
115116
if zero_point is not None:
116117
zero_point = zero_point.squeeze(axis=axis)
117118

118119
if apply_transpose:
119-
scale = fns.transpose(scale)
120+
scale = fns.moveaxis(scale, -1, -2)
120121
if zero_point is not None:
121-
zero_point = fns.transpose(zero_point)
122+
zero_point = fns.moveaxis(zero_point, -1, -2)
122123

123124
if zero_point is not None:
124125
zero_point = zero_point.astype(tensor.dtype)
@@ -267,6 +268,10 @@ def transform_model(
267268
# For opsets earlier than 21, we use the `MatMulNBits` operation from ONNX Runtime contrib operators.
268269
# See https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md
269270
if opset_version < 21 and dequantize_block_size > 0:
271+
if len(weight.shape) == 3:
272+
msg = """ONNX does not support 3D weights for opset version < 21.
273+
Please use a higher opset version or per-channel quantization"""
274+
raise nncf.ParameterNotSupportedError(msg)
270275
compressed_weight, scale, zero_point = self._preprocess_compressed_weight(
271276
compressed_weight, weight.shape, dequantize_block_size=None, apply_transpose=True
272277
)

tests/cross_fw/test_templates/template_test_weights_compression.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def test_scale_estimation_outlier_channel_has_lowest_error(self, mocker):
360360
# AWQ Tests
361361
@staticmethod
362362
@abstractmethod
363-
def get_awq_act_model(with_multiply, n_layers):
363+
def get_awq_act_model(is_3d_weights, with_multiply, n_layers):
364364
"Returns a backend model for test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul."
365365

366366
@staticmethod
@@ -372,13 +372,16 @@ def get_num_multiply_from_awq(model: TModel) -> int:
372372
def int4_mode(self, request):
373373
return None
374374

375+
@pytest.mark.parametrize("is_3d_weights", [True, False])
375376
@pytest.mark.parametrize("with_multiply", (True, False))
376-
def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(self, int4_mode, with_multiply, mocker):
377+
def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(
378+
self, int4_mode, with_multiply, is_3d_weights, mocker
379+
):
377380
n_layers = 8
378381
n_awq_target = n_layers - 1 # first MatMul is always int8
379-
model = self.get_awq_act_model(with_multiply, n_layers)
382+
model = self.get_awq_act_model(is_3d_weights, with_multiply, n_layers)
380383

381-
dataset = Dataset([self.to_tensor(np.ones([1, 8, 8], dtype=np.float32))], self.get_transform_func())
384+
dataset = Dataset([self.to_tensor(np.ones([2, 8, 8], dtype=np.float32))], self.get_transform_func())
382385

383386
with SpyWeightCompressionStatisticsContext(mocker):
384387
model = compress_weights(model, mode=int4_mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)
@@ -388,8 +391,11 @@ def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(self, int
388391

389392
@staticmethod
390393
@abstractmethod
391-
def get_awq_model(non_mergable_pattern: bool) -> TModel:
392-
"Returns a backend model for test_awq_with_ignored_scope."
394+
def get_awq_model(non_mergable_pattern: bool, is_3d_weights: bool) -> TModel:
395+
"""
396+
Returns a backend model for test_awq_with_ignored_scope."
397+
:param is_3d_weights: The model has 3d weights
398+
"""
393399

394400
@staticmethod
395401
@abstractmethod
@@ -408,16 +414,19 @@ def get_num_int4_group_sizes(model: TModel) -> dict[int, int]:
408414

409415
@staticmethod
410416
@abstractmethod
411-
def get_ignored_scope_name() -> str:
417+
def get_ignored_scope_name(is_3d_weights) -> str:
412418
"Returns ignored scope name for test_awq_with_ignored_scope."
413419

414-
def test_awq_with_ignored_scope(self, mocker):
415-
model = self.get_awq_model(non_mergable_pattern=False)
420+
@pytest.mark.parametrize("is_3d_weights", [True, False])
421+
def test_awq_with_ignored_scope(self, mocker, is_3d_weights):
422+
model = self.get_awq_model(non_mergable_pattern=False, is_3d_weights=is_3d_weights)
416423
sz = 8
417424
n_samples = 10
418425

426+
input_shape = [2, 8, sz]
427+
419428
dataset = Dataset(
420-
[self.to_tensor(np.ones([1, 8, sz], dtype=np.float32)) for i in range(n_samples)],
429+
[self.to_tensor(np.ones(input_shape, dtype=np.float32)) for i in range(n_samples)],
421430
self.get_transform_func(),
422431
)
423432

@@ -429,12 +438,12 @@ def test_awq_with_ignored_scope(self, mocker):
429438
group_size=-1,
430439
dataset=dataset,
431440
awq=True,
432-
ignored_scope=IgnoredScope(names=[self.get_ignored_scope_name()]),
441+
ignored_scope=IgnoredScope(names=[self.get_ignored_scope_name(is_3d_weights)]),
433442
)
434443

435444
int4_ref_num_compressed = 4 # last MatMul is always int8; one - is ignored; total 6 matmuls
436445
int4_num_nodes = self.get_num_int4_nodes(compressed_model)
437-
assert int4_num_nodes == int4_ref_num_compressed
446+
assert int4_num_nodes == int4_ref_num_compressed, int4_num_nodes
438447

439448
def test_rope_weight_compression(self):
440449
model = self.get_RoPE_model()
@@ -490,12 +499,14 @@ def transpose_a_supported(self) -> bool:
490499

491500
# Transpose inputs does not affect mergable pattern code, skippting (True, False)
492501
@pytest.mark.parametrize("transpose_a,non_mergable_pattern", [(True, True), (False, True), (False, False)])
502+
@pytest.mark.parametrize("is_3d_weights", [True, False])
493503
def test_awq_scale_reference(
494504
self,
495505
non_mergable_pattern,
496506
transpose_a,
497507
test_awq_scale_ref,
498508
transpose_a_supported,
509+
is_3d_weights,
499510
monkeypatch,
500511
mocker,
501512
):
@@ -505,11 +516,14 @@ def test_awq_scale_reference(
505516
msg = "Transpose a is not supported for the current backend"
506517
pytest.skip(msg)
507518

508-
INPUT_SHAPE = (2, 4)
509-
model = self.get_transposable_awq_model(transpose_a=True, transpose_b=True, input_shape=INPUT_SHAPE)
519+
INPUT_SHAPE = (2, 2, 4) if is_3d_weights else (2, 4)
520+
model = self.get_transposable_awq_model(
521+
transpose_a=True, transpose_b=True, input_shape=INPUT_SHAPE, is_3d_weights=is_3d_weights
522+
)
510523
else:
511-
INPUT_SHAPE = (1, 4, 8)
512-
model = self.get_awq_model(non_mergable_pattern)
524+
batch_size = 1 if not is_3d_weights else 2
525+
INPUT_SHAPE = (batch_size, 4, 8)
526+
model = self.get_awq_model(non_mergable_pattern, is_3d_weights)
513527
input = 0.01 * np.arange(0, np.multiply.reduce(INPUT_SHAPE), dtype=np.float32).reshape(INPUT_SHAPE) + 0.02
514528
input = self.to_tensor(input)
515529
dataset = Dataset([input] * 2, self.get_transform_func())
@@ -526,7 +540,7 @@ def test_awq_scale_reference(
526540
)
527541
assert spy_instance is not None
528542
for node_name, scales in spy_instance._scale_per_target_node.items():
529-
ref = test_awq_scale_ref[node_name]
543+
ref = test_awq_scale_ref[is_3d_weights][node_name]
530544
assert fns.allclose(scales, ref)
531545
assert scales.shape == ref.shape
532546

@@ -652,14 +666,15 @@ def test_group_size_fallback_modes(
652666
f"Expected {ref_num_group_sizes} group size values, but got {num_group_sizes}."
653667
)
654668

655-
@pytest.mark.parametrize("dataset", [None, np.ones([1, 8, 8], dtype=np.float32)])
669+
@pytest.mark.parametrize("is_3d_weights", [True, False])
670+
@pytest.mark.parametrize("dataset", [None, np.ones([2, 8, 8], dtype=np.float32)])
656671
@pytest.mark.parametrize("prefer_data_aware_scaling", [True, False])
657-
def test_data_free_awq(self, dataset, prefer_data_aware_scaling, mocker):
658-
input_data = np.ones([1, 8, 8], dtype=np.float32)
672+
def test_data_free_awq(self, dataset, prefer_data_aware_scaling, is_3d_weights, mocker):
673+
input_data = np.ones([2, 8, 8], dtype=np.float32)
659674

660675
n_layers = 8
661676
n_awq_target = n_layers - 1 # first MatMul is always int8
662-
model = self.get_awq_act_model(True, n_layers)
677+
model = self.get_awq_act_model(is_3d_weights, True, n_layers)
663678
model = self.wrap_model(model, input_data)
664679

665680
if dataset is not None:
@@ -778,7 +793,9 @@ def test_process_stats(self, case: ProcessStatsTestCase):
778793

779794
@staticmethod
780795
@abstractmethod
781-
def get_transposable_awq_model(transpose_a: bool, transpose_b: bool, input_shape=None) -> TModel:
796+
def get_transposable_awq_model(
797+
transpose_a: bool, transpose_b: bool, input_shape=None, is_3d_weights: bool = False
798+
) -> TModel:
782799
"Returns a backend model for test_compression_with_transpose."
783800

784801
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)