@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
408408 cl_program program_mul_mv_id_mxfp4_f32_flat;
409409 cl_program program_mul_mm_f32_f32_l4_lm;
410410 cl_program program_mul_mm_f16_f32_l4_lm;
411+ cl_program program_mul_mm_q8_0_f32_l4_lm;
411412
412413 cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
413414 cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
480481 cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
481482 cl_kernel kernel_mul_mm_f32_f32_l4_lm;
482483 cl_kernel kernel_mul_mm_f16_f32_l4_lm;
484+ cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
483485
484486 std::vector<ProfilingInfo> profiling_info;
485487
@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
11911193 GGML_LOG_CONT (" ." );
11921194 }
11931195
1196+ // mul_mm_q8_0_f32_l4_lm: obtain the kernel source, build it into a program, and create the kernel object stored on backend_ctx
1197+ {
1198+ #ifdef GGML_OPENCL_EMBED_KERNELS // source text comes from a generated header included into the initializer
1199+ const std::string kernel_src {
1200+ #include " mul_mm_q8_0_f32_l4_lm.cl.h"
1201+ };
1202+ #else // otherwise the source text is obtained through read_file()
1203+ const std::string kernel_src = read_file (" mul_mm_q8_0_f32_l4_lm.cl" );
1204+ #endif // GGML_OPENCL_EMBED_KERNELS
1205+ backend_ctx->program_mul_mm_q8_0_f32_l4_lm = // built with the same compile_opts variable that is in scope for this function
1206+ build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
1207+
1208+ CL_CHECK ((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel (backend_ctx->program_mul_mm_q8_0_f32_l4_lm , " kernel_mul_mm_q8_0_f32_l4_lm" , &err), err)); // err is written by clCreateKernel and then handed to CL_CHECK
1209+ GGML_LOG_CONT (" ." ); // one dot is logged for this kernel
1210+ }
1211+
11941212 // mul
11951213 {
11961214#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6961,6 +6979,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
69616979 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
69626980 return ;
69636981 }
6982+ case GGML_TYPE_Q8_0: { // Q8_0 src0: set up and launch kernel_mul_mm_q8_0_f32_l4_lm, then return
6983+ if (ne11 < 32 ) { // when ne11 is below 32 this kernel is not used
6984+ break ; // leaves the switch without enqueuing anything here
6985+ }
6986+ kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm ;
6987+ nth0 = 128 ; // work-group size in dimension 0 (see local_work_size); calculated as (BM*BN)/(TM*TN)
6988+
6989+ int batch_stride_a = ne00*ne01; // passed as kernel arg 14
6990+ int batch_stride_b = ne10*ne11; // passed as kernel arg 15
6991+ int batch_stride_d = ne0*ne1; // passed as kernel arg 16
6992+
6993+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0_q8_0->q )); // NOTE(review): presumably the quantized values of src0 - confirm against the extra0_q8_0 type
6994+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_mem), &extra0_q8_0->d )); // NOTE(review): presumably the per-block scales of src0 - confirm against the extra0_q8_0 type
6995+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
6996+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
6997+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
6998+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
6999+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
7000+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
7001+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
7002+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne11));
7003+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12));
7004+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne10)); // stride_a - NOTE(review): set from ne10 rather than ne00; presumably the two are equal for mul_mat - confirm
7005+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne10)); // stride_b
7006+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne01)); // stride_d - NOTE(review): set from ne01 while batch_stride_d uses ne0*ne1; confirm against the kernel signature
7007+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &batch_stride_a));
7008+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &batch_stride_b));
7009+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &batch_stride_d));
7010+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &r2));
7011+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (int ), &r3));
7012+
7013+ // 64 is the block tile size (BM and BN) - update the two 64 literals below whenever BM or BN in the kernel change. Dim 0 spans CEIL_DIV(ne01, 64) groups of nth0 work-items, dim 1 spans CEIL_DIV(ne11, 64), dim 2 spans ne12*ne13.
7014+ size_t global_work_size[] = {(size_t )(CEIL_DIV (ne01, 64 )*nth0), (size_t )(CEIL_DIV (ne11, 64 )), (size_t )ne12*ne13};
7015+ size_t local_work_size[] = {(size_t )nth0, 1 , 1 }; // one work-group is nth0 x 1 x 1
7016+
7017+ backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst); // 3 is the number of work dimensions
7018+ return ; // the function ends here once this kernel has been enqueued
7019+ }
69647020 default :
69657021 break ;
69667022 }
0 commit comments