Skip to content

Commit a935eb7

Browse files
author
sangchengmeng
committed
add-unittest-mrope
1 parent ac89450 commit a935eb7

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

unit_tests/models/qwen2_vl/test_mrope.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,23 @@ def test_mrope_triton_correctness(B, H_q, H_k, L, D, mrope_section):
4343

4444
torch.manual_seed(0)
4545
device = "cuda"
46+
HALF = D // 2
4647

4748
q = torch.rand((B, H_q, L, D), dtype=torch.float32, device=device)
4849
k = torch.rand((B, H_k, L, D), dtype=torch.float32, device=device)
49-
cos = torch.rand((3, 1, L, D), dtype=torch.float32, device=device)
50-
sin = torch.rand((3, 1, L, D), dtype=torch.float32, device=device)
5150

52-
ref_q, ref_k = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1)
51+
cos_half = torch.rand((3, L, HALF), dtype=torch.float32, device=device)
52+
sin_half = torch.rand((3, L, HALF), dtype=torch.float32, device=device)
5353

54-
out_q, out_k = mrope_triton(q, k, cos, sin, axis_map)
54+
cos_full = torch.cat([cos_half, cos_half], dim=-1)
55+
sin_full = torch.cat([sin_half, sin_half], dim=-1)
5556

57+
cos_ref = cos_full.unsqueeze(1)
58+
sin_ref = sin_full.unsqueeze(1)
59+
60+
ref_q, ref_k = apply_multimodal_rotary_pos_emb(q, k, cos_ref, sin_ref, mrope_section, unsqueeze_dim=1)
61+
62+
out_q, out_k = mrope_triton(q, k, cos_half, sin_half, axis_map)
5663
assert torch.allclose(out_q, ref_q, rtol=1e-3, atol=1e-3)
5764
assert torch.allclose(out_k, ref_k, rtol=1e-3, atol=1e-3)
5865

0 commit comments

Comments
 (0)