Skip to content

Commit 0441ccb

Browse files
committed
fix: Subset input states to match ids
The code now runs cleanly for parallel requests Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 188ae84 commit 0441ccb

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/llama-model.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11876,6 +11876,17 @@ struct llm_graph_context_mamba : public llm_graph_context {
1187611876
// TODO: make this configurable
1187711877
const uint32_t chunk_size = 256;
1187811878

11879+
// extract the state(s) for the sequences identified by ids
11880+
if (ssm->ne[3] != ids->ne[0]) {
11881+
ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1
11882+
// ggml_tensor * ids_perm = ggml_permute(ctx, ids, 1, 2, 3, 0); // put the taget dim in dim 0
11883+
ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids,
11884+
ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape
11885+
ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows
11886+
ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape
11887+
GGML_ASSERT(ssm->ne[3] == ids->ne[0]);
11888+
}
11889+
1187911890
// step 1: compute dt softplus
1188011891
// NOTE: In other implementations, the bias is added after
1188111892
// the softplus. This shouldn't be a problem, but it's a
@@ -11944,7 +11955,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1194411955
cb(surrogate_attention_matrix, "surrogate_attention_matrix", il);
1194511956

1194611957
// step 6: compute y
11947-
ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3)
11958+
ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3));
1194811959
ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
1194911960
y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3));
1195011961
cb(y_chunk, "y_chunk", il);

0 commit comments

Comments
 (0)