@@ -399,6 +399,7 @@ struct ggml_backend_opencl_context {
399
399
cl_program program_conv_2d_f16_f32;
400
400
cl_program program_tsembd;
401
401
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
402
+ cl_program program_mul_mv_id_mxfp4_f32;
402
403
cl_program program_mul_mm_f32_f32_l4_lm;
403
404
cl_program program_mul_mm_f16_f32_l4_lm;
404
405
@@ -457,6 +458,7 @@ struct ggml_backend_opencl_context {
457
458
cl_kernel kernel_conv_2d_f16_f32;
458
459
cl_kernel kernel_timestep_embedding;
459
460
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
461
+ cl_kernel kernel_mul_mv_id_mxfp4_f32;
460
462
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
461
463
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
462
464
@@ -1629,6 +1631,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1629
1631
GGML_LOG_CONT (" ." );
1630
1632
}
1631
1633
1634
+ // mul_mv_id_mxfp4_f32
1635
+ {
1636
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1637
+ const std::string kernel_src {
1638
+ #include " mul_mv_id_mxfp4_f32.cl.h"
1639
+ };
1640
+ #else
1641
+ const std::string kernel_src = read_file (" mul_mv_id_mxfp4_f32.cl" );
1642
+ #endif
1643
+ backend_ctx->program_mul_mv_id_mxfp4_f32 =
1644
+ build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
1645
+
1646
+ CL_CHECK ((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel (backend_ctx->program_mul_mv_id_mxfp4_f32 , " kernel_mul_mv_id_mxfp4_f32" , &err), err));
1647
+ GGML_LOG_CONT (" ." );
1648
+ }
1649
+
1632
1650
// Adreno kernels
1633
1651
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
1634
1652
// transpose
@@ -2576,7 +2594,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2576
2594
}
2577
2595
return false ;
2578
2596
case GGML_OP_MUL_MAT_ID:
2579
- if (op->src [0 ]->type == GGML_TYPE_Q4_0) {
2597
+ if (op->src [0 ]->type == GGML_TYPE_Q4_0 ||
2598
+ op->src [0 ]->type == GGML_TYPE_MXFP4) {
2580
2599
if (op->src [1 ]->type == GGML_TYPE_F32) {
2581
2600
return ggml_is_contiguous (op->src [0 ]) && ggml_is_contiguous (op->src [1 ]);
2582
2601
}
@@ -6361,10 +6380,12 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
6361
6380
6362
6381
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
6363
6382
6383
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra ;
6364
6384
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra ;
6365
6385
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra ;
6366
6386
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra ;
6367
6387
6388
+ cl_ulong offset0 = extra0->offset + src0->view_offs ;
6368
6389
cl_ulong offset1 = extra1->offset + src1->view_offs ;
6369
6390
cl_ulong offset2 = extra2->offset + src2->view_offs ;
6370
6391
cl_ulong offsetd = extrad->offset + dst->view_offs ;
@@ -6379,7 +6400,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
6379
6400
const int ne03 = src0->ne [3 ];
6380
6401
6381
6402
const cl_ulong nb00 = src0->nb [0 ];
6403
+ const cl_ulong nb01 = src0->nb [1 ];
6382
6404
const cl_ulong nb02 = src0->nb [2 ];
6405
+ const cl_ulong nb03 = src0->nb [3 ];
6383
6406
6384
6407
const int ne10 = src1->ne [0 ];
6385
6408
const int ne11 = src1->ne [1 ];
@@ -6388,6 +6411,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
6388
6411
6389
6412
const cl_ulong nb11 = src1->nb [1 ];
6390
6413
const cl_ulong nb12 = src1->nb [2 ];
6414
+ const cl_ulong nb13 = src1->nb [3 ];
6391
6415
6392
6416
const int ne20 = src2->ne [0 ];
6393
6417
const int ne21 = src2->ne [1 ];
@@ -6455,6 +6479,49 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
6455
6479
6456
6480
break ;
6457
6481
}
6482
+ case GGML_TYPE_MXFP4: {
6483
+ kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32 ;
6484
+
6485
+ if (backend_ctx->gpu_family == INTEL) {
6486
+ sgs = 16 ;
6487
+ nsg = 2 ;
6488
+ ndst = 2 ;
6489
+ } else if (backend_ctx->gpu_family == ADRENO) {
6490
+ sgs = 64 ;
6491
+ nsg = 2 ;
6492
+ ndst = 2 ;
6493
+ } else {
6494
+ GGML_ASSERT (false && " TODO: Unknown GPU" );
6495
+ }
6496
+
6497
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
6498
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
6499
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
6500
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
6501
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extra2->data_device ));
6502
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offset2));
6503
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (cl_mem), &extrad->data_device ));
6504
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &offsetd));
6505
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne00));
6506
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb01));
6507
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb02));
6508
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb03));
6509
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne11));
6510
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne12));
6511
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb11));
6512
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb12));
6513
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb13));
6514
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne20));
6515
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (int ), &ne21));
6516
+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb21));
6517
+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (int ), &ne0));
6518
+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (int ), &ne1));
6519
+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &r2));
6520
+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &r3));
6521
+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (float )*sgs,nullptr ));
6522
+
6523
+ break ;
6524
+ }
6458
6525
default :
6459
6526
GGML_ASSERT (false && " not implemented" );;
6460
6527
}
0 commit comments