Commit 23dd166

pytorchbot authored and abhinaykukkadapu committed
[ET-VK][ez] SDPA don't branch based on whether bounds check needed (pytorch#15600)
Title says it all! Why?

* The branching path causes incorrect output on the Samsung S24. The exact underlying issue is unclear, but the problem is not reproducible on other GPUs and appears to be specific to the Adreno 750 architecture. To be safe, always use bounds checking.

Differential Revision: [D86226136](https://our.internmc.facebook.com/intern/diff/D86226136/)
1 parent 0feac4e commit 23dd166
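For context on what the `_with_checks` loaders do: their bodies are not part of this diff, so the snippet below is only a sketch of the general pattern, with assumed names throughout (`VEC4_T`, `t_k_cache`, and `k_cache_idx` are illustrative, not the shaders' actual declarations). The idea is that a checked loader guards each row against the valid context length and zero-fills out-of-range entries, so it is safe for partial edge tiles as well as full ones, which is what makes using it unconditionally a conservative fix.

// Hypothetical sketch of a bounds-checked K-cache tile load. All names,
// types, and the indexing helper below are assumptions for illustration
// only; they are not taken from the actual shaders.
void load_k_cache_tile_with_checks(
    inout VEC4_T w_tile[TILE_N],
    const int d4,
    const int c,
    const int kv_h,
    const int D4,
    const int context_len,
    const int C,
    const int KV_H) {
  for (int n = 0; n < TILE_N; ++n) {
    // Rows past context_len are zero-filled instead of read out of
    // bounds; the removed _no_checks variants skipped this guard when
    // the tile was known to be fully in range.
    w_tile[n] = (c + n < context_len)
        ? t_k_cache[k_cache_idx(d4, c + n, kv_h, D4, C, KV_H)]
        : VEC4_T(0);
  }
}

Under this sketch, fully in-range tiles produce identical results either way; the checked path only adds a per-row comparison, so the expected cost of the change is a small amount of extra ALU work in exchange for uniform behavior across drivers.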

2 files changed: +43 −98 lines changed


backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl

Lines changed: 21 additions & 49 deletions
@@ -119,55 +119,27 @@ void main() {
   }
   // Otherwise, need to actually compute output tile
   else {
-    const bool dont_check_bounds = (S - s) >= TILE_M &&
-        (context_len - c) >= TILE_N;
-
-    if (dont_check_bounds) {
-      for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
-        load_q_projected_tile_no_checks(
-            q_tile,
-            d4,
-            s,
-            q_h,
-            D4,
-            Q_H,
-            S);
-
-        load_k_cache_tile_no_checks(
-            w_tile,
-            d4,
-            c,
-            kv_h,
-            D4,
-            context_len,
-            C,
-            KV_H);
-
-        fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
-      }
-    } else {
-      for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
-        load_q_projected_tile_with_checks(
-            q_tile,
-            d4,
-            s,
-            q_h,
-            D4,
-            Q_H,
-            S);
-
-        load_k_cache_tile_with_checks(
-            w_tile,
-            d4,
-            c,
-            kv_h,
-            D4,
-            context_len,
-            C,
-            KV_H);
-
-        fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
-      }
+    for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
+      load_q_projected_tile_with_checks(
+          q_tile,
+          d4,
+          s,
+          q_h,
+          D4,
+          Q_H,
+          S);
+
+      load_k_cache_tile_with_checks(
+          w_tile,
+          d4,
+          c,
+          kv_h,
+          D4,
+          context_len,
+          C,
+          KV_H);
+
+      fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
     }
   }
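A reading note on the coop variant above: the d4 loop is strided, so the D4 reduction is split across NUM_WORKERS_PER_OUT cooperating workers, each accumulating a partial out_tile (the cross-worker combination happens outside this hunk). A small illustration with assumed values:

// Illustration only: NUM_WORKERS_PER_OUT = 4 and D4 = 10 are assumed here.
// The strided loop assigns each d4 to exactly one worker:
//   worker 0 -> d4 = 0, 4, 8
//   worker 1 -> d4 = 1, 5, 9
//   worker 2 -> d4 = 2, 6
//   worker 3 -> d4 = 3, 7
for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) {
  // checked loads + fp_accumulate_with_fp_weight, as in the hunk above
}

The tiled variant below makes the contrast clear: there, a single invocation walks all of d4 = 0 .. D4-1 serially.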

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl

Lines changed: 22 additions & 49 deletions
@@ -130,55 +130,28 @@ void main() {
   }
   // Otherwise, need to actually compute output tile
   else {
-    const bool dont_check_bounds = (S - s) >= TILE_M &&
-        (context_len - c) >= TILE_N;
-
-    if (dont_check_bounds) {
-      for (int d4 = 0; d4 < D4; d4++) {
-        load_q_projected_tile_no_checks(
-            q_tile,
-            d4,
-            s,
-            q_h,
-            D4,
-            Q_H,
-            S);
-
-        load_k_cache_tile_no_checks(
-            w_tile,
-            d4,
-            c,
-            kv_h,
-            D4,
-            context_len,
-            C,
-            KV_H);
-
-        fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
-      }
-    } else {
-      for (int d4 = 0; d4 < D4; d4++) {
-        load_q_projected_tile_with_checks(
-            q_tile,
-            d4,
-            s,
-            q_h,
-            D4,
-            Q_H,
-            S);
-
-        load_k_cache_tile_with_checks(
-            w_tile,
-            d4,
-            c,
-            kv_h,
-            D4,
-            context_len,
-            C,
-            KV_H);
-
-        fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
-      }
+    for (int d4 = 0; d4 < D4; d4++) {
+      load_q_projected_tile_with_checks(
+          q_tile,
+          d4,
+          s,
+          q_h,
+          D4,
+          Q_H,
+          S);
+
+      load_k_cache_tile_with_checks(
+          w_tile,
+          d4,
+          c,
+          kv_h,
+          D4,
+          context_len,
+          C,
+          KV_H);
+
+
+      fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile);
     }

     // Apply scale and mask
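The hunk ends just as the shaders reach the scale-and-mask step. For orientation only (the shaders' actual scale factor and mask handling are outside this diff): in standard scaled dot-product attention, the raw scores accumulated above would next be transformed before the softmax roughly as

attn_weight[s][c] = dot(q_s, k_c) * scale + mask[s][c], with scale typically 1 / sqrt(head_dim)

Treat this as a hedged reference for readers new to the kernel, not as the shader's code.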
