Skip to content

Commit 3d47c26

Browse files
committed
vulkan: fix diag_mask_inf
With robustbufferaccess disabled, this shader was showing OOB stores. There is a bounds check in the code, but the workgrouop dimensions were reversed vs CUDA and it was running the wrong number of threads. So fix the workgroup dimensions and disable robustness for this pipeline.
1 parent 9f7add1 commit 3d47c26

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2021,7 +2021,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
20212021
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
20222022
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
20232023

2024-
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
2024+
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
20252025

20262026
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
20272027
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);

ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
1212

1313
#include "types.comp"
1414

15-
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
15+
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
1616

1717
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
1818
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

0 commit comments

Comments
 (0)