Skip to content

Commit e3fc42d

Browse files
RUTHLESS-BOT and gemini-code-assist[bot]
authored and committed
[Misc] parametrize 'dtype' in test_flash_mla (vllm-project#22641)
Signed-off-by: RUTHLESS-BOT <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 25e5d42 commit e3fc42d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

tests/kernels/attention/test_flashmla.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
3535
@pytest.mark.parametrize("block_size", [64])
3636
@pytest.mark.parametrize("causal", [True])
3737
@pytest.mark.parametrize("varlen", [False, True])
38+
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
3839
@torch.inference_mode()
3940
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
40-
varlen):
41-
# TODO: parametrize using pytest
42-
dtype = torch.bfloat16
41+
varlen, dtype):
4342
device = torch.device("cuda:0")
4443
torch.set_default_dtype(dtype)
4544
torch.set_default_device(device)
@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
4847
random.seed(0)
4948

5049
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
51-
f"{d=}, {dv=}, {causal=}, {varlen=}")
50+
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
5251

5352
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
5453
if varlen:

0 commit comments

Comments (0)