Skip to content

Commit dc5cf33

Browse files
committed
init
1 parent 5a71a80 commit dc5cf33

File tree

1 file changed

+38
-25
lines changed
  • src/nncf/quantization/algorithms/weight_compression

1 file changed

+38
-25
lines changed

src/nncf/quantization/algorithms/weight_compression/awq.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ def apply(
183183
prev_statistics = statistics[merge_node.node_name]
184184
scale = self._data_aware_step(wp, weight, statistics[k], prev_weight, prev_statistics)
185185

186-
w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
187-
a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])
186+
w_scale = fns.unsqueeze(scale, -wp.reduction_axes[0])
187+
a_scale = fns.unsqueeze(1.0 / scale, -wp.reduction_axes[0])
188188

189189
scaled_weight = (weight * w_scale).astype(weight_dtype)
190190
self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)
@@ -194,9 +194,9 @@ def apply(
194194
merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
195195
merge_weight = (merge_weight * a_scale).astype(weight_dtype)
196196
self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
197-
a_scale = fns.transpose(a_scale)
197+
a_scale = fns.moveaxis(a_scale, -1, -2)
198198
else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
199-
a_scale = fns.transpose(a_scale).astype(weight_dtype)
199+
a_scale = fns.moveaxis(a_scale, -1, -2).astype(weight_dtype)
200200
next_nodes = graph.get_next_nodes(merge_node)
201201
source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
202202
scale_insertion_command = self._backend_entity.scale_insertion_command(
@@ -217,6 +217,15 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
217217
s = s.astype(TensorDataType.float32)
218218
X = X.astype(TensorDataType.float32)
219219

220+
is_3d_weight = weight.ndim == 3
221+
if s.ndim == 1:
222+
s = fns.unsqueeze(s, 0) # [hidden_dim] -> [1, hidden_dim]
223+
X = fns.unsqueeze(X, 0) # [hidden_dim, samples] -> [1, hidden_dim, samples]
224+
weight = fns.unsqueeze(weight, 0) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
225+
is_2d_weight = True
226+
else:
227+
is_2d_weight = False
228+
220229
assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
221230
reduction_axis = wp.reduction_axes[0]
222231

@@ -226,40 +235,45 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
226235
prev_s = prev_s.astype(TensorDataType.float32).max().item()
227236
prev_w = fns.mean(fns.abs(prev_weight), axis=reduction_axis)
228237

229-
top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
230-
topk_idxs = fns.argsort(-s)[:top_k]
238+
top_k = max(int(s.shape[-1] * self._percent_to_apply), 1)
239+
topk_idxs = fns.argsort(-s)[:, :top_k]
231240

232241
group_size = config.group_size
233242
if group_size == -1:
234-
group_size = s.shape[0]
243+
group_size = s.shape[-1]
235244

236245
groups_to_correct = set()
237-
for idx in topk_idxs:
238-
groups_to_correct.add(idx.data // group_size)
246+
for expert_idx in range(topk_idxs.shape[0]):
247+
for k_idx in range(topk_idxs.shape[1]):
248+
idx = topk_idxs[expert_idx, k_idx].item()
249+
group_idx = idx // group_size
250+
groups_to_correct.add((expert_idx, group_idx))
239251

240252
groups_to_correct = list(groups_to_correct)
241253

242-
if reduction_axis == 0:
243-
weight = fns.transpose(weight)
244-
reduction_axis = 1
254+
if reduction_axis == 0 or (reduction_axis == 1 and is_3d_weight):
255+
# Weights
256+
# 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
257+
# 2D: [1, hidden_dimension, out_features] -> [1, out_features, hidden_dimension]
258+
weight = fns.moveaxis(weight, -1, -2)
259+
reduction_axis = weight.ndim - 1
245260

246-
shape_vector = fns.mean(X, axis=1)
261+
shape_vector = fns.mean(X, axis=-1)
247262
scale = fns.ones_like(shape_vector)
248263

249264
awq_config = deepcopy(config)
250265
awq_config.group_size = -1
251266

252-
for gi in groups_to_correct:
267+
for expert_idx, gi in groups_to_correct:
253268
offset = gi * group_size
254-
gscale = s[offset : offset + group_size]
269+
gscale = s[expert_idx, offset : offset + group_size]
270+
gweight = weight[expert_idx, :, offset : offset + group_size]
271+
gacts = X[expert_idx, offset : offset + group_size, :]
255272

256273
a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
257274
a_max = 1e2
258275
gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)
259276

260-
gweight = weight[:, offset : offset + group_size]
261-
gacts = X[offset : offset + group_size, :]
262-
263277
fp32_out = fns.matmul(gweight, gacts)
264278
min_diff = fns.max(fns.abs(fp32_out))
265279
best_scale = None
@@ -290,13 +304,9 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
290304

291305
weights_to_fake_quantize = gweight * cur_scale
292306
if not config.is_integer:
293-
g_decompressed_weighs = float_quantize_dequantize_weight(
294-
weights_to_fake_quantize, awq_config, reduction_axis
295-
)
307+
g_decompressed_weighs = float_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
296308
else:
297-
g_decompressed_weighs = integer_quantize_dequantize_weight(
298-
weights_to_fake_quantize, awq_config, reduction_axis
299-
)
309+
g_decompressed_weighs = integer_quantize_dequantize_weight(weights_to_fake_quantize, awq_config, -1)
300310
sacts = gacts / fns.unsqueeze(cur_scale, 1)
301311

302312
cur_out = fns.matmul(g_decompressed_weighs, sacts)
@@ -307,7 +317,10 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
307317
alpha += alpha_step
308318

309319
if best_scale is not None:
310-
scale.data[offset : offset + group_size] = best_scale.data
320+
scale.data[expert_idx, offset : offset + group_size] = best_scale.data
321+
322+
if is_2d_weight:
323+
scale = fns.squeeze(scale, 0) # [1, hidden_dim] -> [hidden_dim]
311324

312325
return scale
313326

0 commit comments

Comments
 (0)