@@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const {
218218 fname += std::to_string (ncall);
219219 }
220220
221+ // avoid writing imatrix entries that do not have full data
222+ // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
223+
224+ int n_entries = 0 ;
225+ std::vector<std::string> to_store;
226+
227+ bool is_first = true ; // for printing
228+ for (const auto & kv : m_stats) {
229+ const int n_all = kv.second .counts .size ();
230+
231+ if (n_all == 0 ) {
232+ continue ;
233+ }
234+
235+ int n_zeros = 0 ;
236+ for (const int c : kv.second .counts ) {
237+ if (c == 0 ) {
238+ n_zeros++;
239+ }
240+ }
241+
242+ if (n_zeros != 0 && is_first) {
243+ fprintf (stderr, " \n " );
244+ is_first = false ;
245+ }
246+
247+ if (n_zeros == n_all) {
248+ fprintf (stderr, " %s: entry '%40s' has no data - skipping\n " , __func__, kv.first .c_str ());
249+ continue ;
250+ }
251+
252+ if (n_zeros > 0 ) {
253+ fprintf (stderr, " %s: entry '%40s' has partial data (%.2f%%) - skipping\n " , __func__, kv.first .c_str (), 100 .0f * (n_all - n_zeros) / n_all);
254+ continue ;
255+ }
256+
257+ n_entries++;
258+ to_store.push_back (kv.first );
259+ }
260+
261+ if (to_store.size () < m_stats.size ()) {
262+ fprintf (stderr, " %s: warning: storing only %zu out of %zu entries\n " , __func__, to_store.size (), m_stats.size ());
263+ }
264+
221265 std::ofstream out (fname, std::ios::binary);
222- int n_entries = m_stats.size ();
223266 out.write ((const char *) &n_entries, sizeof (n_entries));
224- for (const auto & p : m_stats) {
225- int len = p.first .size ();
267+ for (const auto & name : to_store) {
268+ const auto & stat = m_stats.at (name);
269+ int len = name.size ();
226270 out.write ((const char *) &len, sizeof (len));
227- out.write (p. first .c_str (), len);
228- out.write ((const char *) &p. second . ncall , sizeof (p. second .ncall ));
229- int nval = p. second .values .size ();
271+ out.write (name .c_str (), len);
272+ out.write ((const char *) &stat. ncall , sizeof (stat .ncall ));
273+ int nval = stat .values .size ();
230274 out.write ((const char *) &nval, sizeof (nval));
231275 if (nval > 0 ) {
232276 std::vector<float > tmp (nval);
233277 for (int i = 0 ; i < nval; i++) {
234- tmp[i] = (p. second . values [i] / static_cast <float >(p. second . counts [i])) * static_cast <float >(p. second .ncall );
278+ tmp[i] = (stat. values [i] / static_cast <float >(stat. counts [i])) * static_cast <float >(stat .ncall );
235279 }
236280 out.write ((const char *)tmp.data (), nval*sizeof (float ));
237281 }
0 commit comments