@@ -183,8 +183,8 @@ def apply(
183183 prev_statistics = statistics [merge_node .node_name ]
184184 scale = self ._data_aware_step (wp , weight , statistics [k ], prev_weight , prev_statistics )
185185
186- w_scale = fns .unsqueeze (scale , 1 - wp .reduction_axes [0 ])
187- a_scale = fns .unsqueeze (1.0 / scale , wp .reduction_axes [0 ])
186+ w_scale = fns .unsqueeze (scale , - 1 - wp .reduction_axes [0 ])
187+ a_scale = fns .unsqueeze (1.0 / scale , - wp .reduction_axes [0 ])
188188
189189 scaled_weight = (weight * w_scale ).astype (weight_dtype )
190190 self ._backend_entity .set_weight (wp .node_with_weight , weight_port_id , model , graph , scaled_weight )
@@ -194,9 +194,9 @@ def apply(
194194 merge_weight = self ._backend_entity .get_weight (merge_node , port_id , model , graph )
195195 merge_weight = (merge_weight * a_scale ).astype (weight_dtype )
196196 self ._backend_entity .set_weight (merge_node , port_id , model , graph , merge_weight )
197- a_scale = fns .transpose (a_scale )
197+ a_scale = fns .moveaxis (a_scale , - 1 , - 2 )
198198 else : # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
199- a_scale = fns .transpose (a_scale ).astype (weight_dtype )
199+ a_scale = fns .moveaxis (a_scale , - 1 , - 2 ).astype (weight_dtype )
200200 next_nodes = graph .get_next_nodes (merge_node )
201201 source_node_output_port = graph .get_output_edges (merge_node )[0 ].output_port_id
202202 scale_insertion_command = self ._backend_entity .scale_insertion_command (
@@ -217,49 +217,63 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
217217 s = s .astype (TensorDataType .float32 )
218218 X = X .astype (TensorDataType .float32 )
219219
220+ is_2d_weight = weight .ndim == 2
221+
220222 assert isinstance (wp .reduction_axes , tuple ) and len (wp .reduction_axes ) == 1
221223 reduction_axis = wp .reduction_axes [0 ]
222224
223225 prev_s , prev_w = None , None
224226 if prev_statistics is not None and prev_weight is not None :
225227 prev_s , _ = process_stats (prev_statistics , self ._subset_size )
226228 prev_s = prev_s .astype (TensorDataType .float32 ).max ().item ()
227- prev_w = fns .mean (fns .abs (prev_weight ), axis = reduction_axis )
229+ prev_weight = fns .unsqueeze (prev_weight , 0 ) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
230+ prev_w = fns .mean (fns .abs (prev_weight ), axis = reduction_axis + 1 )
231+
232+ if is_2d_weight :
233+ s = fns .unsqueeze (s , 0 ) # [hidden_dim] -> [1, hidden_dim]
234+ X = fns .unsqueeze (X , 0 ) # [hidden_dim, samples] -> [1, hidden_dim, samples]
235+ weight = fns .unsqueeze (weight , 0 ) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
236+ reduction_axis += 1
228237
229- top_k = max (int (s .shape [0 ] * self ._percent_to_apply ), 1 )
230- topk_idxs = fns .argsort (- s )[:top_k ]
238+ top_k = max (int (s .shape [- 1 ] * self ._percent_to_apply ), 1 )
239+ topk_idxs = fns .argsort (- s )[:, : top_k ]
231240
232241 group_size = config .group_size
233242 if group_size == - 1 :
234- group_size = s .shape [0 ]
243+ group_size = s .shape [- 1 ]
235244
236245 groups_to_correct = set ()
237- for idx in topk_idxs :
238- groups_to_correct .add (idx .data // group_size )
246+ for expert_idx in range (topk_idxs .shape [0 ]):
247+ for k_idx in range (topk_idxs .shape [1 ]):
248+ idx = topk_idxs [expert_idx , k_idx ].item ()
249+ group_idx = idx // group_size
250+ groups_to_correct .add ((expert_idx , group_idx ))
239251
240252 groups_to_correct = list (groups_to_correct )
241253
242- if reduction_axis == 0 :
243- weight = fns .transpose (weight )
244- reduction_axis = 1
254+ if reduction_axis == 1 :
255+ # Weights
256+ # 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
257+ # 2D: [1, hidden_dimension, out_features] -> [1, out_features, hidden_dimension]
258+ weight = fns .moveaxis (weight , - 1 , - 2 )
259+ reduction_axis = weight .ndim - 1
245260
246- shape_vector = fns .mean (X , axis = 1 )
261+ shape_vector = fns .mean (X , axis = - 1 )
247262 scale = fns .ones_like (shape_vector )
248263
249264 awq_config = deepcopy (config )
250265 awq_config .group_size = - 1
251266
252- for gi in groups_to_correct :
267+ for expert_idx , gi in groups_to_correct :
253268 offset = gi * group_size
254- gscale = s [offset : offset + group_size ]
269+ gscale = s [expert_idx , offset : offset + group_size ]
270+ gweight = weight [expert_idx , :, offset : offset + group_size ]
271+ gacts = X [expert_idx , offset : offset + group_size , :]
255272
256273 a_min = fns .astype (fns .quantile (gscale , 0.1 ), TensorDataType .float32 )
257274 a_max = 1e2
258275 gscale = fns .clip (gscale , a_min = a_min , a_max = a_max )
259276
260- gweight = weight [:, offset : offset + group_size ]
261- gacts = X [offset : offset + group_size , :]
262-
263277 fp32_out = fns .matmul (gweight , gacts )
264278 min_diff = fns .max (fns .abs (fp32_out ))
265279 best_scale = None
@@ -275,28 +289,26 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
275289 # per channel magnitudes for the previous MatMul
276290 # mean(abs(prev_weight)) * max(abs(prev_activation)) * prev_weight.shape[reduction_axis]
277291 magnitudes = (
278- (prev_w [offset : offset + group_size ] / cur_scale ) * prev_s * prev_weight .shape [reduction_axis ]
292+ (prev_w [expert_idx , offset : offset + group_size ] / cur_scale )
293+ * prev_s
294+ * prev_weight .shape [reduction_axis ]
279295 )
280296 if magnitudes .max () >= threshold :
281297 cur_scale = AWQ ._clamp_scale (
282298 magnitudes ,
283299 threshold ,
284300 cur_scale ,
285- prev_w [offset : offset + group_size ]
301+ prev_w [expert_idx , offset : offset + group_size ]
286302 * prev_s
287303 * prev_weight .shape [reduction_axis ]
288304 / threshold ,
289305 )
290306
291307 weights_to_fake_quantize = gweight * cur_scale
292308 if not config .is_integer :
293- g_decompressed_weighs = float_quantize_dequantize_weight (
294- weights_to_fake_quantize , awq_config , reduction_axis
295- )
309+ g_decompressed_weighs = float_quantize_dequantize_weight (weights_to_fake_quantize , awq_config , - 1 )
296310 else :
297- g_decompressed_weighs = integer_quantize_dequantize_weight (
298- weights_to_fake_quantize , awq_config , reduction_axis
299- )
311+ g_decompressed_weighs = integer_quantize_dequantize_weight (weights_to_fake_quantize , awq_config , - 1 )
300312 sacts = gacts / fns .unsqueeze (cur_scale , 1 )
301313
302314 cur_out = fns .matmul (g_decompressed_weighs , sacts )
@@ -307,7 +319,10 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
307319 alpha += alpha_step
308320
309321 if best_scale is not None :
310- scale .data [offset : offset + group_size ] = best_scale .data
322+ scale .data [expert_idx , offset : offset + group_size ] = best_scale .data
323+
324+ if is_2d_weight :
325+ scale = fns .squeeze (scale , 0 ) # [1, hidden_dim] -> [hidden_dim]
311326
312327 return scale
313328
0 commit comments