11import torch
22import pytest
3- from lightllm .common .basemodel .triton_kernel .multimodal_emb import mark_multimodal_obj
3+ from lightllm .common .basemodel .triton_kernel .multimodal_emb import mark_multimodal_obj , multimodal_emb
44from lightllm .utils .log_utils import init_logger
55
66logger = init_logger (__name__ )
@@ -18,5 +18,31 @@ def test_mark_mubltimodal_obj():
1818 assert torch .equal (mark_obj , torch .tensor ([1 , 0 , 0 ], device = "cuda" ))
1919
2020
def test_multimodal_emb():
    """Smoke-test ``multimodal_emb`` on a large prompt mixing two id ranges.

    The prompt holds ``S`` ids. The first ``image_size * image_token_size``
    positions carry ids shifted by ``vob_size * 10`` (matching the values in
    ``img_start_token_ids``); every other position holds its own index.

    No output values are compared. The test fails only if the call, or the
    device work it enqueues, raises.
    NOTE(review): a value check against a reference lookup should be added
    once the kernel's contract for ids outside ``[0, vob_size)`` is confirmed.
    """
    S, D = 1024 * 1000, 128 * 64
    vob_size = 320000
    image_size = 10
    image_token_size = 512
    total_img_tokens = image_size * image_token_size
    # Offset that places the second id range well above the first one.
    img_id_base = vob_size * 10

    text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
    img_weight = torch.randn((total_img_tokens, D), device="cuda", dtype=torch.float16)
    img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)

    # One entry per image: 0, image_token_size, 2 * image_token_size, ...
    img_start_locs = torch.arange(0, total_img_tokens, image_token_size, device="cuda", dtype=torch.long)
    img_start_token_ids = img_start_locs + img_id_base

    prompt_ids = torch.arange(0, S, 1, device="cuda", dtype=torch.long)
    prompt_ids[:total_img_tokens] = img_id_base + torch.arange(
        0, total_img_tokens, 1, device="cuda", dtype=torch.long
    )

    # At float16 this buffer alone is S * D * 2 bytes (about 15.6 GiB).
    out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
    multimodal_emb(
        out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
    )
    # Device work runs asynchronously with respect to the host, so a faulting
    # kernel would otherwise go unnoticed (or be blamed on a later test).
    # Block here so any such error is raised inside this test.
    torch.cuda.synchronize()
45+
46+
# Let the file be executed directly as a script by handing off to pytest.
if __name__ == "__main__":
    pytest.main()
0 commit comments