@@ -398,6 +398,13 @@ def test_cudnn_dropout_against_xla_dropout(
     test_fn = CuDNNGPUFlashAttention.default_config().set(**cfg).instantiate()
     ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()
 
+    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
+    q = jax.random.normal(k1, qkv_shape, dtype=dtype)
+    k = jax.random.normal(k2, qkv_shape, dtype=dtype)
+    v = jax.random.normal(k3, qkv_shape, dtype=dtype)
+    input_batch = dict(query=q, key=k, value=v, bias=bias, logit_sink=None)
+    chex.assert_equal(test_fn.is_supported(input_batch, kv_cache_type=None), True)
+
     dropout_mask = (
         test_fn(
             dict(
@@ -416,13 +423,6 @@ def test_cudnn_dropout_against_xla_dropout(
     # the same mask.
     jax.clear_caches()
 
-    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
-    q = jax.random.normal(k1, qkv_shape, dtype=dtype)
-    k = jax.random.normal(k2, qkv_shape, dtype=dtype)
-    v = jax.random.normal(k3, qkv_shape, dtype=dtype)
-    input_batch = dict(query=q, key=k, value=v, bias=bias, logit_sink=None)
-    chex.assert_equal(test_fn.is_supported(input_batch, kv_cache_type=None), True)
-
     ref_fn = functools.partial(
         ref_fn,
         dropout_mask=dropout_mask,
@@ -492,6 +492,7 @@ def test_cudnn_dropout_determinism():
         logit_sink=None,
     )
     fn = CuDNNGPUFlashAttention.default_config().set(dropout_rate=0.1).instantiate()
+    chex.assert_equal(fn.is_supported(input_batch, kv_cache_type=None), True)
 
     outputs = []
     grads = []
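
For readers following the change, here is a minimal standalone sketch of the check-before-use pattern the first and third hunks introduce: build the exact input batch the test will run, then assert that the cuDNN kernel supports it before the layer is invoked for the first time. The import path, qkv_shape, dtype, and dropout_rate below are assumptions chosen for illustration; only the calls themselves (default_config().set(...).instantiate(), is_supported(..., kv_cache_type=None), chex.assert_equal) come from the diff.

# Standalone sketch of the "assert support before first use" pattern in this
# diff. Import path, shapes, dtype, and dropout_rate are illustrative
# assumptions, not values taken from the PR.
import chex
import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.gpu_attention import CuDNNGPUFlashAttention  # assumed path

qkv_shape = (2, 128, 4, 64)  # (batch, seq_len, num_heads, per_head_dim); illustrative values
dtype = jnp.bfloat16

test_fn = CuDNNGPUFlashAttention.default_config().set(dropout_rate=0.1).instantiate()

# Build the exact input batch the test will use...
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, qkv_shape, dtype=dtype)
k = jax.random.normal(k2, qkv_shape, dtype=dtype)
v = jax.random.normal(k3, qkv_shape, dtype=dtype)
input_batch = dict(query=q, key=k, value=v, bias=None, logit_sink=None)

# ...and check kernel support before any call that would dispatch to cuDNN.
chex.assert_equal(test_fn.is_supported(input_batch, kv_cache_type=None), True)

Placing the assertion ahead of the first test_fn call (as the first hunk does, instead of after jax.clear_caches()) means an unsupported configuration fails immediately, before any dropout-mask capture or kernel work runs.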