@@ -365,6 +365,7 @@ struct ggml_backend_opencl_context {
365365 cl_program program_mul_mv_q4_0_f32_1d_8x_flat;
366366 cl_program program_mul_mv_q4_0_f32_1d_16x_flat;
367367 cl_program program_mul_mv_q6_K;
368+ cl_program program_mul_mv_mxfp4_f32;
368369 cl_program program_mul_mv_f16_f16;
369370 cl_program program_mul_mv_f16_f32_1row;
370371 cl_program program_mul_mv_f16_f32_l4;
@@ -439,6 +440,7 @@ struct ggml_backend_opencl_context {
439440 cl_kernel kernel_convert_block_q4_0_noshuffle;
440441 cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
441442 cl_kernel kernel_mul_mv_q6_K_f32;
443+ cl_kernel kernel_mul_mv_mxfp4_f32;
442444 cl_kernel kernel_im2col_f32, kernel_im2col_f16;
443445 cl_kernel kernel_argsort_f32_i32;
444446 cl_kernel kernel_sum_rows_f32;
@@ -971,6 +973,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
971973 GGML_LOG_CONT (" ." );
972974 }
973975
976+ // mul_mv_mxfp4_f32
// Loads the OpenCL source for the MXFP4 x F32 mat-vec kernel, builds it into
// backend_ctx->program_mul_mv_mxfp4_f32 and creates the kernel object stored in
// backend_ctx->kernel_mul_mv_mxfp4_f32.
977+ {
// The kernel text comes from a generated header that is textually included as
// the std::string initializer when GGML_OPENCL_EMBED_KERNELS is defined;
// otherwise it is obtained at run time through read_file().
978+ #ifdef GGML_OPENCL_EMBED_KERNELS
979+ const std::string kernel_src {
980+ #include " mul_mv_mxfp4_f32.cl.h"
981+ };
982+ #else
983+ const std::string kernel_src = read_file (" mul_mv_mxfp4_f32.cl" );
984+ #endif
// Build with the same context, device and compile_opts that are passed to
// build_program_from_source here; how build failures are reported is decided
// inside that helper (not visible in this hunk).
985+ backend_ctx->program_mul_mv_mxfp4_f32 =
986+ build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
987+
// The comma expression first assigns the new kernel handle, then hands the
// error code written by clCreateKernel to CL_CHECK.
988+ CL_CHECK ((backend_ctx->kernel_mul_mv_mxfp4_f32 = clCreateKernel (backend_ctx->program_mul_mv_mxfp4_f32 , " kernel_mul_mv_mxfp4_f32" , &err), err));
// Continuation log output; presumably a per-kernel progress marker - confirm
// against the other loader blocks in this function.
989+ GGML_LOG_CONT (" ." );
990+ }
991+
974992 // mul_mv_f16_f16
975993 {
976994#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2552,7 +2570,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
25522570 return true ;
25532571 } else if (op->src [0 ]->type == GGML_TYPE_F32) {
25542572 return op->src [1 ]->type == GGML_TYPE_F32;
2555- } else if (op->src [0 ]->type == GGML_TYPE_Q4_0 ||
2573+ } else if (op->src [0 ]->type == GGML_TYPE_Q4_0 || op-> src [ 0 ]-> type == GGML_TYPE_MXFP4 ||
25562574 op->src [0 ]->type == GGML_TYPE_Q6_K) {
25572575 return op->src [1 ]->type == GGML_TYPE_F32 && ggml_is_contiguous (op->src [0 ]) && ggml_is_contiguous (op->src [1 ]);
25582576 }
@@ -6254,11 +6272,47 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
62546272 CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &r2));
62556273 CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &r3));
62566274 break ;
6275+ case GGML_TYPE_MXFP4: {
// MXFP4 source type: select the dedicated kernel, choose per-GPU-family launch
// parameters, then bind all 19 kernel arguments.
6276+ kernel = backend_ctx->kernel_mul_mv_mxfp4_f32 ;
6277+
// nth0 / nth1 / ndst are consumed after the switch (outside this hunk).
// NOTE(review): presumably nth0 x nth1 is the work-group shape and ndst the
// number of output rows handled per work-group - confirm against the launch code.
// Both known families use nth1 = 2 and ndst = 4; only nth0 differs (16 vs 64).
6278+ if (backend_ctx->gpu_family == INTEL) {
6279+ nth0 = 16 ;
6280+ nth1 = 2 ;
6281+ ndst = nth1*2 ;
6282+ } else if (backend_ctx->gpu_family == ADRENO) {
6283+ nth0 = 64 ;
6284+ nth1 = 2 ;
6285+ ndst = nth1*2 ;
6286+ } else {
// No tuning exists for any other GPU family: fail the assertion rather than
// continue with launch parameters that were never set in this case.
6287+ GGML_ASSERT (false && " TODO: Unknown GPU" );
6288+ }
6289+
// Args 0-5: device buffer handle and cl_ulong offset for src0, src1 and dst,
// in that order.
6290+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
6291+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
6292+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
6293+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
6294+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
6295+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
// Args 6-17: sizes are passed as int and nb* values as cl_ulong. The order must
// match the kernel's parameter list in the .cl source, which is not visible here.
6296+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
6297+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
6298+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
6299+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
6300+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12));
6301+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb11));
6302+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb12));
6303+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb13));
6304+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne0));
6305+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne1));
6306+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &r2));
6307+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &r3));
// Arg 18: a size of nth0 floats with a null value pointer. In the OpenCL API
// this form is used for local-memory kernel parameters; presumably a per-work-group
// scratch buffer - confirm against the kernel signature.
6308+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (float )*nth0,nullptr ));
6309+ break ;
6310+ }
62576311 default :
62586312 GGML_ASSERT (false && " not implemented" );
62596313 }
62606314
6261- if (src0t == GGML_TYPE_Q4_0 ||
6315+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
62626316 src0t == GGML_TYPE_Q4_1 ||
62636317 src0t == GGML_TYPE_Q8_0 ||
62646318 src0t == GGML_TYPE_Q2_K) {
0 commit comments