@@ -402,6 +402,7 @@ struct ggml_backend_opencl_context {
     cl_program program_conv_2d_f32;
     cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
+    cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
     cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
     cl_program program_mul_mv_id_mxfp4_f32;
@@ -452,7 +453,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32_tiled;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
-    cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
+    cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4;
     cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
     cl_kernel kernel_convert_block_q4_0_noshuffle;
@@ -475,6 +476,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
+    cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
     cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
     cl_kernel kernel_mul_mv_id_mxfp4_f32;
@@ -777,6 +779,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_convert_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -1991,6 +1994,43 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
         GGML_LOG_CONT(".");
     }
+
+    std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
+        " -cl-mad-enable "
+        " -qcom-disable-promote-pointer-to-texture"
+        " -cl-fast-relaxed-math";
+
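+    // The same options are shared by both MoE programs below;
+    // -qcom-disable-promote-pointer-to-texture is a Qualcomm-specific extension,
+    // and -cl-fast-relaxed-math trades strict IEEE accuracy for speed.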
+    // gemv_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemv_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemv_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // gemm_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemm_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemm_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
     GGML_LOG_CONT("\n");
 }
@@ -3596,14 +3636,40 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                             CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
         CL_CHECK(err);

+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
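+        // Expert FFN weights (tensor names containing "ffn") are converted with the
+        // transposed-layout kernel consumed by the Adreno MoE paths; matching on the
+        // name substring is a heuristic.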
+        if (strstr(tensor->name, "ffn") != NULL) {
+            int ne00 = tensor->ne[0];
+            int ne01 = tensor->ne[1];
+            int ne02 = tensor->ne[2];
+
+            cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
+
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &ne01));
+
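+            // One work-item per (output row, 32-element MXFP4 block, expert): dim0
+            // covers the ne01 rows padded to the 64-wide work-group, dim1 the ne00/32
+            // blocks per row, dim2 the ne02 experts.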
+            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+            size_t local_work_size[3]  = {64, 2, 1};
+
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+            CL_CHECK(clReleaseMemObject(data_device));
+            tensor->extra = extra;
+
+            return;
+        }
+#endif
         cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;

         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
         CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));

-        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
-        size_t local_work_size[] = {64, 1, 1};
+        size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[3]  = {64, 1, 1};

         cl_event evt;
         CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
@@ -3619,7 +3685,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                 { extra->q }
             };
             extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
-
             tensor->extra = extra;

             return;
@@ -7545,6 +7610,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
     const int ne21 = src2->ne[1];

     const cl_ulong nb21 = src2->nb[1];
+    const cl_ulong nb20 = src2->nb[0];

     const int ne0 = dst->ne[0];
     const int ne1 = dst->ne[1];
@@ -7684,6 +7750,103 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
             break;
         }
         case GGML_TYPE_MXFP4: {
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+            if (true) { // TODO: add a proper dispatch condition
+                cl_int status;
+
+                size_t local_size[3]  = {64, 4, 1};
+                size_t global_size[3] = {64, 4, 1};
+
+                cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
+
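+                // ne12 == 1 is the single-batch case: a matrix-vector kernel with one
+                // launch slice per output row and per selected expert (ne20) suffices.
+                // Larger batches are regrouped per expert and tiled for the GEMM kernel.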
+                if (ne12 == 1) { // for gemv
+                    kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
+
+                    // create a sub_buffer for src2
+                    cl_buffer_region region;
+                    region.origin = offset2;
+                    region.size   = ne20 * ne21 * sizeof(int);
+                    buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+                    CL_CHECK(status);
+
+                    // set thread grid
+                    global_size[0] = static_cast<size_t>(ne01);
+                    global_size[2] = static_cast<size_t>(ne20);
+                } else { // for gemm
+                    kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
+
+                    // preprocess router table
+                    int tile_size = 320;
+                    int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
+                    void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
+                    void * host_src2 = malloc(ne21 * nb21);
+                    CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
+                    int total_experts = nb21 / nb20;
+                    int out_idx = 0;
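+                    // Expand each (token, selected expert) pair into one 4-short record:
+                    // {expert id, src1 row, dst row, tile index}.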
+                    for (int i_expert = 0; i_expert < ne02; i_expert++) {
+                        for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
+                            for (int j = 0; j < ne21; j++) {
+                                for (int i = 0; i < ne20; i++) {
+                                    int expert = ((int *)host_src2)[j * total_experts + i];
+                                    if (i_expert == expert) {
+                                        ((short *)host_src2_reorder)[out_idx]     = static_cast<short>(expert);
+                                        ((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
+                                        ((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
+                                        ((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
+                                        out_idx += 4;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
+                    CL_CHECK(status);
+                    // CL_MEM_COPY_HOST_PTR copies at creation, so the host staging
+                    // buffers can be released here (they would otherwise leak)
+                    free(host_src2_reorder);
+                    free(host_src2);
+
+                    // set thread grid
+                    global_size[0] = static_cast<size_t>(tile_size);
+                    global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
+                }
+
+                // create a sub_buffer for src1
+                cl_buffer_region region;
+                region.origin = offset1;
+                region.size   = ne10 * ne11 * ne12 * sizeof(float);
+                src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+                CL_CHECK(status);
+
+                // create image for src1
+                cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
+                cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, src1_sub_buffer};
+                buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
+                CL_CHECK(status);
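+                // The RGBA/FLOAT 1D image view packs 4 consecutive floats per texel,
+                // so ne10 * ne11 * ne12 is assumed to be a multiple of 4 here.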
+
+                // Set kernel args
+                int arg_idx = 0;
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
+                if (ne12 == 1) {
+                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
+                } else {
+                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne02));
+                }
+
+                // launch kernel
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
+
+                // deallocate sub buffers and images
+                CL_CHECK(clReleaseMemObject(src1_sub_buffer));
+                CL_CHECK(clReleaseMemObject(buf_src1_image));
+                CL_CHECK(clReleaseMemObject(buf_src2));
+                return;
+            } // else fallback to generic kernel
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
 #ifdef GGML_OPENCL_SOA_Q
             kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;