@@ -146,49 +146,4 @@ def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch
146146 num_warps = 1 ,
147147 num_stages = 1 ,
148148 )
149- return out_mark
150-
151-
def test():
    """Micro-benchmark for ``multimodal_emb`` on synthetic data.

    Builds a fake text-vocabulary embedding table and an image-token
    embedding table on the GPU, constructs a prompt whose prefix is image
    tokens (ids offset far above the text vocab) and whose remainder is
    plain text ids, then times ``iters`` kernel launches.

    Requires a CUDA device; prints the mean per-launch time in seconds.
    Returns None.
    """
    import time  # local import kept function-scoped, but hoisted to the top of the scope

    S, D = 1024 * 1000, 128 * 64
    vob_size = 320000
    image_size = 10
    image_token_size = 512

    text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
    img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
    img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
    # Image token ids are offset by vob_size * 10 so they cannot collide with text ids.
    img_start_token_ids = (
        (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
    )
    img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()

    prompt_ids = torch.arange(0, S, 1).cuda().long()
    # The prompt prefix holds every image token id in order; the rest stay text ids.
    prompt_ids[0 : image_size * image_token_size] = (
        (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
    )

    out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
    print(out.shape)

    # Warm-up launch so Triton JIT/compile cost is excluded from the timed loop.
    multimodal_emb(
        out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
    )

    torch.cuda.synchronize()
    iters = 20
    # perf_counter is the monotonic, high-resolution clock intended for timing;
    # time.time() can jump with wall-clock adjustments.
    t1 = time.perf_counter()
    for _ in range(iters):
        multimodal_emb(
            out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
        )
    # Kernel launches are asynchronous; synchronize before reading the clock.
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print("Triton time cost", (t2 - t1) / iters)
    return


if __name__ == "__main__":
    test()
149+ return out_mark
0 commit comments