@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
408
408
cl_program program_mul_mv_id_mxfp4_f32_flat;
409
409
cl_program program_mul_mm_f32_f32_l4_lm;
410
410
cl_program program_mul_mm_f16_f32_l4_lm;
411
+ cl_program program_mul_mm_q8_0_f32_l4_lm;
411
412
412
413
cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
413
414
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
480
481
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
481
482
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
482
483
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
484
+ cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
483
485
484
486
std::vector<ProfilingInfo> profiling_info;
485
487
@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1191
1193
GGML_LOG_CONT (" ." );
1192
1194
}
1193
1195
1196
+ // mul_mm_q8_0_f32_l4_lm: builds program_mul_mm_q8_0_f32_l4_lm and creates kernel_mul_mm_q8_0_f32_l4_lm from it.
// NOTE(review): the string literals in this block start with a space in this view; that looks
// like an extraction artifact - confirm the file name and kernel name against the real file.
1197
+ {
1198
// Kernel source: taken from the embedded header when GGML_OPENCL_EMBED_KERNELS is defined,
// otherwise obtained through read_file().
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1199
+ const std::string kernel_src {
1200
+ #include " mul_mm_q8_0_f32_l4_lm.cl.h"
1201
+ };
1202
+ #else
1203
+ const std::string kernel_src = read_file (" mul_mm_q8_0_f32_l4_lm.cl" );
1204
+ #endif
1205
// Build the program from kernel_src with the context, device and compile_opts of this backend
// and keep the handle in the backend context.
+ backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
1206
+ build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
1207
+
1208
// clCreateKernel reports its status through err, which CL_CHECK then receives as its second argument.
+ CL_CHECK ((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel (backend_ctx->program_mul_mm_q8_0_f32_l4_lm , " kernel_mul_mm_q8_0_f32_l4_lm" , &err), err));
1209
// Same log continuation call that closes the preceding kernel block.
+ GGML_LOG_CONT (" ." );
1210
+ }
1211
+
1194
1212
// mul
1195
1213
{
1196
1214
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6956,6 +6974,41 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
6956
6974
backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
6957
6975
return ;
6958
6976
}
6977
+ case GGML_TYPE_Q8_0: {
// Q8_0 path: selects kernel_mul_mm_q8_0_f32_l4_lm, sets its 19 arguments (indices 0-18),
// enqueues it as a 3-dimensional range and returns from the enclosing function.
6978
+ kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm ;
6979
+ nth0 = 128 ; // calculated as (BM*BN)/(TM*TN); also used below as local_work_size[0]
6980
+
6981
// Batch strides handed to the kernel as arguments 14-16; each is a product of two extents.
// NOTE(review): the products are stored in int - confirm they cannot overflow for the
// tensor sizes this path is expected to handle.
+ int batch_stride_a = ne00*ne01;
6982
+ int batch_stride_b = ne10*ne11;
6983
+ int batch_stride_d = ne0*ne1;
6984
+
6985
// Args 0-1: the q and d buffers of extra0_q8_0 (presumably quantized values and scales - confirm).
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0_q8_0->q ));
6986
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_mem), &extra0_q8_0->d ));
6987
// Args 2-5: extra1 device buffer with offset1, then extrad device buffer with offsetd.
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
6988
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
6989
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
6990
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
6991
// Args 6-10: extents ne00, ne01, ne02, ne11, ne12, in that order.
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
6992
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
6993
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
6994
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne11));
6995
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12));
6996
// Args 11-13: row strides.
// NOTE(review): stride_a is passed as ne10 rather than ne00 - presumably fine because the
// two are equal for a valid matrix multiplication; confirm against the kernel signature.
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne10)); // stride_a
6997
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne10)); // stride_b
6998
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne01)); // stride_d
6999
// Args 14-16: the batch strides computed above.
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &batch_stride_a));
7000
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &batch_stride_b));
7001
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &batch_stride_d));
7002
// Args 17-18: r2 and r3 (defined outside this view; presumably broadcast ratios - confirm).
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &r2));
7003
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (int ), &r3));
7004
+
7005
+ // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
// Dimension 0: CEIL_DIV(ne01, 64) work-groups of nth0 work-items each;
// dimension 1: CEIL_DIV(ne11, 64); dimension 2: ne12*ne13.
7006
+ size_t global_work_size[] = {(size_t )(CEIL_DIV (ne01, 64 )*nth0), (size_t )(CEIL_DIV (ne11, 64 )), (size_t )ne12*ne13};
7007
+ size_t local_work_size[] = {(size_t )nth0, 1 , 1 };
7008
+
7009
+ backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
7010
// Return here so control does not reach the code after the switch.
+ return ;
7011
+ }
6959
7012
default :
6960
7013
break ;
6961
7014
}
0 commit comments