@@ -402,6 +402,7 @@ struct ggml_backend_opencl_context {
     cl_program program_conv_2d_f32;
     cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
+    cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
     cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
     cl_program program_mul_mv_id_mxfp4_f32;
@@ -452,7 +453,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32_tiled;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
-    cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
+    cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
     cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
     cl_kernel kernel_convert_block_q4_0_noshuffle;
@@ -475,6 +476,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
+    cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
     cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
     cl_kernel kernel_mul_mv_id_mxfp4_f32;
@@ -559,14 +561,14 @@ struct ggml_backend_opencl_context {
 
         fprintf(ftrace, "[\n");
         for (const ProfilingInfo & info : profiling_info) {
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
                 info.kernel_name.c_str(), info.cmd_queued/1000);
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
                 info.kernel_name.c_str(), info.cmd_submit/1000);
 
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
                 info.kernel_name.c_str(), info.cmd_start/1000);
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
                 info.kernel_name.c_str(), info.cmd_end/1000);
         }
         fclose(ftrace);
@@ -777,6 +779,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
     CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
     CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
     CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
+    CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
+    CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
     CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
     CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
     CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -1991,6 +1995,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
         GGML_LOG_CONT(".");
     }
+
+    std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
+        " -cl-mad-enable "
+        " -cl-fast-relaxed-math";
+
+    // gemv_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemv_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemv_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // gemm_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemm_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemm_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
     GGML_LOG_CONT("\n");
 }
@@ -3299,6 +3339,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
            tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
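+// Heuristic for the Adreno MoE path: only expert weight tensors whose name contains
+// "ffn" or "as" and whose row count ne01 is a multiple of 64 take the transposed
+// MXFP4 layout and the dedicated gemv/gemm kernels; everything else keeps the
+// generic conversion and mul_mat_id kernels.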
+inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
+    GGML_UNUSED(backend_ctx);
+    int ne01 = tensor->ne[1];
+    return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
+}
+
 static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
 
@@ -3601,14 +3647,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
             CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
         CL_CHECK(err);
 
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (use_adreno_moe_kernels(backend_ctx, tensor)) {
+            cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
+
+            int ne00 = tensor->ne[0];
+            int ne01 = tensor->ne[1];
+            int ne02 = tensor->ne[2];
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
+
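+            // Dispatch one work-item per (output row, 32-element MXFP4 block): dim 0 covers
+            // ne01 rows padded up to the 64-wide workgroup, dim 1 covers the ne00/32 blocks
+            // per row, and dim 2 walks ne02 (one slice per expert).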
+            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+            size_t local_work_size[3] = {64, 2, 1};
+
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+            CL_CHECK(clReleaseMemObject(data_device));
+            tensor->extra = extra;
+
+            return;
+        }
+#endif
         cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
 
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
         CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
 
-        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
-        size_t local_work_size[] = {64, 1, 1};
+        size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[3] = {64, 1, 1};
 
         cl_event evt;
         CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
@@ -3624,7 +3695,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
             { extra->q }
         };
         extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
-
         tensor->extra = extra;
 
         return;
@@ -3751,6 +3821,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
             ggml_nbytes(tensor), NULL, &err);
         CL_CHECK(err);
 
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (use_adreno_moe_kernels(backend_ctx, tensor)) {
+            cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
+
+            int ne00 = tensor->ne[0];
+            int ne01 = tensor->ne[1];
+            int ne02 = tensor->ne[2];
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
+
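+            // Inverse of kernel_convert_block_mxfp4_trans: same grid shape, reading the
+            // transposed q/e buffers back into contiguous MXFP4 blocks for the host copy below.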
+            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+            size_t local_work_size[3] = {64, 2, 1};
+
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+                global_work_size, local_work_size, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+            CL_CHECK(clEnqueueReadBuffer(
+                queue, data_device, CL_TRUE, offset,
+                size, data, 0, NULL, NULL));
+            CL_CHECK(clReleaseMemObject(data_device));
+            return;
+        }
+#endif
         cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -7553,6 +7650,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
     const int ne21 = src2->ne[1];
 
     const cl_ulong nb21 = src2->nb[1];
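+    // byte stride between router entries in src2; nb21/nb20 recovers the number of
+    // expert slots per row in the Adreno MoE path below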
+    const cl_ulong nb20 = src2->nb[0];
 
     const int ne0 = dst->ne[0];
     const int ne1 = dst->ne[1];
@@ -7692,6 +7790,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
             break;
         }
         case GGML_TYPE_MXFP4: {
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+            if (use_adreno_moe_kernels(backend_ctx, src0)) {
+                cl_int status;
+
+                size_t local_size[3] = {64, 2, 1};
+                size_t global_size[3] = {64, 2, 1};
+
+                cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
+
+                int tile_size = 320;
+                if (ne12 == 1) { // for gemv
+                    kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
+
+                    // create a sub_buffer for src2
+                    cl_buffer_region region;
+                    region.origin = offset2;
+                    region.size = ne20 * ne21 * sizeof(int);
+                    buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+                    CL_CHECK(status);
+
+                    // set thread grid
+                    global_size[0] = static_cast<size_t>(ne01);
+                    global_size[1] = 4;
+                    global_size[2] = static_cast<size_t>(ne20);
+                    local_size[1] = 4;
+                } else { // for gemm
+                    kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
+
+                    // preprocess router table
+                    int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
+                    void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
+                    void * host_src2 = malloc(ne21 * nb21);
+                    CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
+                    int total_experts = nb21 / nb20;
+                    int out_idx = 0;
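+                    // Pack each routed (token, slot) pair once per weight tile as 4 shorts:
+                    // [expert id, src1 row, dst row, tile index], grouped by expert and then
+                    // by tile; this reordered table is what gets uploaded as buf_src2 below.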
+                    for (int i_expert = 0; i_expert < ne02; i_expert++) {
+                        for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
+                            for (int j = 0; j < ne21; j++) {
+                                for (int i = 0; i < ne20; i++) {
+                                    int expert = ((int *)host_src2)[j * total_experts + i];
+                                    if (i_expert == expert) {
+                                        ((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
+                                        ((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
+                                        ((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
+                                        ((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
+                                        out_idx += 4;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
+                    CL_CHECK(status);
+
+                    // set thread grid
+                    global_size[0] = static_cast<size_t>(tile_size);
+                    global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
+                }
+
+                // create a sub_buffer for src1
+                cl_buffer_region region;
+                region.origin = offset1;
+                region.size = ne10 * ne11 * ne12 * sizeof(float);
+                src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+                CL_CHECK(status);
+
+                // create image for src1
+                cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
+                cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
+                buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
+                CL_CHECK(status);
+
+                // Set kernel args
+                int arg_idx = 0;
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
+                if (ne12 == 1) {
+                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
+                } else {
+                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
+                }
+
+                // launch kernel
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
+
+                // deallocate sub buffers and images
+                CL_CHECK(clReleaseMemObject(src1_sub_buffer));
+                CL_CHECK(clReleaseMemObject(buf_src1_image));
+                CL_CHECK(clReleaseMemObject(buf_src2));
+                return;
+            } // else fallback to generic kernel
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
 #ifdef GGML_OPENCL_SOA_Q
             kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
 