
Commit 2b95d0c

[CUDA] Use cuDNN attention when T_q != T_kv (#2843)
1 parent b054838 commit 2b95d0c

1 file changed

mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 21 additions & 7 deletions
@@ -144,13 +144,13 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
 
 auto& sdpa_cache() {
   static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
-      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
+      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
   return cache;
 }
 
 auto& sdpa_backward_cache() {
   static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
-      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
+      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
   return cache;
 }
 
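The hunk above quadruples the default capacity of both cuDNN graph caches from 16 to 64 entries. As a minimal sketch of the override pattern the constructor arguments suggest (the helper below is hypothetical, not MLX's LRUBytesKeyCache):

    // Hypothetical sketch: resolve a cache capacity from an environment
    // variable, falling back to the compiled-in default.
    #include <cstdio>
    #include <cstdlib>

    static int capacity_from_env(const char* name, int default_capacity) {
      if (const char* v = std::getenv(name)) {
        return std::atoi(v); // e.g. MLX_CUDA_SDPA_CACHE_SIZE=128
      }
      return default_capacity;
    }

    int main() {
      std::printf("%d\n", capacity_from_env("MLX_CUDA_SDPA_CACHE_SIZE", 64));
      return 0;
    }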
@@ -207,8 +207,14 @@ fe::graph::Graph build_sdpa_graph(
   auto options = fe::graph::SDPA_attributes()
                      .set_name("sdpa_cudnn")
                      .set_attn_scale(scale)
-                     .set_causal_mask(do_causal)
                      .set_generate_stats(output_logsumexp);
+  if (do_causal) {
+    if (q.shape(2) > k.shape(2)) {
+      options.set_causal_mask(do_causal);
+    } else {
+      options.set_causal_mask_bottom_right(do_causal);
+    }
+  }
   if (mask_arr) {
     auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
     set_tensor_attrs(bias_, BIAS, *mask_arr);
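Instead of always applying a top-left-aligned causal mask, the forward graph now picks the alignment from the sequence lengths on shape axis 2: top-left when T_q > T_kv, bottom-right otherwise, which keeps the causal diagonal anchored to the last query position when prefilling against an existing KV cache. A minimal sketch of the two predicates under their usual definitions (illustrative code, not MLX's or cuDNN's):

    #include <cstdio>

    // Top-left causal: query row i attends to kv columns j <= i.
    static bool top_left(int i, int j) {
      return j <= i;
    }

    // Bottom-right causal: the diagonal ends at the bottom-right corner,
    // so row i attends to j <= i + (t_kv - t_q). For t_q == t_kv the two
    // alignments coincide.
    static bool bottom_right(int i, int j, int t_q, int t_kv) {
      return j <= i + (t_kv - t_q);
    }

    int main() {
      // Prefill two new tokens against five cached KV positions.
      const int t_q = 2, t_kv = 5;
      for (int i = 0; i < t_q; ++i) {
        for (int j = 0; j < t_kv; ++j) {
          std::printf("%d", bottom_right(i, j, t_q, t_kv) ? 1 : 0);
        }
        std::printf("\n"); // prints 11110 then 11111
      }
      (void)top_left; // top_left(i, j) would instead print 10000 / 11000
      return 0;
    }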
@@ -282,7 +288,14 @@ fe::graph::Graph build_sdpa_backward_graph(
   auto options = fe::graph::SDPA_backward_attributes()
                      .set_name("sdpa_backward_cudnn")
                      .set_attn_scale(scale)
-                     .set_causal_mask(do_causal);
+                     .set_attn_scale(scale);
+  if (do_causal) {
+    if (q.shape(2) > k.shape(2)) {
+      options.set_causal_mask(do_causal);
+    } else {
+      options.set_causal_mask_bottom_right(do_causal);
+    }
+  }
   if (mask_arr) {
     auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
     set_tensor_attrs(bias_, BIAS, *mask_arr);
@@ -340,6 +353,7 @@ bool supports_sdpa_cudnn(
     const array& q,
     const array& k,
     const array& v,
+    bool do_causal,
     Stream s) {
   static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
   if (!enabled) {
@@ -351,8 +365,8 @@ bool supports_sdpa_cudnn(
     return false;
   }
 
-  // Only use cuDNN for prefilling and training.
-  if (q.shape(2) != k.shape(2)) {
+  // Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
+  if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
     return false;
   }
 
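With the relaxed check, only single-token decoding against a longer KV sequence is rejected; prefilling with T_q != T_kv now reaches cuDNN, which is the point of the commit. A small sketch mirroring the predicate on the two sequence lengths (names hypothetical):

    #include <cstdio>

    // Mirrors the shape check in supports_sdpa_cudnn: fall back only for
    // single-token decode (T_q == 1) with a mismatched KV length.
    static bool cudnn_eligible(int t_q, int t_kv) {
      return !(t_q == 1 && t_q != t_kv);
    }

    int main() {
      std::printf("%d\n", cudnn_eligible(1, 512));   // 0: decode -> vector kernel
      std::printf("%d\n", cudnn_eligible(128, 640)); // 1: prefill with cache (newly allowed)
      std::printf("%d\n", cudnn_eligible(512, 512)); // 1: training / full prefill
      return 0;
    }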
@@ -520,7 +534,7 @@ bool ScaledDotProductAttention::use_fallback(
 
   return !supports_sdpa_vector(
              q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
-      !supports_sdpa_cudnn(q, k, v, s);
+      !supports_sdpa_cudnn(q, k, v, do_causal, s);
 }
 
 bool ScaledDotProductAttention::supports_bool_mask() {
