Skip to content

Commit c24b666

Browse files
committed
add tiled mul_mat_f16_f32
1 parent 67d1ef2 commit c24b666

File tree

3 files changed

+202
-0
lines changed

3 files changed

+202
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ set(GGML_OPENCL_KERNELS
103103
tanh
104104
pad
105105
repeat
106+
mul_mat_f16_f32
106107
)
107108

108109
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ struct ggml_backend_opencl_context {
367367
cl_program program_mul_mv_f16_f32;
368368
cl_program program_mul_mv_f32_f32;
369369
cl_program program_mul;
370+
cl_program program_mul_mat_f16_f32_tiled;
370371
cl_program program_div;
371372
cl_program program_sub;
372373
cl_program program_norm;
@@ -419,6 +420,7 @@ struct ggml_backend_opencl_context {
419420
cl_kernel kernel_mul_mat_f16_f32_1row;
420421
cl_kernel kernel_mul_mat_f16_f32;
421422
cl_kernel kernel_mul_mat_f16_f32_l4;
423+
cl_kernel kernel_mul_mat_f16_f32_tiled;
422424
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
423425
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
424426
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@@ -1000,6 +1002,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
10001002
GGML_LOG_CONT(".");
10011003
}
10021004

1005+
// mul_mat_f16_f32_tiled
1006+
{
1007+
#ifdef GGML_OPENCL_EMBED_KERNELS
1008+
const std::string kernel_src {
1009+
#include "mul_mat_f16_f32.cl.h"
1010+
};
1011+
#else
1012+
const std::string kernel_src = read_file("mul_mat_f16_f32.cl");
1013+
#endif
1014+
backend_ctx->program_mul_mat_f16_f32_tiled =
1015+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1016+
1017+
CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err));
1018+
GGML_LOG_CONT(".");
1019+
}
1020+
10031021
// mul
10041022
{
10051023
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -4742,6 +4760,47 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
47424760
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
47434761
}
47444762

4763+
// Launch the tiled f16 x f32 matrix multiplication kernel.
// src0 is the f16 matrix A (K = ne[0], M = ne[1]), src1 the f32 matrix B
// (K = ne[0], N = ne[1]); the result is written into dst. Caller guarantees
// 2D contiguous tensors (see the dispatch check in ggml_cl_mul_mat).
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    ggml_tensor_extra_cl * ext_src0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * ext_src1 = (ggml_tensor_extra_cl *)src1->extra;
    ggml_tensor_extra_cl * ext_dst  = (ggml_tensor_extra_cl *)dst->extra;

    // Byte offsets into the device buffers, including any view offset.
    cl_ulong off_src0 = ext_src0->offset + src0->view_offs;
    cl_ulong off_src1 = ext_src1->offset + src1->view_offs;
    cl_ulong off_dst  = ext_dst->offset  + dst->view_offs;

    const int M = src0->ne[1]; // rows of A
    const int N = src1->ne[1]; // rows of B (columns of the result)
    const int K = src0->ne[0]; // shared inner dimension

    cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;

    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int),      &M));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int),      &N));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),      &K));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &ext_src0->data_device));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &off_src0));
    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem),   &ext_src1->data_device));
    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &off_src1));
    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem),   &ext_dst->data_device));
    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &off_dst));

    // Launch geometry. These constants must match the OPWM/OPWN/WG_M/WG_N
    // tile parameters compiled into mul_mat_f16_f32.cl: each work-group of
    // TPWM x TPWN work-items produces one OPWM x OPWN tile of the output.
    const int OPWM = 64; // output tile height (M) per work-group
    const int OPWN = 64; // output tile width  (N) per work-group
    const int TPWM = 16; // work-items per group along M
    const int TPWN = 8;  // work-items per group along N

    size_t local_work_size[2] = { TPWM, TPWN };
    size_t global_work_size[2] = {
        (size_t) ((M + OPWM - 1) / OPWM) * TPWM, // one group per tile row
        (size_t) ((N + OPWN - 1) / OPWN) * TPWN, // one group per tile column
    };

    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
4803+
47454804
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
47464805
GGML_ASSERT(src0);
47474806
GGML_ASSERT(src0->extra);
@@ -4755,6 +4814,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
47554814

47564815
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
47574816

4817+
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
4818+
src0->ne[1] > 32 && // M > 32
4819+
src1->ne[1] > 32 && // N > 32
4820+
src0->ne[0] > 32 && // K > 32
4821+
src0->ne[2] == 1 && src0->ne[3] == 1 &&
4822+
src1->ne[2] == 1 && src1->ne[3] == 1 &&
4823+
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
4824+
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
4825+
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
4826+
return;
4827+
}
4828+
47584829
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
47594830
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
47604831
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#if defined(cl_qcom_reqd_sub_group_size)
4+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
5+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
6+
#else
7+
#define REQD_SUBGROUP_SIZE_128
8+
#endif
9+
10+
#define OPWM 64
11+
#define OPWN 64
12+
#define CPWK 8
13+
#define OPTM 4
14+
#define OPTN 8
15+
16+
#define WG_M (OPWM / OPTM)
17+
#define WG_N (OPWN / OPTN)
18+
#define VEC_K (CPWK / 4)
19+
20+
REQD_SUBGROUP_SIZE_128
21+
__kernel void mul_mat_f16_f32(
22+
const int M, const int N, const int K,
23+
__global const void* A_void, ulong A_offset,
24+
__global const void* B_void, ulong B_offset,
25+
__global void* C_void, ulong C_offset) {
26+
27+
__global const half* A = (__global const half* )((__global const char*)A_void + A_offset);
28+
__global const float* B = (__global const float*)((__global const char*)B_void + B_offset);
29+
__global float* C = (__global float*)((__global char*)C_void + C_offset);
30+
31+
const int lidm = get_local_id(0);
32+
const int lidn = get_local_id(1);
33+
const int lid = lidn * WG_M + lidm;
34+
35+
const int offsetM = get_group_id(0) * OPWM;
36+
const int offsetN = get_group_id(1) * OPWN;
37+
38+
__local half4 Alocal[OPWM][VEC_K];
39+
__local float4 Blocal[OPWN][VEC_K];
40+
41+
float sum[OPTM][OPTN];
42+
43+
for (int wm = 0; wm < OPTM; wm++) {
44+
for (int wn = 0; wn < OPTN; wn++) {
45+
sum[wm][wn] = 0.0f;
46+
}
47+
}
48+
49+
const int numTiles = (K + CPWK - 1) / CPWK;
50+
51+
const int load_row_a = lid % OPWM;
52+
const int load_vec_k_a = lid / OPWM;
53+
const int global_row_a = offsetM + load_row_a;
54+
55+
const int load_row_b = lid % OPWN;
56+
const int load_vec_k_b = lid / OPWN;
57+
const int global_row_b = offsetN + load_row_b;
58+
59+
for (int t = 0; t < numTiles; t++) {
60+
const int k_start = t * CPWK;
61+
const int k_vec_start_a = k_start + load_vec_k_a * 4;
62+
const int k_vec_start_b = k_start + load_vec_k_b * 4;
63+
64+
if (global_row_a < M && k_vec_start_a < K) {
65+
if (k_vec_start_a + 3 < K) {
66+
Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a);
67+
} else {
68+
half4 tempA = (half4)(0.0h);
69+
if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a];
70+
if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1];
71+
if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2];
72+
Alocal[load_row_a][load_vec_k_a] = tempA;
73+
}
74+
} else {
75+
Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h);
76+
}
77+
78+
if (global_row_b < N && k_vec_start_b < K) {
79+
if (k_vec_start_b + 3 < K) {
80+
Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b);
81+
} else {
82+
float4 tempB = (float4)(0.0f);
83+
if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b];
84+
if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1];
85+
if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2];
86+
Blocal[load_row_b][load_vec_k_b] = tempB;
87+
}
88+
} else {
89+
Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f);
90+
}
91+
92+
barrier(CLK_LOCAL_MEM_FENCE);
93+
94+
#pragma unroll
95+
for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
96+
float4 a_fvecs[OPTM];
97+
int current_row_a = lidm;
98+
for (int wm = 0; wm < OPTM; wm++) {
99+
a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);
100+
current_row_a += WG_M;
101+
}
102+
103+
float4 b_fvecs[OPTN];
104+
int current_row_b = lidn;
105+
for (int wn = 0; wn < OPTN; wn++) {
106+
b_fvecs[wn] = Blocal[current_row_b][k_vec];
107+
current_row_b += WG_N;
108+
}
109+
110+
for (int wm = 0; wm < OPTM; wm++) {
111+
for (int wn = 0; wn < OPTN; wn++) {
112+
sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
113+
}
114+
}
115+
}
116+
barrier(CLK_LOCAL_MEM_FENCE);
117+
}
118+
119+
for (int wm = 0; wm < OPTM; wm++) {
120+
int globalRow = offsetM + lidm + wm * WG_M;
121+
if (globalRow < M) {
122+
for (int wn = 0; wn < OPTN; wn++) {
123+
int globalCol = offsetN + lidn + wn * WG_N;
124+
if (globalCol < N) {
125+
C[globalCol * M + globalRow] = sum[wm][wn];
126+
}
127+
}
128+
}
129+
}
130+
}

0 commit comments

Comments
 (0)