@@ -218,13 +218,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
218218 X = X .astype (TensorDataType .float32 )
219219
220220 is_3d_weight = weight .ndim == 3
221- if s .ndim == 1 :
222- s = fns .unsqueeze (s , 0 ) # [hidden_dim] -> [1, hidden_dim]
223- X = fns .unsqueeze (X , 0 ) # [hidden_dim, samples] -> [1, hidden_dim, samples]
224- weight = fns .unsqueeze (weight , 0 ) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
225- is_2d_weight = True
226- else :
227- is_2d_weight = False
221+ is_2d_weight = weight .ndim == 2
228222
229223 assert isinstance (wp .reduction_axes , tuple ) and len (wp .reduction_axes ) == 1
230224 reduction_axis = wp .reduction_axes [0 ]
@@ -233,7 +227,13 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
233227 if prev_statistics is not None and prev_weight is not None :
234228 prev_s , _ = process_stats (prev_statistics , self ._subset_size )
235229 prev_s = prev_s .astype (TensorDataType .float32 ).max ().item ()
236- prev_w = fns .mean (fns .abs (prev_weight ), axis = reduction_axis )
230+ prev_weight = fns .unsqueeze (prev_weight , 0 ) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
231+ prev_w = fns .mean (fns .abs (prev_weight ), axis = reduction_axis + 1 )
232+
233+ if is_2d_weight :
234+ s = fns .unsqueeze (s , 0 ) # [hidden_dim] -> [1, hidden_dim]
235+ X = fns .unsqueeze (X , 0 ) # [hidden_dim, samples] -> [1, hidden_dim, samples]
236+ weight = fns .unsqueeze (weight , 0 ) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
237237
238238 top_k = max (int (s .shape [- 1 ] * self ._percent_to_apply ), 1 )
239239 topk_idxs = fns .argsort (- s )[:, :top_k ]
@@ -289,7 +289,9 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
289289 # per channel magnitudes for the previous MatMul
290290 # mean(abs(prev_weight)) * max(abs((prev_activation))) * prev_weight.shape[reduction_axis]
291291 magnitudes = (
292- (prev_w [offset : offset + group_size ] / cur_scale ) * prev_s * prev_weight .shape [reduction_axis ]
292+ (prev_w [expert_idx , offset : offset + group_size ] / cur_scale )
293+ * prev_s
294+ * prev_weight .shape [reduction_axis ]
293295 )
294296 if magnitudes .max () >= threshold :
295297 cur_scale = AWQ ._clamp_scale (
0 commit comments