Skip to content

Commit 7a15e0e

Browse files
committed
opencl: use broadcast semantic for mul_mv_id_mxfp4_f32_flat
1 parent 7184682 commit 7a15e0e

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7228,20 +7228,21 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
72287228
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
72297229
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
72307230
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
7231-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
7232-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
7233-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11));
7234-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
7235-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
7236-
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
7237-
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb13));
7238-
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne20));
7239-
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne21));
7240-
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb21));
7241-
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne0));
7242-
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne1));
7243-
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r2));
7244-
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r3));
7231+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
7232+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
7233+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
7234+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
7235+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
7236+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
7237+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
7238+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
7239+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne20));
7240+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne21));
7241+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
7242+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0));
7243+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1));
7244+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
7245+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
72457246
#else // GGML_OPENCL_SOA_Q
72467247
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;
72477248

ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ kernel void kernel_mul_mv_id_mxfp4_f32_flat(
7979
global uchar * dst,
8080
ulong offsetd,
8181
int ne00,
82-
int ne01,
83-
int ne02,
82+
ulong nb01,
83+
ulong nb02,
84+
ulong nb03,
8485
int ne11,
8586
int ne12,
8687
ulong nb11,
@@ -105,7 +106,10 @@ kernel void kernel_mul_mv_id_mxfp4_f32_flat(
105106

106107
int nb = ne00 / QK_MXFP4;
107108

108-
src0_e = src0_e + i02 * nb * ne01;
109+
uint src0_off = i02*nb02;
110+
src0_off /= 17; // 17 = sizeof(block_mxfp4)
111+
112+
src0_e = src0_e + src0_off;
109113

110114
dst = dst + (idx * ne0 + iid1 * ne1 * ne0) * sizeof(float);
111115

@@ -114,11 +118,12 @@ kernel void kernel_mul_mv_id_mxfp4_f32_flat(
114118

115119
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
116120

117-
ulong offset_src0 = first_row * nb;
121+
uint offset_src0 = first_row*nb01;
122+
offset_src0 /= 17; // 17 = sizeof(block_mxfp4)
118123
#ifdef SRC0Q_IMG
119-
ulong offset_q = i02 * nb * ne01 + offset_src0;
124+
ulong offset_q = src0_off + offset_src0;
120125
#else
121-
src0_q = src0_q + i02 * nb * 16 * ne01;
126+
src0_q = src0_q + src0_off*16;
122127
global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
123128
#endif
124129
global uchar * x_e = src0_e + offset_src0;

0 commit comments

Comments
 (0)