diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py
index 021de6843..e45cc11c7 100755
--- a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -5,6 +5,7 @@
 from typing import Tuple, Any, Optional
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
+from .triton_kernel.multimodal_emb import mark_multimodal_obj
 
 
 class InferStateInfo:
@@ -98,3 +99,24 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
             if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
                 attr_.copy_(attr_value, non_blocking=True)
         return
+
+    def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
+        """
+        Marks which multimodal objects have tokens that take part in the computation during chunked prefill.
+        Because the prompt is split into chunks, not every multimodal object's tokens need to be computed.
+        """
+        multi_objs = []
+        for _, p in enumerate(self.multimodal_params):
+            for obj in p["images"] + p["audios"]:
+                multi_objs.append(obj)
+
+        if multi_objs:
+            obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            marks = mark_multimodal_obj(
+                obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
+            )
+            marks_array = marks.detach().cpu().numpy()
+            for mark, obj in zip(marks_array, multi_objs):
+                obj["_prefill_"] = mark > 0
+        return
diff --git a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
index 8b66827a5..0f9279c55 100644
--- a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
+++ b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
@@ -5,48 +5,78 @@
 
 @triton.jit
 def _fwd_kernel(
-    Prompt_ids, 
+    Prompt_ids,
     Text_weight_embs,
     Img_embs,
     Out,
     Img_token_lens,
     Img_start_token_ids,
     Img_start_locs,
-    stride_text_emb_s, stride_text_emb_d, # text_stride
-    stride_img_emb_s, stride_img_emb_d, # img_stride
-    stride_out_s, stride_out_d,
+    stride_text_emb_s,
+    stride_text_emb_d,  # text_stride
+    stride_img_emb_s,
+    stride_img_emb_d,  # img_stride
+    stride_out_s,
+    stride_out_d,
     tp_text_start_token_id,
     tp_text_end_token_id,
     hidden_size,
-    BLOCK_HIDDEN_DIM: tl.constexpr
-    ):
+    BLOCK_HIDDEN_DIM: tl.constexpr,
+):
     seq_index = tl.program_id(0).to(tl.int64)
     img_handle_id = tl.program_id(1)
     token_id = tl.load(Prompt_ids + seq_index)
     off_d = tl.arange(0, BLOCK_HIDDEN_DIM)
-    
+
     # load store text emb
-    for _ in range(0, tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0), 1):
-        load_emb = tl.load(Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0),
+        1,
+    ):
+        load_emb = tl.load(
+            Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
-    
+
     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     # load store img emb
-    for _ in range(0, tl.where((img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len), 1, 0), 1):
-        load_emb = tl.load(Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where(
+            (img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len),
+            1,
+            0,
+        ),
+        1,
+    ):
+        load_emb = tl.load(
+            Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
     return
 
 
 @torch.no_grad()
-def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs: torch.Tensor, img_embs: torch.Tensor,
-                   img_token_lens: torch.Tensor, img_start_token_ids: torch.Tensor, img_start_locs: torch.Tensor,
-                   tp_text_start_token_id,
-                   tp_text_end_token_id):
+def multimodal_emb(
+    out: torch.Tensor,
+    prompt_ids: torch.Tensor,
+    text_weight_embs: torch.Tensor,
+    img_embs: torch.Tensor,
+    img_token_lens: torch.Tensor,
+    img_start_token_ids: torch.Tensor,
+    img_start_locs: torch.Tensor,
+    tp_text_start_token_id,
+    tp_text_end_token_id,
+):
     total_len = prompt_ids.shape[0]
     BLOCK = triton.next_power_of_2(out.shape[1])
     # print(len(img_token_lens))
@@ -60,9 +90,12 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
         img_token_lens,
         img_start_token_ids,
         img_start_locs,
-        text_weight_embs.stride(0), text_weight_embs.stride(1),
-        img_embs.stride(0), img_embs.stride(1),
-        out.stride(0), out.stride(1),
+        text_weight_embs.stride(0),
+        text_weight_embs.stride(1),
+        img_embs.stride(0),
+        img_embs.stride(1),
+        out.stride(0),
+        out.stride(1),
         tp_text_start_token_id,
         tp_text_end_token_id,
         hidden_size=out.shape[1],
@@ -73,40 +106,44 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
     return
 
 
-
-def test():
-    S, D = 1024 * 1000, 128 * 64
-    vob_size = 320000
-    image_size = 10
-    image_token_size = 512
-
-    text_weight = torch.randn((vob_size, D), device='cuda', dtype=torch.float16)
-    img_weight = torch.randn((image_size * image_token_size, D), device='cuda', dtype=torch.float16)
-    img_token_lens = torch.full((image_size,), image_token_size, device='cuda', dtype=torch.long)
-    img_start_token_ids = (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
-    img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()
-
-    prompt_ids = torch.arange(0, S, 1).cuda().long()
-    prompt_ids[0: image_size * image_token_size] = (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
-
-    out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
-    print(out.shape)
-
-    import time
-
-    triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size)
-
-    torch.cuda.synchronize()
-    iters = 20
-    t1 = time.time()
-    for _ in range(iters):
-        triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size)
-    torch.cuda.synchronize()
-    t2 = time.time()
-    print("Triton time cost", (t2 - t1) / iters)
+@triton.jit
+def _mark_multimodal_obj_need_kernel(
+    obj_start_token_ids_ptr,
+    obj_token_lens_ptr,
+    obj_marks_ptr,
+    input_ids_ptr,
+    input_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    obj_index = tl.program_id(0)
+    start_id = tl.load(obj_start_token_ids_ptr + obj_index)
+    token_len = tl.load(obj_token_lens_ptr + obj_index)
+
+    for block_start in range(0, input_size, BLOCK_SIZE):
+        block_range = block_start + tl.arange(0, BLOCK_SIZE)
+        cur_input_ids = tl.load(input_ids_ptr + block_range, mask=block_range < input_size, other=0)
+        mark = tl.where((cur_input_ids >= start_id) & (cur_input_ids < start_id + token_len), 1, 0)
+        mark = tl.sum(mark)
+        tl.store(obj_marks_ptr + obj_index, 1, mask=mark > 0)
     return
 
 
-# if __name__ == "__main__":
-#     test()
-
+@torch.no_grad()
+def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch.Tensor, input_ids: torch.Tensor):
+    out_mark = torch.empty_like(obj_start_token_ids)
+    out_mark.fill_(0)
+    assert obj_start_token_ids.shape == obj_token_lens.shape
+    BLOCK = 512
+    grid = (obj_start_token_ids.shape[0],)
+    _mark_multimodal_obj_need_kernel[grid](
+        obj_start_token_ids_ptr=obj_start_token_ids,
+        obj_token_lens_ptr=obj_token_lens,
+        obj_marks_ptr=out_mark,
+        input_ids_ptr=input_ids,
+        input_size=input_ids.shape[0],
+        BLOCK_SIZE=BLOCK,
+        num_warps=1,
+        num_stages=1,
+    )
+    return out_mark
diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
index 89c8e0d8d..46b782879 100644
--- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -35,10 +35,12 @@ def context_forward(self, input_ids, infer_state, layer_weight):
             else:
                 weight_mask[idx] = scale
 
+        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
+
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids:
+                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
                     continue
                 # pull the img_embeds by uid from shm
                 data = read_shm(get_shm_name_embed(img["uuid"]))
diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
index 60c9e0564..b5b31a413 100644
--- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -42,10 +42,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
+
+        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
+
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids:
+                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
                     continue
                 # pull the img_embeds by uid from shm
                 data = read_shm(get_shm_name_embed(img["uuid"]))
diff --git a/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py
new file mode 100755
index 000000000..438eaa157
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py
@@ -0,0 +1,48 @@
+import torch
+import pytest
+from lightllm.common.basemodel.triton_kernel.multimodal_emb import mark_multimodal_obj, multimodal_emb
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def test_mark_multimodal_obj():
+    obj_start_ids = torch.tensor([1, 4, 100], device="cuda", dtype=torch.int64)
+    obj_token_lens = torch.tensor([1, 3, 2], device="cuda", dtype=torch.int64)
+    input_ids = torch.tensor([1, 7, 9, 333], device="cuda", dtype=torch.int64)
+
+    mark_obj = mark_multimodal_obj(
+        obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
+    )
+
+    assert torch.equal(mark_obj, torch.tensor([1, 0, 0], device="cuda"))
+
+
+def test_multimodal_emb():
+    S, D = 1024 * 1000, 128 * 64
+    vob_size = 320000
+    image_size = 10
+    image_token_size = 512
+
+    text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
+    img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
+    img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
+    img_start_token_ids = (
+        (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
+    )
+    img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()
+
+    prompt_ids = torch.arange(0, S, 1).cuda().long()
+    prompt_ids[0 : image_size * image_token_size] = (
+        (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
+    )
+
+    out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
+    multimodal_emb(
+        out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
+    )
+    return
+
+
+if __name__ == "__main__":
+    pytest.main()
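
Note: as a sanity check for what the new mark_multimodal_obj kernel computes, here is a minimal pure-PyTorch reference (the helper name mark_multimodal_obj_ref is hypothetical and not part of this patch). An object is marked when any id in input_ids falls inside its token-id range [token_id, token_id + token_num), which is the condition test_mark_multimodal_obj asserts.

import torch

def mark_multimodal_obj_ref(obj_start_token_ids, obj_token_lens, input_ids):
    # An object is "hit" when any input id lies in [start, start + token_len).
    starts = obj_start_token_ids.view(-1, 1)
    ends = starts + obj_token_lens.view(-1, 1)
    hit = (input_ids.view(1, -1) >= starts) & (input_ids.view(1, -1) < ends)
    return hit.any(dim=1).to(obj_start_token_ids.dtype)

# Example, matching the unit test above:
# mark_multimodal_obj_ref(torch.tensor([1, 4, 100]), torch.tensor([1, 3, 2]), torch.tensor([1, 7, 9, 333]))
# -> tensor([1, 0, 0])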