@@ -11876,6 +11876,17 @@ struct llm_graph_context_mamba : public llm_graph_context {
         // TODO: make this configurable
         const uint32_t chunk_size = 256;

+        // extract the state(s) for the sequences identified by ids
+        if (ssm->ne[3] != ids->ne[0]) {
+            ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1
+            // ggml_tensor * ids_perm = ggml_permute(ctx, ids, 1, 2, 3, 0); // put the target dim in dim 0
+            ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids,
+                ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape
+            ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows
+            ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape
+            GGML_ASSERT(ssm->ne[3] == ids->ne[0]);
+        }
+
         // step 1: compute dt softplus
         // NOTE: In other implementations, the bias is added after
         // the softplus. This shouldn't be a problem, but it's a
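The selection added above works around the fact that `ggml_get_rows` only indexes along dim 1 of its first argument: the sequence dim (dim 3 of `ssm`) is permuted into dim 1, the index tensor is broadcast with `ggml_repeat_4d` so its trailing dims satisfy `ggml_get_rows`' shape checks, and a final permute restores the original layout. Below is a minimal, shape-only sketch of that bookkeeping; it is not part of this change, the tensor sizes are made up for illustration, and it only builds ops with `no_alloc` so no tensor data is ever touched.

```cpp
// Shape-only sketch (hypothetical sizes): select 3 of 5 cached sequence
// states from a [d_state=4, head_dim=8, n_head=2, n_seq=5] tensor by id.
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // only shape metadata, no tensor data
    };
    struct ggml_context * ctx = ggml_init(params);

    ggml_tensor * ssm = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 8, 2, 5);
    ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);

    // move the sequence dim (dim 3) into dim 1, where ggml_get_rows indexes
    ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1);    // [4, 5, 8, 2]
    // broadcast ids so that ids_rep->ne[1] == ssm_perm->ne[2] and ids_rep->ne[3] == 1
    ggml_tensor * ids_rep  = ggml_repeat_4d(ctx, ids, 3, 8, 2, 1);  // [3, 8, 2, 1]
    ggml_tensor * sel      = ggml_get_rows(ctx, ssm_perm, ids_rep); // [4, 3, 8, 2]
    // move the selected-sequence dim back into dim 3
    ggml_tensor * out = ggml_cont(ctx, ggml_permute(ctx, sel, 0, 3, 1, 2)); // [4, 8, 2, 3]

    GGML_ASSERT(out->ne[0] == 4 && out->ne[1] == 8 && out->ne[2] == 2 && out->ne[3] == 3);

    ggml_free(ctx);
    return 0;
}
```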
@@ -11944,7 +11955,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
         cb(surrogate_attention_matrix, "surrogate_attention_matrix", il);

         // step 6: compute y
-        ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3)
+        ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3));
         ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
         y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3));
         cb(y_chunk, "y_chunk", il);
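For context on the permute in step 6: `ggml_mul_mat(ctx, a, b)` contracts over dim 0 of both operands and produces `[a->ne[1], b->ne[1], b->ne[2], b->ne[3]]`, broadcasting `a` over dims 2 and 3, so the operand passed as `a` must have the contraction dim permuted into dim 0 first. A minimal, shape-only sketch of that convention follows; the sizes are hypothetical and it is not part of this change.

```cpp
// Shape-only sketch of the ggml_mul_mat convention: a [k, n, h, 1] operand
// times a [k, m, h, 1] operand contracts over k and yields [n, m, h, 1].
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // only shape metadata, no tensor data
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t k = 64, n = 16, m = 32, h = 4;    // hypothetical sizes

    ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, k, n, h, 1);
    ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, k, m, h, 1);

    ggml_tensor * y = ggml_mul_mat(ctx, a, b);      // [n, m, h, 1]
    GGML_ASSERT(y->ne[0] == n && y->ne[1] == m && y->ne[2] == h && y->ne[3] == 1);

    ggml_free(ctx);
    return 0;
}
```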