1010#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
1111#endif
1212
13- #ifdef COOPMAT
14- #extension GL_KHR_cooperative_matrix : enable
15- #extension GL_KHR_memory_scope_semantics : enable
16- #extension GL_KHR_shader_subgroup_basic : enable
17- #endif
18-
1913#ifdef MUL_MAT_ID
2014#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
2115#endif
@@ -79,10 +73,6 @@ layout (constant_id = 10) const uint WARP = 32;
7973
8074#define BK 32
8175
82- #ifdef COOPMAT
83- #define SHMEM_STRIDE (BK / 4 + 4)
84- #endif
85-
8676#define MMQ_SHMEM
8777
8878#include "mul_mmq_shmem_types.glsl"
@@ -92,7 +82,7 @@ shared block_a_cache buf_a[BM];
9282shared block_b_cache buf_b[BN];
9383// Register cache
9484block_a_cache cache_a[WMITER * TM];
95- block_b_cache cache_b[TN] ;
85+ block_b_cache cache_b;
9686
9787#define LOAD_VEC_A (4 * QUANT_R_MMQ)
9888#define LOAD_VEC_B 16
@@ -104,10 +94,6 @@ shared u16vec2 row_ids[4096];
10494
10595#define NUM_WARPS (BLOCK_SIZE / WARP)
10696
107- #ifdef COOPMAT
108- shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
109- #endif
110-
11197#include "mul_mmq_funcs.glsl"
11298
11399void main() {
@@ -137,26 +123,12 @@ void main() {
137123 const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
138124 const uint WSUBM = WM / WMITER;
139125 const uint WSUBN = WN / WNITER;
140-
141- #ifdef COOPMAT
142- const uint warp_i = gl_SubgroupID;
143-
144- const uint tiw = gl_SubgroupInvocationID;
145-
146- const uint cms_per_row = WM / TM;
147- const uint cms_per_col = WN / TN;
148-
149- const uint storestride = WARP / TM;
150- const uint store_r = tiw % TM;
151- const uint store_c = tiw / TM;
152- #else
153126 const uint warp_i = gl_LocalInvocationID.x / WARP;
154127
155128 const uint tiw = gl_LocalInvocationID.x % WARP;
156129
157130 const uint tiwr = tiw % (WSUBM / TM);
158131 const uint tiwc = tiw / (WSUBM / TM);
159- #endif
160132
161133 const uint warp_r = warp_i % (BM / WM);
162134 const uint warp_c = warp_i / (BM / WM);
@@ -207,26 +179,11 @@ void main() {
207179 uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
208180#endif
209181
210- #ifdef COOPMAT
211- coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
212- coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
213- coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
214-
215- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
216-
217- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
182+ ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN / 2];
218183
219- [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col ; i++) {
220- sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> (0.0f);
184+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2 ; i++) {
185+ sums[i] = ACC_TYPE_VEC2 (0.0f);
221186 }
222- #else
223-
224- ACC_TYPE sums[WMITER * TM * WNITER * TN];
225-
226- [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
227- sums[i] = ACC_TYPE(0.0f);
228- }
229- #endif
230187
231188 for (uint block = start_k; block < end_k; block += BK) {
232189 [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
@@ -267,38 +224,6 @@ void main() {
267224 pos_a_ib += 1;
268225 pos_b_ib += 1;
269226
270- #ifdef COOPMAT
271- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
272- const uint ib_a = warp_r * WM + cm_row * TM;
273- // Load from shared into cache
274- coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
275-
276- // TODO: only cache values that are actually needed
277- [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
278- cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
279- }
280-
281- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
282- const uint ib_b = warp_c * WN + cm_col * TN;
283- coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
284-
285- // TODO: only cache values that are actually needed
286- [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
287- cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
288- }
289-
290- cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
291- cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
292-
293- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
294- coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
295- }
296-
297- coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
298- sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
299- }
300- }
301- #else
302227 // Load from shared into cache
303228 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
304229 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
@@ -312,24 +237,22 @@ void main() {
312237 [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
313238 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
314239 const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
315- cache_b[cc] .ds = buf_b[ib].ds;
240+ cache_b.ds = buf_b[ib].ds;
316241 [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
317- cache_b[cc] .qs[iqs] = buf_b[ib].qs[iqs];
242+ cache_b.qs[iqs] = buf_b[ib].qs[iqs];
318243 }
319- }
320244
321- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
322- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
323- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
324- const uint cache_a_idx = wsir * TM + cr;
325- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
245+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
246+ [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
247+ const uint cache_a_idx = wsir * TM + cr * 2;
248+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
326249
327- sums[sums_idx] += mmq_dot_product(cache_a_idx, cc);
250+ sums[sums_idx].x += mmq_dot_product(cache_a_idx);
251+ sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
328252 }
329253 }
330254 }
331255 }
332- #endif
333256
334257 barrier();
335258 }
@@ -341,54 +264,6 @@ void main() {
341264 const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
342265#endif
343266
344- #ifdef COOPMAT
345- #ifdef MUL_MAT_ID
346- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
347- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
348- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
349-
350- [[unroll]] for (uint col = 0; col < BN; col += storestride) {
351- const uint row_i = dc + cm_col * TN + col + store_c;
352- if (row_i >= _ne1) break;
353-
354- const u16vec2 row_idx = row_ids[row_i];
355-
356- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
357- }
358- }
359- }
360- #else
361- const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
362-
363- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
364- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
365- const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
366-
367- if (is_aligned && is_in_bounds) {
368- // Full coopMat is within bounds and stride_d is aligned with 16B
369- coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
370- coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
371- } else if (is_in_bounds) {
372- // Full coopMat is within bounds, but stride_d is not aligned
373- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
374-
375- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
376- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
377- }
378- } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
379- // Partial coopMat is within bounds
380- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
381-
382- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
383- if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
384- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
385- }
386- }
387- }
388- }
389- }
390- #endif // MUL_MAT_ID
391- #else
392267 [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
393268 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
394269
@@ -399,19 +274,27 @@ void main() {
399274 const uint row_i = dc_warp + cc;
400275 if (row_i >= _ne1) break;
401276
402- const u16vec2 row_idx = row_ids[row_i];
277+ const u16vec2 row_idx = row_ids[row_i - ic * BN ];
403278#endif // MUL_MAT_ID
404- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
279+ [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
280+ const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
405281#ifdef MUL_MAT_ID
406- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
282+ if (dr_warp + 2 * cr < p.M) {
283+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
284+ }
285+ if (dr_warp + 2 * cr + 1 < p.M) {
286+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
287+ }
407288#else
408- if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
409- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
289+ if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
290+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
291+ }
292+ if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
293+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
410294 }
411295#endif // MUL_MAT_ID
412296 }
413297 }
414298 }
415299 }
416- #endif // COOPMAT
417300}
0 commit comments