# Commit 09246b2: Support 3D Weights in SE Algorithm (#3706)
### Changes

Support 3D (e.g. MoE expert) weights in the Scale Estimation (SE) algorithm.

### Reason for changes

Enable data-aware weight compression with Scale Estimation for models whose MatMul weights are 3D, such as MoE expert layers.

### Related tickets

175212

### Tests

- PR Performance Job: post_training_weight_compression_performance - 57
- Develop Branch Performance Job: post_training_weight_compression_performance - 58
- WC Conformance Test: https://github.com/openvinotoolkit/nncf/actions/runs/19161281090: Pass

**Model: Qwen/Qwen3-30B-A3B** (NNCF backend: OpenVINO; higher is better)

Task: gsm8k. Limit: 100. Max new tokens: 10000. n-shots: 5 (default). OpenVINO version: 2026.0.0.dev20251111 (with WA for 176465).

| Precision Type | Filter | Value | Stderr |
| -- | -- | -- | -- |
| INT4 SYM Per-Channel with Scale Estimation (calibrated on wikitext, 128 samples) | flexible-extract | 0.66 | 0.0476 |
| | strict-match | 0.38 | 0.0488 |
| INT4 SYM Per-Channel | flexible-extract | 0.77 | 0.0423 |
| | strict-match | 0.28 | 0.0451 |
| FP16 | flexible-extract | 0.91 | 0.0288 |
| | strict-match | 0.86 | 0.0349 |

WWB results with reasoning disabled:

- INT4 Sym Per-Channel: 0.826173 (vs FP16)
- INT4 Sym Per-Channel with SE: 0.938537 (vs FP16)

**Model: openai/gpt-oss-20b** (NNCF backend: Torch; higher is better)

```
time \
accelerate launch -m lm_eval --model hf \
  --model_args "{\"pretrained\":\"${MODEL_DIR}\",\"enable_thinking\":false}" \
  --tasks gsm8k_cot_llama \
  --fewshot_as_multiturn \
  --apply_chat_template=True \
  --device cuda \
  --limit 100 \
  --output_path $EXP_DIR \
  --gen_kwargs max_new_tokens=1024,temperature=0.6,top_p=0.95,top_k=20
```

| gpt-oss-20b | strict-match | flexible-extract |
| -- | -- | -- |
| bf16 | 0.96 | 0.96 |
| int4_sym_gs32_experts_int8_the_rest | 0.96 | 0.94 |
| int4_sym_gs32_int8_the_rest | 0.64 | 0.96 |
| int4_sym_gs32_int8_the_rest_SE_32tokens_128samples_64se_samples | 0.79 | 0.96 |
| int4_sym_gs32_int8_the_rest_SE_128tokens_128samples_64se_samples | 0.81 | 0.96 |
| int4_sym_gs32_int8_the_rest_SE_half_128tokens_128samples_64se_samples | 0.83 | 0.97 |
| int4_sym_gs32_int8_the_rest_SE_half_256tokens_256samples_64se_samples | 0.79 | 0.95 |
| int4_sym_gs32_int8_the_rest_SE_256tokens_256samples_64se_samples | 0.90 | 0.96 |
| int4_sym_gs32_int8_the_rest_SE_256tokens_256samples_256se_samples | 0.95 | 0.94 |

---------

Co-authored-by: Daniil Lyakhov <[email protected]>
1 parent: 6d28a5d

**12 files changed:** +428 −48 lines


#### src/nncf/onnx/graph/node_utils.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -138,6 +138,8 @@ def get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
     transpose = node.layer_attributes.node_attrs[trans_attr]
     # 0 - (M, K), 1 - (K, N)
     weight_channel_axis = -1 - port_id if transpose else -2 + port_id
+    if node.metatype == om.ONNXMatMulMetatype:
+        weight_channel_axis = -1 - port_id if port_id == 0 else -2 + port_id
     return weight_channel_axis
```
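For intuition, here is a standalone sketch of the branch above; the metatype check and layer attributes are NNCF internals, modeled here as plain arguments (illustrative only, not the NNCF API):

```python
# Standalone sketch of the channel-axis selection above; the metatype check
# and layer attributes are modeled as plain arguments (illustrative only).

def weight_channel_axis(port_id: int, transpose: bool, is_plain_matmul: bool) -> int:
    # Gemm-style ops: 0 - (M, K), 1 - (K, N); a transA/transB attribute flips the axis.
    axis = -1 - port_id if transpose else -2 + port_id
    # Plain ONNX MatMul carries no transpose attributes, so it gets a separate rule.
    if is_plain_matmul:
        axis = -1 - port_id if port_id == 0 else -2 + port_id
    return axis

# Negative axes are what let this extend to 3D [num_experts, K, N] MoE weights:
# the axis is counted from the end, so a leading experts axis changes nothing.
print(weight_channel_axis(1, transpose=False, is_plain_matmul=True))  # -1
```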

#### src/nncf/quantization/algorithms/weight_compression/activation_stats.py

Lines changed: 26 additions & 9 deletions

```diff
@@ -27,17 +27,34 @@ def process_stats(stats: WCTensorStatistic, subset_size: int) -> tuple[Tensor, T
     s - maximum channel magnitude across samples [HiddenDim]
     X - average channel magnitude across tokens in the sequence [HiddenDim, min(SampleSize, ~subset_size)]
     """
-    X = fns.stack(stats.mean_values)  # [SampleSize, HiddenDim]
-    X_full = fns.transpose(X)  # [HiddenDim, SampleSize]
+    X = fns.stack(
+        stats.mean_values
+    )  # [SampleSize, HiddenDim] for 2-D or [SampleSize, No. of Experts, HiddenDim] for 3-D
 
-    # prevent high memory and time consumption
-    if X_full.shape[1] > subset_size:
-        # activations were reduced across all but the last dimension
+    # Move SampleSize to the last axis: [HiddenDim, SampleSize] or [No. of Experts, HiddenDim, SampleSize]
+    # General approach: move axis 0 to the end
+    axes = list(range(1, len(X.shape))) + [0]
+    X_full = fns.transpose(X, axes=axes)
+
+    # The sample dimension is always the last axis after transpose
+    sample_axis = -1
+
+    # Prevent high memory and time consumption by sampling
+    if X_full.shape[sample_axis] > subset_size:
+        # Activations were reduced across all but the last dimension
         lens = [reduce(mul, shape[:-1], 1) for shape in stats.shape_values]
-        step = X_full.shape[1] // subset_size
-        idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
-        X = X_full[:, idxs]  # [HiddenDim, ~SubsetSize]
+        step = X_full.shape[sample_axis] // subset_size
+        sorted_idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
+        idxs = [idx for idx in sorted_idxs if idx < X_full.shape[sample_axis]][:subset_size]
+
+        # Create index slices for all dimensions except the last one
+        # This works for both 2D and 3D (and theoretically any dimensionality)
+        index_slices = [slice(None)] * (len(X_full.shape) - 1) + [idxs]
+        X = X_full[tuple(index_slices)]
     else:
         X = X_full
-    s = fns.max(fns.abs(X_full), axis=1)  # [HiddenDim]
+
+    # Compute max magnitude along the sample axis (last axis)
+    # Result: [HiddenDim] or [No. of Experts, HiddenDim]
+    s = fns.max(fns.abs(X_full), axis=sample_axis)
     return s, X
```
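The axis bookkeeping can be exercised in isolation. A minimal NumPy sketch, with NumPy standing in for NNCF's `fns` tensor functions, and with illustrative shapes and `lens` values:

```python
import numpy as np

# Stand-in for fns.stack(stats.mean_values), 3D case:
# [SampleSize=5, NumExperts=2, HiddenDim=4].
X = np.random.rand(5, 2, 4)

# Move the sample axis (axis 0) to the end, exactly as in the diff.
axes = list(range(1, X.ndim)) + [0]
X_full = np.transpose(X, axes)  # [2, 4, 5]
sample_axis = -1

subset_size = 3
if X_full.shape[sample_axis] > subset_size:
    # Illustrative stand-in for ranking samples by their reduced lengths.
    lens = [10, 40, 30, 20, 50]
    step = X_full.shape[sample_axis] // subset_size
    sorted_idxs = [i for i, _ in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
    idxs = [i for i in sorted_idxs if i < X_full.shape[sample_axis]][:subset_size]
    # slice(None) on every axis but the last keeps the expert/hidden dims whole.
    index_slices = [slice(None)] * (X_full.ndim - 1) + [idxs]
    X_sub = X_full[tuple(index_slices)]
else:
    X_sub = X_full

# Max magnitude over samples: [NumExperts, HiddenDim].
s = np.abs(X_full).max(axis=sample_axis)
print(X_sub.shape, s.shape)  # (2, 4, 3) (2, 4)
```

The same path handles the 2D `[SampleSize, HiddenDim]` case, where the axis move degenerates to a plain matrix transpose.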

#### src/nncf/quantization/algorithms/weight_compression/algorithm.py

Lines changed: 53 additions & 9 deletions

```diff
@@ -14,7 +14,9 @@
 from collections import OrderedDict
 from collections import defaultdict
 from functools import reduce
-from typing import Any, Iterable, Optional, TypeVar
+from typing import Any, Optional, TypeVar
+
+from packaging import version
 
 import nncf
 from nncf import Dataset
@@ -786,6 +788,14 @@ def is_weight_compression_supported(
 
         return is_supported_dtype and not no_bit_reduction
 
+    def _maybe_get_ov_major_version(self) -> Optional[str]:
+        try:
+            import openvino as ov
+
+            return ov.__version__.split(".")[0]
+        except Exception:
+            return None
+
     def get_weight_compression_parameters(
         self,
         model: TModel,
@@ -851,6 +861,22 @@ def get_weight_compression_parameters(
                     f"node name: {node.node_name}. The node will be in {self._backup_mode} mode."
                 )
 
+            model_backend = get_backend(model)
+            ov_version = self._maybe_get_ov_major_version()
+            if (
+                model_backend == BackendType.OPENVINO
+                and len(weight_shape) == 3
+                and ov_version
+                and version.parse(ov_version) <= version.parse("2026")
+                and node.metatype in self._backend_entity.matmul_metatypes
+            ):
+                # MoE operations are usually matmuls, so the check for matmul metatype is done
+                # This is to avoid raising the error for non-MoE cases with 3D weights.
+                msg = f"""NNCF does not support 3D weights with current version of Openvino {ov_version}
+                due to a known issue in statistics collection Ticket - 176465
+                Node with weight: {node.node_name}"""
+                raise nncf.UnsupportedModelError(msg)
+
             if self._backup_mode != BackupMode.NONE:
                 mode = (
                     CompressWeightsMode.INT8_ASYM
@@ -899,7 +925,7 @@ def get_weight_compression_parameters(
             matmul_nodes_to_compress, graph
         )
         if statistic_points is None:
-            statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
+            statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map)
             statistic_points = self._collect_statistics(dataset, graph, model, statistic_points)
         statistics = self._get_statistics_for_weights_compression(
             matmul_input_to_output_nodes_map, statistic_points
@@ -1089,28 +1115,46 @@ def get_statistic_points(
         self,
         model: TModel,
         graph: NNCFGraph,
-        nodes_and_port_ids: Iterable[tuple[NNCFNode, int]],
+        matmul_input_to_output_nodes_map: dict[tuple[NNCFNode, int], list[NNCFNode]],
     ) -> StatisticPointsContainer:
         """
         Returns statistic points, for which StatisticsCollector should collect statistics.
 
         :param model: Model for statistics collection.
         :param graph: Model graph.
-        :param nodes_and_port_ids: Nodes and port ids for which statistics should be collected.
+        :param matmul_input_to_output_nodes_map: A mapping from activation node and a port id to corresponding matmul
+            nodes which accept this activation as an input.
         :return: Statistic points, for which StatisticsCollector should collect statistics.
         """
         statistic_container = StatisticPointsContainer()
+
        # Statistics for data aware algorithms
         if self._data_aware_compression:
-            for node, output_port_id in nodes_and_port_ids:
+            for (node, output_port_id), node_with_weights in matmul_input_to_output_nodes_map.items():
                 statistic_point = self._backend_entity.target_point(
                     TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id
                 )
-                # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden
-                # size dimension.
+                all_weight_dims = []
+                for node_with_weight in node_with_weights:
+                    _, weight_port_ids = zip(
+                        *self._backend_entity.get_weight_names_and_port_ids(node_with_weight, graph)
+                    )
+                    weight_dims = [
+                        len(self._backend_entity.get_weight_shape(node_with_weight, weight_port_id, graph))
+                        for weight_port_id in weight_port_ids
+                    ]
+                    all_weight_dims.extend(weight_dims)
+
+                # by default, reduce activations across all but the last dimension. The last dimension is
+                # assumed to be the hidden size dimension.
                 n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
+                reduction_axes = tuple(range(n_dims - 1))
+
+                # For 3D weights, hidden dimension is the second dimension. Reduce by all other dimensions
+                reduction_axes = (1,) if any(weight_dim == 3 for weight_dim in all_weight_dims) else reduction_axes
+
                 stat_collector = self._backend_entity.mean_statistic_collector(
-                    reduction_axes=tuple(range(n_dims - 1)), subset_size=self._subset_size
+                    reduction_axes=reduction_axes, subset_size=self._subset_size
                 )
                 statistic_container.add_statistic_point(
                     StatisticPoint(
@@ -1120,7 +1164,7 @@ def get_statistic_points(
         # Statistics for mixed precision algorithm
         if self._data_aware_mixed_precision:
             mixed_precision_statistics = self._mixed_precision_algo.get_statistic_points(
-                model, graph, nodes_and_port_ids
+                model, graph, matmul_input_to_output_nodes_map.keys()
             )
             for points in mixed_precision_statistics.values():
                 for point in points:
```
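The reduction-axis decision above boils down to a few lines. A minimal sketch mirroring the diff's logic; `pick_reduction_axes` is a made-up helper name and the shapes are illustrative:

```python
# Mirrors the selection above: by default all activation axes except the last
# (hidden) one are reduced; if any consuming matmul holds a 3D weight, axis 1
# is used instead, per the diff.

def pick_reduction_axes(activation_ndim: int, weight_ndims: list[int]) -> tuple[int, ...]:
    reduction_axes = tuple(range(activation_ndim - 1))
    if any(ndim == 3 for ndim in weight_ndims):
        reduction_axes = (1,)
    return reduction_axes

print(pick_reduction_axes(3, [2, 2]))  # (0, 1) - classic 2D weights
print(pick_reduction_axes(3, [3]))     # (1,)   - MoE-style 3D weights
```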

#### src/nncf/quantization/algorithms/weight_compression/onnx_backend.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -144,6 +144,9 @@ def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool:
     def get_reduction_axes(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Optional[tuple[int]]:
         channel_axes = (get_weight_quantization_axis(node_with_weight, weight_port_id),)
         const_shape = node_with_weight.layer_attributes.weight_attrs[weight_port_id]["shape"]
+        # Everything remains the same, except when 3D weights, reduce by batch dimension also.
+        if len(const_shape) == 3:
+            channel_axes = (0,) + channel_axes
         return get_reduction_axes(channel_axes, const_shape)
 
     @staticmethod
```
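A toy illustration of the extra channel axis; `complement` below is a stand-in for nncf's `get_reduction_axes` utility, assumed (from its usage here) to return the axes of the shape that are not channel axes:

```python
# Toy version of the change above: for a 3D ONNX weight the leading
# batch/experts axis is also treated as a channel axis, so it is excluded
# from reduction. complement() is an assumed stand-in, not the nncf API.

def complement(channel_axes: tuple[int, ...], shape: tuple[int, ...]) -> tuple[int, ...]:
    ndim = len(shape)
    keep = {ax % ndim for ax in channel_axes}
    return tuple(ax for ax in range(ndim) if ax not in keep)

channel_axes = (-1,)  # e.g. from get_weight_quantization_axis
print(complement(channel_axes, (768, 3072)))  # (0,)

const_shape = (8, 768, 3072)  # [num_experts, K, N]
if len(const_shape) == 3:
    channel_axes = (0,) + channel_axes
print(complement(channel_axes, const_shape))  # (1,)
```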

#### src/nncf/quantization/algorithms/weight_compression/scale_estimation.py

Lines changed: 28 additions & 22 deletions

```diff
@@ -196,11 +196,15 @@ def calculate_quantization_params(
         X = X.astype(TensorDataType.float32)
         weight = weight.astype(TensorDataType.float32)
         eps = fns.finfo(weight).eps
+        is_3d_weight = len(weight.shape) == 3
 
         was_transposed = False
-        if reduction_axis == 0:
-            weight = fns.transpose(weight)
-            reduction_axis = 1
+        if reduction_axis == 0 or (reduction_axis == 1 and is_3d_weight):
+            # Weights
+            # 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
+            # 2D: [hidden_dimension, out_features] -> [out_features, hidden_dimension]
+            weight = fns.moveaxis(weight, -1, -2)
+            reduction_axis = weight.ndim - 1
             was_transposed = True
 
         group_size = config.group_size if config.group_size != -1 else weight.shape[reduction_axis]
@@ -220,7 +224,7 @@ def calculate_quantization_params(
         if zp is not None:
             zp = zp.astype(scale.dtype)
 
-        s = fns.unsqueeze(s, 0)
+        s = fns.unsqueeze(s, -2)
         s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, group_size)
 
         original_weight, _ = reshape_weight_for_grouped_quantization(original_weight, reduction_axis, group_size)
@@ -233,18 +237,20 @@ def calculate_quantization_params(
         importance = fns.where(zero_mask, 0.0, importance)
 
         # normalize importances for every group of weights to make sum of them equal to 1.0
-        denum = fns.sum(importance, axis=2, keepdims=True)
+        denum = fns.sum(importance, axis=-1, keepdims=True)
         importance = importance / (denum + eps)
 
-        X, _ = reshape_weight_for_grouped_quantization(X, 0, group_size)
+        X, _ = reshape_weight_for_grouped_quantization(X, -2, group_size)
         best_diffs = None
         result_scale = None
-        fp_outs = fns.matmul(fns.transpose(original_weight, (1, 0, 2)), X)
-        q_outs = fns.matmul(fns.transpose(q_weights, (1, 0, 2)), X)
+        fp_outs = fns.matmul(fns.moveaxis(original_weight, -2, -3), X)
+        q_outs = fns.matmul(fns.moveaxis(q_weights, -2, -3), X)
 
         # metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE
+        # For 3D weights, it is [Batch Size, C_OUT, N_GROUPS]
         min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
-        min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0))
+        min_max_scale_diffs = fns.moveaxis(min_max_scale_diffs, -1, -2)
+
         if weight_penalty > 0.0:
             min_max_scale_diffs += weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)
@@ -272,10 +278,10 @@ def calculate_quantization_params(
             )
 
             q_weights_ = fns.zeros_like(original_weight) + out
-            q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
+            q_outs = fns.matmul(fns.moveaxis(q_weights_, -2, -3), X)
 
             ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
-            ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
+            ideal_scale_diffs = fns.moveaxis(ideal_scale_diffs, -1, -2)
             if weight_penalty > 0.0:
                 ideal_scale_diffs += weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)
@@ -286,7 +292,7 @@ def calculate_quantization_params(
 
             best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs
 
-            mask = fns.unsqueeze(mask, axis=2)
+            mask = fns.unsqueeze(mask, axis=-1)
 
             if result_scale is None:
                 near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
@@ -340,17 +346,17 @@ def calculate_quantization_params(
             )
             q_weights_ = fns.zeros_like(original_weight) + out
 
-            q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
+            q_outs = fns.matmul(fns.moveaxis(q_weights_, -2, -3), X)
             ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
-            ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
+            ideal_scale_diffs = fns.moveaxis(ideal_scale_diffs, -1, -2)
             if weight_penalty > 0.0:
                 ideal_scale_diffs += weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)
 
             mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype)
 
             best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs
 
-            mask = fns.unsqueeze(mask, axis=2)
+            mask = fns.unsqueeze(mask, axis=-1)
 
             if result_scale is None:
                 near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
@@ -359,19 +365,19 @@ def calculate_quantization_params(
             result_scale = near_to_ideal_scale
 
         if config.group_size == -1:
-            result_scale = fns.squeeze(result_scale, axis=1)
+            result_scale = fns.squeeze(result_scale, axis=-2)
         if zp is not None and config.group_size == -1:
-            zp = fns.squeeze(zp, axis=1)
+            zp = fns.squeeze(zp, axis=-2)
 
         if was_transposed:
             if config.group_size == -1:
-                result_scale = fns.transpose(result_scale)
+                result_scale = fns.moveaxis(result_scale, -1, -2)
                 if zp is not None:
-                    zp = fns.transpose(zp)
+                    zp = fns.moveaxis(zp, -1, -2)
             else:
-                result_scale = fns.transpose(result_scale, axes=(1, 2, 0))
+                result_scale = fns.moveaxis(result_scale, (-1, -2, -3), (-2, -3, -1))
                 if zp is not None:
-                    zp = fns.transpose(zp, axes=(1, 2, 0))
+                    zp = fns.moveaxis(zp, (-1, -2, -3), (-2, -3, -1))
 
         return result_scale, zp
@@ -421,5 +427,5 @@ def estimate_scales(weight: Tensor, target: Tensor, zero_mask: Tensor, importanc
     """
     ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask)
     weighted_scale = ideal_scale * importance
-    near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
+    near_to_ideal_scale = fns.sum(weighted_scale, axis=-1, keepdims=True)
     return near_to_ideal_scale
```
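The recurring pattern in this file is replacing fixed-rank `transpose` calls with negative-axis `moveaxis`, which behaves identically in the 2D-weight case but transparently carries an extra leading experts axis. A quick NumPy check with illustrative shapes:

```python
import numpy as np

# moveaxis(w, -1, -2) swaps the last two axes for any rank, so one code path
# covers both 2D [K, N] and 3D [num_experts, K, N] weights.
w2 = np.zeros((16, 32))
w3 = np.zeros((4, 16, 32))
assert np.moveaxis(w2, -1, -2).shape == (32, 16)     # same as w2.T
assert np.moveaxis(w3, -1, -2).shape == (4, 32, 16)  # experts axis untouched

# transpose(x, (1, 0, 2)) on the grouped tensors equals moveaxis(x, -2, -3),
# and the moveaxis form keeps working when an experts axis is prepended.
x3 = np.zeros((16, 8, 32))     # [C_OUT, N_GROUPS, GROUP_SIZE]
x4 = np.zeros((4, 16, 8, 32))  # [num_experts, C_OUT, N_GROUPS, GROUP_SIZE]
assert np.moveaxis(x3, -2, -3).shape == x3.transpose(1, 0, 2).shape
assert np.moveaxis(x4, -2, -3).shape == (4, 8, 16, 32)
print("ok")
```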

#### src/nncf/quantization/statistics_caching.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -39,7 +39,7 @@ def register_statistics_for_algorithm(
     :param matmul_input_to_output_nodes_map: A dictionary mapping from a tuple of (activation node, port ID)
         to a list of MatMul nodes that accept the activation as input.
     """
-    statistic_points = compression_algo.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
+    statistic_points = compression_algo.get_statistic_points(model, graph, matmul_input_to_output_nodes_map)
     aggregator.register_statistic_points(statistic_points)
```
4545

#### tests/cross_fw/test_templates/template_test_weights_compression.py

Lines changed: 33 additions & 5 deletions

```diff
@@ -231,20 +231,40 @@ def get_model_for_test_scale_estimation() -> TModel:
         Returns a backend model for test_scale_estimation.
         """
 
+    @staticmethod
+    @abstractmethod
+    def get_moe_model_for_test_scale_estimation() -> TModel:
+        """
+        Returns a backend MoE model for test_scale_estimation with 3D weights.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def get_moe_scale_estimation_ref() -> TTensor:
+        """
+        Returns the reference output of calculate_quantization_params for MoE model.
+        """
+
     @staticmethod
     @abstractmethod
     def get_scale_estimation_ref() -> TTensor:
         """
         Returns the reference output of calculate_quantization_params of ScaleEstimation.
         """
 
-    def test_scale_estimation(self, mocker):
+    @pytest.mark.parametrize("is_moe", [False, True])
+    def test_scale_estimation(self, mocker, is_moe):
         """Checks that scales match the reference."""
         calc_q_params_spy = mocker.spy(ScaleEstimation, "calculate_quantization_params")
-        model = self.get_model_for_test_scale_estimation()
+
+        if is_moe:
+            model = self.get_moe_model_for_test_scale_estimation()
+            input = np.arange(0, 2 * 4 * 8, dtype=np.float32).reshape(2, 4, 8)
+        else:
+            model = self.get_model_for_test_scale_estimation()
+            input = np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8)
 
         # prepare dataset with one input tensor
-        input = np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8)
         input = self.to_tensor(input)
         dataset = Dataset([input], self.get_transform_func())
@@ -258,8 +278,15 @@ def test_scale_estimation(self, mocker):
             all_layers=True,
             dataset=dataset,
         )
-        reference = self.get_scale_estimation_ref()
-        assert fns.allclose(Tensor(reference), calc_q_params_spy.spy_return[0])
+
+        computed_scale = calc_q_params_spy.spy_return[0]
+
+        if is_moe:
+            reference = self.get_moe_scale_estimation_ref()
+        else:
+            reference = self.get_scale_estimation_ref()
+
+        assert fns.allclose(Tensor(reference), computed_scale)
 
     @staticmethod
     @abstractmethod
@@ -328,6 +355,7 @@ def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(self, int
         model = self.get_awq_act_model(with_multiply, n_layers)
 
         dataset = Dataset([self.to_tensor(np.ones([1, 8, 8], dtype=np.float32))], self.get_transform_func())
+
         with SpyWeightCompressionStatisticsContext(mocker):
             model = compress_weights(model, mode=int4_mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)
```
333361
