@@ -39,6 +39,7 @@ struct Stats {
3939 std::vector<float > values;
4040 std::vector<int > counts;
4141 int ncall = 0 ;
42+ int n_as = 1 ;
4243};
4344
4445class IMatrixCollector {
@@ -132,11 +133,15 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
132133 if (e.values .empty ()) {
133134 e.values .resize (src1->ne [0 ]*n_as, 0 );
134135 e.counts .resize (src1->ne [0 ]*n_as, 0 );
136+ e.n_as = n_as;
135137 }
136138 else if (e.values .size () != (size_t )src1->ne [0 ]*n_as) {
137139 fprintf (stderr, " Oops: inconsistent size for %s (%d vs %d)\n " , wname.c_str (), (int )e.values .size (), (int )src1->ne [0 ]*n_as);
138140 exit (1 ); // GGML_ABORT("fatal error");
139141 }
142+ else if (e.n_as != n_as) {
143+ fprintf (stderr, " Oops: inconsistent n_as for %s (%d vs %d)\n " , wname.c_str (), e.n_as , n_as);
144+ }
140145 if (m_params.verbosity > 1 ) {
141146 printf (" %s[%d]: %32s, %s, %5d x %5d, %d\n " , __func__, m_last_call, wname.c_str (), ggml_op_name (t->op ), (int )src1->ne [0 ], (int )src1->ne [2 ], (int )src1->type );
142147 }
@@ -258,8 +263,38 @@ void IMatrixCollector::save_imatrix(int ncall) const {
258263 }
259264
260265 if (n_zeros > 0 ) {
261- fprintf (stderr, " %s: entry '%40s' has partial data (%.2f%%) - skipping\n " , __func__, kv.first .c_str (), 100 .0f * (n_all - n_zeros) / n_all);
262- continue ;
266+ fprintf (stderr, " %s: entry '%40s' has partial data (%.2f%%)" , __func__, kv.first .c_str (), 100 .0f * (n_all - n_zeros) / n_all);
267+ bool store_it = false ;
268+ if (kv.second .n_as > 1 ) {
269+ int n_per_expert = n_all / kv.second .n_as ;
270+ std::vector<int > bad_experts;
271+ bad_experts.reserve (kv.second .n_as );
272+ for (int i = 0 ; i < kv.second .n_as ; ++i) {
273+ auto counts = kv.second .counts .data () + i*n_per_expert;
274+ int nz_i = 0 ;
275+ for (int j = 0 ; j < n_per_expert; ++j) {
276+ if (counts[j] == 0 ) ++nz_i;
277+ }
278+ if (nz_i > 0 ) bad_experts.push_back (i);
279+ }
280+ fprintf (stderr, " %d out of %d experts are missing data" , int (bad_experts.size ()), kv.second .n_as );
281+ if (bad_experts.size () < round (kv.second .n_as * 0.05 )) {
282+ fprintf (stderr, " Storing **but be aware**\n " );
283+ store_it = true ;
284+ for (auto i : bad_experts) {
285+ auto counts = (int *)kv.second .counts .data () + i*n_per_expert;
286+ auto values = (float *)kv.second .values .data () + i*n_per_expert;
287+ for (int j = 0 ; j < n_per_expert; ++j) {
288+ counts[j] = 1 ;
289+ values[j] = 1 ;
290+ }
291+ }
292+ }
293+ }
294+ if (!store_it) {
295+ fprintf (stderr, " - skipping\n " );
296+ continue ;
297+ }
263298 }
264299
265300 n_entries++;
0 commit comments