@@ -78,27 +78,41 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
7878
7979 for (int block = 0 ; block < ne00 ; block += BK ) {
8080 for (int l = 0 ; l < BM ; l += loadstride_a ) {
81- int idx = pos_a + (loadc_a + l ) * stride_a / LOAD_VEC_A + loadr_a ;
82- int ib = idx / 8 ;
83- int iqs = idx % 8 ;
84-
85- float d = (float )src0_d [ib ];
86- global char4 * qs = src0_q + ib * 8 + iqs ;
87- char4 q = * qs ;
88- float4 v = convert_float4 (q )* d ;
89-
90- buf_a [(loadr_a * LOAD_VEC_A + 0 ) * BM + loadc_a + l ] = v .s0 ;
91- buf_a [(loadr_a * LOAD_VEC_A + 1 ) * BM + loadc_a + l ] = v .s1 ;
92- buf_a [(loadr_a * LOAD_VEC_A + 2 ) * BM + loadc_a + l ] = v .s2 ;
93- buf_a [(loadr_a * LOAD_VEC_A + 3 ) * BM + loadc_a + l ] = v .s3 ;
81+ if (loadc_a + l < ne01 ) {
82+ int idx = pos_a + (loadc_a + l ) * stride_a / LOAD_VEC_A + loadr_a ;
83+ int ib = idx / 8 ;
84+ int iqs = idx % 8 ;
85+
86+ float d = (float )src0_d [ib ];
87+ global char4 * qs = src0_q + ib * 8 + iqs ;
88+ char4 q = * qs ;
89+ float4 v = convert_float4 (q )* d ;
90+
91+ buf_a [(loadr_a * LOAD_VEC_A + 0 ) * BM + loadc_a + l ] = v .s0 ;
92+ buf_a [(loadr_a * LOAD_VEC_A + 1 ) * BM + loadc_a + l ] = v .s1 ;
93+ buf_a [(loadr_a * LOAD_VEC_A + 2 ) * BM + loadc_a + l ] = v .s2 ;
94+ buf_a [(loadr_a * LOAD_VEC_A + 3 ) * BM + loadc_a + l ] = v .s3 ;
95+ } else {
96+ buf_a [(loadr_a * LOAD_VEC_A + 0 ) * BM + loadc_a + l ] = 0.0f ;
97+ buf_a [(loadr_a * LOAD_VEC_A + 1 ) * BM + loadc_a + l ] = 0.0f ;
98+ buf_a [(loadr_a * LOAD_VEC_A + 2 ) * BM + loadc_a + l ] = 0.0f ;
99+ buf_a [(loadr_a * LOAD_VEC_A + 3 ) * BM + loadc_a + l ] = 0.0f ;
100+ }
94101 }
95102
96103 for (int l = 0 ; l < BN ; l += loadstride_b ) {
97- int idx = pos_b + (loadc_b + l ) * stride_b / LOAD_VEC_B + loadr_b ;
98- buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = src1 [idx ].s0 ;
99- buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = src1 [idx ].s1 ;
100- buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = src1 [idx ].s2 ;
101- buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = src1 [idx ].s3 ;
104+ if (loadc_b + l < ne11 ) {
105+ int idx = pos_b + (loadc_b + l ) * stride_b / LOAD_VEC_B + loadr_b ;
106+ buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = src1 [idx ].s0 ;
107+ buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = src1 [idx ].s1 ;
108+ buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = src1 [idx ].s2 ;
109+ buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = src1 [idx ].s3 ;
110+ } else {
111+ buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = 0.0f ;
112+ buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = 0.0f ;
113+ buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = 0.0f ;
114+ buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = 0.0f ;
115+ }
102116 }
103117
104118 barrier (CLK_LOCAL_MEM_FENCE );
0 commit comments