Skip to content

Commit fe12b20

Browse files
committed
opencl: use original mxfp4 mv for structs
1 parent 7aa67ce commit fe12b20

File tree

3 files changed

+66
-86
lines changed

3 files changed

+66
-86
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7042,6 +7042,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
70427042
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1));
70437043
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
70447044
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
7045+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr));
70457046
#endif
70467047
break;
70477048
}
@@ -7282,6 +7283,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
72827283
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1));
72837284
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
72847285
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
7286+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr));
72857287
#endif // GGML_OPENCL_SOA_Q
72867288
break;
72877289
}

ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -24,38 +24,19 @@ typedef struct {
2424
uchar qs[QK_MXFP4/2];
2525
} block_mxfp4;
2626

27-
// single ushort contains 4 mxfp4 as input
28-
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
29-
ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
30-
fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
31-
fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
32-
fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
33-
fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;
34-
35-
bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
36-
bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
37-
bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
38-
bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;
39-
40-
fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
41-
fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
42-
fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
43-
fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;
44-
45-
sign_a.lo = (fp4x4 << 12) & 0x8000;
46-
sign_a.hi = (fp4x4 << 8) & 0x8000;
47-
sign_b.lo = (fp4x4 << 4) & 0x8000;
48-
sign_b.hi = fp4x4 & 0x8000;
49-
50-
fp16_packed_a = sign_a + bias_a + fp16_packed_a;
51-
fp16_packed_b = sign_b + bias_b + fp16_packed_b;
52-
53-
return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
54-
}
27+
constant static float kvalues_mxfp4_f[16] = {
28+
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
29+
};
5530

5631
static inline float e8m0_to_fp32(uchar x) {
5732
int bits;
58-
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
33+
34+
if (x == 0) {
35+
bits = 0x00400000;
36+
} else {
37+
bits = (uint) x << 23;
38+
}
39+
5940
return as_float(bits);
6041
}
6142

@@ -84,8 +65,10 @@ inline void mul_mv_mxfp4_f32(
8465
int ne0,
8566
int ne1,
8667
int r2,
87-
int r3
68+
int r3,
69+
local char * shmem
8870
) {
71+
local float * shmem_f32 = (local float *) shmem;
8972
int nb = ne00/QK_MXFP4;
9073

9174
int r0 = get_group_id(0);
@@ -106,25 +89,31 @@ inline void mul_mv_mxfp4_f32(
10689
const short ix = get_sub_group_local_id()/2; // 0...15
10790
const short it = get_sub_group_local_id()%2; // 0 or 1
10891

92+
shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
93+
barrier(CLK_LOCAL_MEM_FENCE);
94+
95+
float4 yl[4];
10996
float sumf[N_R0_MXFP4] = {0.f};
11097

11198
global float * yb = y + ix * QK_MXFP4 + it * 8;
11299

113100
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
114101
global float4 * y4 = (global float4 *)yb;
102+
yl[0] = y4[0];
103+
yl[1] = y4[4];
104+
yl[2] = y4[1];
105+
yl[3] = y4[5];
115106

116107
for (short row = 0; row < N_R0_MXFP4; row++) {
117108
global block_mxfp4 * xb = x + row*nb + ib;
118-
global ushort * q2 = (global ushort *)(xb->qs + 8*it);
109+
global uchar * q2 = (global uchar *)(xb->qs + 8*it);
110+
111+
float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
112+
float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
113+
float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
114+
float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
119115

120-
half4 fp16x4_0 = mxfp4_to_fp16_packed(q2[0]);
121-
half4 fp16x4_1 = mxfp4_to_fp16_packed(q2[1]);
122-
float4 acc1 = y4[0]*(float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
123-
acc1 += y4[4]*(float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
124-
fp16x4_0 = mxfp4_to_fp16_packed(q2[2]);
125-
fp16x4_1 = mxfp4_to_fp16_packed(q2[3]);
126-
acc1 += y4[1]*(float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
127-
acc1 += y4[5]*(float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
116+
acc1 = (acc1 + acc3) + (acc2 + acc4);
128117

129118
sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
130119
}
@@ -171,7 +160,8 @@ kernel void kernel_mul_mv_id_mxfp4_f32(
171160
int ne0,
172161
int ne1,
173162
int r2,
174-
int r3
163+
int r3,
164+
local char * shmem
175165
) {
176166
src0 = (global char *)((global char *)src0 + offset0);
177167
src1 = (global char *)((global char *)src1 + offset1);
@@ -195,5 +185,5 @@ kernel void kernel_mul_mv_id_mxfp4_f32(
195185
global char * dst_cur = dst + (i1*ne0 + i2*ne1*ne0)*sizeof(float);
196186

197187
mul_mv_mxfp4_f32(src0_cur, src1_cur, dst_cur,
198-
ne00, nb01, nb02, nb03, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3);
188+
ne00, nb01, nb02, nb03, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shmem);
199189
}

ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl

Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,37 +24,19 @@ typedef struct {
2424
uchar qs[QK_MXFP4/2];
2525
} block_mxfp4;
2626

27-
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
28-
ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
29-
fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
30-
fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
31-
fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
32-
fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;
33-
34-
bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
35-
bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
36-
bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
37-
bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;
38-
39-
fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
40-
fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
41-
fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
42-
fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;
43-
44-
sign_a.lo = (fp4x4 << 12) & 0x8000;
45-
sign_a.hi = (fp4x4 << 8) & 0x8000;
46-
sign_b.lo = (fp4x4 << 4) & 0x8000;
47-
sign_b.hi = fp4x4 & 0x8000;
48-
49-
fp16_packed_a = sign_a + bias_a + fp16_packed_a;
50-
fp16_packed_b = sign_b + bias_b + fp16_packed_b;
51-
52-
return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
53-
}
27+
constant static float kvalues_mxfp4_f[16] = {
28+
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
29+
};
5430

5531
static inline float e8m0_to_fp32(uchar x) {
5632
int bits;
57-
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
33+
34+
if (x == 0) {
35+
bits = 0x00400000;
36+
} else {
37+
bits = (uint) x << 23;
38+
}
39+
5840
return as_float(bits);
5941
}
6042

@@ -91,53 +73,59 @@ kernel void kernel_mul_mv_mxfp4_f32(
9173
int ne0,
9274
int ne1,
9375
int r2,
94-
int r3
76+
int r3,
77+
local char * shmem
9578
) {
9679
src0 = (global char*)((global char*)src0 + offset0);
9780
src1 = (global char*)((global char*)src1 + offset1);
9881
dst = (global char*)((global char*)dst + offsetd);
9982

100-
int nb = ne00 / QK_MXFP4;
83+
local float * shmem_f32 = (local float *) shmem;
84+
int nb = ne00/QK_MXFP4;
10185

10286
int r0 = get_group_id(0);
10387
int r1 = get_group_id(1);
10488
int im = get_group_id(2);
10589

10690
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
10791

108-
uint i12 = im % ne12;
109-
uint i13 = im / ne12;
92+
uint i12 = im%ne12;
93+
uint i13 = im/ne12;
11094

11195
ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
11296
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
11397

11498
global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);
11599
global float * y = (global float *) (src1 + offset_src1);
116100

117-
const short ix = get_sub_group_local_id() >> 1; // 0...15
118-
const short it = get_sub_group_local_id() & 1; // 0 or 1
101+
const short ix = get_sub_group_local_id()/2; // 0...15
102+
const short it = get_sub_group_local_id()%2; // 0 or 1
119103

104+
shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
105+
barrier(CLK_LOCAL_MEM_FENCE);
106+
107+
float4 yl[4];
120108
float sumf[N_R0_MXFP4] = {0.f};
121109

122110
global float * yb = y + ix * QK_MXFP4 + it * 8;
123111

124112
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
125113
global float4 * y4 = (global float4 *)yb;
114+
yl[0] = y4[0];
115+
yl[1] = y4[4];
116+
yl[2] = y4[1];
117+
yl[3] = y4[5];
126118

127119
for (short row = 0; row < N_R0_MXFP4; row++) {
128120
global block_mxfp4 * xb = x + row*nb + ib;
129121
global uchar * q2 = (global uchar *)(xb->qs + 8*it);
130-
ushort4 xb_q = ((global ushort4 *)(q2))[0];
131-
132-
half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
133-
half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
134-
float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
135-
acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
136-
137-
fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
138-
fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
139-
acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
140-
acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
122+
123+
float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
124+
float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
125+
float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
126+
float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
127+
128+
acc1 = (acc1 + acc3) + (acc2 + acc4);
141129

142130
sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
143131
}

0 commit comments

Comments
 (0)