Skip to content

Commit 086a63e

Browse files
metal: SSM kernel improvements (#17876)
* feat: Add a batched version of ssm_conv This was done using Claude Code. It found a number of optimizations around how the threads were organized, resulting in a huge performance boost! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]> * feat: Optimized SSM_SCAN kernel for metal This used Claude Code and resulted in a modest performance improvement while maintaining correctness. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]> * test: Add test-backend-ops perf tests for SSM_CONV Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <[email protected]> * test: Real representative tests for SSM_CONV Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <[email protected]> * refactor: Use function constant for ssm_conv batch size Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <[email protected]> * test: backend op tests for ssm_scan from granite4 1b-h Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <[email protected]> * style: remove commented out templates Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <[email protected]> * feat: float4 version of ssm_conv_batched Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <[email protected]> * fix: Add missing ggml_metal_cv_free Signed-off-by: Gabe Goodhart <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]> --------- Signed-off-by: Gabe Goodhart <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
1 parent b635092 commit 086a63e

File tree

6 files changed

+209
-17
lines changed

6 files changed

+209
-17
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,38 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me
411411
return res;
412412
}
413413

414+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
415+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
416+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
417+
418+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
419+
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
420+
421+
char base[256];
422+
char name[256];
423+
424+
const char * suffix = "";
425+
if (op->src[1]->ne[0] % 4 == 0) {
426+
suffix = "_4";
427+
}
428+
429+
snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
430+
snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
431+
432+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
433+
if (!res.pipeline) {
434+
ggml_metal_cv_t cv = ggml_metal_cv_init();
435+
436+
ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
437+
438+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
439+
440+
ggml_metal_cv_free(cv);
441+
}
442+
443+
return res;
444+
}
445+
414446
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
415447
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
416448

@@ -427,7 +459,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
427459
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
428460
}
429461

430-
res.smem = 32*sizeof(float)*nsg;
462+
// Shared memory layout:
463+
// - sgptg * NW floats for partial sums (nsg * 32)
464+
// - sgptg floats for shared_x_dt (nsg)
465+
// - sgptg floats for shared_dA (nsg)
466+
// Total: nsg * (32 + 2) floats
467+
res.smem = (32 + 2)*sizeof(float)*nsg;
431468

432469
return res;
433470
}

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad
117117
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
118118
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
119119
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
120121
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
121122
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
122123
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
#define FC_MUL_MV 600
7878
#define FC_MUL_MM 700
7979
#define FC_ROPE 800
80+
#define FC_SSM_CONV 900
8081

8182
// op-specific constants
8283
#define OP_FLASH_ATTN_EXT_NQPTG 8

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,15 +1365,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
13651365
/*.nb2 =*/ nb2,
13661366
};
13671367

1368-
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1368+
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
1369+
const bool use_batched = (ne1 > 1);
1370+
1371+
if (use_batched) {
1372+
// Determine the smallest power of 2 that's >= ne1, but <= 256
1373+
int BATCH_SIZE;
1374+
if (ne1 > 128) BATCH_SIZE = 256;
1375+
else if (ne1 > 64 ) BATCH_SIZE = 128;
1376+
else if (ne1 > 32 ) BATCH_SIZE = 64;
1377+
else if (ne1 > 16 ) BATCH_SIZE = 32;
1378+
else if (ne1 > 8 ) BATCH_SIZE = 16;
1379+
else if (ne1 > 4 ) BATCH_SIZE = 8;
1380+
else BATCH_SIZE = 2;
1381+
1382+
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
13691383

1370-
ggml_metal_encoder_set_pipeline(enc, pipeline);
1371-
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1372-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1373-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1374-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1384+
ggml_metal_encoder_set_pipeline(enc, pipeline);
1385+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1386+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1387+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1388+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
13751389

1376-
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1390+
// Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
1391+
// Each threadgroup has BATCH_SIZE threads, each handling one token
1392+
const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
1393+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
1394+
} else {
1395+
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1396+
1397+
ggml_metal_encoder_set_pipeline(enc, pipeline);
1398+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1399+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1400+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1401+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1402+
1403+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1404+
}
13771405

13781406
return 1;
13791407
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 127 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2343,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
23432343
x[0] = sumf;
23442344
}
23452345

// Function constant holding the token batch size per threadgroup; set at
// pipeline-compile time (see ggml_metal_library_get_pipeline_ssm_conv_batched).
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];

// Batched version: each threadgroup processes multiple tokens for better efficiency.
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens.
kernel void kernel_ssm_conv_f32_f32_batched(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    // tgpig.x = row index (ir)
    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
    // tgpig.z = sequence index (i3)
    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
    const short BATCH_SIZE = FC_ssm_conv_bs;

    const int64_t ir      = tgpig.x;
    const int64_t i2_base = tgpig.y * BATCH_SIZE;
    const int64_t i3      = tgpig.z;
    const int64_t i2_off  = tpitg.x;
    const int64_t i2      = i2_base + i2_off;

    const int64_t nc  = args.ne10; // conv kernel size (typically 4)
    const int64_t n_t = args.ne1;  // number of tokens

    // Bounds check for partial batches at the end
    if (i2 >= n_t) {
        return;
    }

    // Conv weights for this row (shared across all tokens of the row)
    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);

    // Source window for this specific token
    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);

    // Output location for this token
    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);

    float sumf = 0.0f;
    for (int64_t i0 = 0; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }

    x[0] = sumf;
}
// float4 variant of the batched SSM_CONV kernel: identical thread layout, but
// loads weights and source in float4 chunks and accumulates with dot().
// Selected only when the conv kernel width (args.ne10) is a multiple of 4.
kernel void kernel_ssm_conv_f32_f32_batched_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    // tgpig.x = row index (ir)
    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
    // tgpig.z = sequence index (i3)
    // tpitg.x = thread within batch (0..BATCH_SIZE-1)
    const short BATCH_SIZE = FC_ssm_conv_bs;

    const int64_t ir      = tgpig.x;
    const int64_t i2_base = tgpig.y * BATCH_SIZE;
    const int64_t i3      = tgpig.z;
    const int64_t i2_off  = tpitg.x;
    const int64_t i2      = i2_base + i2_off;

    const int64_t nc  = args.ne10; // conv kernel size (typically 4)
    const int64_t n_t = args.ne1;  // number of tokens

    // Bounds check for partial batches at the end
    if (i2 >= n_t) {
        return;
    }

    // Conv weights for this row (shared across all tokens of the row)
    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);

    // Source window for this specific token
    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);

    // Output location for this token
    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);

    float sumf = 0.0f;
    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
        sumf += dot(s[i0], c[i0]);
    }

    x[0] = sumf;
}
23462440
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
2441+
// Optimized version: reduces redundant memory loads by having one thread load shared values
23472442
kernel void kernel_ssm_scan_f32(
23482443
constant ggml_metal_kargs_ssm_scan & args,
23492444
device const void * src0,
@@ -2363,7 +2458,15 @@ kernel void kernel_ssm_scan_f32(
23632458
uint3 tgpg[[threadgroups_per_grid]]) {
23642459
constexpr short NW = N_SIMDWIDTH;
23652460

2366-
shared[tpitg.x] = 0.0f;
2461+
// Shared memory layout:
2462+
// [0..sgptg*NW-1]: partial sums for reduction (existing)
2463+
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
2464+
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
2465+
threadgroup float * shared_sums = shared;
2466+
threadgroup float * shared_x_dt = shared + sgptg * NW;
2467+
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
2468+
2469+
shared_sums[tpitg.x] = 0.0f;
23672470

23682471
const int32_t i0 = tpitg.x;
23692472
const int32_t i1 = tgpig.x;
@@ -2403,32 +2506,47 @@ kernel void kernel_ssm_scan_f32(
24032506
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
24042507
threadgroup_barrier(mem_flags::mem_threadgroup);
24052508

2406-
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
2407-
const float dt0 = dt[0];
2509+
// Pre-compute x_dt and dA for this batch of tokens
2510+
// Only first sgptg threads do the loads and expensive math
2511+
if (i0 < sgptg && i2 + i0 < n_t) {
2512+
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
2513+
device const float * x_t = x + i0 * args.ns12;
2514+
device const float * dt_t = dt + i0 * args.ns21;
2515+
2516+
const float dt0 = dt_t[0];
24082517
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
2409-
const float x_dt = x[0] * dtsp;
2410-
const float dA = exp(dtsp * A0);
2518+
shared_x_dt[i0] = x_t[0] * dtsp;
2519+
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
2520+
}
2521+
2522+
threadgroup_barrier(mem_flags::mem_threadgroup);
2523+
2524+
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
2525+
const float x_dt = shared_x_dt[t];
2526+
const float dA = exp(shared_dA[t] * A0);
24112527

24122528
s = (s0 * dA) + (B[i0] * x_dt);
24132529

24142530
const float sumf = simd_sum(s * C[i0]);
24152531

24162532
if (tiisg == 0) {
2417-
shared[t*NW + sgitg] = sumf;
2533+
shared_sums[t*NW + sgitg] = sumf;
24182534
}
24192535

24202536
// recurse
24212537
s0 = s;
24222538

2423-
x += args.ns12;
2424-
dt += args.ns21;
24252539
B += args.ns42;
24262540
C += args.ns52;
24272541
}
24282542

2543+
// Advance pointers for next batch
2544+
x += sgptg * args.ns12;
2545+
dt += sgptg * args.ns21;
2546+
24292547
threadgroup_barrier(mem_flags::mem_threadgroup);
24302548

2431-
const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
2549+
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
24322550

24332551
if (tiisg == 0 && i2 + sgitg < n_t) {
24342552
y[sgitg*nh*nr] = sumf;

tests/test-backend-ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8193,6 +8193,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
81938193
}
81948194
}
81958195

8196+
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
8197+
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
8198+
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
8199+
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
8200+
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
8201+
8202+
81968203
return test_cases;
81978204
}
81988205

0 commit comments

Comments
 (0)