Commit 81e8522

opencl: add mm_q4_0_f32_lm
1 parent 300c352 commit 81e8522

3 files changed: +199 −0 lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS
     mul_mv_id_mxfp4_f32_flat
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
+    mul_mm_q4_0_f32_l4_lm
     mul_mm_q8_0_f32_l4_lm
     mul
     norm

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 53 additions & 0 deletions
@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_id_mxfp4_f32_flat;
     cl_program program_mul_mm_f32_f32_l4_lm;
     cl_program program_mul_mm_f16_f32_l4_lm;
+    cl_program program_mul_mm_q4_0_f32_l4_lm;
     cl_program program_mul_mm_q8_0_f32_l4_lm;

     cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
@@ -481,6 +482,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
     cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;

     std::vector<ProfilingInfo> profiling_info;
@@ -1193,6 +1195,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }

+    // mul_mm_q4_0_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q4_0_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_q4_0_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q4_0_f32_l4_lm, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mm_q8_0_f32_l4_lm
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6974,6 +6992,41 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
             return;
         }
+        case GGML_TYPE_Q4_0: {
+            kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm;
+            nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+            int batch_stride_a = ne00*ne01;
+            int batch_stride_b = ne10*ne11;
+            int batch_stride_d = ne0*ne1;
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+            // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+            size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+            size_t local_work_size[]  = {(size_t)nth0, 1, 1};
+
+            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+            return;
+        }
         case GGML_TYPE_Q8_0: {
             kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
             nth0 = 128; // calculated as (BM*BN)/(TM*TN)
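
For reference, the dispatch geometry above follows directly from the kernel's tile constants: each work-group computes one BM x BN = 64 x 64 tile of dst, each work-item produces a TM x TN = 4 x 8 sub-tile, so nth0 = (BM*BN)/(TM*TN) = 128 work-items per group. Below is a minimal standalone C++ sketch of that arithmetic; it is illustrative only, and the ceil_div helper and sample tensor sizes are assumptions, not part of this commit.

#include <cstddef>
#include <cstdio>

// Illustrative constants mirroring kernel_mul_mm_q4_0_f32_l4_lm.
constexpr int BM = 64, BN = 64;  // output tile per work-group
constexpr int TM = 4,  TN = 8;   // output values per work-item

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int nth0 = (BM * BN) / (TM * TN);  // 128 work-items per work-group

    // Hypothetical problem size: dst is ne01 x ne11 with ne12*ne13 batches.
    const int ne01 = 4096, ne11 = 512, ne12 = 1, ne13 = 1;

    // One work-group per 64x64 output tile; dimension 0 carries nth0 work-items per tile.
    const size_t global[3] = {
        (size_t)(ceil_div(ne01, BM) * nth0),
        (size_t) ceil_div(ne11, BN),
        (size_t)(ne12 * ne13),
    };
    const size_t local[3] = { (size_t)nth0, 1, 1 };

    std::printf("global = {%zu, %zu, %zu}, local = {%zu, 1, 1}\n",
                global[0], global[1], global[2], local[0]);
}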
ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 8
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q4_0_f32_l4_lm(
+        global uchar4 * src0_q,
+        global half   * src0_d,
+        global float4 * src1,
+        ulong offset1,
+        global float  * dst,
+        ulong offsetd,
+
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne11,
+        int ne12,
+
+        int stride_a,
+        int stride_b,
+        int stride_d,
+
+        int batch_stride_a,
+        int batch_stride_b,
+        int batch_stride_d,
+
+        int r2,
+        int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid  = get_local_id(0);
+    const int th_r = tid % (BM / TM);
+    const int th_c = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+            int ib  = idx / 4;
+            int iqs = idx % 4;
+
+            float d = (float)src0_d[ib];
+            global uchar4 * qs = src0_q + ib*4 + iqs;
+            uchar4 q = *qs;
+            float4 v1 = (convert_float4((uchar4)((q.s0   )&0x0F, (q.s1   )&0x0F, (q.s2   )&0x0F, (q.s3   )&0x0F)) - 8.0f)*d;
+            float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
+
+            buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = v1.s0;
+            buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = v1.s1;
+            buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = v1.s2;
+            buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = v1.s3;
+            buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
+            buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
+            buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
+            buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
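
The A-tile load in the kernel's inner loop is where the Q4_0 dequantization happens: each block of 32 weights is stored as 16 quant bytes (two 4-bit values per byte; the low nibbles form the first 16 values of the block, the high nibbles the last 16) plus one fp16 scale, and here the quants and scales arrive as separate flat buffers (src0_q, src0_d). Below is a minimal host-side C++ sketch of the same unpacking; it is illustrative only, assumes the scale has already been converted to float, and is not part of this commit.

#include <cstdint>
#include <cstdio>

constexpr int QK4_0 = 32;  // values per Q4_0 block

// Dequantize one Q4_0 block from its 16 quant bytes and scale.
// Mirrors the kernel: low nibbles yield values 0..15 of the block,
// high nibbles yield values 16..31, and each nibble is offset by -8 before scaling.
void dequant_q4_0_block(const uint8_t qs[QK4_0 / 2], float d, float out[QK4_0]) {
    for (int j = 0; j < QK4_0 / 2; ++j) {
        out[j]             = float((qs[j] & 0x0F) - 8) * d;
        out[j + QK4_0 / 2] = float((qs[j] >>   4) - 8) * d;
    }
}

int main() {
    // Hypothetical block: every nibble is 0x9, so each value dequantizes to (9 - 8) * d = d.
    uint8_t qs[QK4_0 / 2];
    for (auto & b : qs) b = 0x99;

    float out[QK4_0];
    dequant_q4_0_block(qs, 0.5f, out);
    std::printf("out[0] = %.2f  out[16] = %.2f\n", out[0], out[16]);  // both 0.50
}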
