@@ -80,47 +80,123 @@ kernel void kernel_mul_mv_q8_0_f32_flat(
80
80
global float * y = (global float * ) (src1 + offset_src1 );
81
81
82
82
// pointers to src0 rows
83
- global char * ax [N_R0_Q8_0 ];
84
- global half * ad [N_R0_Q8_0 ];
85
- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
86
- ulong offset_src0 = ((first_row + row )* nb01 + (i12 /r2 )* nb02 + (i13 /r3 )* nb03 )/34 ;
87
- ax [row ] = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
88
- ad [row ] = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
89
- }
83
+ uint offset_src0_base = first_row * nb01 + (i12 /r2 )* nb02 + (i13 /r3 )* nb03 ;
84
+
85
+ global char * ax0 , * ax1 , * ax2 , * ax3 ;
86
+ global half * ad0 , * ad1 , * ad2 , * ad3 ;
87
+ uint offset_src0 ;
88
+
89
+ offset_src0 = offset_src0_base + 0 * nb01 ;
90
+ offset_src0 = offset_src0 /34 ;
91
+ ax0 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
92
+ ad0 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
90
93
91
- float yl [NB_Q8_0 ];
92
- float sumf [N_R0_Q8_0 ] = { 0.f };
94
+ offset_src0 = offset_src0_base + 1 * nb01 ;
95
+ offset_src0 = offset_src0 /34 ;
96
+ ax1 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
97
+ ad1 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
98
+
99
+ offset_src0 = offset_src0_base + 2 * nb01 ;
100
+ offset_src0 = offset_src0 /34 ;
101
+ ax2 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
102
+ ad2 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
103
+
104
+ offset_src0 = offset_src0_base + 3 * nb01 ;
105
+ offset_src0 = offset_src0 /34 ;
106
+ ax3 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
107
+ ad3 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
93
108
94
109
const short ix = get_sub_group_local_id ()/4 ;
95
110
const short il = get_sub_group_local_id ()%4 ;
96
111
97
112
global float * yb = y + ix * QK8_0 + il * NB_Q8_0 ;
98
113
114
+ float8 yl ;
115
+ float8 qv ;
116
+ float4 sumf = 0.f ;
117
+ float sumq = 0.f ;
118
+ global char * qs ;
119
+
99
120
// each thread handles NB_Q8_0 quants at a time
100
121
for (int ib = ix ; ib < nb ; ib += N_SIMDWIDTH /4 ) {
101
- for (short i = 0 ; i < NB_Q8_0 ; ++ i ) {
102
- yl [i ] = yb [i ];
103
- }
104
-
105
- for (short row = 0 ; row < N_R0_Q8_0 ; row ++ ) {
106
- global char * qs = ax [row ] + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
107
- float sumq = 0.f ;
108
- for (short iq = 0 ; iq < NB_Q8_0 ; ++ iq ) {
109
- sumq += qs [iq ] * yl [iq ];
110
- }
111
- sumf [row ] += sumq * ad [row ][ib ];
112
- }
122
+ yl = vload8 (0 , yb );
123
+
124
+ qs = ax0 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
125
+ qv = convert_float8 (vload8 (0 , qs ));
126
+ sumq = 0 ;
127
+ sumq += qv .s0 * yl .s0 ;
128
+ sumq += qv .s1 * yl .s1 ;
129
+ sumq += qv .s2 * yl .s2 ;
130
+ sumq += qv .s3 * yl .s3 ;
131
+ sumq += qv .s4 * yl .s4 ;
132
+ sumq += qv .s5 * yl .s5 ;
133
+ sumq += qv .s6 * yl .s6 ;
134
+ sumq += qv .s7 * yl .s7 ;
135
+ sumf .s0 += sumq * ad0 [ib ];
136
+
137
+ qs = ax1 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
138
+ qv = convert_float8 (vload8 (0 , qs ));
139
+ sumq = 0 ;
140
+ sumq += qv .s0 * yl .s0 ;
141
+ sumq += qv .s1 * yl .s1 ;
142
+ sumq += qv .s2 * yl .s2 ;
143
+ sumq += qv .s3 * yl .s3 ;
144
+ sumq += qv .s4 * yl .s4 ;
145
+ sumq += qv .s5 * yl .s5 ;
146
+ sumq += qv .s6 * yl .s6 ;
147
+ sumq += qv .s7 * yl .s7 ;
148
+ sumf .s1 += sumq * ad1 [ib ];
149
+
150
+ qs = ax2 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
151
+ qv = convert_float8 (vload8 (0 , qs ));
152
+ sumq = 0 ;
153
+ sumq += qv .s0 * yl .s0 ;
154
+ sumq += qv .s1 * yl .s1 ;
155
+ sumq += qv .s2 * yl .s2 ;
156
+ sumq += qv .s3 * yl .s3 ;
157
+ sumq += qv .s4 * yl .s4 ;
158
+ sumq += qv .s5 * yl .s5 ;
159
+ sumq += qv .s6 * yl .s6 ;
160
+ sumq += qv .s7 * yl .s7 ;
161
+ sumf .s2 += sumq * ad2 [ib ];
162
+
163
+ qs = ax3 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
164
+ qv = convert_float8 (vload8 (0 , qs ));
165
+ sumq = 0 ;
166
+ sumq += qv .s0 * yl .s0 ;
167
+ sumq += qv .s1 * yl .s1 ;
168
+ sumq += qv .s2 * yl .s2 ;
169
+ sumq += qv .s3 * yl .s3 ;
170
+ sumq += qv .s4 * yl .s4 ;
171
+ sumq += qv .s5 * yl .s5 ;
172
+ sumq += qv .s6 * yl .s6 ;
173
+ sumq += qv .s7 * yl .s7 ;
174
+ sumf .s3 += sumq * ad3 [ib ];
113
175
114
176
yb += N_SIMDWIDTH * NB_Q8_0 ;
115
177
}
116
178
117
179
global float * dst_f32 = (global float * ) dst + (ulong )im * ne0 * ne1 + (ulong )r1 * ne0 ;
118
180
119
- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
120
- float tot = sub_group_reduce_add (sumf [row ]);
181
+ float4 tot = (float4 )(
182
+ sub_group_reduce_add (sumf .s0 ),
183
+ sub_group_reduce_add (sumf .s1 ),
184
+ sub_group_reduce_add (sumf .s2 ),
185
+ sub_group_reduce_add (sumf .s3 )
186
+ );
121
187
122
- if (get_sub_group_local_id () == 0 && first_row + row < ne01 ) {
123
- dst_f32 [first_row + row ] = tot ;
188
+ if (get_sub_group_local_id () == 0 ) {
189
+ if (first_row + 0 < ne01 ) {
190
+ dst_f32 [first_row + 0 ] = tot .s0 ;
191
+ }
192
+ if (first_row + 1 < ne01 ) {
193
+ dst_f32 [first_row + 1 ] = tot .s1 ;
194
+ }
195
+ if (first_row + 2 < ne01 ) {
196
+ dst_f32 [first_row + 2 ] = tot .s2 ;
197
+ }
198
+ if (first_row + 3 < ne01 ) {
199
+ dst_f32 [first_row + 3 ] = tot .s3 ;
124
200
}
125
201
}
126
202
}
0 commit comments