11#version 450
22
3+ #extension GL_EXT_control_flow_attributes : enable
4+
35#ifdef USE_COLLECTIVES
46# extension GL_KHR_shader_subgroup_shuffle : enable
57#endif
68
79#include "types.comp"
810
911// Make spec constant
10- #define SHMEM_PAD 0
12+ #define SHMEM_PAD 4
1113
1214// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
1315layout(binding = 0) readonly buffer A {
@@ -56,6 +58,12 @@ layout(push_constant) uniform parameter {
5658 uint32_t nb1;
5759 uint32_t nb2;
5860 uint32_t nb3;
61+
62+ // fastdiv helper values
63+ uint32_t KWmp; uint32_t KWL;
64+ uint32_t KWKHmp; uint32_t KWKHL;
65+ uint32_t OWmp; uint32_t OWL;
66+ uint32_t OWOHmp; uint32_t OWOHL;
5967}
6068
6169p;
@@ -131,6 +139,14 @@ uint32_t Br = tid / BS_NPQ;
131139uint32_t Bc = tid % BS_NPQ;
132140const uint32_t BrpWg = WG_SIZE / BS_NPQ;
133141
142+ // see init_fastdiv_values in ggml-vulkan.cpp
143+ uint fastdiv(uint n, uint mp, uint L) {
144+ uint msbs, lsbs;
145+ // msbs = mulhi(n, mp)
146+ umulExtended(n, mp, msbs, lsbs);
147+ return (msbs + n) >> L;
148+ }
149+
134150void main() {
135151 for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
136152 for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
@@ -151,9 +167,9 @@ void main() {
151167 uint32_t cached_KW_idx;
152168 if (use_collectives == 1) {
153169 cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
154- cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
170+ cached_Cin_idx = fastdiv( cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
155171 uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
156- cached_KH_idx = cached_CRS_remainder / p.KW;
172+ cached_KH_idx = fastdiv( cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
157173 cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
158174
159175 CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
@@ -162,16 +178,16 @@ void main() {
162178 KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
163179 } else {
164180 CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
165- Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
181+ Cin_idx_a = fastdiv( CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
166182 uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
167- KH_idx_a = CRS_remainder / p.KW;
183+ KH_idx_a = fastdiv( CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
168184 KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
169185 }
170186#else
171187 CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
172- Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
188+ Cin_idx_a = fastdiv( CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
173189 CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
174- KH_idx_a = CRS_remainder / p.KW;
190+ KH_idx_a = fastdiv( CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
175191 KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
176192#endif
177193
@@ -188,13 +204,13 @@ void main() {
188204 Ash[B_ly * Ash_stride + B_lx] = val;
189205 }
190206 /* Load input to B_block: (BS_CRS x BS_NPQ) */
191- for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
207+ [[unroll]] for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
192208 uint32_t B_ly = r_offset + Br; /* Row index of B block */
193209 uint32_t B_lx = Bc;
194210 uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
195- uint32_t N_idx = NPQ_idx / ( p.OH * p.OW) ;
211+ uint32_t N_idx = fastdiv( NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
196212 uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
197- uint32_t OH_idx = NPQ_remainder / p.OW;
213+ uint32_t OH_idx = fastdiv( NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
198214 uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
199215
200216 uint32_t CRS_idx_b;
@@ -209,16 +225,16 @@ void main() {
209225 KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
210226 } else {
211227 CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
212- Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
228+ Cin_idx_b = fastdiv( CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
213229 uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
214- KH_idx_b = CRS_remainder / p.KW;
230+ KH_idx_b = fastdiv( CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
215231 KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
216232 }
217233#else
218234 CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
219- Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
235+ Cin_idx_b = fastdiv( CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
220236 uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
221- KH_idx_b = CRS_remainder / p.KW;
237+ KH_idx_b = fastdiv( CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
222238 KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
223239#endif
224240
@@ -233,32 +249,36 @@ void main() {
233249 Bsh[B_ly * Bsh_stride + B_lx] = val;
234250 }
235251 barrier();
236- for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
237- for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
238- regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
239- }
240- for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
241- regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
242- }
243- for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
252+ if (T_y * TS_K < K) {
253+ [[unroll]] for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
254+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
255+ regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
256+ }
244257 for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
245- regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
258+ regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
259+ }
260+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
261+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
262+ regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
263+ }
246264 }
247265 }
248266 }
249267 barrier();
250268 }
251269 /* Save C* */
252- for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
253- for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
254- uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
255- uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
256- uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
257- uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
258- uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
259- uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
260- if (K_idx < K && NPQ_idx < NPQ) {
261- dst_data[dst_idx] = regC[T_ly][T_lx];
270+ if (T_y * TS_K < K) {
271+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
272+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
273+ uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
274+ uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
275+ uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
276+ uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
277+ uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
278+ uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
279+ if (K_idx < K && NPQ_idx < NPQ) {
280+ dst_data[dst_idx] = regC[T_ly][T_lx];
281+ }
262282 }
263283 }
264284 }
0 commit comments