@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_id_mxfp4_f32_flat;
     cl_program program_mul_mm_f32_f32_l4_lm;
     cl_program program_mul_mm_f16_f32_l4_lm;
+    cl_program program_mul_mm_q4_0_f32_l4_lm;
     cl_program program_mul_mm_q8_0_f32_l4_lm;
 
     cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
@@ -481,6 +482,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
     cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
@@ -1193,6 +1195,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_q4_0_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q4_0_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_q4_0_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q4_0_f32_l4_lm, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mm_q8_0_f32_l4_lm
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
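The dispatch hunk that follows binds the backend's "flat" Q4_0 representation to the new kernel: the 4-bit quants and the per-block scales are passed as two separate cl_mem buffers (extra0_q4_0->q and extra0_q4_0->d), while the f32 activations and the f32 destination use their regular data_device buffers plus byte offsets. A minimal host-side sketch of that view, assuming ggml's standard Q4_0 block of 32 weights (the struct name is illustrative; only the q and d members appear in the diff):

#include <CL/cl.h>

// Sketch only -- not the PR's definitions.
struct flat_q4_0_view_sketch {
    cl_mem q; // all 4-bit quants: 16 bytes per 32-weight block, packed contiguously
    cl_mem d; // all per-block scales (fp16 in ggml's block_q4_0), one per 32-weight block
};

Arguments 0 and 1 of kernel_mul_mm_q4_0_f32_l4_lm take these two buffers; arguments 2-5 are the usual source-1 and destination buffer-plus-offset pairs.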
@@ -6974,6 +6992,41 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
             return;
         }
+        case GGML_TYPE_Q4_0: {
+            kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm;
+            nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+            int batch_stride_a = ne00*ne01;
+            int batch_stride_b = ne10*ne11;
+            int batch_stride_d = ne0*ne1;
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+            // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+            size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+            size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+            return;
+        }
         case GGML_TYPE_Q8_0: {
             kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
             nth0 = 128; // calculated as (BM*BN)/(TM*TN)
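For reference, the work decomposition in the new GGML_TYPE_Q4_0 case mirrors the existing Q8_0 path: with the 64x64 block tile noted in the comment and nth0 = 128 work-items per work-group, dimension 0 covers CEIL_DIV(ne01, 64) output tiles (times 128 work-items each), dimension 1 covers CEIL_DIV(ne11, 64) tiles, and dimension 2 covers the ne12*ne13 batch. A small standalone sketch of that arithmetic with made-up shapes (not taken from the PR; ceil_div stands in for the backend's CEIL_DIV macro):

#include <cstdio>
#include <cstddef>

// Assumed to match CEIL_DIV in the diff: integer ceiling division.
static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    // Illustrative tensor shape only.
    const size_t ne01 = 4096, ne11 = 512, ne12 = 2, ne13 = 1;
    const size_t BM = 64, BN = 64; // block tile size, per the kernel comment
    const size_t nth0 = 128;       // (BM*BN)/(TM*TN)

    const size_t global[3] = { ceil_div(ne01, BM) * nth0, ceil_div(ne11, BN), ne12 * ne13 };
    const size_t local[3]  = { nth0, 1, 1 };

    // For this shape: 64 tiles * 128 = 8192 work-items on dim 0, 8 groups on dim 1, 2 on dim 2.
    printf("global = {%zu, %zu, %zu}, local = {%zu, %zu, %zu}\n",
           global[0], global[1], global[2], local[0], local[1], local[2]);
    return 0;
}

Only ne01 is multiplied by nth0 because each work-group handles one 64x64 output tile: the local size {128, 1, 1} divides dimension 0 back into CEIL_DIV(ne01, 64) groups.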