Skip to content

Commit ff700e3

Browse files
author
Stefan Savic
committed
Using this only for f32 and f16
Signed-off-by: Stefan Savic <[email protected]>
1 parent ba67c4a commit ff700e3

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2606,9 +2606,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
26062606
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
26072607
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
26082608

2609-
l_warptile = { 128, 128, 128, (device->coopmat_support ? 16u : 32u), subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2610-
m_warptile = { 128, 64, 64, (device->coopmat_support ? 16u : 32u), subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2611-
s_warptile = { subgroup_size_16, 32, 32, (device->coopmat_support ? 16u : 32u), 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
2609+
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2610+
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2611+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
26122612

26132613
l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
26142614
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ layout (push_constant) uniform parameter
100100
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
101101
layout (constant_id = 1) const uint BM = 64;
102102
layout (constant_id = 2) const uint BN = 64;
103-
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
104103
layout (constant_id = 4) const uint WM = 32;
105104
layout (constant_id = 5) const uint WN = 32;
106105
layout (constant_id = 6) const uint WMITER = 2;
@@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2;
109108
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
110109
layout (constant_id = 10) const uint WARP = 32;
111110

111+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
112+
#define BK 32
113+
#define BK_STEP 4
114+
#else
115+
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
116+
#define BK_STEP 2
117+
#endif
118+
112119
#ifdef COOPMAT
113120
#define SHMEM_STRIDE (BK / 2 + 4)
114121
#else
@@ -244,8 +251,13 @@ void main() {
244251
}
245252
#else
246253
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
254+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
247255
FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
248256
FLOAT_TYPE_VEC4 cache_b;
257+
#else
258+
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
259+
FLOAT_TYPE_VEC2 cache_b;
260+
#endif
249261

250262
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
251263
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
@@ -283,30 +295,41 @@ void main() {
283295
}
284296
}
285297
#else
286-
[[unroll]] for (uint i = 0; i < BK / 4; i++) {
298+
[[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) {
287299
// Load from shared into cache
288300
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
289-
const uint base_a = (warp_r * WM + wsir * WSUBM + tiwr * TM) * SHMEM_STRIDE + 2 * i;
290301
[[unroll]] for (uint j = 0; j < TM; j++) {
291-
cache_a[wsir * TM + j].xy = buf_a[base_a + j * SHMEM_STRIDE ];
292-
cache_a[wsir * TM + j].zw = buf_a[base_a + j * SHMEM_STRIDE + 1];
302+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
303+
cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i ];
304+
cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1];
305+
#else
306+
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
307+
#endif
293308
}
294309
}
295310

296311
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
297-
const uint base_b = (warp_c * WN + wsic * WSUBN + tiwc * TN) * SHMEM_STRIDE + 2 * i;
298312
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
299-
cache_b.xy = buf_b[base_b + cc * SHMEM_STRIDE ];
300-
cache_b.zw = buf_b[base_b + cc * SHMEM_STRIDE + 1];
313+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
314+
cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i ];
315+
cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1];
316+
#else
317+
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
318+
#endif
301319

302320
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
303321
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
304322
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
305323
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
324+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
306325
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
307326
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
308327
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
309328
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
329+
#else
330+
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
331+
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
332+
#endif
310333
}
311334
}
312335
}

0 commit comments

Comments
 (0)