@@ -35,11 +35,10 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
35
35
@pytest .mark .parametrize ("block_size" , [64 ])
36
36
@pytest .mark .parametrize ("causal" , [True ])
37
37
@pytest .mark .parametrize ("varlen" , [False , True ])
38
+ @pytest .mark .parametrize ("dtype" , [torch .bfloat16 , torch .float16 ])
38
39
@torch .inference_mode ()
39
40
def test_flash_mla (b , s_q , mean_sk , h_q , h_kv , d , dv , block_size , causal ,
40
- varlen ):
41
- # TODO: parametrize using pytest
42
- dtype = torch .bfloat16
41
+ varlen , dtype ):
43
42
device = torch .device ("cuda:0" )
44
43
torch .set_default_dtype (dtype )
45
44
torch .set_default_device (device )
@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
48
47
random .seed (0 )
49
48
50
49
print (f"{ b = } , { s_q = } , { mean_sk = } , { h_q = } , { h_kv = } , "
51
- f"{ d = } , { dv = } , { causal = } , { varlen = } " )
50
+ f"{ d = } , { dv = } , { causal = } , { varlen = } , { dtype = } " )
52
51
53
52
cache_seqlens = torch .full ((b , ), mean_sk , dtype = torch .int32 )
54
53
if varlen :
0 commit comments