@@ -572,6 +572,7 @@ def testSameConfigDifferentFreeVar(tmp_path, mfma_variant):
572572 output = device_zeros (o_shape , dtype = torch .float32 )
573573 # TODO: Add variant of non-transposed V attention kernel.
574574 non_causal_mb = base_attention (q , k , v .permute ([0 , 2 , 1 ]), output )
575+ torch .cuda .synchronize ()
575576 assert (
576577 cache_manager .cache_misses == 1 and cache_manager .cache_hits == 0
577578 ), "Expected first call to not be cached."
@@ -596,13 +597,13 @@ def testSameConfigDifferentFreeVar(tmp_path, mfma_variant):
596597 )
597598 options = set_default_run_config (options )
598599 causal_attention = wave_compile (options , causal_attention )
599-
600600 q = device_randn (q_shape , dtype = torch .float16 )
601601 k = device_randn (k_shape , dtype = torch .float16 )
602602 v = device_randn (v_shape , dtype = torch .float16 )
603603 output = device_zeros (o_shape , dtype = torch .float32 )
604604 # TODO: Add variant of non-transposed V attention kernel.
605605 causal_mb = causal_attention (q , k , v .permute ([0 , 2 , 1 ]), output )
606+ torch .cuda .synchronize ()
606607 assert (
607608 cache_manager .cache_misses == 2 and cache_manager .cache_hits == 0
608609 ), "Expected to be cached despite same config, since it has different values for is_causal."
@@ -780,6 +781,7 @@ def double_kernel(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
780781@require_e2e
781782@require_cache
782783@require_cdna3
784+ @pytest .mark .skip (reason = "Crashes and/or produces incorrect results." )
783785def testAsmBackendCache (tmp_path ):
784786 """Test that ASM backend caching works correctly."""
785787 reset_cache_manager (tmp_path )
@@ -835,6 +837,7 @@ def simple_copy(
835837 # First compilation - should be a cache miss
836838 kernel1 = wave_compile (options , simple_copy )
837839 kernel1 (a , b )
840+ assert_close (a , b )
838841
839842 assert (
840843 cache_manager .cache_misses == 1 and cache_manager .cache_hits == 0
@@ -851,6 +854,7 @@ def simple_copy(
851854 # Second compilation - should be a cache hit
852855 kernel2 = wave_compile (options , simple_copy )
853856 kernel2 (a , b )
857+ assert_close (a , b )
854858
855859 assert (
856860 cache_manager .cache_misses == 1 and cache_manager .cache_hits == 1
0 commit comments