Skip to content

Commit ff4bad9

Browse files
committed
opencl: fix data loading for incomplete tiles
1 parent e3a0c84 commit ff4bad9

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl

Lines changed: 41 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -78,32 +78,50 @@ kernel void kernel_mul_mm_q4_0_f32_l4_lm(
7878

7979
for (int block = 0; block < ne00; block += BK) {
8080
for (int l = 0; l < BM; l += loadstride_a) {
81-
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
82-
int ib = idx / 4;
83-
int iqs = idx % 4;
84-
85-
float d = (float)src0_d[ib];
86-
global uchar4 * qs = src0_q + ib*4 + iqs;
87-
uchar4 q = *qs;
88-
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
89-
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
90-
91-
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
92-
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
93-
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
94-
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
95-
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
96-
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
97-
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
98-
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
81+
if (loadc_a + l < ne01) {
82+
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
83+
int ib = idx / 4;
84+
int iqs = idx % 4;
85+
86+
float d = (float)src0_d[ib];
87+
global uchar4 * qs = src0_q + ib*4 + iqs;
88+
uchar4 q = *qs;
89+
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
90+
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
91+
92+
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
93+
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
94+
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
95+
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
96+
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
97+
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
98+
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
99+
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
100+
} else {
101+
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
102+
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
103+
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
104+
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
105+
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
106+
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
107+
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
108+
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
109+
}
99110
}
100111

101112
for (int l = 0; l < BN; l += loadstride_b) {
102-
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
103-
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
104-
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
105-
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
106-
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
113+
if (loadc_b + l < ne11) {
114+
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
115+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
116+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
117+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
118+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
119+
} else {
120+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
121+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
122+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
123+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
124+
}
107125
}
108126

109127
barrier(CLK_LOCAL_MEM_FENCE);

0 commit comments

Comments (0)