Skip to content

Commit b3e6d4e

Browse files
committed
fix
1 parent ab03176 commit b3e6d4e

File tree

1 file changed

+1
-46
lines changed

1 file changed

+1
-46
lines changed

lightllm/common/basemodel/triton_kernel/multimodal_emb.py

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -146,49 +146,4 @@ def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch
146146
num_warps=1,
147147
num_stages=1,
148148
)
149-
return out_mark
150-
151-
152-
def test():
153-
S, D = 1024 * 1000, 128 * 64
154-
vob_size = 320000
155-
image_size = 10
156-
image_token_size = 512
157-
158-
text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
159-
img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
160-
img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
161-
img_start_token_ids = (
162-
(torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
163-
)
164-
img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()
165-
166-
prompt_ids = torch.arange(0, S, 1).cuda().long()
167-
prompt_ids[0 : image_size * image_token_size] = (
168-
(vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
169-
)
170-
171-
out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
172-
print(out.shape)
173-
174-
import time
175-
176-
multimodal_emb(
177-
out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
178-
)
179-
180-
torch.cuda.synchronize()
181-
iters = 20
182-
t1 = time.time()
183-
for _ in range(iters):
184-
multimodal_emb(
185-
out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
186-
)
187-
torch.cuda.synchronize()
188-
t2 = time.time()
189-
print("Triton time cost", (t2 - t1) / iters)
190-
return
191-
192-
193-
# if __name__ == "__main__":
194-
# test()
149+
return out_mark

0 commit comments

Comments
 (0)