@@ -80,47 +80,123 @@ kernel void kernel_mul_mv_q8_0_f32_flat(
8080 global float * y = (global float * ) (src1 + offset_src1 );
8181
8282 // pointers to src0 rows
83- global char * ax [N_R0_Q8_0 ];
84- global half * ad [N_R0_Q8_0 ];
85- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
86- ulong offset_src0 = ((first_row + row )* nb01 + (i12 /r2 )* nb02 + (i13 /r3 )* nb03 )/34 ;
87- ax [row ] = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
88- ad [row ] = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
89- }
83+ uint offset_src0_base = first_row * nb01 + (i12 /r2 )* nb02 + (i13 /r3 )* nb03 ;
84+
85+ global char * ax0 , * ax1 , * ax2 , * ax3 ;
86+ global half * ad0 , * ad1 , * ad2 , * ad3 ;
87+ uint offset_src0 ;
88+
89+ offset_src0 = offset_src0_base + 0 * nb01 ;
90+ offset_src0 = offset_src0 /34 ;
91+ ax0 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
92+ ad0 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
9093
91- float yl [NB_Q8_0 ];
92- float sumf [N_R0_Q8_0 ] = { 0.f };
94+ offset_src0 = offset_src0_base + 1 * nb01 ;
95+ offset_src0 = offset_src0 /34 ;
96+ ax1 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
97+ ad1 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
98+
99+ offset_src0 = offset_src0_base + 2 * nb01 ;
100+ offset_src0 = offset_src0 /34 ;
101+ ax2 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
102+ ad2 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
103+
104+ offset_src0 = offset_src0_base + 3 * nb01 ;
105+ offset_src0 = offset_src0 /34 ;
106+ ax3 = (global char * ) ((global char * ) src0_q + offset_src0 * sizeof (char )* QK8_0 );
107+ ad3 = (global half * ) ((global char * ) src0_d + offset_src0 * sizeof (half ));
93108
94109 const short ix = get_sub_group_local_id ()/4 ;
95110 const short il = get_sub_group_local_id ()%4 ;
96111
97112 global float * yb = y + ix * QK8_0 + il * NB_Q8_0 ;
98113
114+ float8 yl ;
115+ float8 qv ;
116+ float4 sumf = 0.f ;
117+ float sumq = 0.f ;
118+ global char * qs ;
119+
99120 // each thread handles NB_Q8_0 quants at a time
100121 for (int ib = ix ; ib < nb ; ib += N_SIMDWIDTH /4 ) {
101- for (short i = 0 ; i < NB_Q8_0 ; ++ i ) {
102- yl [i ] = yb [i ];
103- }
104-
105- for (short row = 0 ; row < N_R0_Q8_0 ; row ++ ) {
106- global char * qs = ax [row ] + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
107- float sumq = 0.f ;
108- for (short iq = 0 ; iq < NB_Q8_0 ; ++ iq ) {
109- sumq += qs [iq ] * yl [iq ];
110- }
111- sumf [row ] += sumq * ad [row ][ib ];
112- }
122+ yl = vload8 (0 , yb );
123+
124+ qs = ax0 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
125+ qv = convert_float8 (vload8 (0 , qs ));
126+ sumq = 0 ;
127+ sumq += qv .s0 * yl .s0 ;
128+ sumq += qv .s1 * yl .s1 ;
129+ sumq += qv .s2 * yl .s2 ;
130+ sumq += qv .s3 * yl .s3 ;
131+ sumq += qv .s4 * yl .s4 ;
132+ sumq += qv .s5 * yl .s5 ;
133+ sumq += qv .s6 * yl .s6 ;
134+ sumq += qv .s7 * yl .s7 ;
135+ sumf .s0 += sumq * ad0 [ib ];
136+
137+ qs = ax1 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
138+ qv = convert_float8 (vload8 (0 , qs ));
139+ sumq = 0 ;
140+ sumq += qv .s0 * yl .s0 ;
141+ sumq += qv .s1 * yl .s1 ;
142+ sumq += qv .s2 * yl .s2 ;
143+ sumq += qv .s3 * yl .s3 ;
144+ sumq += qv .s4 * yl .s4 ;
145+ sumq += qv .s5 * yl .s5 ;
146+ sumq += qv .s6 * yl .s6 ;
147+ sumq += qv .s7 * yl .s7 ;
148+ sumf .s1 += sumq * ad1 [ib ];
149+
150+ qs = ax2 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
151+ qv = convert_float8 (vload8 (0 , qs ));
152+ sumq = 0 ;
153+ sumq += qv .s0 * yl .s0 ;
154+ sumq += qv .s1 * yl .s1 ;
155+ sumq += qv .s2 * yl .s2 ;
156+ sumq += qv .s3 * yl .s3 ;
157+ sumq += qv .s4 * yl .s4 ;
158+ sumq += qv .s5 * yl .s5 ;
159+ sumq += qv .s6 * yl .s6 ;
160+ sumq += qv .s7 * yl .s7 ;
161+ sumf .s2 += sumq * ad2 [ib ];
162+
163+ qs = ax3 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
164+ qv = convert_float8 (vload8 (0 , qs ));
165+ sumq = 0 ;
166+ sumq += qv .s0 * yl .s0 ;
167+ sumq += qv .s1 * yl .s1 ;
168+ sumq += qv .s2 * yl .s2 ;
169+ sumq += qv .s3 * yl .s3 ;
170+ sumq += qv .s4 * yl .s4 ;
171+ sumq += qv .s5 * yl .s5 ;
172+ sumq += qv .s6 * yl .s6 ;
173+ sumq += qv .s7 * yl .s7 ;
174+ sumf .s3 += sumq * ad3 [ib ];
113175
114176 yb += N_SIMDWIDTH * NB_Q8_0 ;
115177 }
116178
117179 global float * dst_f32 = (global float * ) dst + (ulong )im * ne0 * ne1 + (ulong )r1 * ne0 ;
118180
119- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
120- float tot = sub_group_reduce_add (sumf [row ]);
181+ float4 tot = (float4 )(
182+ sub_group_reduce_add (sumf .s0 ),
183+ sub_group_reduce_add (sumf .s1 ),
184+ sub_group_reduce_add (sumf .s2 ),
185+ sub_group_reduce_add (sumf .s3 )
186+ );
121187
122- if (get_sub_group_local_id () == 0 && first_row + row < ne01 ) {
123- dst_f32 [first_row + row ] = tot ;
188+ if (get_sub_group_local_id () == 0 ) {
189+ if (first_row + 0 < ne01 ) {
190+ dst_f32 [first_row + 0 ] = tot .s0 ;
191+ }
192+ if (first_row + 1 < ne01 ) {
193+ dst_f32 [first_row + 1 ] = tot .s1 ;
194+ }
195+ if (first_row + 2 < ne01 ) {
196+ dst_f32 [first_row + 2 ] = tot .s2 ;
197+ }
198+ if (first_row + 3 < ne01 ) {
199+ dst_f32 [first_row + 3 ] = tot .s3 ;
124200 }
125201 }
126202}
0 commit comments