Skip to content

Commit de43d0b

Browse files
committed
feat(ggml-metal): Add support for F16 and BF16 ssm_conv weights
Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 86788a2 commit de43d0b

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
218218
};
219219

220220
const char * suffix = "";
221-
if (n % 4 == 0) {
221+
if (n % 4 == 0 && op->type == GGML_TYPE_F32) {
222222
suffix = "_4";
223223
}
224224

@@ -394,7 +394,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar
394394

395395
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
396396
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
397-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
398397

399398
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
400399
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
@@ -404,7 +403,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
404403

405404
const char * suffix = "";
406405

407-
if (op->src[1]->ne[0] % 4 == 0) {
406+
if (op->src[1]->ne[0] % 4 == 0 && op->src[1]->type == GGML_TYPE_F32) {
408407
suffix = "_4";
409408
}
410409

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,14 @@ kernel void kernel_scale_f32(
12351235
dst[tpig] = src0[tpig] * args.scale + args.bias;
12361236
}
12371237

1238+
kernel void kernel_scale_f16(
1239+
constant ggml_metal_kargs_scale & args,
1240+
device const half * src0,
1241+
device half * dst,
1242+
uint tpig[[thread_position_in_grid]]) {
1243+
dst[tpig] = src0[tpig] * args.scale + args.bias;
1244+
}
1245+
12381246
kernel void kernel_scale_f32_4(
12391247
constant ggml_metal_kargs_scale & args,
12401248
device const float4 * src0,
@@ -2207,8 +2215,9 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne
22072215
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
22082216
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
22092217

2210-
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
2211-
kernel void kernel_ssm_conv_f32_f32(
2218+
// ref: ggml.c:ggml_compute_forward_ssm_conv_impl
2219+
template<typename src_t, typename conv_t>
2220+
kernel void kernel_ssm_conv_impl(
22122221
constant ggml_metal_kargs_ssm_conv & args,
22132222
device const void * src0,
22142223
device const void * src1,
@@ -2226,14 +2235,14 @@ kernel void kernel_ssm_conv_f32_f32(
22262235
//const int64_t n_t = args.ne1;
22272236
//const int64_t n_s = args.ne2;
22282237

2229-
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2230-
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
2231-
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
2238+
device const src_t * s = (device const src_t *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2239+
device const conv_t * c = (device const conv_t *) ((device const char *) src1 + ir*args.nb11);
2240+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
22322241

22332242
float sumf = 0.0f;
22342243

22352244
for (int64_t i0 = 0; i0 < nc; ++i0) {
2236-
sumf += s[i0] * c[i0];
2245+
sumf += static_cast<float>(s[i0]) * static_cast<float>(c[i0]);
22372246
}
22382247

22392248
x[0] = sumf;
@@ -2270,6 +2279,13 @@ kernel void kernel_ssm_conv_f32_f32_4(
22702279
x[0] = sumf;
22712280
}
22722281

2282+
typedef decltype(kernel_ssm_conv_impl<float, float>) kernel_ssm_conv_t;
2283+
template [[host_name("kernel_ssm_conv_f32_f32")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float, float>;
2284+
template [[host_name("kernel_ssm_conv_f32_f16")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float, half>;
2285+
#if defined(GGML_METAL_HAS_BF16)
2286+
template [[host_name("kernel_ssm_conv_f32_bf16")]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float, bfloat>;
2287+
#endif
2288+
22732289
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
22742290
kernel void kernel_ssm_scan_f32(
22752291
constant ggml_metal_kargs_ssm_scan & args,

0 commit comments

Comments (0)