@@ -1235,6 +1235,14 @@ kernel void kernel_scale_f32(
12351235 dst[tpig] = src0[tpig] * args.scale + args.bias ;
12361236}
12371237
1238+ kernel void kernel_scale_f16 (
1239+ constant ggml_metal_kargs_scale & args,
1240+ device const half * src0,
1241+ device half * dst,
1242+ uint tpig[[thread_position_in_grid]]) {
1243+ dst[tpig] = src0[tpig] * args.scale + args.bias ;
1244+ }
1245+
12381246kernel void kernel_scale_f32_4 (
12391247 constant ggml_metal_kargs_scale & args,
12401248 device const float4 * src0,
@@ -2207,8 +2215,9 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne
22072215template [[host_name(" kernel_soft_max_f16_4" )]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
22082216template [[host_name(" kernel_soft_max_f32_4" )]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
22092217
2210- // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
2211- kernel void kernel_ssm_conv_f32_f32 (
2218+ // ref: ggml.c:ggml_compute_forward_ssm_conv_impl
2219+ template <typename src_t , typename conv_t >
2220+ kernel void kernel_ssm_conv_impl (
22122221 constant ggml_metal_kargs_ssm_conv & args,
22132222 device const void * src0,
22142223 device const void * src1,
@@ -2226,14 +2235,14 @@ kernel void kernel_ssm_conv_f32_f32(
22262235 // const int64_t n_t = args.ne1;
22272236 // const int64_t n_s = args.ne2;
22282237
2229- device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02 );
2230- device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11 );
2231- device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2 );
2238+ device const src_t * s = (device const src_t *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02 );
2239+ device const conv_t * c = (device const conv_t *) ((device const char *) src1 + ir*args.nb11 );
2240+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2 );
22322241
22332242 float sumf = 0 .0f ;
22342243
22352244 for (int64_t i0 = 0 ; i0 < nc; ++i0) {
2236- sumf += s[i0] * c[i0];
2245+ sumf += static_cast < float >( s[i0]) * static_cast < float >( c[i0]) ;
22372246 }
22382247
22392248 x[0 ] = sumf;
@@ -2270,6 +2279,13 @@ kernel void kernel_ssm_conv_f32_f32_4(
22702279 x[0 ] = sumf;
22712280}
22722281
2282+ typedef decltype (kernel_ssm_conv_impl<float , float >) kernel_ssm_conv_t;
2283+ template [[host_name(" kernel_ssm_conv_f32_f32" )]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float , float >;
2284+ template [[host_name(" kernel_ssm_conv_f32_f16" )]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float , half>;
2285+ #if defined(GGML_METAL_HAS_BF16)
2286+ template [[host_name(" kernel_ssm_conv_f32_bf16" )]] kernel kernel_ssm_conv_t kernel_ssm_conv_impl<float , bfloat>;
2287+ #endif
2288+
22732289// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
22742290kernel void kernel_ssm_scan_f32 (
22752291 constant ggml_metal_kargs_ssm_scan & args,
0 commit comments