Commit 2cb1b7b

Fix GPU Attention Tests (#1314)
* Fix if statement and ensure config is_supported is run
* Add to other test
1 parent 3ae1f24 commit 2cb1b7b

2 files changed: +9 -8 lines changed

axlearn/common/flash_attention/gpu_attention.py

Lines changed: 1 addition & 1 deletion
@@ -811,7 +811,7 @@ def is_supported(
             # key/value to be even.
             if not self._check_block_size(input_batch, block_size=2):
                 return False
-        if kv_cache_type == KVCache:
+        elif kv_cache_type == KVCache:
             if query.shape[1] > 1:
                 return self._log_unsupported("multi-step decoding is not supported.")
             if not key.shape[1] % 2 == 0:

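The fix attaches the KVCache checks to the preceding `if` as an `elif`, so they are evaluated as an alternative branch rather than as a separate, unconditional statement. A minimal sketch of that control-flow difference follows; the condition on the preceding branch is not visible in the hunk, so the `kv_cache_type is None` case below is assumed purely for illustration and is not taken from the source file.

# A minimal control-flow sketch, not the actual axlearn implementation.
def is_supported_sketch(kv_cache_type, num_query_steps: int) -> bool:
    if kv_cache_type is None:  # assumed preceding branch (not shown in the hunk)
        # Block-size checks for this path would go here.
        pass
    elif kv_cache_type == "KVCache":
        # With `elif`, these checks run only when the preceding condition is
        # false, i.e. only on the KVCache path; a bare `if` would evaluate
        # them regardless of which branch ran above.
        if num_query_steps > 1:
            return False  # "multi-step decoding is not supported."
    return True
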
axlearn/common/flash_attention/gpu_attention_test.py

Lines changed: 8 additions & 7 deletions
@@ -398,6 +398,13 @@ def test_cudnn_dropout_against_xla_dropout(
     test_fn = CuDNNGPUFlashAttention.default_config().set(**cfg).instantiate()
     ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()

+    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
+    q = jax.random.normal(k1, qkv_shape, dtype=dtype)
+    k = jax.random.normal(k2, qkv_shape, dtype=dtype)
+    v = jax.random.normal(k3, qkv_shape, dtype=dtype)
+    input_batch = dict(query=q, key=k, value=v, bias=bias, logit_sink=None)
+    chex.assert_equal(test_fn.is_supported(input_batch, kv_cache_type=None), True)
+
     dropout_mask = (
         test_fn(
             dict(
@@ -416,13 +423,6 @@ def test_cudnn_dropout_against_xla_dropout(
     # the same mask.
     jax.clear_caches()

-    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
-    q = jax.random.normal(k1, qkv_shape, dtype=dtype)
-    k = jax.random.normal(k2, qkv_shape, dtype=dtype)
-    v = jax.random.normal(k3, qkv_shape, dtype=dtype)
-    input_batch = dict(query=q, key=k, value=v, bias=bias, logit_sink=None)
-    chex.assert_equal(test_fn.is_supported(input_batch, kv_cache_type=None), True)
-
     ref_fn = functools.partial(
         ref_fn,
         dropout_mask=dropout_mask,
@@ -492,6 +492,7 @@ def test_cudnn_dropout_determinism():
         logit_sink=None,
     )
     fn = CuDNNGPUFlashAttention.default_config().set(dropout_rate=0.1).instantiate()
+    chex.assert_equal(fn.is_supported(input_batch, kv_cache_type=None), True)

     outputs = []
     grads = []
