11#version 450
22
3- #extension GL_EXT_control_flow_attributes : enable
3+ #define USE_COLLECTIVES
4+
5+ #ifdef USE_COLLECTIVES
6+ #extension GL_KHR_shader_subgroup_shuffle: enable
7+ #endif
48
59#include "types.comp"
610
11+ // Make spec constant
12+ #define SHMEM_PAD 0
13+
714// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
815layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout]
916layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; // src1 - input: [W, H, Cin, N] -- channel_first format
@@ -45,12 +52,16 @@ layout (push_constant) uniform parameter {
4552 uint32_t nb3;
4653} p;
4754
48- #define WG_SIZE 256
49-
50- layout(local_size_x = WG_SIZE, local_size_y = 1, local_size_z = 1) in;
55+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
56+ // Blocktile sizes
57+ layout(constant_id = 1) const uint BS_K = 128;
58+ layout(constant_id = 2) const uint BS_CRS = 16;
59+ layout(constant_id = 3) const uint BS_NPQ = 128;
60+ // Thread-tile sizes
61+ layout(constant_id = 4) const uint TS_K = 8;
5162
5263uint32_t tid = gl_LocalInvocationID.x;
53- const uint32_t bs = gl_WorkGroupSize.x;
64+ const uint32_t WG_SIZE = gl_WorkGroupSize.x;
5465
5566uint splitWork(uint work_size, uint block_size){
5667 return (block_size + work_size -1) / block_size;
@@ -62,16 +73,11 @@ uint32_t NPQ = p.N*p.OH*p.OW;
6273
6374uint32_t n_elems_out = K*NPQ;
6475
65- // Blocktile sizes
66- const uint32_t BS_K = 128;
67- const uint32_t BS_CRS = 16;
68- const uint32_t BS_NPQ = 128;
69-
7076// Number of blocktiles per input
7177uint32_t NB_CRS = splitWork(CRS, BS_CRS);
7278
73- const uint32_t Ash_stride = BS_CRS+1 ;
74- const uint32_t Bsh_stride = BS_NPQ+1 ;
79+ const uint32_t Ash_stride = BS_CRS+SHMEM_PAD ;
80+ const uint32_t Bsh_stride = BS_NPQ+SHMEM_PAD ;
7581
7682const uint32_t Ash_numel = BS_K*BS_CRS;
7783const uint32_t Bsh_numel = BS_CRS*BS_NPQ;
@@ -83,7 +89,6 @@ shared float Ash[Ash_len]; // K x CRS
8389shared float Bsh[Bsh_len]; // CRS x NPQ
8490
8591// Threadtile sizes
86- const uint32_t TS_K = 16;
8792const uint32_t TS_NPQ = BS_K*BS_NPQ / WG_SIZE / TS_K;
8893
8994// Number of threadtiles per blocktile
@@ -111,134 +116,111 @@ uint32_t T_x = tid % NT_NPQ;
111116
112117uint32_t Ar = tid / BS_CRS;
113118uint32_t Ac = tid % BS_CRS;
114- uint32_t ArpWg = WG_SIZE / BS_CRS;
119+ const uint32_t ArpWg = WG_SIZE / BS_CRS;
115120
116121uint32_t Br = tid / BS_NPQ;
117122uint32_t Bc = tid % BS_NPQ;
118- uint32_t BrpWg = WG_SIZE / BS_NPQ;
123+ const uint32_t BrpWg = WG_SIZE / BS_NPQ;
119124
120- void initReg (){
125+ void main (){\
121126 for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
122127 for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
123128 regC[T_ly][T_lx] = 0.0;
124129 }
125130 }
126- }
127-
128- void outProdReg(){
129- for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){
130- for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
131- regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx];
131+ /* Advance block in CRS dim */\
132+ for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){
133+ #ifdef USE_COLLECTIVES
134+ uint32_t cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID;
135+ uint32_t cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH);
136+ uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH);
137+ uint32_t cached_KH_idx = cached_CRS_remainder / p.KW;
138+ uint32_t cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW;
139+
140+ uint32_t CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
141+ uint32_t Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
142+ uint32_t KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
143+ uint32_t KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
144+ #else
145+ uint32_t CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A)
146+ uint32_t Cin_idx_a = CRS_idx_a / (p.KW*p.KH);
147+ uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH;
148+ uint32_t KH_idx_a = CRS_remainder / p.KW;
149+ uint32_t KW_idx_a = CRS_remainder - KH_idx_a*p.KW;
150+ #endif
151+
152+ /* Load kernel to A_block: (BS_K x BS_CRS)*/
153+ for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){
154+ uint32_t B_ly = r_offset + Ar;
155+ uint32_t B_lx = Ac;
156+ uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/
157+ uint32_t knl_idx = min(KW_idx_a + KH_idx_a*p.nb01 + Cin_idx_a*p.nb02 + K_idx*p.nb03, K*CRS-1);
158+ float val = knl_data[knl_idx];
159+ if(K_idx >= K || CRS_idx_a >= CRS){
160+ val = 0.0;
161+ }
162+ Ash[B_ly * Ash_stride + B_lx] = val;
132163 }
133- for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
134- regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx];
164+ /* Load input to B_block: (BS_CRS x BS_NPQ) */
165+ for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){
166+ uint32_t B_ly = r_offset + Br; /* Row index of B block */
167+ uint32_t B_lx = Bc;
168+ uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
169+ uint32_t N_idx = NPQ_idx / (p.OH*p.OW);
170+ uint32_t NPQ_remainder = NPQ_idx - N_idx*p.OH*p.OW;
171+ uint32_t OH_idx = NPQ_remainder / p.OW;
172+ uint32_t OW_idx = NPQ_remainder - OH_idx*p.OW;
173+
174+ #ifdef USE_COLLECTIVES
175+ uint32_t CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
176+ uint32_t Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
177+ uint32_t KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
178+ uint32_t KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
179+ #else
180+ uint32_t CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */
181+ uint32_t Cin_idx_b = CRS_idx_b / (p.KW*p.KH);
182+ uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH;
183+ uint32_t KH_idx_b = CRS_remainder / p.KW;
184+ uint32_t KW_idx_b = CRS_remainder - KH_idx_b*p.KW;
185+ #endif
186+
187+ uint32_t H_idx = OH_idx*p.s1 + KH_idx_b*p.d1 - p.p1;
188+ uint32_t W_idx = OW_idx*p.s0 + KW_idx_b*p.d0 - p.p0;
189+ uint32_t src_idx = min(max(W_idx + H_idx*p.nb11 + Cin_idx_b*p.nb12 + N_idx*p.nb13, 0), p.Cin*p.N*p.W*p.H-1);
190+ float val = src_data[src_idx];
191+ if(CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W){
192+ val = 0.0;
193+ }
194+ Bsh[B_ly * Bsh_stride + B_lx] = val;
135195 }
136- for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
196+ barrier();
197+ for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){
198+ for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
199+ regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx];
200+ }
137201 for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
138- regC[T_ly][T_lx] += regA[T_ly] * regB[T_lx];
202+ regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx];
203+ }
204+ for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
205+ for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
206+ regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
207+ }
139208 }
140209 }
210+ barrier();
141211 }
142- }
143-
144- // Generate different functions for computing the sides.
145-
146- #define NOOP()
147-
148- #define DEF_BOUNDARY_CONDITION_A_IF()\
149- if(K_idx < K && CRS_idx < CRS){
150-
151- #define DEF_BOUNDARY_CONDITION_A_ELSE()\
152- }else{\
153- Ash[B_ly * Ash_stride + B_lx] = 0.0;\
154- }
155-
156- #define DEF_BOUNDARY_CONDITION_B_IF()\
157- if(CRS_idx < CRS && NPQ_idx < NPQ){
158-
159- #define DEF_BOUNDARY_CONDITION_B_ELSE()\
160- }else{\
161- Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\
162- }
163-
164- #define MAIN_LOOP(FUNC_NAME_SUFFIX, BOUNDARY_CONDITION_A_IF, BOUNDARY_CONDITION_A_ELSE, BOUNDARY_CONDITION_B_IF, BOUNDARY_CONDITION_B_ELSE)\
165- void mainLoop ## FUNC_NAME_SUFFIX(){\
166- initReg();\
167- /* Advance block in CRS dim */\
168- for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){\
169- /* Load kernel to A_block: (BS_K x BS_CRS)*/\
170- for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){\
171- uint32_t B_ly = r_offset + Ar;\
172- uint32_t B_lx = Ac;\
173- uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/\
174- uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_lx; /* Global CRS_idx (column index of A)*/\
175- BOUNDARY_CONDITION_A_IF()\
176- uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\
177- uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\
178- uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\
179- uint32_t knl_idx = KW_idx + KH_idx*p.nb01 + Cin_idx*p.nb02 + K_idx*p.nb03;\
180- Ash[B_ly * Ash_stride + B_lx] = knl_data[knl_idx];\
181- BOUNDARY_CONDITION_A_ELSE()\
182- }\
183- barrier();\
184- /* Load input to B_block: (BS_CRS x BS_NPQ) */\
185- for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){\
186- uint32_t B_ly = r_offset + Br; /* Row index of B block */\
187- uint32_t B_lx = Bc; /* Column index of B block */\
188- uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */\
189- uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */\
190- BOUNDARY_CONDITION_B_IF()\
191- uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\
192- uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\
193- uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\
194- uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\
195- uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\
196- uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\
197- uint32_t H_idx = OH_idx*p.s1 + KH_idx*p.d1 - p.p1;\
198- uint32_t W_idx = OW_idx*p.s0 + KW_idx*p.d0 - p.p0;\
199- if(H_idx >= 0 && H_idx < p.H && W_idx >= 0 && W_idx < p.W){\
200- uint32_t src_idx = W_idx + H_idx*p.nb11 + Cin_idx*p.nb12 + N_idx*p.nb13;\
201- Bsh[B_ly * Bsh_stride + B_lx] = src_data[src_idx];\
202- }else{\
203- Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\
204- }\
205- BOUNDARY_CONDITION_B_ELSE()\
206- }\
207- barrier();\
208- outProdReg();\
209- barrier();\
210- }\
211- /* Save C* */\
212- for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){\
213- for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){\
214- uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;\
215- uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;\
216- if(K_idx < K && NPQ_idx < NPQ){\
217- uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\
218- uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\
219- uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\
220- uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;\
221- dst_data[dst_idx] = regC[T_ly][T_lx];\
222- }\
223- }\
224- }\
225- }
226-
227- // Generates mainLoopBoundaryCheck
228- MAIN_LOOP(BoundaryCheck,
229- DEF_BOUNDARY_CONDITION_A_IF,
230- DEF_BOUNDARY_CONDITION_A_ELSE,
231- DEF_BOUNDARY_CONDITION_B_IF,
232- DEF_BOUNDARY_CONDITION_B_ELSE)
233-
234- // Generates mainLoopNoBoundaryCheck
235- MAIN_LOOP(NoBoundaryCheck,
236- NOOP, NOOP, NOOP, NOOP)
237-
238- void main(){
239- if(gl_WorkGroupID.x == gl_NumWorkGroups.x-1 || gl_WorkGroupID.y == gl_NumWorkGroups.y-1){
240- mainLoopBoundaryCheck();
241- }else{
242- mainLoopNoBoundaryCheck();
212+ /* Save C* */
213+ for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
214+ for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
215+ uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
216+ uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
217+ uint32_t N_idx = NPQ_idx / (p.OH*p.OW);
218+ uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;
219+ uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;
220+ uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;
221+ if(K_idx < K && NPQ_idx < NPQ){
222+ dst_data[dst_idx] = regC[T_ly][T_lx];
223+ }
224+ }
243225 }
244226}
0 commit comments