@@ -78,27 +78,41 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
78
78
79
79
for (int block = 0 ; block < ne00 ; block += BK ) {
80
80
for (int l = 0 ; l < BM ; l += loadstride_a ) {
81
- int idx = pos_a + (loadc_a + l ) * stride_a / LOAD_VEC_A + loadr_a ;
82
- int ib = idx / 8 ;
83
- int iqs = idx % 8 ;
84
-
85
- float d = (float )src0_d [ib ];
86
- global char4 * qs = src0_q + ib * 8 + iqs ;
87
- char4 q = * qs ;
88
- float4 v = convert_float4 (q )* d ;
89
-
90
- buf_a [(loadr_a * LOAD_VEC_A + 0 ) * BM + loadc_a + l ] = v .s0 ;
91
- buf_a [(loadr_a * LOAD_VEC_A + 1 ) * BM + loadc_a + l ] = v .s1 ;
92
- buf_a [(loadr_a * LOAD_VEC_A + 2 ) * BM + loadc_a + l ] = v .s2 ;
93
- buf_a [(loadr_a * LOAD_VEC_A + 3 ) * BM + loadc_a + l ] = v .s3 ;
81
+ if (loadc_a + l < ne01 ) {
82
+ int idx = pos_a + (loadc_a + l ) * stride_a / LOAD_VEC_A + loadr_a ;
83
+ int ib = idx / 8 ;
84
+ int iqs = idx % 8 ;
85
+
86
+ float d = (float )src0_d [ib ];
87
+ global char4 * qs = src0_q + ib * 8 + iqs ;
88
+ char4 q = * qs ;
89
+ float4 v = convert_float4 (q )* d ;
90
+
91
+ buf_a [(loadr_a * LOAD_VEC_A + 0 ) * BM + loadc_a + l ] = v .s0 ;
92
+ buf_a [(loadr_a * LOAD_VEC_A + 1 ) * BM + loadc_a + l ] = v .s1 ;
93
+ buf_a [(loadr_a * LOAD_VEC_A + 2 ) * BM + loadc_a + l ] = v .s2 ;
94
+ buf_a [(loadr_a * LOAD_VEC_A + 3 ) * BM + loadc_a + l ] = v .s3 ;
95
+ } else {
96
+ buf_a [(loadr_a * LOAD_VEC_A + 0 ) * BM + loadc_a + l ] = 0.0f ;
97
+ buf_a [(loadr_a * LOAD_VEC_A + 1 ) * BM + loadc_a + l ] = 0.0f ;
98
+ buf_a [(loadr_a * LOAD_VEC_A + 2 ) * BM + loadc_a + l ] = 0.0f ;
99
+ buf_a [(loadr_a * LOAD_VEC_A + 3 ) * BM + loadc_a + l ] = 0.0f ;
100
+ }
94
101
}
95
102
96
103
for (int l = 0 ; l < BN ; l += loadstride_b ) {
97
- int idx = pos_b + (loadc_b + l ) * stride_b / LOAD_VEC_B + loadr_b ;
98
- buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = src1 [idx ].s0 ;
99
- buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = src1 [idx ].s1 ;
100
- buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = src1 [idx ].s2 ;
101
- buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = src1 [idx ].s3 ;
104
+ if (loadc_b + l < ne11 ) {
105
+ int idx = pos_b + (loadc_b + l ) * stride_b / LOAD_VEC_B + loadr_b ;
106
+ buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = src1 [idx ].s0 ;
107
+ buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = src1 [idx ].s1 ;
108
+ buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = src1 [idx ].s2 ;
109
+ buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = src1 [idx ].s3 ;
110
+ } else {
111
+ buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = 0.0f ;
112
+ buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = 0.0f ;
113
+ buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = 0.0f ;
114
+ buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = 0.0f ;
115
+ }
102
116
}
103
117
104
118
barrier (CLK_LOCAL_MEM_FENCE );
0 commit comments