diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index a4bf588949b..7dec6c1697f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -119,55 +119,27 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index ef0c3c571c9..2892f74e05f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -130,55 +130,28 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } // Apply scale and mask