From 77ed617fc4f4b1dcf1491ff1241a006db0ecb4e7 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 4 Nov 2025 11:52:47 -0800 Subject: [PATCH] [ET-VK][ez] SDPA don't branch based on whether bounds check needed Title says it all! Why? * The branching path is causing incorrect output on Samsung S24. It's unclear what the exact underlying issue is but the problem is not reproducible on other GPUs and appears to be an issue specific to Adreno 750 architecture. To be safe, always use bounds checking. Differential Revision: [D86226136](https://our.internmc.facebook.com/intern/diff/D86226136/) [ghstack-poisoned] --- .../glsl/sdpa_compute_attn_weights_coop.glsl | 70 ++++++------------ .../glsl/sdpa_compute_attn_weights_tiled.glsl | 71 ++++++------------- 2 files changed, 43 insertions(+), 98 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index a4bf588949b..7dec6c1697f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -119,55 +119,27 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index ef0c3c571c9..2892f74e05f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -130,55 +130,28 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } // Apply scale and mask