Skip to content

Commit 28d3073

Browse files
committed
opencl: fix data loading for incomplete tile
1 parent 2b99b5e commit 28d3073

File tree

3 files changed

+79
-37
lines changed

3 files changed

+79
-37
lines changed

ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl

Lines changed: 23 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
7979

8080
for (int block = 0; block < ne00; block += BK) {
8181
for (int l = 0; l < BM; l += loadstride_a) {
82+
if (loadc_a + l < ne01) {
8283
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
83-
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
84-
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
85-
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
86-
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
84+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
85+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
86+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
87+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
88+
} else {
89+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h;
90+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h;
91+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h;
92+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h;
93+
}
8794
}
8895

8996
for (int l = 0; l < BN; l += loadstride_b) {
90-
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
91-
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
92-
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
93-
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
94-
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
97+
if (loadc_b + l < ne11) {
98+
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
99+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
100+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
101+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
102+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
103+
} else {
104+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h;
105+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h;
106+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h;
107+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h;
108+
}
95109
}
96110

97111
barrier(CLK_LOCAL_MEM_FENCE);

ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
7979

8080
for (int block = 0; block < ne00; block += BK) {
8181
for (int l = 0; l < BM; l += loadstride_a) {
82-
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
83-
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
84-
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
85-
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
86-
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
82+
if (loadc_a + l < ne01) {
83+
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
84+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
85+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
86+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
87+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
88+
} else {
89+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
90+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
91+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
92+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
93+
}
8794
}
8895

8996
for (int l = 0; l < BN; l += loadstride_b) {
90-
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
91-
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
92-
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
93-
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
94-
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
97+
if (loadc_b + l < ne11) {
98+
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
99+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
100+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
101+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
102+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
103+
} else {
104+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
105+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
106+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
107+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
108+
}
95109
}
96110

97111
barrier(CLK_LOCAL_MEM_FENCE);

ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl

Lines changed: 32 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -78,27 +78,41 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
7878

7979
for (int block = 0; block < ne00; block += BK) {
8080
for (int l = 0; l < BM; l += loadstride_a) {
81-
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
82-
int ib = idx / 8;
83-
int iqs = idx % 8;
84-
85-
float d = (float)src0_d[ib];
86-
global char4 * qs = src0_q + ib*8 + iqs;
87-
char4 q = *qs;
88-
float4 v = convert_float4(q)*d;
89-
90-
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
91-
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
92-
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
93-
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
81+
if (loadc_a + l < ne01) {
82+
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
83+
int ib = idx / 8;
84+
int iqs = idx % 8;
85+
86+
float d = (float)src0_d[ib];
87+
global char4 * qs = src0_q + ib*8 + iqs;
88+
char4 q = *qs;
89+
float4 v = convert_float4(q)*d;
90+
91+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
92+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
93+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
94+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
95+
} else {
96+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
97+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
98+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
99+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
100+
}
94101
}
95102

96103
for (int l = 0; l < BN; l += loadstride_b) {
97-
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
98-
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
99-
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
100-
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
101-
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
104+
if (loadc_b + l < ne11) {
105+
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
106+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
107+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
108+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
109+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
110+
} else {
111+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
112+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
113+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
114+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
115+
}
102116
}
103117

104118
barrier(CLK_LOCAL_MEM_FENCE);

0 commit comments

Comments
 (0)