Skip to content

Commit ab03176

Browse files
committed
fix
1 parent 0e0b08f commit ab03176

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import pytest
3-
from lightllm.common.basemodel.triton_kernel.multimodal_emb import mark_multimodal_obj
3+
from lightllm.common.basemodel.triton_kernel.multimodal_emb import mark_multimodal_obj, multimodal_emb
44
from lightllm.utils.log_utils import init_logger
55

66
logger = init_logger(__name__)
@@ -18,5 +18,31 @@ def test_mark_mubltimodal_obj():
1818
assert torch.equal(mark_obj, torch.tensor([1, 0, 0], device="cuda"))
1919

2020

21+
def test_multimodal_emb():
22+
S, D = 1024 * 1000, 128 * 64
23+
vob_size = 320000
24+
image_size = 10
25+
image_token_size = 512
26+
27+
text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
28+
img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
29+
img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
30+
img_start_token_ids = (
31+
(torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
32+
)
33+
img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()
34+
35+
prompt_ids = torch.arange(0, S, 1).cuda().long()
36+
prompt_ids[0 : image_size * image_token_size] = (
37+
(vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
38+
)
39+
40+
out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
41+
multimodal_emb(
42+
out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
43+
)
44+
return
45+
46+
2147
if __name__ == "__main__":
2248
pytest.main()

0 commit comments

Comments
 (0)