diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
index 7e6c843846708..6f6bba55e2805 100644
--- a/ggml/src/ggml-opencl/CMakeLists.txt
+++ b/ggml/src/ggml-opencl/CMakeLists.txt
@@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS
     mul_mv_id_mxfp4_f32_flat
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
+    mul_mm_q8_0_f32_l4_lm
     mul
     norm
     relu
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 79d2148744f90..88271926872f3 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_id_mxfp4_f32_flat;
     cl_program program_mul_mm_f32_f32_l4_lm;
     cl_program program_mul_mm_f16_f32_l4_lm;
+    cl_program program_mul_mm_q8_0_f32_l4_lm;

     cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
     cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;

     std::vector<ProfilingInfo> profiling_info;
@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }

+    // mul_mm_q8_0_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q8_0_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6956,6 +6974,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
             return;
         }
+        case GGML_TYPE_Q8_0: {
+            if (ne11 < 32) {
+                break;
+            }
+            kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
+            nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+            int batch_stride_a = ne00*ne01;
+            int batch_stride_b = ne10*ne11;
+            int batch_stride_d = ne0*ne1;
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
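+            // Launch one work-group per 64 x 64 output tile: CEIL_DIV(ne01, 64) groups along
+            // the rows of src0, CEIL_DIV(ne11, 64) along the columns of src1, and ne12*ne13 batches.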
+            // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+            size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+            size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+            return;
+        }
         default:
             break;
     }
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
index 9599a0e157262..1a1bfe144f610 100644
--- a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(

     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
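+            // Rows beyond ne01 fall outside this tile; pad the local buffer with zeros so
+            // the accumulation loop below needs no extra bounds checks.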
+            if (loadc_a + l < ne01) {
             const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
-            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
-            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
-            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
-            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h;
+            }
         }

         for (int l = 0; l < BN; l += loadstride_b) {
-            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
-            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
-            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
-            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
-            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            if (loadc_b + l < ne11) {
+                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h;
+            }
         }

         barrier(CLK_LOCAL_MEM_FENCE);
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
index 58c5178e39cc8..39a5d4868ffaa 100644
--- a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(

     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
-            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
-            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
-            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
-            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            if (loadc_a + l < ne01) {
+                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
+            }
         }

         for (int l = 0; l < BN; l += loadstride_b) {
-            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
-            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
-            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
-            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
-            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            if (loadc_b + l < ne11) {
+                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
         }

         barrier(CLK_LOCAL_MEM_FENCE);
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl
new file mode 100644
index 0000000000000..fd47e8a89dcef
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl
@@ -0,0 +1,154 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 4
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
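+
+// Each work-group produces one BM x BN tile of the result, streaming the K dimension
+// through local memory in BK-wide slices; every work-item accumulates a TM x TN
+// sub-tile, so the host launches (BM*BN)/(TM*TN) = 128 work-items per work-group.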
+
+kernel void kernel_mul_mm_q8_0_f32_l4_lm(
+        global char4 * src0_q,
+        global half * src0_d,
+        global float4 * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne11,
+        int ne12,
+
+        int stride_a,
+        int stride_b,
+        int stride_d,
+
+        int batch_stride_a,
+        int batch_stride_b,
+        int batch_stride_d,
+
+        int r2,
+        int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float *)((global char*)dst + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r = tid % (BM / TM);
+    const int th_c = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
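+                // A q8_0 block is 32 int8 quants sharing one half-precision scale, i.e.
+                // 8 char4 groups per block: ib selects the block (and its scale in src0_d),
+                // iqs the char4 group inside it, and v is the dequantized float4.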
+                int ib = idx / 8;
+                int iqs = idx % 8;
+
+                float d = (float)src0_d[ib];
+                global char4 * qs = src0_q + ib*8 + iqs;
+                char4 q = *qs;
+                float4 v = convert_float4(q)*d;
+
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2fa16b497a6b7..c4fab9041bd44 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6354,6 +6354,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }

+#if 0
+    {
+        // Test paths in OpenCL
+        std::vector<int> ns = {32, 64, 128, 256, 512, 1024, 4096};
+        std::vector<int> ks = {896, 1536, 4096};
+        for (auto n : ns) {
+            for (auto k : ks) {
+                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1}));
+            }
+        }
+    }
+#endif
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {