@@ -399,6 +399,7 @@ struct ggml_backend_opencl_context {
     cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
+    cl_program program_mul_mv_id_mxfp4_f32;
     cl_program program_mul_mm_f32_f32_l4_lm;
     cl_program program_mul_mm_f16_f32_l4_lm;
 
@@ -457,6 +458,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
+    cl_kernel kernel_mul_mv_id_mxfp4_f32;
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
 
@@ -1629,6 +1631,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mv_id_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_id_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_mul_mv_id_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32, "kernel_mul_mv_id_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // Adreno kernels
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // transpose
@@ -2576,7 +2594,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             }
             return false;
         case GGML_OP_MUL_MAT_ID:
-            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+            if (op->src[0]->type == GGML_TYPE_Q4_0 ||
+                op->src[0]->type == GGML_TYPE_MXFP4) {
                 if (op->src[1]->type == GGML_TYPE_F32) {
                     return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
                 }
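
In plain terms, this hunk widens the MUL_MAT_ID gate so that MXFP4 expert weights are routed to the OpenCL backend under the same conditions as Q4_0. Below is a small standalone sketch of the resulting check, reconstructed from the hunk; the helper name `ggml_opencl_supports_mul_mat_id` is hypothetical, and the trailing `return false` mirrors the usual fallthrough that sits outside the hunk.

```cpp
#include "ggml.h"

// Hypothetical helper equivalent to the MUL_MAT_ID branch of
// ggml_opencl_supports_op after this patch: src0 (expert weights) must be
// Q4_0 or MXFP4, src1 (activations) must be F32, and both must be contiguous.
static bool ggml_opencl_supports_mul_mat_id(const struct ggml_tensor * op) {
    if (op->src[0]->type == GGML_TYPE_Q4_0 ||
        op->src[0]->type == GGML_TYPE_MXFP4) {
        if (op->src[1]->type == GGML_TYPE_F32) {
            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
        }
    }
    return false; // other type combinations are rejected (fallthrough outside the hunk)
}
```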
@@ -6361,10 +6380,12 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
     ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offset1 = extra1->offset + src1->view_offs;
     cl_ulong offset2 = extra2->offset + src2->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
@@ -6379,7 +6400,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
     const int ne03 = src0->ne[3];
 
     const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
     const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
 
     const int ne10 = src1->ne[0];
     const int ne11 = src1->ne[1];
@@ -6388,6 +6411,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
 
     const cl_ulong nb11 = src1->nb[1];
     const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
 
     const int ne20 = src2->ne[0];
     const int ne21 = src2->ne[1];
@@ -6455,6 +6479,49 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
 
             break;
         }
+        case GGML_TYPE_MXFP4: {
+            kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                sgs = 16;
+                nsg = 2;
+                ndst = 2;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                sgs = 64;
+                nsg = 2;
+                ndst = 2;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne20));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne21));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, nullptr));
+
+            break;
+        }
         default:
             GGML_ASSERT(false && "not implemented");;
     }
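
The kernel source itself (`mul_mv_id_mxfp4_f32.cl`) is not part of this diff, but its host-facing interface can be read off the 25 `clSetKernelArg` calls above. A plausible OpenCL C signature follows as a sketch: the parameter names, pointer element types, and the purpose of the trailing `local` buffer are assumptions; only the argument order, count, and host-side sizes come from the diff. The per-GPU-family `sgs`/`nsg`/`ndst` values presumably set the subgroup size, the number of subgroups per work-group, and the number of output rows each subgroup produces, matching the existing `mul_mv_id` kernels.

```c
// OpenCL C sketch, reconstructed from the clSetKernelArg() calls above.
kernel void kernel_mul_mv_id_mxfp4_f32(
        global char * src0,    //  0: MXFP4 expert weight matrices
        ulong         offset0, //  1
        global char * src1,    //  2: F32 activations
        ulong         offset1, //  3
        global char * src2,    //  4: expert id tensor selecting experts from src0
        ulong         offset2, //  5
        global char * dst,     //  6: F32 output
        ulong         offsetd, //  7
        int           ne00,    //  8: src0 row length (elements)
        ulong         nb01,    //  9: src0 byte strides
        ulong         nb02,    // 10
        ulong         nb03,    // 11
        int           ne11,    // 12: src1 extents
        int           ne12,    // 13
        ulong         nb11,    // 14: src1 byte strides
        ulong         nb12,    // 15
        ulong         nb13,    // 16
        int           ne20,    // 17: id tensor extents
        int           ne21,    // 18
        ulong         nb21,    // 19: id tensor byte stride
        int           ne0,     // 20: dst extents
        int           ne1,     // 21
        int           r2,      // 22: broadcast ratios
        int           r3,      // 23
        local float * shmem);  // 24: sgs floats of local memory (size set on the host, data nullptr)
```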