diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 62acbf107a298..2255f9c168e6e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -345,7 +345,7 @@ void main() { float Lfrcp[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Lfrcp[r] = 1.0 / Lf[r]; + Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 2066a05b34902..8699fa6c9cbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -380,7 +380,7 @@ void main() { float Lfrcp[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Lfrcp[r] = 1.0 / Lf[r]; + Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 910da1ab0c28f..fcfc60a878544 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -121,7 +121,11 @@ void main() { const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); L = coopmat(0); +#if defined(ACC_TYPE_MAX) + M = coopmat(-ACC_TYPE_MAX / ACC_TYPE(2)); +#else M = coopmat(NEG_FLT_MAX_OVER_2); +#endif coopmat slopeMat = coopmat(1.0); @@ -294,7 +298,7 @@ void main() { [[unroll]] for (int k = 0; k < Ldiag.length(); ++k) { - Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]); } O = Ldiag*O; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 06e83822fe326..4eaddd31a8f58 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -91,7 +91,7 @@ void main() { L = L*ms + vs; } - L = 1.0 / L; + L = (L == 0.0) ? 0.0 : 1.0 / L; // D dimension is split across workgroups in the y dimension uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;