Skip to content

Commit 300c352

Browse files
committed
opencl: add mm_q8_0_f32
1 parent e74c92e commit 300c352

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS
9393
mul_mv_id_mxfp4_f32_flat
9494
mul_mm_f32_f32_l4_lm
9595
mul_mm_f16_f32_l4_lm
96+
mul_mm_q8_0_f32_l4_lm
9697
mul
9798
norm
9899
relu

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
408408
cl_program program_mul_mv_id_mxfp4_f32_flat;
409409
cl_program program_mul_mm_f32_f32_l4_lm;
410410
cl_program program_mul_mm_f16_f32_l4_lm;
411+
cl_program program_mul_mm_q8_0_f32_l4_lm;
411412

412413
cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
413414
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
480481
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
481482
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
482483
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
484+
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
483485

484486
std::vector<ProfilingInfo> profiling_info;
485487

@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
11911193
GGML_LOG_CONT(".");
11921194
}
11931195

1196+
// mul_mm_q8_0_f32_l4_lm
1197+
{
1198+
#ifdef GGML_OPENCL_EMBED_KERNELS
1199+
const std::string kernel_src {
1200+
#include "mul_mm_q8_0_f32_l4_lm.cl.h"
1201+
};
1202+
#else
1203+
const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl");
1204+
#endif
1205+
backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
1206+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1207+
1208+
CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err));
1209+
GGML_LOG_CONT(".");
1210+
}
1211+
11941212
// mul
11951213
{
11961214
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6956,6 +6974,41 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
69566974
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
69576975
return;
69586976
}
6977+
case GGML_TYPE_Q8_0: {
6978+
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
6979+
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
6980+
6981+
int batch_stride_a = ne00*ne01;
6982+
int batch_stride_b = ne10*ne11;
6983+
int batch_stride_d = ne0*ne1;
6984+
6985+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q));
6986+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d));
6987+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
6988+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6989+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6990+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6991+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
6992+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
6993+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
6994+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
6995+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
6996+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
6997+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
6998+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
6999+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
7000+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
7001+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
7002+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
7003+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
7004+
7005+
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
7006+
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
7007+
size_t local_work_size[] = {(size_t)nth0, 1, 1};
7008+
7009+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
7010+
return;
7011+
}
69597012
default:
69607013
break;
69617014
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#define LOAD_VEC_A 4
4+
#define LOAD_VEC_B 4
5+
6+
#define BM 64
7+
#define BN 64
8+
#define BK 32
9+
#define TM 4
10+
#define TN 8
11+
12+
kernel void kernel_mul_mm_q8_0_f32_l4_lm(
13+
global char4 * src0_q,
14+
global half * src0_d,
15+
global float4 * src1,
16+
ulong offset1,
17+
global float * dst,
18+
ulong offsetd,
19+
20+
int ne00,
21+
int ne01,
22+
int ne02,
23+
int ne11,
24+
int ne12,
25+
26+
int stride_a,
27+
int stride_b,
28+
int stride_d,
29+
30+
int batch_stride_a,
31+
int batch_stride_b,
32+
int batch_stride_d,
33+
34+
int r2,
35+
int r3
36+
) {
37+
src1 = (global float4*)((global char*)src1 + offset1);
38+
dst = (global float *)((global char*)dst + offsetd);
39+
40+
local float buf_a[BM * BK];
41+
local float buf_b[BN * BK];
42+
43+
const int batch_idx = get_global_id(2);
44+
45+
const int i13 = batch_idx / ne12;
46+
const int i12 = batch_idx % ne12;
47+
48+
const int i03 = i13 / r3;
49+
const int i02 = i12 / r2;
50+
51+
const int batch_idx_a = i03 * ne02 + i02;
52+
53+
const int ir = get_group_id(0);
54+
const int ic = get_group_id(1);
55+
56+
const int tid = get_local_id(0);
57+
const int th_r = tid % (BM / TM);
58+
const int th_c = tid / (BM / TM);
59+
60+
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
61+
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
62+
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
63+
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
64+
65+
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
66+
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
67+
68+
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
69+
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
70+
71+
float sums[TM * TN];
72+
float cache_a[TM];
73+
float cache_b[TN];
74+
75+
for (int i = 0; i < TM * TN; i++) {
76+
sums[i] = 0.0f;
77+
}
78+
79+
for (int block = 0; block < ne00; block += BK) {
80+
for (int l = 0; l < BM; l += loadstride_a) {
81+
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
82+
int ib = idx / 8;
83+
int iqs = idx % 8;
84+
85+
float d = (float)src0_d[ib];
86+
global char4 * qs = src0_q + ib*8 + iqs;
87+
char4 q = *qs;
88+
float4 v = convert_float4(q)*d;
89+
90+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
91+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
92+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
93+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
94+
}
95+
96+
for (int l = 0; l < BN; l += loadstride_b) {
97+
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
98+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
99+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
100+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
101+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
102+
}
103+
104+
barrier(CLK_LOCAL_MEM_FENCE);
105+
106+
pos_a += BK / LOAD_VEC_A;
107+
pos_b += BK / LOAD_VEC_B;
108+
109+
for (int i = 0; i < BK; i++) {
110+
for (int j = 0; j < TM; j++) {
111+
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
112+
}
113+
114+
for (int j = 0; j < TN; j++) {
115+
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
116+
}
117+
118+
for (int cc = 0; cc < TN; cc++) {
119+
for (int cr = 0; cr < TM; cr++) {
120+
const int sums_idx = cc*TM + cr;
121+
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
122+
}
123+
}
124+
}
125+
barrier(CLK_LOCAL_MEM_FENCE);
126+
}
127+
128+
const int dr = ir * BM + th_r * TM;
129+
const int dc = ic * BN + th_c * TN;
130+
131+
const int offsets = batch_idx * batch_stride_d;
132+
133+
for (int cc = 0; cc < TN; cc++) {
134+
for (int cr = 0; cr < TM; cr++) {
135+
if (dr + cr < ne01 && dc + cc < ne11) {
136+
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
137+
}
138+
}
139+
}
140+
}

0 commit comments

Comments
 (0)