Skip to content

Commit d2a65ec

Browse files
jeffbolznv0cc4m
andcommitted
0cc4m's fixes for AMD perf
Co-authored-by: 0cc4m <[email protected]>
1 parent e8643c0 commit d2a65ec

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3099,6 +3099,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
30993099
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
31003100
conv2d_SHMEM_PAD = 0;
31013101
conv2d_UNROLL = false;
3102+
} else if (device->vendor_id == VK_VENDOR_ID_AMD) {
3103+
conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
31023104
}
31033105

31043106
switch (s) {
@@ -3107,6 +3109,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
31073109
conv2d_BS_K = 128;
31083110
conv2d_BS_NPQ = 128;
31093111
conv2d_BS_CRS = 16;
3112+
if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
3113+
conv2d_UNROLL = false;
3114+
}
31103115
break;
31113116
case CONV_SHAPE_64x32:
31123117
conv2d_BS_K = 64;
@@ -3121,13 +3126,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
31213126
break;
31223127
}
31233128

3124-
// Use collectives on pre-Turing NVIDIA GPUs, which had slower integer math.
3129+
// Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
31253130
bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
31263131
device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
3132+
bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
3133+
device->architecture == vk_device_architecture::AMD_GCN;
31273134

31283135
if (device->subgroup_shuffle &&
31293136
device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
3130-
allow_collectives_nv) {
3137+
allow_collectives_nv &&
3138+
allow_collectives_amd) {
31313139
use_collectives = 1;
31323140
conv2d_BS_CRS = std::min(
31333141
device->subgroup_size,

0 commit comments

Comments
 (0)