@@ -217,7 +217,6 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
217217 s = s .astype (TensorDataType .float32 )
218218 X = X .astype (TensorDataType .float32 )
219219
220- is_3d_weight = weight .ndim == 3
221220 is_2d_weight = weight .ndim == 2
222221
223222 assert isinstance (wp .reduction_axes , tuple ) and len (wp .reduction_axes ) == 1
@@ -234,6 +233,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
234233 s = fns .unsqueeze (s , 0 ) # [hidden_dim] -> [1, hidden_dim]
235234 X = fns .unsqueeze (X , 0 ) # [hidden_dim, samples] -> [1, hidden_dim, samples]
236235 weight = fns .unsqueeze (weight , 0 ) # [out_features, hidden_dim] -> [1, out_features, hidden_dim]
236+ reduction_axis += 1
237237
238238 top_k = max (int (s .shape [- 1 ] * self ._percent_to_apply ), 1 )
239239 topk_idxs = fns .argsort (- s )[:, :top_k ]
@@ -251,7 +251,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
251251
252252 groups_to_correct = list (groups_to_correct )
253253
254- if reduction_axis == 0 or ( reduction_axis == 1 and is_3d_weight ) :
254+ if reduction_axis == 1 :
255255 # Weights
256256 # 3D: [num_experts, hidden_dimension, out_features] -> [num_experts, out_features, hidden_dimension]
257257 # 2D: [1, hidden_dimension, out_features] -> [1, out_features, hidden_dimension]
@@ -298,7 +298,7 @@ def _data_aware_step(self, wp, weight, statistics, prev_weight=None, prev_statis
298298 magnitudes ,
299299 threshold ,
300300 cur_scale ,
301- prev_w [offset : offset + group_size ]
301+ prev_w [expert_idx , offset : offset + group_size ]
302302 * prev_s
303303 * prev_weight .shape [reduction_axis ]
304304 / threshold ,
0 commit comments