Skip to content

Commit 6560184

Browse files
authored
Disable flaky cache test (#1151)
Kernel execution is asynchronous, so it can outlive the test function if we don't check the result or explicitly synchronize. --------- Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
1 parent 80e64d3 commit 6560184

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tests/kernel/runtime/cache_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ def testSameConfigDifferentFreeVar(tmp_path, mfma_variant):
572572
output = device_zeros(o_shape, dtype=torch.float32)
573573
# TODO: Add variant of non-transposed V attention kernel.
574574
non_causal_mb = base_attention(q, k, v.permute([0, 2, 1]), output)
575+
torch.cuda.synchronize()
575576
assert (
576577
cache_manager.cache_misses == 1 and cache_manager.cache_hits == 0
577578
), "Expected first call to not be cached."
@@ -596,13 +597,13 @@ def testSameConfigDifferentFreeVar(tmp_path, mfma_variant):
596597
)
597598
options = set_default_run_config(options)
598599
causal_attention = wave_compile(options, causal_attention)
599-
600600
q = device_randn(q_shape, dtype=torch.float16)
601601
k = device_randn(k_shape, dtype=torch.float16)
602602
v = device_randn(v_shape, dtype=torch.float16)
603603
output = device_zeros(o_shape, dtype=torch.float32)
604604
# TODO: Add variant of non-transposed V attention kernel.
605605
causal_mb = causal_attention(q, k, v.permute([0, 2, 1]), output)
606+
torch.cuda.synchronize()
606607
assert (
607608
cache_manager.cache_misses == 2 and cache_manager.cache_hits == 0
608609
), "Expected to be cached despite same config, since it has different values for is_causal."
@@ -780,6 +781,7 @@ def double_kernel(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
780781
@require_e2e
781782
@require_cache
782783
@require_cdna3
784+
@pytest.mark.skip(reason="Crashes and/or produces incorrect results.")
783785
def testAsmBackendCache(tmp_path):
784786
"""Test that ASM backend caching works correctly."""
785787
reset_cache_manager(tmp_path)
@@ -835,6 +837,7 @@ def simple_copy(
835837
# First compilation - should be a cache miss
836838
kernel1 = wave_compile(options, simple_copy)
837839
kernel1(a, b)
840+
assert_close(a, b)
838841

839842
assert (
840843
cache_manager.cache_misses == 1 and cache_manager.cache_hits == 0
@@ -851,6 +854,7 @@ def simple_copy(
851854
# Second compilation - should be a cache hit
852855
kernel2 = wave_compile(options, simple_copy)
853856
kernel2(a, b)
857+
assert_close(a, b)
854858

855859
assert (
856860
cache_manager.cache_misses == 1 and cache_manager.cache_hits == 1

0 commit comments

Comments
 (0)