Commit 3122842

enable awq

1 parent c7afa57

1 file changed: +43 -28 lines

src/nncf/quantization/algorithms/weight_compression/awq.py (43 additions & 28 deletions)
@@ -183,8 +183,8 @@ def apply(
         prev_statistics = statistics[merge_node.node_name]
         scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics)

-        w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
-        a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])
+        w_scale = fns.unsqueeze(scale, -1 - wp.reduction_axes[0])
+        a_scale = fns.unsqueeze(1.0 / scale, -wp.reduction_axes[0])

         scaled_weight = (weight * w_scale).astype(weight_dtype)
         self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)
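Note on the sign flip: for the existing 2D path the two forms address the same axis, but counting from the end keeps the arithmetic valid once a leading experts dimension is prepended. A minimal numpy sketch (shapes and names are illustrative assumptions, not NNCF code), for a 2D weight [out_features, hidden_dim] with reduction axis 1:

    import numpy as np

    reduction_axis = 1
    scale = np.random.rand(8)  # one scale per hidden channel

    old_w_scale = np.expand_dims(scale, 1 - reduction_axis)     # axis 0  -> (1, 8)
    new_w_scale = np.expand_dims(scale, -1 - reduction_axis)    # axis -2 -> (1, 8)
    old_a_scale = np.expand_dims(1.0 / scale, reduction_axis)   # axis 1  -> (8, 1)
    new_a_scale = np.expand_dims(1.0 / scale, -reduction_axis)  # axis -1 -> (8, 1)
    assert old_w_scale.shape == new_w_scale.shape
    assert old_a_scale.shape == new_a_scale.shape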
@@ -194,9 +194,9 @@ def apply(
             merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
             merge_weight = (merge_weight * a_scale).astype(weight_dtype)
             self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
-            a_scale = fns.transpose(a_scale)
+            a_scale = fns.moveaxis(a_scale, -1, -2)
         else:  # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
-            a_scale = fns.transpose(a_scale).astype(weight_dtype)
+            a_scale = fns.moveaxis(a_scale, -1, -2).astype(weight_dtype)
             next_nodes = graph.get_next_nodes(merge_node)
             source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
             scale_insertion_command = self._backend_entity.scale_insertion_command(
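Why moveaxis instead of transpose: a full transpose with no axes argument reverses every axis, which coincides with a matrix transpose only for 2D tensors, while moveaxis(x, -1, -2) swaps just the last two axes, i.e. a batched transpose that leaves the leading experts dimension in place. Numpy shows the difference (illustrative sketch, assuming fns mirrors numpy semantics here):

    import numpy as np

    a_scale = np.zeros((4, 2, 3))              # [num_experts, rows, cols]
    print(np.transpose(a_scale).shape)         # (3, 2, 4): all axes reversed
    print(np.moveaxis(a_scale, -1, -2).shape)  # (4, 3, 2): per-expert transpose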
@@ -217,49 +217,63 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
         s = s.astype(TensorDataType.float32)
         X = X.astype(TensorDataType.float32)

+        is_2d_weight = weight.ndim == 2
+
         assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
         reduction_axis = wp.reduction_axes[0]

         prev_s, prev_w = None, None
         if prev_statistics is not None and prev_weight is not None:
             prev_s, _ = process_stats(prev_statistics, self._subset_size)
             prev_s = prev_s.astype(TensorDataType.float32).max().item()
-            prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis)
+            prev_weight = fns.unsqueeze(prev_weight, 0)  # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
+            prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis + 1)
+
+        if is_2d_weight:
+            s = fns.unsqueeze(s, 0)  # [hidden_dim] -> [1, hidden_dim]
+            X = fns.unsqueeze(X, 0)  # [hidden_dim, samples] -> [1, hidden_dim, samples]
+            weight = fns.unsqueeze(weight, 0)  # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
+            reduction_axis += 1

-        top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
-        topk_idxs = fns.argsort(-s)[:top_k]
+        top_k = max(int(s.shape[-1] * self._percent_to_apply), 1)
+        topk_idxs = fns.argsort(-s)[:, :top_k]

         group_size = config.group_size
         if group_size == -1:
-            group_size = s.shape[0]
+            group_size = s.shape[-1]

         groups_to_correct = set()
-        for idx in topk_idxs:
-            groups_to_correct.add(idx.data // group_size)
+        for expert_idx in range(topk_idxs.shape[0]):
+            for k_idx in range(topk_idxs.shape[1]):
+                idx = topk_idxs[expert_idx, k_idx].item()
+                group_idx = idx // group_size
+                groups_to_correct.add((expert_idx, group_idx))

         groups_to_correct = list(groups_to_correct)

-        if reduction_axis == 0:
-            weight = fns.transpose(weight)
-            reduction_axis = 1
+        if reduction_axis == 1:
+            # Weights
+            # 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
+            # 2D: [1, hidden_dimension, out_features] -> [1, out_features, hidden_dimension]
+            weight = fns.moveaxis(weight, -1, -2)
+            reduction_axis = weight.ndim - 1

-        shape_vector = fns.mean(X, axis=1)
+        shape_vector = fns.mean(X, axis=-1)
         scale = fns.ones_like(shape_vector)

         awq_config = deepcopy(config)
         awq_config.group_size = -1

-        for gi in groups_to_correct:
+        for expert_idx, gi in groups_to_correct:
             offset = gi * group_size
-            gscale = s[offset : offset + group_size]
+            gscale = s[expert_idx, offset : offset + group_size]
+            gweight = weight[expert_idx, :, offset : offset + group_size]
+            gacts = X[expert_idx, offset : offset + group_size, :]

             a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
             a_max = 1e2
             gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

-            gweight = weight[:, offset : offset + group_size]
-            gacts = X[offset : offset + group_size, :]
-
             fp32_out = fns.matmul(gweight, gacts)
             min_diff = fns.max(fns.abs(fp32_out))
             best_scale = None
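The group selection now runs per expert: the top-k strongest activation channels are picked row-wise, and each hit is mapped to an (expert, group) pair. A self-contained numpy sketch of the same idea (sizes are made up for illustration, not NNCF code):

    import numpy as np

    num_experts, hidden_dim, group_size, percent_to_apply = 2, 8, 4, 0.25
    s = np.abs(np.random.randn(num_experts, hidden_dim))  # per-channel activation stats

    top_k = max(int(s.shape[-1] * percent_to_apply), 1)
    topk_idxs = np.argsort(-s)[:, :top_k]                 # row-wise top-k channel indices

    groups_to_correct = set()
    for expert_idx in range(topk_idxs.shape[0]):
        for k_idx in range(topk_idxs.shape[1]):
            groups_to_correct.add((expert_idx, int(topk_idxs[expert_idx, k_idx]) // group_size))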
@@ -275,28 +289,26 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
                     # per channel magnitudes for the previous MatMul
                     # mean(abs(prev_weight)) * max(abs((prev_activation))) * prev_weight.shape[reduction_axis]
                     magnitudes = (
-                        (prev_w[offset : offset + group_size] / cur_scale) * prev_s * prev_weight.shape[reduction_axis]
+                        (prev_w[expert_idx, offset : offset + group_size] / cur_scale)
+                        * prev_s
+                        * prev_weight.shape[reduction_axis]
                     )
                     if magnitudes.max() >= threshold:
                         cur_scale = AWQ._clamp_scale(
                             magnitudes,
                             threshold,
                             cur_scale,
-                            prev_w[offset : offset + group_size]
+                            prev_w[expert_idx, offset : offset + group_size]
                             * prev_s
                             * prev_weight.shape[reduction_axis]
                             / threshold,
                         )

                 weights_to_fake_quantize = gweight * cur_scale
                 if not config.is_integer:
-                    g_decompressed_weighs = float_quantize_dequantize_weight(
-                        weights_to_fake_quantize, awq_config, reduction_axis
-                    )
+                    g_decompressed_weighs = float_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
                 else:
-                    g_decompressed_weighs = integer_quantize_dequantize_weight(
-                        weights_to_fake_quantize, awq_config, reduction_axis
-                    )
+                    g_decompressed_weighs = integer_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
                 sacts = gacts / fns.unsqueeze(cur_scale, 1)

                 cur_out = fns.matmul(g_decompressed_weighs, sacts)
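For context, the search these lines feed is the standard AWQ loop: scale a weight group up, fake-quantize it, divide the activations by the same factor, and keep whichever scale best preserves the group's matmul output. A toy numpy version of that invariant (the quantizer, the cur_scale formula, and all shapes are simplified assumptions, not the NNCF implementation):

    import numpy as np

    def fake_quantize(w, levels=16):
        # crude symmetric round-to-nearest quantize-dequantize, illustration only
        step = np.abs(w).max() / (levels / 2 - 1)
        return np.round(w / step) * step

    rng = np.random.default_rng(0)
    gweight = rng.standard_normal((4, 8))  # [out_features, group_size]
    gacts = rng.standard_normal((8, 32))   # [group_size, samples]
    gscale = np.abs(gacts).mean(axis=1)    # per-channel activation magnitude

    fp32_out = gweight @ gacts
    best_scale, min_diff = None, np.abs(fp32_out).max()

    for alpha in np.linspace(0.0, 1.0, 11):
        cur_scale = gscale**alpha                  # assumed form of the candidate scale
        q_w = fake_quantize(gweight * cur_scale)   # scale weights up, then quantize
        s_acts = gacts / cur_scale[:, None]        # compensate on the activation side
        diff = np.mean(np.abs(q_w @ s_acts - fp32_out))
        if diff < min_diff:
            min_diff, best_scale = diff, cur_scale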
@@ -307,7 +319,10 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
                 alpha += alpha_step

             if best_scale is not None:
-                scale.data[offset : offset + group_size] = best_scale.data
+                scale.data[expert_idx, offset : offset + group_size] = best_scale.data
+
+        if is_2d_weight:
+            scale = fns.squeeze(scale, 0)  # [1, hidden_dim] -> [hidden_dim]

         return scale
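End to end, the 2D case is now handled as a single-expert 3D problem: inputs gain a singleton leading axis at the top of _data_aware_step, every intermediate keeps that axis, and the returned scale drops it again. A short sketch of the round trip (numpy standing in for fns, shapes as in the diff comments):

    import numpy as np

    weight = np.zeros((16, 8))              # [out_features, hidden_dim]
    is_2d_weight = weight.ndim == 2
    if is_2d_weight:
        weight = np.expand_dims(weight, 0)  # -> [1, out_features, hidden_dim]

    # ... per-expert search fills a [num_experts, hidden_dim] scale ...
    scale = np.ones((weight.shape[0], weight.shape[2]))

    if is_2d_weight:
        scale = np.squeeze(scale, 0)        # -> [hidden_dim], what 2D callers expect
    assert scale.shape == (8,)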
