Skip to content

Commit 6733bda

Browse files
committed
feat: Remove sub-ubatch batching
Unlike Qwen3Next, we don't hit big complexity scaling issues here, so removing all of the batching gives a big reduction in complexity and a big boost to performance! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 426a97c commit 6733bda

File tree

1 file changed

+72
-113
lines changed

1 file changed

+72
-113
lines changed

src/models/graph-context-mamba.cpp

Lines changed: 72 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
243243
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
244244
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
245245

246-
// Empty y that will be extended with each chunk of tokens
247-
ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]);
248-
249246
if (n_seq_tokens == 1) {
250247
// if (true) {
251248
//DEBUG
@@ -259,9 +256,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
259256

260257
// otherwise, use the SSD formulation
261258

262-
// TODO: make this configurable
263-
const uint32_t chunk_size = 256;
264-
265259
// extract the state(s) for the sequences identified by ids
266260
if (ssm->ne[3] != ids->ne[0]) {
267261
ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1
@@ -287,115 +281,80 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
287281
ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs}
288282
cb(dtX, "dtX", il);
289283

290-
// loop over all chunks
284+
285+
// step 3: compute CB
291286
uint32_t repeats = n_head / n_group;
292-
for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) {
293-
294-
// chunk views
295-
const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i));
296-
// slice dtA on dim 1
297-
ggml_tensor * dtA_chunk = ggml_view_3d(ctx, dtA,
298-
dtA->ne[0], chunk_size_i, dtA->ne[2],
299-
dtA->nb[1], dtA->nb[2],
300-
chunk_i * dtA->nb[1]);
301-
cb(dtA_chunk, "dtA_chunk", il);
302-
// slice dtX on dim 2
303-
ggml_tensor * dtX_chunk = ggml_view_4d(ctx, dtX,
304-
dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3],
305-
dtX->nb[1], dtX->nb[2], dtX->nb[3],
306-
chunk_i * dtX->nb[2]);
307-
cb(dtX_chunk, "dtX_chunk", il);
308-
// slice B on dim 2
309-
ggml_tensor * B_chunk = ggml_view_4d(ctx, B,
310-
B->ne[0], B->ne[1], chunk_size_i, B->ne[3],
311-
B->nb[1], B->nb[2], B->nb[3],
312-
chunk_i * B->nb[2]);
313-
cb(B_chunk, "B_chunk", il);
314-
// slice C on dim 2
315-
ggml_tensor * C_chunk = ggml_view_4d(ctx, C,
316-
C->ne[0], C->ne[1], chunk_size_i, C->ne[3],
317-
C->nb[1], C->nb[2], C->nb[3],
318-
chunk_i * C->nb[2]);
319-
cb(C_chunk, "C_chunk", il);
320-
321-
// step 3: compute CB
322-
ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
323-
ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
324-
ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
325-
CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs}
326-
cb(CB, "CB", il);
327-
328-
// step 4: compute decay
329-
ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk,
330-
dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
331-
ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
332-
ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
333-
cb(segsum, "segsum", il);
334-
ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
335-
decay = ggml_permute(ctx, decay, 2, 1, 3, 0); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs}
336-
cb(decay, "decay", il);
337-
338-
// step 5: compute surrogate_attention_matrix
339-
ggml_tensor * CBdecay = ggml_mul(ctx, CB, ggml_cont(ctx, decay));
340-
ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG);
341-
cb(surrogate_attention_matrix, "surrogate_attention_matrix", il);
342-
343-
// step 6: compute y
344-
ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3));
345-
ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
346-
y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3));
347-
cb(y_chunk, "y_chunk", il);
348-
349-
// step 7: compute dtxdecay
350-
ggml_tensor * decay_last = ggml_view_4d(ctx, decay,
351-
decay->ne[0], 1, decay->ne[2], decay->ne[3],
352-
decay->nb[1], decay->nb[2], decay->nb[3],
353-
(decay->ne[1] - 1) * decay->nb[1]);
354-
decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3));
355-
cb(decay_last, "decay_last", il);
356-
B_perm = ggml_cont(ctx, B_perm);
357-
B_perm = ggml_repeat_4d(ctx, B_perm,
358-
B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]);
359-
ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last);
360-
dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3));
361-
cb(dtxdecay, "dtxdecay", il);
362-
363-
// step 8: compute next_state
364-
ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
365-
if (next_state->type != ssm->type) {
366-
next_state = ggml_cast(ctx, next_state, ssm->type);
367-
}
368-
cb(next_state, "next_state", il);
369-
370-
// TODO: Skip y and state updates if no previous state
371-
372-
// step 9: update from previous state
373-
ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); // {n_head, chunk_size_i, n_seqs}
374-
cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
375-
ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
376-
exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3],
377-
exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
378-
(exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs}
379-
cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
380-
ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs}
381-
next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm)));
382-
cb(next_state, "next_state_updated", il);
383-
384-
// step 10: update from previous y
385-
ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm);
386-
cb(y_prev, "y_prev", il);
387-
y_prev = ggml_mul(ctx,
388-
ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)),
389-
ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0)));
390-
cb(y_prev, "y_prev_mul", il);
391-
y_chunk = ggml_add(ctx, y_chunk, y_prev);
392-
cb(y_chunk, "y_chunk_updated", il);
393-
394-
// step 11: recurse
395-
y = ggml_concat(ctx, y, y_chunk, 2);
396-
cb(y, "y", il);
397-
ssm = next_state;
287+
ggml_tensor * C_perm = ggml_permute(ctx, C, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
288+
ggml_tensor * B_perm = ggml_permute(ctx, B, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
289+
ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
290+
CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs}
291+
cb(CB, "CB", il);
292+
293+
// step 4: compute decay
294+
ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA,
295+
dtA->ne[0], dtA->ne[1], dtA->ne[2], dtA->ne[3] * n_seq_tokens); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
296+
ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
297+
ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
298+
cb(segsum, "segsum", il);
299+
ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
300+
decay = ggml_permute(ctx, decay, 2, 1, 3, 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
301+
cb(decay, "decay", il);
302+
303+
// step 5: compute surrogate_attention_matrix
304+
ggml_tensor * CBdecay = ggml_mul(ctx, CB, ggml_cont(ctx, decay));
305+
ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG);
306+
cb(surrogate_attention_matrix, "surrogate_attention_matrix", il);
307+
308+
// step 6: compute y
309+
ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX, 1, 2, 0, 3));
310+
ggml_tensor * y = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
311+
y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3));
312+
cb(y, "y", il);
313+
314+
// step 7: compute dtxdecay
315+
ggml_tensor * decay_last = ggml_view_4d(ctx, decay,
316+
decay->ne[0], 1, decay->ne[2], decay->ne[3],
317+
decay->nb[1], decay->nb[2], decay->nb[3],
318+
(decay->ne[1] - 1) * decay->nb[1]);
319+
decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3));
320+
cb(decay_last, "decay_last", il);
321+
B_perm = ggml_cont(ctx, B_perm);
322+
B_perm = ggml_repeat_4d(ctx, B_perm,
323+
B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]);
324+
ggml_tensor * dtxdecay = ggml_mul(ctx, dtX, decay_last);
325+
dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3));
326+
cb(dtxdecay, "dtxdecay", il);
327+
328+
// step 8: compute next_state
329+
ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
330+
if (next_state->type != ssm->type) {
331+
next_state = ggml_cast(ctx, next_state, ssm->type);
398332
}
333+
cb(next_state, "next_state", il);
334+
335+
// TODO: Skip y and state updates if no previous state
336+
337+
// step 9: update from previous state
338+
ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA, 1)); // {n_head, chunk_size_i, n_seqs}
339+
cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
340+
ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
341+
exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3],
342+
exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
343+
(exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs}
344+
cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
345+
ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs}
346+
next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm)));
347+
cb(next_state, "next_state_updated", il);
348+
349+
// step 10: update from previous y
350+
ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C, 0, 2, 1, 3), ssm);
351+
cb(y_prev, "y_prev", il);
352+
y_prev = ggml_mul(ctx,
353+
ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)),
354+
ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0)));
355+
cb(y_prev, "y_prev_mul", il);
356+
y = ggml_add(ctx, y, y_prev);
357+
cb(y, "y_updated", il);
399358

400359
// Concat the output y and state
401360
if (ssm->type != y->type) {

0 commit comments

Comments
 (0)