Skip to content

Commit e458c30

Browse files
committed
Fix imatrix calculation for MLA models
1 parent 91ecc29 commit e458c30

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

tools/imatrix/imatrix.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -178,23 +178,36 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
178178
} else {
179179
auto & e = m_stats[wname];
180180
if (e.values.empty()) {
181-
e.values.resize(src1->ne[0], 0);
182-
e.counts.resize(src1->ne[0], 0);
181+
if (src0->ne[3] > 1) {
182+
LOG_ERR("Unsupported 4D tensor %s\n", wname.c_str());
183+
exit(1);
184+
}
185+
// If we have a 3D tensor as it is the case for the attn_k_b and attn_v_b for DeepSeek MLA models,
186+
// than we need to compute the imatrix for each head, and not just one imatrx for all heads.
187+
// Hence, the storage we need is src0->ne[0]*src0->ne[2].
188+
e.values.resize(src0->ne[0]*src0->ne[2], 0);
189+
e.counts.resize(src0->ne[0]*src0->ne[2], 0);
183190
}
184-
else if (e.values.size() != (size_t)src1->ne[0]) {
191+
else if (e.values.size() != (size_t)(src0->ne[0]*src0->ne[2])) {
185192
LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
186193
exit(1); //GGML_ABORT("fatal error");
187194
}
188195
++e.ncall;
189196
LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
190-
for (int row = 0; row < (int)src1->ne[1]; ++row) {
191-
const float * x = (const float *) (data + row * src1->nb[1]);
192-
for (int j = 0; j < (int)src1->ne[0]; ++j) {
193-
e.values[j] += x[j]*x[j];
194-
e.counts[j]++;
195-
if (!std::isfinite(e.values[j])) {
196-
LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str());
197-
exit(1);
197+
int rk2 = src1->ne[2]/src0->ne[2];
198+
for (int i12 = 0; i12 < (int)src1->ne[2]; ++i12) { // i.e., loop over attention heads for MLA models
199+
int i02 = i12/rk2;
200+
auto values = e.values.data() + i02*src0->ne[0];
201+
auto counts = e.counts.data() + i02*src0->ne[0];
202+
for (int i11 = 0; i11 < (int)src1->ne[1]; ++i11) {
203+
const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
204+
for (int j = 0; j < (int)src1->ne[0]; ++j) {
205+
values[j] += x[j]*x[j];
206+
counts[j]++;
207+
if (!std::isfinite(values[j])) {
208+
LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str());
209+
exit(1);
210+
}
198211
}
199212
}
200213
}

0 commit comments

Comments
 (0)