Skip to content

Commit aa9ecb1

Browse files
committed
0813-add-unit-tests
1 parent 48474c5 commit aa9ecb1

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,6 @@ def __init__(
209209
self.processor = Qwen2VLImageProcessor(**processor_config_dict)
210210

211211
self._init_datatype()
212-
self.load_model(kvargs["weight_dir"])
213-
self.cuda()
214212

215213
def _init_datatype(self):
216214
if isinstance(self.data_type, torch.dtype):
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import math
2+
import torch
3+
import pytest
4+
5+
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
6+
7+
8+
def rotate_half(x):
9+
"""Rotates half the hidden dims of the input."""
10+
x1 = x[..., : x.shape[-1] // 2]
11+
x2 = x[..., x.shape[-1] // 2 :]
12+
return torch.cat((-x2, x1), dim=-1)
13+
14+
15+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
16+
orig_dtype = tensor.dtype
17+
tensor = tensor.float()
18+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
19+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
20+
output = (tensor * cos) + (rotate_half(tensor) * sin)
21+
output = output.to(orig_dtype)
22+
return output
23+
24+
25+
@pytest.mark.parametrize(
26+
"shape",
27+
[
28+
(16, 1296, 64, 80),
29+
(2, 1024, 2, 192),
30+
(1, 1024, 1, 256),
31+
(2, 1024, 3, 160),
32+
],
33+
)
34+
def test_triton_matches_reference(shape):
35+
B, L, H, D = shape
36+
assert D % 2 == 0
37+
38+
torch.manual_seed(0)
39+
40+
freqs = torch.randn(L, D // 2, device="cuda", dtype=torch.bfloat16)
41+
cos = freqs.cos()
42+
sin = freqs.sin()
43+
44+
tensor = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)
45+
46+
ref = apply_rotary_pos_emb_vision(tensor, cos, sin)
47+
out = apply_rotary_pos_emb_triton(tensor, cos, sin)
48+
49+
assert out.dtype == tensor.dtype, "输出 dtype 应与输入一致"
50+
assert out.shape == tensor.shape, "输出形状应与输入一致"
51+
assert torch.allclose(out, ref, rtol=1e-2, atol=1e-2), "Triton 与参考实现不一致"
52+
53+
54+
if __name__ == "__main__":
55+
pytest.main()

0 commit comments

Comments
 (0)