Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2492,9 +2492,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;

l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
l_warptile = { 128, 128, 128, (device->coopmat_support ? 16u : 32u), subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
m_warptile = { 128, 64, 64, (device->coopmat_support ? 16u : 32u), subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile = { subgroup_size_16, 32, 32, (device->coopmat_support ? 16u : 32u), 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
Expand Down
20 changes: 13 additions & 7 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ void main() {
}
#else
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
FLOAT_TYPE_VEC2 cache_b;
FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
FLOAT_TYPE_VEC4 cache_b;

[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
Expand Down Expand Up @@ -353,24 +353,30 @@ void main() {
}
}
#else
[[unroll]] for (uint i = 0; i < BK / 2; i++) {
[[unroll]] for (uint i = 0; i < BK / 4; i++) {
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
const uint base_a = (warp_r * WM + wsir * WSUBM + tiwr * TM) * SHMEM_STRIDE + 2 * i;
[[unroll]] for (uint j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
cache_a[wsir * TM + j].xy = buf_a[base_a + j * SHMEM_STRIDE ];
cache_a[wsir * TM + j].zw = buf_a[base_a + j * SHMEM_STRIDE + 1];
}
}

[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
const uint base_b = (warp_c * WN + wsic * WSUBN + tiwc * TN) * SHMEM_STRIDE + 2 * i;
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
cache_b.xy = buf_b[base_b + cc * SHMEM_STRIDE ];
cache_b.zw = buf_b[base_b + cc * SHMEM_STRIDE + 1];

[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
}
}
}
Expand Down
Loading