Skip to content

Commit 3427959

Browse files
committed
opencl: add reference mul_mv_id for mxfp4
1 parent 695758b commit 3427959

File tree

3 files changed

+258
-1
lines changed

3 files changed

+258
-1
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ set(GGML_OPENCL_KERNELS
8484
mul_mv_q6_k
8585
mul_mv_mxfp4_f32
8686
mul_mv_id_q4_0_f32_8x_flat
87+
mul_mv_id_mxfp4_f32
8788
mul_mm_f32_f32_l4_lm
8889
mul_mm_f16_f32_l4_lm
8990
mul

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ struct ggml_backend_opencl_context {
399399
cl_program program_conv_2d_f16_f32;
400400
cl_program program_tsembd;
401401
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
402+
cl_program program_mul_mv_id_mxfp4_f32;
402403
cl_program program_mul_mm_f32_f32_l4_lm;
403404
cl_program program_mul_mm_f16_f32_l4_lm;
404405

@@ -457,6 +458,7 @@ struct ggml_backend_opencl_context {
457458
cl_kernel kernel_conv_2d_f16_f32;
458459
cl_kernel kernel_timestep_embedding;
459460
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
461+
cl_kernel kernel_mul_mv_id_mxfp4_f32;
460462
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
461463
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
462464

@@ -1629,6 +1631,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
16291631
GGML_LOG_CONT(".");
16301632
}
16311633

1634+
// mul_mv_id_mxfp4_f32
1635+
{
1636+
#ifdef GGML_OPENCL_EMBED_KERNELS
1637+
const std::string kernel_src {
1638+
#include "mul_mv_id_mxfp4_f32.cl.h"
1639+
};
1640+
#else
1641+
const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32.cl");
1642+
#endif
1643+
backend_ctx->program_mul_mv_id_mxfp4_f32 =
1644+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1645+
1646+
CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32, "kernel_mul_mv_id_mxfp4_f32", &err), err));
1647+
GGML_LOG_CONT(".");
1648+
}
1649+
16321650
// Adreno kernels
16331651
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
16341652
// transpose
@@ -2576,7 +2594,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
25762594
}
25772595
return false;
25782596
case GGML_OP_MUL_MAT_ID:
2579-
if (op->src[0]->type == GGML_TYPE_Q4_0) {
2597+
if (op->src[0]->type == GGML_TYPE_Q4_0 ||
2598+
op->src[0]->type == GGML_TYPE_MXFP4) {
25802599
if (op->src[1]->type == GGML_TYPE_F32) {
25812600
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
25822601
}
@@ -6361,10 +6380,12 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
63616380

63626381
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
63636382

6383+
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
63646384
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
63656385
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
63666386
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
63676387

6388+
cl_ulong offset0 = extra0->offset + src0->view_offs;
63686389
cl_ulong offset1 = extra1->offset + src1->view_offs;
63696390
cl_ulong offset2 = extra2->offset + src2->view_offs;
63706391
cl_ulong offsetd = extrad->offset + dst->view_offs;
@@ -6379,7 +6400,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
63796400
const int ne03 = src0->ne[3];
63806401

63816402
const cl_ulong nb00 = src0->nb[0];
6403+
const cl_ulong nb01 = src0->nb[1];
63826404
const cl_ulong nb02 = src0->nb[2];
6405+
const cl_ulong nb03 = src0->nb[3];
63836406

63846407
const int ne10 = src1->ne[0];
63856408
const int ne11 = src1->ne[1];
@@ -6388,6 +6411,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
63886411

63896412
const cl_ulong nb11 = src1->nb[1];
63906413
const cl_ulong nb12 = src1->nb[2];
6414+
const cl_ulong nb13 = src1->nb[3];
63916415

63926416
const int ne20 = src2->ne[0];
63936417
const int ne21 = src2->ne[1];
@@ -6455,6 +6479,49 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
64556479

64566480
break;
64576481
}
6482+
case GGML_TYPE_MXFP4: {
6483+
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;
6484+
6485+
if (backend_ctx->gpu_family == INTEL) {
6486+
sgs = 16;
6487+
nsg = 2;
6488+
ndst = 2;
6489+
} else if (backend_ctx->gpu_family == ADRENO) {
6490+
sgs = 64;
6491+
nsg = 2;
6492+
ndst = 2;
6493+
} else {
6494+
GGML_ASSERT(false && "TODO: Unknown GPU");
6495+
}
6496+
6497+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
6498+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6499+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
6500+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6501+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
6502+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
6503+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
6504+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
6505+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
6506+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
6507+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
6508+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
6509+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
6510+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
6511+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
6512+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
6513+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
6514+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne20));
6515+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne21));
6516+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
6517+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0));
6518+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1));
6519+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
6520+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
6521+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr));
6522+
6523+
break;
6524+
}
64586525
default:
64596526
GGML_ASSERT(false && "not implemented");;
64606527
}
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#ifdef cl_intel_subgroups
4+
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
5+
#else
6+
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
7+
#endif
8+
9+
#ifdef cl_intel_required_subgroup_size
10+
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11+
#define INTEL_GPU 1
12+
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13+
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14+
#elif defined(cl_qcom_reqd_sub_group_size)
15+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16+
#define ADRENO_GPU 1
17+
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
18+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19+
#endif
20+
21+
#define QK_MXFP4 32
22+
typedef struct {
23+
uchar e; // E8M0
24+
uchar qs[QK_MXFP4/2];
25+
} block_mxfp4;
26+
27+
constant static float kvalues_mxfp4_f[16] = {
28+
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
29+
};
30+
31+
static inline float e8m0_to_fp32(uchar x) {
32+
int bits;
33+
34+
if (x == 0) {
35+
bits = 0x00400000;
36+
} else {
37+
bits = (uint) x << 23;
38+
}
39+
40+
return as_float(bits);
41+
}
42+
43+
#ifdef INTEL_GPU
44+
#define N_R0_MXFP4 2 // number of rows each subgroup works on
45+
#define N_SG_MXFP4 2 // number of subgroups in a work group
46+
#define N_SIMDWIDTH 16 // subgroup size
47+
#elif defined (ADRENO_GPU)
48+
#define N_R0_MXFP4 2
49+
#define N_SG_MXFP4 2
50+
#define N_SIMDWIDTH 64
51+
#endif
52+
53+
inline void mul_mv_mxfp4_f32(
54+
global char * src0,
55+
global char * src1,
56+
global char * dst,
57+
int ne00,
58+
ulong nb01,
59+
ulong nb02,
60+
ulong nb03,
61+
int ne12,
62+
ulong nb11,
63+
ulong nb12,
64+
ulong nb13,
65+
int ne0,
66+
int ne1,
67+
int r2,
68+
int r3,
69+
local char * shmem
70+
) {
71+
local float * shmem_f32 = (local float *) shmem;
72+
int nb = ne00/QK_MXFP4;
73+
74+
int r0 = get_group_id(0);
75+
int r1 = get_group_id(1);
76+
int im = 0;
77+
78+
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
79+
80+
uint i12 = im%ne12;
81+
uint i13 = im/ne12;
82+
83+
ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
84+
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
85+
86+
global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);
87+
global float * y = (global float *) (src1 + offset_src1);
88+
89+
const short ix = get_sub_group_local_id()/2; // 0...15
90+
const short it = get_sub_group_local_id()%2; // 0 or 1
91+
92+
shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
93+
barrier(CLK_LOCAL_MEM_FENCE);
94+
95+
float4 yl[4];
96+
float sumf[N_R0_MXFP4] = {0.f};
97+
98+
global float * yb = y + ix * QK_MXFP4 + it * 8;
99+
100+
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
101+
global float4 * y4 = (global float4 *)yb;
102+
yl[0] = y4[0];
103+
yl[1] = y4[4];
104+
yl[2] = y4[1];
105+
yl[3] = y4[5];
106+
107+
for (short row = 0; row < N_R0_MXFP4; row++) {
108+
global block_mxfp4 * xb = x + row*nb + ib;
109+
global uchar * q2 = (global uchar *)(xb->qs + 8*it);
110+
111+
float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
112+
float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
113+
float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
114+
float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
115+
116+
acc1 = (acc1 + acc3) + (acc2 + acc4);
117+
118+
sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
119+
}
120+
121+
yb += (N_SIMDWIDTH/2) * QK_MXFP4;
122+
}
123+
124+
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
125+
126+
for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
127+
float sum_all = sub_group_reduce_add(sumf[row]);
128+
if (get_sub_group_local_id() == 0) {
129+
dst_f32[first_row + row] = sum_all;
130+
}
131+
}
132+
}
133+
134+
#ifdef INTEL_GPU
135+
REQD_SUBGROUP_SIZE_16
136+
#elif defined (ADRENO_GPU)
137+
REQD_SUBGROUP_SIZE_64
138+
#endif
139+
kernel void kernel_mul_mv_id_mxfp4_f32(
140+
global char * src0,
141+
ulong offset0,
142+
global char * src1,
143+
ulong offset1,
144+
global char * src2,
145+
ulong offset2,
146+
global char * dst,
147+
ulong offsetd,
148+
int ne00,
149+
ulong nb01,
150+
ulong nb02,
151+
ulong nb03,
152+
int ne11,
153+
int ne12,
154+
ulong nb11,
155+
ulong nb12,
156+
ulong nb13,
157+
int ne20,
158+
int ne21,
159+
ulong nb21,
160+
int ne0,
161+
int ne1,
162+
int r2,
163+
int r3,
164+
local char * shmem
165+
) {
166+
src0 = (global char *)((global char *)src0 + offset0);
167+
src1 = (global char *)((global char *)src1 + offset1);
168+
src2 = (global char *)((global char *)src2 + offset2);
169+
dst = (global char *)((global char *)dst + offsetd);
170+
171+
const int iid1 = get_group_id(2)/ne20;
172+
const int idx = get_group_id(2)%ne20;
173+
174+
int i02 = ((global int *) (src2 + iid1*nb21))[idx];
175+
176+
int i11 = idx % ne11;
177+
int i12 = iid1;
178+
179+
int i1 = idx;
180+
int i2 = i12;
181+
182+
global char * src0_cur = src0 + i02*nb02;
183+
global char * src1_cur = src1 + i11*nb11 + i12*nb12;
184+
185+
global char * dst_cur = dst + (i1*ne0 + i2*ne1*ne0)*sizeof(float);
186+
187+
mul_mv_mxfp4_f32(src0_cur, src1_cur, dst_cur,
188+
ne00, nb01, nb02, nb03, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shmem);
189+
}

0 commit comments

Comments
 (0)