@@ -33,6 +33,7 @@ struct Stats {
3333 std::vector<float > values;
3434 std::vector<int > counts;
3535 int ncall = 0 ;
36+ int n_as = 1 ;
3637};
3738
3839class IMatrixCollector {
@@ -127,11 +128,15 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
127128 if (e.values .empty ()) {
128129 e.values .resize (src1->ne [0 ]*n_as, 0 );
129130 e.counts .resize (src1->ne [0 ]*n_as, 0 );
131+ e.n_as = n_as;
130132 }
131133 else if (e.values .size () != (size_t )src1->ne [0 ]*n_as) {
132134 LOG_ERR (" %s: inconsistent size for %s (%d vs %d)\n " , __func__, wname.c_str (), (int )e.values .size (), (int )src1->ne [0 ]*n_as);
133135 exit (1 ); // GGML_ABORT("fatal error");
134136 }
137+ else if (e.n_as != n_as) {
138+ LOG_ERR (" %s: inconsistent n_as for %s (%d vs %d)\n " , __func__, wname.c_str (), e.n_as , n_as);
139+ }
135140 LOG_DBGV (2 , " %s[%d]: %32s, %s, %5d x %5d, %d\n " , __func__, m_last_call, wname.c_str (), ggml_op_name (t->op ), (int )src1->ne [0 ], (int )src1->ne [2 ], (int )src1->type );
136141 // loop over all possible experts, regardless if they are used or not in the batch
137142 for (int ex = 0 ; ex < n_as; ++ex) {
@@ -173,23 +178,36 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
173178 } else {
174179 auto & e = m_stats[wname];
175180 if (e.values .empty ()) {
176- e.values .resize (src1->ne [0 ], 0 );
177- e.counts .resize (src1->ne [0 ], 0 );
181+ if (src0->ne [3 ] > 1 ) {
182+ LOG_ERR (" Unsupported 4D tensor %s\n " , wname.c_str ());
183+ exit (1 );
184+ }
185+ // If we have a 3D tensor as it is the case for the attn_k_b and attn_v_b for DeepSeek MLA models,
186+ // than we need to compute the imatrix for each head, and not just one imatrx for all heads.
187+ // Hence, the storage we need is src0->ne[0]*src0->ne[2].
188+ e.values .resize (src0->ne [0 ]*src0->ne [2 ], 0 );
189+ e.counts .resize (src0->ne [0 ]*src0->ne [2 ], 0 );
178190 }
179- else if (e.values .size () != (size_t )src1 ->ne [0 ]) {
191+ else if (e.values .size () != (size_t )(src0 ->ne [0 ]*src0-> ne [ 2 ]) ) {
180192 LOG_ERR (" %s: inconsistent size for %s (%d vs %d)\n " , __func__, wname.c_str (), (int )e.values .size (), (int )src1->ne [0 ]);
181193 exit (1 ); // GGML_ABORT("fatal error");
182194 }
183195 ++e.ncall ;
184196 LOG_DBGV (2 , " %s[%d]: %32s, %s, %5d x %5d, %d\n " , __func__, m_last_call, wname.c_str (), ggml_op_name (t->op ), (int )src1->ne [0 ], (int )src1->ne [1 ], (int )src1->type );
185- for (int row = 0 ; row < (int )src1->ne [1 ]; ++row) {
186- const float * x = (const float *) (data + row * src1->nb [1 ]);
187- for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
188- e.values [j] += x[j]*x[j];
189- e.counts [j]++;
190- if (!std::isfinite (e.values [j])) {
191- LOG_ERR (" %f detected in %s\n " , e.values [j], wname.c_str ());
192- exit (1 );
197+ int rk2 = src1->ne [2 ]/src0->ne [2 ];
198+ for (int i12 = 0 ; i12 < (int )src1->ne [2 ]; ++i12) { // i.e., loop over attention heads for MLA models
199+ int i02 = i12/rk2;
200+ auto values = e.values .data () + i02*src0->ne [0 ];
201+ auto counts = e.counts .data () + i02*src0->ne [0 ];
202+ for (int i11 = 0 ; i11 < (int )src1->ne [1 ]; ++i11) {
203+ const float * x = (const float *)((const char *)data + i11*src1->nb [1 ] + i12*src1->nb [2 ]);
204+ for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
205+ values[j] += x[j]*x[j];
206+ counts[j]++;
207+ if (!std::isfinite (values[j])) {
208+ LOG_ERR (" %f detected in %s\n " , e.values [j], wname.c_str ());
209+ exit (1 );
210+ }
193211 }
194212 }
195213 }
@@ -221,6 +239,10 @@ void IMatrixCollector::save_imatrix(int ncall) const {
221239 int n_entries = 0 ;
222240 std::vector<std::string> to_store;
223241
242+ // Retrieve the REQUIRED_GOOD_EXPERT_PERCENTAGE from the environment
243+ const char * required_good_expert_percentage_env_value = getenv (" REQUIRED_GOOD_EXPERT_PERCENTAGE" );
244+ double required_good_expert_percentage = required_good_expert_percentage_env_value ? std::clamp (std::stod (required_good_expert_percentage_env_value), 0.0 , 100.0 ) : 90.0 ;
245+
224246 bool is_first = true ; // for printing
225247 for (const auto & kv : m_stats) {
226248 const int n_all = kv.second .counts .size ();
@@ -247,8 +269,40 @@ void IMatrixCollector::save_imatrix(int ncall) const {
247269 }
248270
249271 if (n_zeros > 0 ) {
250- LOG_WRN (" %s: entry '%40s' has partial data (%.2f%%) - skipping\n " , __func__, kv.first .c_str (), 100 .0f * (n_all - n_zeros) / n_all);
251- continue ;
272+ LOG_WRN (" %s: entry '%40s' has partial data (%.2f%%)\n " , __func__, kv.first .c_str (), 100 .0f * (n_all - n_zeros) / n_all);
273+ bool store_it = false ;
274+ if (kv.second .n_as > 1 ) {
275+ int n_per_expert = n_all / kv.second .n_as ;
276+ std::vector<int > bad_experts;
277+ bad_experts.reserve (kv.second .n_as );
278+ for (int i = 0 ; i < kv.second .n_as ; ++i) {
279+ auto counts = kv.second .counts .data () + i*n_per_expert;
280+ int nz_i = 0 ;
281+ for (int j = 0 ; j < n_per_expert; ++j) {
282+ if (counts[j] == 0 ) ++nz_i;
283+ }
284+ if (nz_i > 0 ) bad_experts.push_back (i);
285+ }
286+ size_t required_good_experts = round ((kv.second .n_as * required_good_expert_percentage) / 100.0 );
287+ size_t good_experts = kv.second .n_as - bad_experts.size ();
288+ LOG_WRN (" %s: %d out of %d experts are missing data - %ld out of %ld required\n " , __func__, int (bad_experts.size ()), kv.second .n_as , good_experts, required_good_experts);
289+ if (good_experts >= required_good_experts) {
290+ LOG_WRN (" %s: %d out of %d experts are missing data - storing but be aware\n " , __func__, int (bad_experts.size ()), kv.second .n_as );
291+ store_it = true ;
292+ for (auto i : bad_experts) {
293+ auto counts = const_cast <int *>(kv.second .counts .data ()) + i * n_per_expert;
294+ auto values = const_cast <float *>(kv.second .values .data ()) + i * n_per_expert;
295+ for (int j = 0 ; j < n_per_expert; ++j) {
296+ counts[j] = 1 ;
297+ values[j] = 1 ;
298+ }
299+ }
300+ }
301+ }
302+ if (!store_it) {
303+ LOG_WRN (" %s: Skipping expert with missing data!\n " , __func__);
304+ continue ;
305+ }
252306 }
253307
254308 n_entries++;
0 commit comments