Skip to content

Commit 9984cbb

Browse files
authored
opencl: fix boundary handling for mul_mm (ggml-org#16875)
1 parent ce18efe commit 9984cbb

File tree

3 files changed, with 7 additions and 7 deletions.

ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            if (loadc_a + l < ne01) {
-                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+            if (ir*BM + loadc_a + l < ne01) {
+                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
                 buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
                 buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
                 buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            if (loadc_b + l < ne11) {
+            if (ic*BN + loadc_b + l < ne11) {
                 const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
                 buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
                 buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;

ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            if (loadc_a + l < ne01) {
+            if (ir*BM + loadc_a + l < ne01) {
                 const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
                 buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
                 buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            if (loadc_b + l < ne11) {
+            if (ic*BN + loadc_b + l < ne11) {
                 const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
                 buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
                 buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;

ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            if (loadc_a + l < ne01) {
+            if (ir*BM + loadc_a + l < ne01) {
                 int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
                 int ib = idx / 8;
                 int iqs = idx % 8;
@@ -101,7 +101,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            if (loadc_b + l < ne11) {
+            if (ic*BN + loadc_b + l < ne11) {
                 int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
                 buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
                 buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;

0 commit comments

Comments
 (0)