@@ -178,23 +178,36 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
178178 } else {
179179 auto & e = m_stats[wname];
180180 if (e.values .empty ()) {
181- e.values .resize (src1->ne [0 ], 0 );
182- e.counts .resize (src1->ne [0 ], 0 );
181+ if (src0->ne [3 ] > 1 ) {
182+ LOG_ERR (" Unsupported 4D tensor %s\n " , wname.c_str ());
183+ exit (1 );
184+ }
185+ // If we have a 3D tensor as it is the case for the attn_k_b and attn_v_b for DeepSeek MLA models,
186+ // than we need to compute the imatrix for each head, and not just one imatrx for all heads.
187+ // Hence, the storage we need is src0->ne[0]*src0->ne[2].
188+ e.values .resize (src0->ne [0 ]*src0->ne [2 ], 0 );
189+ e.counts .resize (src0->ne [0 ]*src0->ne [2 ], 0 );
183190 }
184- else if (e.values .size () != (size_t )src1 ->ne [0 ]) {
191+ else if (e.values .size () != (size_t )(src0 ->ne [0 ]*src0-> ne [ 2 ]) ) {
185192 LOG_ERR (" %s: inconsistent size for %s (%d vs %d)\n " , __func__, wname.c_str (), (int )e.values .size (), (int )src1->ne [0 ]);
186193 exit (1 ); // GGML_ABORT("fatal error");
187194 }
188195 ++e.ncall ;
189196 LOG_DBGV (2 , " %s[%d]: %32s, %s, %5d x %5d, %d\n " , __func__, m_last_call, wname.c_str (), ggml_op_name (t->op ), (int )src1->ne [0 ], (int )src1->ne [1 ], (int )src1->type );
190- for (int row = 0 ; row < (int )src1->ne [1 ]; ++row) {
191- const float * x = (const float *) (data + row * src1->nb [1 ]);
192- for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
193- e.values [j] += x[j]*x[j];
194- e.counts [j]++;
195- if (!std::isfinite (e.values [j])) {
196- LOG_ERR (" %f detected in %s\n " , e.values [j], wname.c_str ());
197- exit (1 );
197+ int rk2 = src1->ne [2 ]/src0->ne [2 ];
198+ for (int i12 = 0 ; i12 < (int )src1->ne [2 ]; ++i12) { // i.e., loop over attention heads for MLA models
199+ int i02 = i12/rk2;
200+ auto values = e.values .data () + i02*src0->ne [0 ];
201+ auto counts = e.counts .data () + i02*src0->ne [0 ];
202+ for (int i11 = 0 ; i11 < (int )src1->ne [1 ]; ++i11) {
203+ const float * x = (const float *)((const char *)data + i11*src1->nb [1 ] + i12*src1->nb [2 ]);
204+ for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
205+ values[j] += x[j]*x[j];
206+ counts[j]++;
207+ if (!std::isfinite (values[j])) {
208+ LOG_ERR (" %f detected in %s\n " , e.values [j], wname.c_str ());
209+ exit (1 );
210+ }
198211 }
199212 }
200213 }
0 commit comments