@@ -243,9 +243,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
243243 auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
244244 ggml_tensor * ssm = ggml_reshape_4d (ctx, states, d_state, head_dim, n_head, mctx_cur->get_size ());
245245
246- // Empty y that will be extended with each chunk of tokens
247- ggml_tensor * y = ggml_new_tensor_4d (ctx, x->type , x->ne [0 ], x->ne [1 ], 0 , x->ne [3 ]);
248-
249246 if (n_seq_tokens == 1 ) {
250247 // if (true) {
251248 // DEBUG
@@ -259,9 +256,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
259256
260257 // otherwise, use the SSD formulation
261258
262- // TODO: make this configurable
263- const uint32_t chunk_size = 256 ;
264-
265259 // extract the state(s) for the sequences identified by ids
266260 if (ssm->ne [3 ] != ids->ne [0 ]) {
267261 ggml_tensor * ssm_perm = ggml_permute (ctx, ssm, 0 , 2 , 3 , 1 ); // put the target dim in dim 1
@@ -287,115 +281,80 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
287281 ggml_tensor * dtX = ggml_mul (ctx, x, ggml_reshape_4d (ctx, dt_softplus, 1 , dt_softplus->ne [0 ], dt_softplus->ne [1 ], dt_softplus->ne [2 ])); // {head_dim, n_head, n_seq_tokens, n_seqs}
288282 cb (dtX, " dtX" , il);
289283
290- // loop over all chunks
284+
285+ // step 3: compute CB
291286 uint32_t repeats = n_head / n_group;
292- for (auto chunk_i = 0 ; chunk_i < n_seq_tokens; chunk_i += chunk_size) {
293-
294- // chunk views
295- const auto chunk_size_i = std::min (chunk_size, uint32_t (n_seq_tokens - chunk_i));
296- // slice dtA on dim 1
297- ggml_tensor * dtA_chunk = ggml_view_3d (ctx, dtA,
298- dtA->ne [0 ], chunk_size_i, dtA->ne [2 ],
299- dtA->nb [1 ], dtA->nb [2 ],
300- chunk_i * dtA->nb [1 ]);
301- cb (dtA_chunk, " dtA_chunk" , il);
302- // slice dtX on dim 2
303- ggml_tensor * dtX_chunk = ggml_view_4d (ctx, dtX,
304- dtX->ne [0 ], dtX->ne [1 ], chunk_size_i, dtX->ne [3 ],
305- dtX->nb [1 ], dtX->nb [2 ], dtX->nb [3 ],
306- chunk_i * dtX->nb [2 ]);
307- cb (dtX_chunk, " dtX_chunk" , il);
308- // slice B on dim 2
309- ggml_tensor * B_chunk = ggml_view_4d (ctx, B,
310- B->ne [0 ], B->ne [1 ], chunk_size_i, B->ne [3 ],
311- B->nb [1 ], B->nb [2 ], B->nb [3 ],
312- chunk_i * B->nb [2 ]);
313- cb (B_chunk, " B_chunk" , il);
314- // slice C on dim 2
315- ggml_tensor * C_chunk = ggml_view_4d (ctx, C,
316- C->ne [0 ], C->ne [1 ], chunk_size_i, C->ne [3 ],
317- C->nb [1 ], C->nb [2 ], C->nb [3 ],
318- chunk_i * C->nb [2 ]);
319- cb (C_chunk, " C_chunk" , il);
320-
321- // step 3: compute CB
322- ggml_tensor * C_perm = ggml_permute (ctx, C_chunk, 0 , 2 , 1 , 3 ); // {d_state, n_seq_tokens, n_group, n_seqs}
323- ggml_tensor * B_perm = ggml_permute (ctx, B_chunk, 0 , 2 , 1 , 3 ); // {d_state, n_seq_tokens, n_group, n_seqs}
324- ggml_tensor * CB = ggml_mul_mat (ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
325- CB = ggml_repeat_4d (ctx, CB, CB->ne [0 ], CB->ne [1 ], CB->ne [2 ] * repeats, CB->ne [3 ]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs}
326- cb (CB, " CB" , il);
327-
328- // step 4: compute decay
329- ggml_tensor * dtA_tmp0 = ggml_repeat_4d (ctx, dtA_chunk,
330- dtA_chunk->ne [0 ], dtA_chunk->ne [1 ], dtA_chunk->ne [2 ], dtA_chunk->ne [3 ] * chunk_size_i); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
331- ggml_tensor * dtA_tmp1 = ggml_tri_dims (ctx, dtA_tmp0, nan (" " ), GGML_TRI_TYPE_LOWER, 3 , 1 ); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
332- ggml_tensor * segsum = ggml_cumsum (ctx, dtA_tmp1, 1 ); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
333- cb (segsum, " segsum" , il);
334- ggml_tensor * decay = ggml_exp (ctx, segsum); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
335- decay = ggml_permute (ctx, decay, 2 , 1 , 3 , 0 ); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs}
336- cb (decay, " decay" , il);
337-
338- // step 5: compute surrogate_attention_matrix
339- ggml_tensor * CBdecay = ggml_mul (ctx, CB, ggml_cont (ctx, decay));
340- ggml_tensor * surrogate_attention_matrix = ggml_tri_keep (ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG);
341- cb (surrogate_attention_matrix, " surrogate_attention_matrix" , il);
342-
343- // step 6: compute y
344- ggml_tensor * dtX_chunk_perm = ggml_cont (ctx, ggml_permute (ctx, dtX_chunk, 1 , 2 , 0 , 3 ));
345- ggml_tensor * y_chunk = ggml_mul_mat (ctx, dtX_chunk_perm, surrogate_attention_matrix);
346- y_chunk = ggml_cont (ctx, ggml_permute (ctx, y_chunk, 0 , 2 , 1 , 3 ));
347- cb (y_chunk, " y_chunk" , il);
348-
349- // step 7: compute dtxdecay
350- ggml_tensor * decay_last = ggml_view_4d (ctx, decay,
351- decay->ne [0 ], 1 , decay->ne [2 ], decay->ne [3 ],
352- decay->nb [1 ], decay->nb [2 ], decay->nb [3 ],
353- (decay->ne [1 ] - 1 ) * decay->nb [1 ]);
354- decay_last = ggml_cont (ctx, ggml_permute (ctx, decay_last, 2 , 0 , 1 , 3 ));
355- cb (decay_last, " decay_last" , il);
356- B_perm = ggml_cont (ctx, B_perm);
357- B_perm = ggml_repeat_4d (ctx, B_perm,
358- B_perm->ne [0 ], B_perm->ne [1 ], B_perm->ne [2 ] * repeats, B_perm->ne [3 ]);
359- ggml_tensor * dtxdecay = ggml_mul (ctx, dtX_chunk, decay_last);
360- dtxdecay = ggml_cont (ctx, ggml_permute (ctx, dtxdecay, 1 , 2 , 0 , 3 ));
361- cb (dtxdecay, " dtxdecay" , il);
362-
363- // step 8: compute next_state
364- ggml_tensor * next_state = ggml_mul_mat (ctx, ggml_cont (ctx, ggml_permute (ctx, B_perm, 1 , 0 , 2 , 3 )), dtxdecay);
365- if (next_state->type != ssm->type ) {
366- next_state = ggml_cast (ctx, next_state, ssm->type );
367- }
368- cb (next_state, " next_state" , il);
369-
370- // TODO: Skip y and state updates if no previous state
371-
372- // step 9: update from previous state
373- ggml_tensor * exp_dtA_cumsum = ggml_exp (ctx, ggml_cumsum (ctx, dtA_chunk, 1 )); // {n_head, chunk_size_i, n_seqs}
374- cb (exp_dtA_cumsum, " exp_dtA_cumsum" , il);
375- ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d (ctx, exp_dtA_cumsum,
376- exp_dtA_cumsum->ne [0 ], 1 , exp_dtA_cumsum->ne [2 ], exp_dtA_cumsum->ne [3 ],
377- exp_dtA_cumsum->nb [1 ], exp_dtA_cumsum->nb [2 ], exp_dtA_cumsum->nb [3 ],
378- (exp_dtA_cumsum->ne [1 ] - 1 ) * exp_dtA_cumsum->nb [1 ]); // {n_head, 1, n_seqs}
379- cb (exp_dtA_cumsum_last, " exp_dtA_cumsum_last" , il);
380- ggml_tensor * exp_dtA_cumsum_perm = ggml_permute (ctx, exp_dtA_cumsum_last, 2 , 1 , 3 , 0 ); // {1, 1, n_head, n_seqs}
381- next_state = ggml_add (ctx, next_state, ggml_mul (ctx, ssm, ggml_cont (ctx, exp_dtA_cumsum_perm)));
382- cb (next_state, " next_state_updated" , il);
383-
384- // step 10: update from previous y
385- ggml_tensor * y_prev = ggml_mul_mat (ctx, ggml_permute (ctx, C_chunk, 0 , 2 , 1 , 3 ), ssm);
386- cb (y_prev, " y_prev" , il);
387- y_prev = ggml_mul (ctx,
388- ggml_cont (ctx, ggml_permute (ctx, y_prev, 2 , 0 , 1 , 3 )),
389- ggml_cont (ctx, ggml_permute (ctx, exp_dtA_cumsum, 1 , 2 , 3 , 0 )));
390- cb (y_prev, " y_prev_mul" , il);
391- y_chunk = ggml_add (ctx, y_chunk, y_prev);
392- cb (y_chunk, " y_chunk_updated" , il);
393-
394- // step 11: recurse
395- y = ggml_concat (ctx, y, y_chunk, 2 );
396- cb (y, " y" , il);
397- ssm = next_state;
287+ ggml_tensor * C_perm = ggml_permute (ctx, C, 0 , 2 , 1 , 3 ); // {d_state, n_seq_tokens, n_group, n_seqs}
288+ ggml_tensor * B_perm = ggml_permute (ctx, B, 0 , 2 , 1 , 3 ); // {d_state, n_seq_tokens, n_group, n_seqs}
289+ ggml_tensor * CB = ggml_mul_mat (ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
290+ CB = ggml_repeat_4d (ctx, CB, CB->ne [0 ], CB->ne [1 ], CB->ne [2 ] * repeats, CB->ne [3 ]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs}
291+ cb (CB, " CB" , il);
292+
293+ // step 4: compute decay
294+ ggml_tensor * dtA_tmp0 = ggml_repeat_4d (ctx, dtA,
295+ dtA->ne [0 ], dtA->ne [1 ], dtA->ne [2 ], dtA->ne [3 ] * n_seq_tokens); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
296+ ggml_tensor * dtA_tmp1 = ggml_tri_dims (ctx, dtA_tmp0, nan (" " ), GGML_TRI_TYPE_LOWER, 3 , 1 ); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
297+ ggml_tensor * segsum = ggml_cumsum (ctx, dtA_tmp1, 1 ); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
298+ cb (segsum, " segsum" , il);
299+ ggml_tensor * decay = ggml_exp (ctx, segsum); // {n_head, n_seq_tokens_0, n_seqs, n_seq_tokens_1}
300+ decay = ggml_permute (ctx, decay, 2 , 1 , 3 , 0 ); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
301+ cb (decay, " decay" , il);
302+
303+ // step 5: compute surrogate_attention_matrix
304+ ggml_tensor * CBdecay = ggml_mul (ctx, CB, ggml_cont (ctx, decay));
305+ ggml_tensor * surrogate_attention_matrix = ggml_tri_keep (ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG);
306+ cb (surrogate_attention_matrix, " surrogate_attention_matrix" , il);
307+
308+ // step 6: compute y
309+ ggml_tensor * dtX_chunk_perm = ggml_cont (ctx, ggml_permute (ctx, dtX, 1 , 2 , 0 , 3 ));
310+ ggml_tensor * y = ggml_mul_mat (ctx, dtX_chunk_perm, surrogate_attention_matrix);
311+ y = ggml_cont (ctx, ggml_permute (ctx, y, 0 , 2 , 1 , 3 ));
312+ cb (y, " y" , il);
313+
314+ // step 7: compute dtxdecay
315+ ggml_tensor * decay_last = ggml_view_4d (ctx, decay,
316+ decay->ne [0 ], 1 , decay->ne [2 ], decay->ne [3 ],
317+ decay->nb [1 ], decay->nb [2 ], decay->nb [3 ],
318+ (decay->ne [1 ] - 1 ) * decay->nb [1 ]);
319+ decay_last = ggml_cont (ctx, ggml_permute (ctx, decay_last, 2 , 0 , 1 , 3 ));
320+ cb (decay_last, " decay_last" , il);
321+ B_perm = ggml_cont (ctx, B_perm);
322+ B_perm = ggml_repeat_4d (ctx, B_perm,
323+ B_perm->ne [0 ], B_perm->ne [1 ], B_perm->ne [2 ] * repeats, B_perm->ne [3 ]);
324+ ggml_tensor * dtxdecay = ggml_mul (ctx, dtX, decay_last);
325+ dtxdecay = ggml_cont (ctx, ggml_permute (ctx, dtxdecay, 1 , 2 , 0 , 3 ));
326+ cb (dtxdecay, " dtxdecay" , il);
327+
328+ // step 8: compute next_state
329+ ggml_tensor * next_state = ggml_mul_mat (ctx, ggml_cont (ctx, ggml_permute (ctx, B_perm, 1 , 0 , 2 , 3 )), dtxdecay);
330+ if (next_state->type != ssm->type ) {
331+ next_state = ggml_cast (ctx, next_state, ssm->type );
398332 }
333+ cb (next_state, " next_state" , il);
334+
335+ // TODO: Skip y and state updates if no previous state
336+
337+ // step 9: update from previous state
338+ ggml_tensor * exp_dtA_cumsum = ggml_exp (ctx, ggml_cumsum (ctx, dtA, 1 )); // {n_head, chunk_size_i, n_seqs}
339+ cb (exp_dtA_cumsum, " exp_dtA_cumsum" , il);
340+ ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d (ctx, exp_dtA_cumsum,
341+ exp_dtA_cumsum->ne [0 ], 1 , exp_dtA_cumsum->ne [2 ], exp_dtA_cumsum->ne [3 ],
342+ exp_dtA_cumsum->nb [1 ], exp_dtA_cumsum->nb [2 ], exp_dtA_cumsum->nb [3 ],
343+ (exp_dtA_cumsum->ne [1 ] - 1 ) * exp_dtA_cumsum->nb [1 ]); // {n_head, 1, n_seqs}
344+ cb (exp_dtA_cumsum_last, " exp_dtA_cumsum_last" , il);
345+ ggml_tensor * exp_dtA_cumsum_perm = ggml_permute (ctx, exp_dtA_cumsum_last, 2 , 1 , 3 , 0 ); // {1, 1, n_head, n_seqs}
346+ next_state = ggml_add (ctx, next_state, ggml_mul (ctx, ssm, ggml_cont (ctx, exp_dtA_cumsum_perm)));
347+ cb (next_state, " next_state_updated" , il);
348+
349+ // step 10: update from previous y
350+ ggml_tensor * y_prev = ggml_mul_mat (ctx, ggml_permute (ctx, C, 0 , 2 , 1 , 3 ), ssm);
351+ cb (y_prev, " y_prev" , il);
352+ y_prev = ggml_mul (ctx,
353+ ggml_cont (ctx, ggml_permute (ctx, y_prev, 2 , 0 , 1 , 3 )),
354+ ggml_cont (ctx, ggml_permute (ctx, exp_dtA_cumsum, 1 , 2 , 3 , 0 )));
355+ cb (y_prev, " y_prev_mul" , il);
356+ y = ggml_add (ctx, y, y_prev);
357+ cb (y, " y_updated" , il);
399358
400359 // Concat the output y and state
401360 if (ssm->type != y->type ) {