Skip to content

Commit d2d120c

Browse files
compilade authored and ggerganov committed
llama : initial Mamba-2 support (llama/9126)
llama : initial Mamba-2 support (llama/9126)
* llama : initial Mamba-2 support * ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states * llama : support running Mamba-Codestral-7B-v0.1 * llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted * llama : remove unused variable * llama : add missing break * convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. * llama : avoid redundant state copy for Mamba 1 and 2 * metal : attempt to adapt SSM_SCAN for Mamba-2 * metal : fix SSM_SCAN pipeline scope * metal : use log and exp instead of log1pf and expf in SSM_SCAN * metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. * metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. * metal : fix SSM_SCAN state head offset * metal : fix wrong number of tokens per sequence in SSM_SCAN * ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity. * ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. 
* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks * convert : fix flake8 lint * metal : fix confusion between ; and , * metal : add missing args for nb references in ssm_scan_f32_group * metal : single-user mamba2 inference works * kv-cache : remove const_cast when setting inputs for s_copy And also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy. * convert : avoid AutoConfig for Mamba and Mamba2 hparams * kv-cache : allow context shift for recurrent models * graph : fix recurrent state copies when avoiding copies Works, but using lambda functions might not be that clean. * ggml : fix mamba2 ssm scan when compiled with SVE * ggml-cpu : reorder SVE FMA for consistency with other SIMD arches * cuda : implement ssm scan for Mamba2 There is still room for improvement, but it works! * cuda : adapt Mamba1 ssm scan to shape changes from Mamba2 * mamba : fix mismatched new and delete size for llm_build_mamba Subclasses of llm_graph_context cannot have extra fields, because the called destructor is not the one from the subclass. This otherwise would cause problems when running Mamba-(1|2) inference when compiled -DGGML_SANITIZE_ADDRESS=ON * cuda : graceful fallback for Mamba-1 models with weird embd size
1 parent fb5c409 commit d2d120c

File tree

11 files changed

+593
-246
lines changed

11 files changed

+593
-246
lines changed

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2028,7 +2028,8 @@ extern "C" {
20282028
struct ggml_tensor * dt,
20292029
struct ggml_tensor * A,
20302030
struct ggml_tensor * B,
2031-
struct ggml_tensor * C);
2031+
struct ggml_tensor * C,
2032+
struct ggml_tensor * ids);
20322033

20332034
// partition into non-overlapping windows with padding if needed
20342035
// example:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 184 additions & 94 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cpu/simd-mappings.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
189189
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
190190
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
191191
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
192-
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
192+
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
193193
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
194194
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
195195
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)

ggml/src/ggml-cpu/vec.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,43 +37,43 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
3737
for (int i = 0; i < np; i += ggml_f32_step) {
3838
ax1 = GGML_F32_VEC_LOAD(x + i);
3939
ay1 = GGML_F32_VEC_LOAD(y + i);
40-
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
40+
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
4141

4242
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
4343
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
44-
sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
44+
sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
4545

4646
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
4747
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
48-
sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
48+
sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
4949

5050
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
5151
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
52-
sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
52+
sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
5353

5454
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
5555
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
56-
sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
56+
sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
5757

5858
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
5959
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
60-
sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
60+
sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
6161

6262
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
6363
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
64-
sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
64+
sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
6565

6666
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
6767
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
68-
sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
68+
sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
6969
}
7070
// leftovers
7171
// Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
7272
const int np2 = (n & ~(ggml_f32_epr - 1));
7373
for (int i = np; i < np2; i += ggml_f32_epr) {
7474
ax1 = GGML_F32_VEC_LOAD(x + i);
7575
ay1 = GGML_F32_VEC_LOAD(y + i);
76-
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
76+
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
7777
}
7878
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
7979
if (np2 < n) {

ggml/src/ggml-cpu/vec.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
163163

164164
ax1 = GGML_F32_VEC_LOAD(x + i);
165165
ay1 = GGML_F32_VEC_LOAD(y + i);
166-
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
166+
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
167167

168168
GGML_F32_VEC_STORE(y + i, ay1);
169169

170170
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
171171
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
172-
ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
172+
ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
173173

174174
GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
175175

176176
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
177177
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
178-
ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
178+
ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
179179

180180
GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
181181

182182
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
183183
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
184-
ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
184+
ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
185185

186186
GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
187187

188188
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
189189
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
190-
ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
190+
ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
191191

192192
GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
193193

194194
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
195195
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
196-
ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
196+
ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
197197

198198
GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
199199

200200
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
201201
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
202-
ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
202+
ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
203203

204204
GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
205205

206206
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
207207
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
208-
ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
208+
ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
209209

210210
GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
211211
}
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
215215
for (int i = np; i < np2; i += ggml_f32_epr) {
216216
ax1 = GGML_F32_VEC_LOAD(x + i);
217217
ay1 = GGML_F32_VEC_LOAD(y + i);
218-
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
218+
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
219219

220220
GGML_F32_VEC_STORE(y + i, ay1);
221221
}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3321,9 +3321,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
33213321
case GGML_OP_COS:
33223322
case GGML_OP_CLAMP:
33233323
case GGML_OP_LOG:
3324-
case GGML_OP_SSM_SCAN:
3325-
case GGML_OP_SSM_CONV:
33263324
return true;
3325+
case GGML_OP_SSM_SCAN: {
3326+
if (op->src[3]->ne[0] == 1) {
3327+
// Mamba2
3328+
// (kernel only supports d_state == 128 && d_head % 16 == 0)
3329+
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
3330+
} else {
3331+
// Mamba
3332+
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
3333+
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
3334+
}
3335+
}
3336+
case GGML_OP_SSM_CONV: {
3337+
// assumes d_inner % threads == 0
3338+
return op->src[0]->ne[1] % 128 == 0;
3339+
}
33273340
case GGML_OP_CONT:
33283341
return op->src[0]->type != GGML_TYPE_BF16;
33293342
case GGML_OP_DIAG_MASK_INF:

0 commit comments

Comments
 (0)