@@ -78,32 +78,50 @@ kernel void kernel_mul_mm_q4_0_f32_l4_lm(
78
78
79
79
for (int block = 0 ; block < ne00 ; block += BK ) {
80
80
for (int l = 0 ; l < BM ; l += loadstride_a ) {
81
- int idx = pos_a + (loadc_a + l ) * stride_a / LOAD_VEC_A + loadr_a ;
82
- int ib = idx / 4 ;
83
- int iqs = idx % 4 ;
84
-
85
- float d = (float )src0_d [ib ];
86
- global uchar4 * qs = src0_q + ib * 4 + iqs ;
87
- uchar4 q = * qs ;
88
- float4 v1 = (convert_float4 ((uchar4 )((q .s0 )& 0x0F , (q .s1 )& 0x0F , (q .s2 )& 0x0F , (q .s3 )& 0x0F )) - 8.0f )* d ;
89
- float4 v2 = (convert_float4 ((uchar4 )((q .s0 >>4 )& 0x0F , (q .s1 >>4 )& 0x0F , (q .s2 >>4 )& 0x0F , (q .s3 >>4 )& 0x0F )) - 8.0f )* d ;
90
-
91
- buf_a [(loadr_a * 4 + 0 ) * BM + loadc_a + l ] = v1 .s0 ;
92
- buf_a [(loadr_a * 4 + 1 ) * BM + loadc_a + l ] = v1 .s1 ;
93
- buf_a [(loadr_a * 4 + 2 ) * BM + loadc_a + l ] = v1 .s2 ;
94
- buf_a [(loadr_a * 4 + 3 ) * BM + loadc_a + l ] = v1 .s3 ;
95
- buf_a [(loadr_a * 4 + 16 ) * BM + loadc_a + l ] = v2 .s0 ;
96
- buf_a [(loadr_a * 4 + 17 ) * BM + loadc_a + l ] = v2 .s1 ;
97
- buf_a [(loadr_a * 4 + 18 ) * BM + loadc_a + l ] = v2 .s2 ;
98
- buf_a [(loadr_a * 4 + 19 ) * BM + loadc_a + l ] = v2 .s3 ;
81
+ if (loadc_a + l < ne01 ) {
82
+ int idx = pos_a + (loadc_a + l ) * stride_a / LOAD_VEC_A + loadr_a ;
83
+ int ib = idx / 4 ;
84
+ int iqs = idx % 4 ;
85
+
86
+ float d = (float )src0_d [ib ];
87
+ global uchar4 * qs = src0_q + ib * 4 + iqs ;
88
+ uchar4 q = * qs ;
89
+ float4 v1 = (convert_float4 ((uchar4 )((q .s0 )& 0x0F , (q .s1 )& 0x0F , (q .s2 )& 0x0F , (q .s3 )& 0x0F )) - 8.0f )* d ;
90
+ float4 v2 = (convert_float4 ((uchar4 )((q .s0 >>4 )& 0x0F , (q .s1 >>4 )& 0x0F , (q .s2 >>4 )& 0x0F , (q .s3 >>4 )& 0x0F )) - 8.0f )* d ;
91
+
92
+ buf_a [(loadr_a * 4 + 0 ) * BM + loadc_a + l ] = v1 .s0 ;
93
+ buf_a [(loadr_a * 4 + 1 ) * BM + loadc_a + l ] = v1 .s1 ;
94
+ buf_a [(loadr_a * 4 + 2 ) * BM + loadc_a + l ] = v1 .s2 ;
95
+ buf_a [(loadr_a * 4 + 3 ) * BM + loadc_a + l ] = v1 .s3 ;
96
+ buf_a [(loadr_a * 4 + 16 ) * BM + loadc_a + l ] = v2 .s0 ;
97
+ buf_a [(loadr_a * 4 + 17 ) * BM + loadc_a + l ] = v2 .s1 ;
98
+ buf_a [(loadr_a * 4 + 18 ) * BM + loadc_a + l ] = v2 .s2 ;
99
+ buf_a [(loadr_a * 4 + 19 ) * BM + loadc_a + l ] = v2 .s3 ;
100
+ } else {
101
+ buf_a [(loadr_a * 4 + 0 ) * BM + loadc_a + l ] = 0.0f ;
102
+ buf_a [(loadr_a * 4 + 1 ) * BM + loadc_a + l ] = 0.0f ;
103
+ buf_a [(loadr_a * 4 + 2 ) * BM + loadc_a + l ] = 0.0f ;
104
+ buf_a [(loadr_a * 4 + 3 ) * BM + loadc_a + l ] = 0.0f ;
105
+ buf_a [(loadr_a * 4 + 16 ) * BM + loadc_a + l ] = 0.0f ;
106
+ buf_a [(loadr_a * 4 + 17 ) * BM + loadc_a + l ] = 0.0f ;
107
+ buf_a [(loadr_a * 4 + 18 ) * BM + loadc_a + l ] = 0.0f ;
108
+ buf_a [(loadr_a * 4 + 19 ) * BM + loadc_a + l ] = 0.0f ;
109
+ }
99
110
}
100
111
101
112
for (int l = 0 ; l < BN ; l += loadstride_b ) {
102
- int idx = pos_b + (loadc_b + l ) * stride_b / LOAD_VEC_B + loadr_b ;
103
- buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = src1 [idx ].s0 ;
104
- buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = src1 [idx ].s1 ;
105
- buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = src1 [idx ].s2 ;
106
- buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = src1 [idx ].s3 ;
113
+ if (loadc_b + l < ne11 ) {
114
+ int idx = pos_b + (loadc_b + l ) * stride_b / LOAD_VEC_B + loadr_b ;
115
+ buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = src1 [idx ].s0 ;
116
+ buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = src1 [idx ].s1 ;
117
+ buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = src1 [idx ].s2 ;
118
+ buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = src1 [idx ].s3 ;
119
+ } else {
120
+ buf_b [(loadr_b * LOAD_VEC_B + 0 ) * BN + loadc_b + l ] = 0.0f ;
121
+ buf_b [(loadr_b * LOAD_VEC_B + 1 ) * BN + loadc_b + l ] = 0.0f ;
122
+ buf_b [(loadr_b * LOAD_VEC_B + 2 ) * BN + loadc_b + l ] = 0.0f ;
123
+ buf_b [(loadr_b * LOAD_VEC_B + 3 ) * BN + loadc_b + l ] = 0.0f ;
124
+ }
107
125
}
108
126
109
127
barrier (CLK_LOCAL_MEM_FENCE );
0 commit comments