@@ -144,13 +144,13 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
144144
145145auto & sdpa_cache () {
146146 static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache (
147- " MLX_CUDA_SDPA_CACHE_SIZE" , /* default_capacity */ 16 );
147+ " MLX_CUDA_SDPA_CACHE_SIZE" , /* default_capacity */ 64 );
148148 return cache;
149149}
150150
151151auto & sdpa_backward_cache () {
152152 static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache (
153- " MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE" , /* default_capacity */ 16 );
153+ " MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE" , /* default_capacity */ 64 );
154154 return cache;
155155}
156156
@@ -207,8 +207,14 @@ fe::graph::Graph build_sdpa_graph(
207207 auto options = fe::graph::SDPA_attributes ()
208208 .set_name (" sdpa_cudnn" )
209209 .set_attn_scale (scale)
210- .set_causal_mask (do_causal)
211210 .set_generate_stats (output_logsumexp);
211+ if (do_causal) {
212+ if (q.shape (2 ) > k.shape (2 )) {
213+ options.set_causal_mask (do_causal);
214+ } else {
215+ options.set_causal_mask_bottom_right (do_causal);
216+ }
217+ }
212218 if (mask_arr) {
213219 auto bias_ = graph.tensor (fe::graph::Tensor_attributes ().set_name (" BIAS" ));
214220 set_tensor_attrs (bias_, BIAS, *mask_arr);
@@ -282,7 +288,14 @@ fe::graph::Graph build_sdpa_backward_graph(
282288 auto options = fe::graph::SDPA_backward_attributes ()
283289 .set_name (" sdpa_backward_cudnn" )
284290 .set_attn_scale (scale)
285- .set_causal_mask (do_causal);
291+ .set_attn_scale (scale);
292+ if (do_causal) {
293+ if (q.shape (2 ) > k.shape (2 )) {
294+ options.set_causal_mask (do_causal);
295+ } else {
296+ options.set_causal_mask_bottom_right (do_causal);
297+ }
298+ }
286299 if (mask_arr) {
287300 auto bias_ = graph.tensor (fe::graph::Tensor_attributes ().set_name (" BIAS" ));
288301 set_tensor_attrs (bias_, BIAS, *mask_arr);
@@ -340,6 +353,7 @@ bool supports_sdpa_cudnn(
340353 const array& q,
341354 const array& k,
342355 const array& v,
356+ bool do_causal,
343357 Stream s) {
344358 static bool enabled = env::get_var (" MLX_CUDA_USE_CUDNN_SPDA" , 1 );
345359 if (!enabled) {
@@ -351,8 +365,8 @@ bool supports_sdpa_cudnn(
351365 return false ;
352366 }
353367
354- // Only use cuDNN for prefilling and training.
355- if (q.shape (2 ) != k.shape (2 )) {
368+ // Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv) .
369+ if (( q.shape (2 ) == 1 ) && (q. shape ( 2 ) != k.shape (2 ) )) {
356370 return false ;
357371 }
358372
@@ -520,7 +534,7 @@ bool ScaledDotProductAttention::use_fallback(
520534
521535 return !supports_sdpa_vector (
522536 q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
523- !supports_sdpa_cudnn (q, k, v, s);
537+ !supports_sdpa_cudnn (q, k, v, do_causal, s);
524538}
525539
526540bool ScaledDotProductAttention::supports_bool_mask () {
0 commit comments