1
1
#version 450
2
2
3
3
#extension GL_EXT_control_flow_attributes : enable
4
+ #ifdef COOPMAT2
5
+ #extension GL_NV_cooperative_matrix2 : enable
6
+ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
7
+ #extension GL_KHR_memory_scope_semantics : enable
8
+ #endif
4
9
5
10
#ifdef USE_COLLECTIVES
6
11
# extension GL_KHR_shader_subgroup_shuffle : enable
@@ -91,6 +96,12 @@ uint32_t n_elems_out = K * NPQ;
91
96
// Number of blocktiles per input
92
97
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
93
98
99
+ #ifdef COOPMAT2
100
+ #define SHMEM_TYPE float16_t
101
+ #else
102
+ #define SHMEM_TYPE float
103
+ #endif
104
+
94
105
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
95
106
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
96
107
@@ -100,8 +111,8 @@ const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
100
111
const uint32_t Ash_len = BS_K * Ash_stride;
101
112
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
102
113
103
- shared float Ash[Ash_len]; // K x CRS
104
- shared float Bsh[Bsh_len]; // CRS x NPQ
114
+ shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
115
+ shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
105
116
106
117
// Threadtile sizes
107
118
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
@@ -110,10 +121,6 @@ const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
110
121
const uint32_t NT_K = BS_K / TS_K;
111
122
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
112
123
113
- float regA[TS_K];
114
- float regB[TS_NPQ];
115
- float regC[TS_K][TS_NPQ];
116
-
117
124
/*
118
125
Compute
119
126
KxCRS @ CRSxNPQ = K x NPQ
@@ -145,12 +152,36 @@ uint fastdiv(uint n, uint mp, uint L) {
145
152
return (msbs + n) >> L;
146
153
}
147
154
155
+ #ifdef COOPMAT2
156
+ #define ACC_TYPE float16_t
157
+
158
+ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
159
+ {
160
+ uint32_t K_idx = B_idx_K * BS_K + r;
161
+ uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
162
+ uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
163
+ uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
164
+ uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
165
+ uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
166
+ if (K_idx < K && NPQ_idx < NPQ) {
167
+ dst_data[dst_idx] = D_TYPE(elem);
168
+ }
169
+ return elem;
170
+ }
171
+ #endif
172
+
148
173
void main() {
174
+ #ifdef COOPMAT2
175
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
176
+ matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
177
+ #else
178
+ float regC[TS_K][TS_NPQ];
149
179
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
150
180
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
151
181
regC[T_ly][T_lx] = 0.0;
152
182
}
153
183
}
184
+ #endif
154
185
/* Advance block in CRS dim */
155
186
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
156
187
uint32_t CRS_idx_a;
@@ -199,7 +230,7 @@ void main() {
199
230
if (K_idx >= K || CRS_idx_a >= CRS) {
200
231
val = 0.0;
201
232
}
202
- Ash[B_ly * Ash_stride + B_lx] = val;
233
+ Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE( val) ;
203
234
}
204
235
/* Load input to B_block: (BS_CRS x BS_NPQ) */
205
236
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
@@ -244,11 +275,21 @@ void main() {
244
275
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
245
276
val = 0.0;
246
277
}
247
- Bsh[B_ly * Bsh_stride + B_lx] = val;
278
+ Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE( val) ;
248
279
}
249
280
barrier();
281
+ #ifdef COOPMAT2
282
+ coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
283
+ coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
284
+
285
+ coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
286
+ coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
287
+ matC = coopMatMulAdd(matA, matB, matC);
288
+ #else
250
289
if (T_y * TS_K < K) {
251
290
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
291
+ float regA[TS_K];
292
+ float regB[TS_NPQ];
252
293
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
253
294
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
254
295
}
@@ -262,9 +303,13 @@ void main() {
262
303
}
263
304
}
264
305
}
306
+ #endif
265
307
barrier();
266
308
}
267
309
/* Save C* */
310
+ #ifdef COOPMAT2
311
+ coopMatPerElementNV(matC, matC, perElemOpStore);
312
+ #else
268
313
if (T_y * TS_K < K) {
269
314
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
270
315
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
@@ -280,4 +325,5 @@ void main() {
280
325
}
281
326
}
282
327
}
328
+ #endif
283
329
}
0 commit comments