Skip to content

Commit bea7346

Browse files
committed
fix prev weights case
1 parent dc5cf33 commit bea7346

File tree

1 file changed

+11
-9
lines changed
  • src/nncf/quantization/algorithms/weight_compression

1 file changed

+11
-9
lines changed

src/nncf/quantization/algorithms/weight_compression/awq.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
218218
X = X.astype(TensorDataType.float32)
219219

220220
is_3d_weight = weight.ndim == 3
221-
if s.ndim == 1:
222-
s = fns.unsqueeze(s, 0) # [hidden_dim] -> [1, hidden_dim]
223-
X = fns.unsqueeze(X, 0) # [hidden_dim, samples] -> [1, hidden_dim, samples]
224-
weight = fns.unsqueeze(weight, 0) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
225-
is_2d_weight = True
226-
else:
227-
is_2d_weight = False
221+
is_2d_weight = weight.ndim == 2
228222

229223
assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
230224
reduction_axis = wp.reduction_axes[0]
@@ -233,7 +227,13 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
233227
if prev_statistics is not None and prev_weight is not None:
234228
prev_s, _ = process_stats(prev_statistics, self._subset_size)
235229
prev_s = prev_s.astype(TensorDataType.float32).max().item()
236-
prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis)
230+
prev_weight = fns.unsqueeze(prev_weight, 0) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
231+
prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis + 1)
232+
233+
if is_2d_weight:
234+
s = fns.unsqueeze(s, 0) # [hidden_dim] -> [1, hidden_dim]
235+
X = fns.unsqueeze(X, 0) # [hidden_dim, samples] -> [1, hidden_dim, samples]
236+
weight = fns.unsqueeze(weight, 0) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
237237

238238
top_k = max(int(s.shape[-1] * self._percent_to_apply), 1)
239239
topk_idxs = fns.argsort(-s)[:, :top_k]
@@ -289,7 +289,9 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
289289
# per channel magnitudes for the previous MatMul
290290
# mean(abs(prev_weight)) * max(abs(prev_activation)) * prev_weight.shape[reduction_axis]
291291
magnitudes = (
292-
(prev_w[offset : offset + group_size] / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
292+
(prev_w[expert_idx, offset : offset + group_size] / cur_scale)
293+
* prev_s
294+
* prev_weight.shape[reduction_axis]
293295
)
294296
if magnitudes.max() >= threshold:
295297
cur_scale = AWQ._clamp_scale(

0 commit comments

Comments
 (0)