Skip to content

Commit ef70ad0

Browse files
committed
fix #2
1 parent 1fde10f commit ef70ad0

File tree

1 file changed

+3
-3
lines changed
  • src/nncf/quantization/algorithms/weight_compression

1 file changed

+3
-3
lines changed

src/nncf/quantization/algorithms/weight_compression/awq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
217217
s = s.astype(TensorDataType.float32)
218218
X = X.astype(TensorDataType.float32)
219219

220-
is_3d_weight = weight.ndim == 3
221220
is_2d_weight = weight.ndim == 2
222221

223222
assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
@@ -234,6 +233,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
234233
s = fns.unsqueeze(s, 0) # [hidden_dim] -> [1, hidden_dim]
235234
X = fns.unsqueeze(X, 0) # [hidden_dim, samples] -> [1, hidden_dim, samples]
236235
weight = fns.unsqueeze(weight, 0) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
236+
reduction_axis += 1
237237

238238
top_k = max(int(s.shape[-1] * self._percent_to_apply), 1)
239239
topk_idxs = fns.argsort(-s)[:, :top_k]
@@ -251,7 +251,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
251251

252252
groups_to_correct = list(groups_to_correct)
253253

254-
if reduction_axis == 0 or (reduction_axis == 1 and is_3d_weight):
254+
if reduction_axis == 1:
255255
# Weights
256256
# 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
257257
# 2D: [1, hidden_dimension, out_features] -> [1, out_features, hidden_dimension]
@@ -298,7 +298,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
298298
magnitudes,
299299
threshold,
300300
cur_scale,
301-
prev_w[offset : offset + group_size]
301+
prev_w[expert_idx, offset : offset + group_size]
302302
* prev_s
303303
* prev_weight.shape[reduction_axis]
304304
/ threshold,

0 commit comments

Comments
 (0)