22 changes: 22 additions & 0 deletions lightllm/common/basemodel/infer_struct.py
@@ -5,6 +5,7 @@
from typing import Tuple, Any, Optional
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from .triton_kernel.multimodal_emb import mark_multimodal_obj


class InferStateInfo:
@@ -98,3 +99,24 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
attr_.copy_(attr_value, non_blocking=True)
return

def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
"""
        Marks which multimodal objects (images/audios) have tokens that take part
        in the current chunked-prefill step. Because prefill is split into chunks,
        not every multimodal object's tokens appear in this forward pass.
"""
multi_objs = []
        for p in self.multimodal_params:
for obj in p["images"] + p["audios"]:
multi_objs.append(obj)

if multi_objs:
obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
marks = mark_multimodal_obj(
obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
)
marks_array = marks.detach().cpu().numpy()
            for mark, obj in zip(marks_array, multi_objs):
                # cast to a plain Python bool: numpy.bool_ is never identical to the
                # Python False singleton, so the `img["_prefill_"] is False` checks in
                # the pre-layer infer code would otherwise never fire
                obj["_prefill_"] = bool(mark > 0)
return
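
        # A minimal sketch of the marking semantics (hypothetical toy values, not
        # taken from this diff): an object is marked when any id in
        # [token_id, token_id + token_num) occurs in the chunk's input_ids.
        #
        #   starts = torch.tensor([1000, 2000], dtype=torch.int64, device="cuda")
        #   lens = torch.tensor([4, 8], dtype=torch.int64, device="cuda")
        #   chunk = torch.arange(998, 1003, dtype=torch.int64, device="cuda")
        #   mark_multimodal_obj(starts, lens, chunk)  # -> tensor([1, 0], device='cuda:0')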
143 changes: 90 additions & 53 deletions lightllm/common/basemodel/triton_kernel/multimodal_emb.py
@@ -5,48 +5,78 @@

@triton.jit
def _fwd_kernel(
    Prompt_ids,
    Text_weight_embs,
    Img_embs,
    Out,
    Img_token_lens,
    Img_start_token_ids,
    Img_start_locs,
    stride_text_emb_s,
    stride_text_emb_d,  # text_stride
    stride_img_emb_s,
    stride_img_emb_d,  # img_stride
    stride_out_s,
    stride_out_d,
    tp_text_start_token_id,
    tp_text_end_token_id,
    hidden_size,
    BLOCK_HIDDEN_DIM: tl.constexpr,
):

    # grid axis 0 walks sequence positions; axis 1 selects the handler:
    # 0 for text, i >= 1 for image i - 1
    seq_index = tl.program_id(0).to(tl.int64)
    img_handle_id = tl.program_id(1)

token_id = tl.load(Prompt_ids + seq_index)
off_d = tl.arange(0, BLOCK_HIDDEN_DIM)

    # text tokens: copy the embedding row owned by this tp rank
    # (a 0/1-iteration loop stands in for a branch)
    for _ in range(
        0,
        tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0),
        1,
    ):
        load_emb = tl.load(
            Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d,
            mask=off_d < hidden_size,
            other=0,
        )
tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)

img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
    # image tokens: ids in [img_start_token_id, img_start_token_id + img_token_len)
    # map to rows (img_start_loc + offset) of Img_embs
    for _ in range(
        0,
        tl.where(
            (img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len),
            1,
            0,
        ),
        1,
    ):
        load_emb = tl.load(
            Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
            mask=off_d < hidden_size,
            other=0,
        )
tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
return


@torch.no_grad()
def multimodal_emb(
    out: torch.Tensor,
    prompt_ids: torch.Tensor,
    text_weight_embs: torch.Tensor,
    img_embs: torch.Tensor,
    img_token_lens: torch.Tensor,
    img_start_token_ids: torch.Tensor,
    img_start_locs: torch.Tensor,
    tp_text_start_token_id,
    tp_text_end_token_id,
):
total_len = prompt_ids.shape[0]
BLOCK = triton.next_power_of_2(out.shape[1])
# print(len(img_token_lens))
@@ -60,9 +90,12 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
img_token_lens,
img_start_token_ids,
img_start_locs,
        text_weight_embs.stride(0),
        text_weight_embs.stride(1),
        img_embs.stride(0),
        img_embs.stride(1),
        out.stride(0),
        out.stride(1),
tp_text_start_token_id,
tp_text_end_token_id,
hidden_size=out.shape[1],
@@ -73,40 +106,44 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
return
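
# For intuition, an equivalent (slow) pure-PyTorch version of what the kernel
# writes: a reference sketch under the same argument conventions, not part of
# this diff.
def multimodal_emb_reference(
    out, prompt_ids, text_weight_embs, img_embs,
    img_token_lens, img_start_token_ids, img_start_locs,
    tp_text_start_token_id, tp_text_end_token_id,
):
    # text tokens owned by this tp rank: plain embedding lookup
    is_text = (prompt_ids >= tp_text_start_token_id) & (prompt_ids < tp_text_end_token_id)
    out[is_text] = text_weight_embs[prompt_ids[is_text] - tp_text_start_token_id]
    # image i's ids [start, start + n) map to rows img_embs[loc : loc + n]
    for start, loc, n in zip(
        img_start_token_ids.tolist(), img_start_locs.tolist(), img_token_lens.tolist()
    ):
        is_img = (prompt_ids >= start) & (prompt_ids < start + n)
        out[is_img] = img_embs[loc + prompt_ids[is_img] - start]
    return out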


@triton.jit
def _mark_multimodal_obj_need_kernel(
obj_start_token_ids_ptr,
obj_token_lens_ptr,
obj_marks_ptr,
input_ids_ptr,
input_size,
BLOCK_SIZE: tl.constexpr,
):

    obj_index = tl.program_id(0)
    start_id = tl.load(obj_start_token_ids_ptr + obj_index)
    token_len = tl.load(obj_token_lens_ptr + obj_index)

    # scan input_ids in blocks; mark this object if any token id falls in
    # its range [start_id, start_id + token_len)
    for block_start in range(0, input_size, BLOCK_SIZE):
        block_range = block_start + tl.arange(0, BLOCK_SIZE)
        cur_input_ids = tl.load(input_ids_ptr + block_range, mask=block_range < input_size, other=0)
        mark = tl.where((cur_input_ids >= start_id) & (cur_input_ids < start_id + token_len), 1, 0)
        mark = tl.sum(mark)
        # masked store: only writes 1 when at least one token in the block matched
        tl.store(obj_marks_ptr + obj_index, 1, mask=mark > 0)
return


@torch.no_grad()
def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch.Tensor, input_ids: torch.Tensor):
    assert obj_start_token_ids.shape == obj_token_lens.shape
    out_mark = torch.zeros_like(obj_start_token_ids)
BLOCK = 512
grid = (obj_start_token_ids.shape[0],)
_mark_multimodal_obj_need_kernel[grid](
obj_start_token_ids_ptr=obj_start_token_ids,
obj_token_lens_ptr=obj_token_lens,
obj_marks_ptr=out_mark,
input_ids_ptr=input_ids,
input_size=input_ids.shape[0],
BLOCK_SIZE=BLOCK,
num_warps=1,
num_stages=1,
)
return out_mark
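
# The kernel above is equivalent to this broadcasted check: a clarity sketch,
# not part of this diff (the Triton version scans input_ids in blocks instead
# of materializing the objects-by-tokens matrix).
def mark_multimodal_obj_reference(obj_start_token_ids, obj_token_lens, input_ids):
    ends = obj_start_token_ids + obj_token_lens
    in_range = (input_ids[None, :] >= obj_start_token_ids[:, None]) & (input_ids[None, :] < ends[:, None])
    return in_range.any(dim=1).to(obj_start_token_ids.dtype)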
4 changes: 3 additions & 1 deletion lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -35,10 +35,12 @@ def context_forward(self, input_ids, infer_state, layer_weight):
else:
weight_mask[idx] = scale

infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)

for batch_id, p in enumerate(infer_state.multimodal_params):
for img in p["images"]:
# skip the same image
if img["token_id"] in img_start_token_ids:
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
5 changes: 4 additions & 1 deletion lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -42,10 +42,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
device = layer_weight.wte_weight_.device
dtype = layer_weight.wte_weight_.dtype
hidden_size = layer_weight.wte_weight_.shape[1]

infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)

for batch_id, p in enumerate(infer_state.multimodal_params):
for img in p["images"] + p["audios"]:
# skip the same image
if img["token_id"] in img_start_token_ids:
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
48 changes: 48 additions & 0 deletions unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py
@@ -0,0 +1,48 @@
import torch
import pytest
from lightllm.common.basemodel.triton_kernel.multimodal_emb import mark_multimodal_obj, multimodal_emb
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def test_mark_multimodal_obj():
obj_start_ids = torch.tensor([1, 4, 100], device="cuda", dtype=torch.int64)
obj_token_lens = torch.tensor([1, 3, 2], device="cuda", dtype=torch.int64)
input_ids = torch.tensor([1, 7, 9, 333], device="cuda", dtype=torch.int64)

mark_obj = mark_multimodal_obj(
obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
)

    # only the first object overlaps input_ids: id 1 is present, while ids 4..6
    # and 100..101 never appear
    assert torch.equal(mark_obj, torch.tensor([1, 0, 0], device="cuda"))


def test_multimodal_emb():
S, D = 1024 * 1000, 128 * 64
vob_size = 320000
image_size = 10
image_token_size = 512

text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
img_start_token_ids = (
(torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
)
img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()

prompt_ids = torch.arange(0, S, 1).cuda().long()
prompt_ids[0 : image_size * image_token_size] = (
(vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
)

out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
multimodal_emb(
out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
)
return
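
# A possible strengthening, not part of this diff: test_multimodal_emb above only
# checks that the kernel launches. A spot-check against direct embedding lookups
# (hypothetical test, same public API) could assert output correctness:
def test_multimodal_emb_matches_lookup():
    D, vob_size = 64, 1000
    text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
    img_weight = torch.randn((8, D), device="cuda", dtype=torch.float16)
    img_token_lens = torch.tensor([8], device="cuda", dtype=torch.long)
    img_start_token_ids = torch.tensor([vob_size * 10], device="cuda", dtype=torch.long)
    img_start_locs = torch.tensor([0], device="cuda", dtype=torch.long)
    prompt_ids = torch.tensor([5, vob_size * 10 + 3], device="cuda", dtype=torch.long)
    out = torch.zeros((2, D), device="cuda", dtype=torch.float16)
    multimodal_emb(
        out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
    )
    assert torch.equal(out[0], text_weight[5])  # text token 5 -> text row 5
    assert torch.equal(out[1], img_weight[3])  # image token at offset 3 -> image row 3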


if __name__ == "__main__":
pytest.main()