Skip to content

Commit 695758b

Browse files
committed
opencl: add reference mul_mv_mxfp4_f32
1 parent f4586ee commit 695758b

File tree

3 files changed

+201
-2
lines changed

3 files changed

+201
-2
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ set(GGML_OPENCL_KERNELS
8282
mul_mv_q4_0_f32_1d_8x_flat
8383
mul_mv_q4_0_f32_1d_16x_flat
8484
mul_mv_q6_k
85+
mul_mv_mxfp4_f32
8586
mul_mv_id_q4_0_f32_8x_flat
8687
mul_mm_f32_f32_l4_lm
8788
mul_mm_f16_f32_l4_lm

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ struct ggml_backend_opencl_context {
365365
cl_program program_mul_mv_q4_0_f32_1d_8x_flat;
366366
cl_program program_mul_mv_q4_0_f32_1d_16x_flat;
367367
cl_program program_mul_mv_q6_K;
368+
cl_program program_mul_mv_mxfp4_f32;
368369
cl_program program_mul_mv_f16_f16;
369370
cl_program program_mul_mv_f16_f32_1row;
370371
cl_program program_mul_mv_f16_f32_l4;
@@ -439,6 +440,7 @@ struct ggml_backend_opencl_context {
439440
cl_kernel kernel_convert_block_q4_0_noshuffle;
440441
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
441442
cl_kernel kernel_mul_mv_q6_K_f32;
443+
cl_kernel kernel_mul_mv_mxfp4_f32;
442444
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
443445
cl_kernel kernel_argsort_f32_i32;
444446
cl_kernel kernel_sum_rows_f32;
@@ -971,6 +973,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
971973
GGML_LOG_CONT(".");
972974
}
973975

976+
// mul_mv_mxfp4_f32
977+
{
978+
#ifdef GGML_OPENCL_EMBED_KERNELS
979+
const std::string kernel_src {
980+
#include "mul_mv_mxfp4_f32.cl.h"
981+
};
982+
#else
983+
const std::string kernel_src = read_file("mul_mv_mxfp4_f32.cl");
984+
#endif
985+
backend_ctx->program_mul_mv_mxfp4_f32 =
986+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
987+
988+
CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32, "kernel_mul_mv_mxfp4_f32", &err), err));
989+
GGML_LOG_CONT(".");
990+
}
991+
974992
// mul_mv_f16_f16
975993
{
976994
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2552,7 +2570,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
25522570
return true;
25532571
} else if (op->src[0]->type == GGML_TYPE_F32) {
25542572
return op->src[1]->type == GGML_TYPE_F32;
2555-
} else if (op->src[0]->type == GGML_TYPE_Q4_0 ||
2573+
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
25562574
op->src[0]->type == GGML_TYPE_Q6_K) {
25572575
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
25582576
}
@@ -6254,11 +6272,47 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
62546272
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
62556273
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
62566274
break;
6275+
case GGML_TYPE_MXFP4: {
6276+
kernel = backend_ctx->kernel_mul_mv_mxfp4_f32;
6277+
6278+
if (backend_ctx->gpu_family == INTEL) {
6279+
nth0 = 16;
6280+
nth1 = 2;
6281+
ndst = nth1*2;
6282+
} else if (backend_ctx->gpu_family == ADRENO) {
6283+
nth0 = 64;
6284+
nth1 = 2;
6285+
ndst = nth1*2;
6286+
} else {
6287+
GGML_ASSERT(false && "TODO: Unknown GPU");
6288+
}
6289+
6290+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
6291+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6292+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
6293+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6294+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6295+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6296+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
6297+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
6298+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
6299+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
6300+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
6301+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
6302+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));
6303+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13));
6304+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0));
6305+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1));
6306+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
6307+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
6308+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr));
6309+
break;
6310+
}
62576311
default:
62586312
GGML_ASSERT(false && "not implemented");
62596313
}
62606314

6261-
if (src0t == GGML_TYPE_Q4_0 ||
6315+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
62626316
src0t == GGML_TYPE_Q4_1 ||
62636317
src0t == GGML_TYPE_Q8_0 ||
62646318
src0t == GGML_TYPE_Q2_K) {
mul_mv_mxfp4_f32.cl (new file)

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#ifdef cl_intel_subgroups
4+
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
5+
#else
6+
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
7+
#endif
8+
9+
#ifdef cl_intel_required_subgroup_size
10+
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11+
#define INTEL_GPU 1
12+
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13+
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14+
#elif defined(cl_qcom_reqd_sub_group_size)
15+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16+
#define ADRENO_GPU 1
17+
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
18+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19+
#endif
20+
21+
#define QK_MXFP4 32
22+
typedef struct {
23+
uchar e; // E8M0
24+
uchar qs[QK_MXFP4/2];
25+
} block_mxfp4;
26+
27+
constant static float kvalues_mxfp4_f[16] = {
28+
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
29+
};
30+
31+
static inline float e8m0_to_fp32(uchar x) {
32+
int bits;
33+
34+
if (x == 0) {
35+
bits = 0x00400000;
36+
} else {
37+
bits = (uint) x << 23;
38+
}
39+
40+
return as_float(bits);
41+
}
42+
43+
#ifdef INTEL_GPU
44+
#define N_R0_MXFP4 2 // number of rows each subgroup works on
45+
#define N_SG_MXFP4 2 // number of subgroups in a work group
46+
#define N_SIMDWIDTH 16 // subgroup size
47+
#elif defined (ADRENO_GPU)
48+
#define N_R0_MXFP4 2
49+
#define N_SG_MXFP4 2
50+
#define N_SIMDWIDTH 64
51+
#endif
52+
53+
#ifdef INTEL_GPU
54+
REQD_SUBGROUP_SIZE_16
55+
#elif defined (ADRENO_GPU)
56+
REQD_SUBGROUP_SIZE_64
57+
#endif
58+
kernel void kernel_mul_mv_mxfp4_f32(
59+
global char * src0,
60+
ulong offset0,
61+
global char * src1,
62+
ulong offset1,
63+
global char * dst,
64+
ulong offsetd,
65+
int ne00,
66+
ulong nb01,
67+
ulong nb02,
68+
ulong nb03,
69+
int ne12,
70+
ulong nb11,
71+
ulong nb12,
72+
ulong nb13,
73+
int ne0,
74+
int ne1,
75+
int r2,
76+
int r3,
77+
local char * shmem
78+
) {
79+
src0 = (global char*)((global char*)src0 + offset0);
80+
src1 = (global char*)((global char*)src1 + offset1);
81+
dst = (global char*)((global char*)dst + offsetd);
82+
83+
local float * shmem_f32 = (local float *) shmem;
84+
int nb = ne00/QK_MXFP4;
85+
86+
int r0 = get_group_id(0);
87+
int r1 = get_group_id(1);
88+
int im = get_group_id(2);
89+
90+
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
91+
92+
uint i12 = im%ne12;
93+
uint i13 = im/ne12;
94+
95+
ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
96+
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
97+
98+
global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);
99+
global float * y = (global float *) (src1 + offset_src1);
100+
101+
const short ix = get_sub_group_local_id()/2; // 0...15
102+
const short it = get_sub_group_local_id()%2; // 0 or 1
103+
104+
shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
105+
barrier(CLK_LOCAL_MEM_FENCE);
106+
107+
float4 yl[4];
108+
float sumf[N_R0_MXFP4] = {0.f};
109+
110+
global float * yb = y + ix * QK_MXFP4 + it * 8;
111+
112+
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
113+
global float4 * y4 = (global float4 *)yb;
114+
yl[0] = y4[0];
115+
yl[1] = y4[4];
116+
yl[2] = y4[1];
117+
yl[3] = y4[5];
118+
119+
for (short row = 0; row < N_R0_MXFP4; row++) {
120+
global block_mxfp4 * xb = x + row*nb + ib;
121+
global uchar * q2 = (global uchar *)(xb->qs + 8*it);
122+
123+
float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
124+
float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
125+
float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
126+
float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
127+
128+
acc1 = (acc1 + acc3) + (acc2 + acc4);
129+
130+
sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
131+
}
132+
133+
yb += (N_SIMDWIDTH/2) * QK_MXFP4;
134+
}
135+
136+
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
137+
138+
for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
139+
float sum_all = sub_group_reduce_add(sumf[row]);
140+
if (get_sub_group_local_id() == 0) {
141+
dst_f32[first_row + row] = sum_all;
142+
}
143+
}
144+
}

0 commit comments

Comments
 (0)