Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions lightllm/common/basemodel/triton_kernel/gen_sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,37 +121,54 @@ def _token_id_counter_update_kernel(
counter_stride_m,
counter_stride_n,
next_token_ids_ptr,
mask_ptr,
batch_size,
HAS_MASK: tl.constexpr,
BLOCK: tl.constexpr,
):

block_start_index = tl.program_id(0) * BLOCK
offs = block_start_index + tl.arange(0, BLOCK)
mask = offs < batch_size

req_idx = tl.load(b_req_idx_ptr + offs, mask=mask, other=0)
token_ids = tl.load(next_token_ids_ptr + offs, mask=mask, other=0)

tl.atomic_add(
req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n, 1, mask=mask
)
loc_mask = offs < batch_size

req_idx = tl.load(b_req_idx_ptr + offs, mask=loc_mask, other=0)
token_ids = tl.load(next_token_ids_ptr + offs, mask=loc_mask, other=0)

if HAS_MASK:
mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False)
tl.atomic_add(
req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
1,
mask=loc_mask & mask,
)
else:
tl.atomic_add(
req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
1,
mask=loc_mask,
)
Comment on lines +137 to +149
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability and reduce code duplication, this if/else block can be refactored. The tl.atomic_add call is repeated in both branches with only the mask parameter being different. Determine the mask to use first, and then make a single tl.atomic_add call.

    final_mask = loc_mask
    if HAS_MASK:
        mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False)
        final_mask = final_mask & mask

    tl.atomic_add(
        req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
        1,
        mask=final_mask,
    )

return


@torch.no_grad()
def update_req_to_token_id_counter(
b_req_idx: torch.Tensor, next_token_ids: torch.Tensor, req_to_out_token_id_counter: torch.Tensor
b_req_idx: torch.Tensor,
next_token_ids: torch.Tensor,
req_to_out_token_id_counter: torch.Tensor,
mask: torch.Tensor = None,
):
batch_size = b_req_idx.shape[0]
BLOCK = 256

has_mask = mask is not None
_token_id_counter_update_kernel[(triton.cdiv(batch_size, BLOCK),)](
b_req_idx_ptr=b_req_idx,
req_to_out_token_id_counter_ptr=req_to_out_token_id_counter,
counter_stride_m=req_to_out_token_id_counter.stride(0),
counter_stride_n=req_to_out_token_id_counter.stride(1),
next_token_ids_ptr=next_token_ids,
mask_ptr=mask,
batch_size=batch_size,
HAS_MASK=has_mask,
BLOCK=BLOCK,
num_warps=1,
)
Expand Down
38 changes: 22 additions & 16 deletions lightllm/common/req_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,29 +162,35 @@ def init_req_sampling_params(self, req):

return

def update_reqs_out_token_counter_gpu(
    self, b_req_idx: torch.Tensor, next_token_ids: torch.Tensor, mask: torch.Tensor = None
):
    """Accumulate freshly sampled token ids into the GPU-resident per-request counters.

    Args:
        b_req_idx: CUDA int tensor of request slot indices, one per batch row.
        next_token_ids: CUDA int tensor of sampled token ids, same length as
            ``b_req_idx``.
        mask: optional per-row bool tensor; rows where it is False are skipped
            by the update kernel (e.g. speculative tokens that were rejected).
    """
    # GPU-side accumulation only applies to the counter modes whose
    # statistics live in self.req_to_out_token_id_counter; the
    # "cpu_counter" mode maintains per-request Python dicts elsewhere.
    if self.penalty_counter_mode in ("gpu_counter", "pin_mem_counter"):
        assert (
            b_req_idx.is_cuda
            and next_token_ids.is_cuda
            and b_req_idx.shape[0] == next_token_ids.shape[0]
        )
        update_req_to_token_id_counter(
            b_req_idx=b_req_idx,
            next_token_ids=next_token_ids,
            req_to_out_token_id_counter=self.req_to_out_token_id_counter,
            mask=mask,
        )
    return

def update_reqs_token_counter(
self, req_objs: List, next_token_ids: List[int], accept_mark: Optional[List[List[bool]]] = None
):
from lightllm.server.router.model_infer.infer_batch import InferReq

req_objs: List[InferReq] = req_objs

if self.penalty_counter_mode == "cpu_counter":
for req_obj, next_token_id in zip(req_objs, next_token_ids):
if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0:
req_obj.out_token_id_count[next_token_id] += 1
else:
b_req_idx = torch.tensor(
[req.req_idx for req in req_objs], dtype=torch.int32, device="cpu", pin_memory=True
).cuda(non_blocking=True)
next_token_ids = (
torch.tensor(next_token_ids, dtype=torch.int32, device="cpu").pin_memory().cuda(non_blocking=True)
)
update_req_to_token_id_counter(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
req_to_out_token_id_counter=self.req_to_out_token_id_counter,
)
if self.penalty_counter_mode != "cpu_counter":
return

for req_obj, next_token_id in zip(req_objs, next_token_ids):
if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0:
req_obj.out_token_id_count[next_token_id] += 1
return

def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
Expand Down
1 change: 1 addition & 0 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def register(
self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
):
self.req_manager = req_manager
self.req_sampling_manager = self.req_manager.req_sampling_params_manager
self.radix_cache = radix_cache
self.shm_req_manager = shm_req_manager

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def prefill_normal(
b_mtp_index=model_input.b_mtp_index,
b_has_out=b_has_out,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=model_input.b_req_idx,
next_token_ids=next_token_ids,
mask=b_has_out,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -158,6 +163,10 @@ def decode_normal(
model_input.b_req_idx,
model_input.b_mtp_index,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=model_input.b_req_idx,
next_token_ids=next_token_ids,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -205,6 +214,11 @@ def prefill_mtp(
b_mtp_index=model_input.b_mtp_index,
b_has_out=b_has_out,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=model_input.b_req_idx,
next_token_ids=next_token_ids,
mask=b_has_out,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -305,6 +319,13 @@ def decode_mtp(
b_req_idx=model_input.b_req_idx,
mtp_accept_len=mtp_accept_len,
)

g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=model_input.b_req_idx,
next_token_ids=next_token_ids,
mask=accepted_index == 1,
)

next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down
40 changes: 40 additions & 0 deletions lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def prefill_normal(
b_mtp_index=b_mtp_index,
b_has_out=b_has_out,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
mask=b_has_out,
)

next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -182,6 +188,10 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -254,6 +264,11 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
b_mtp_index=b_mtp_index,
b_has_out=b_has_out,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
mask=b_has_out,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -318,6 +333,10 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -374,6 +393,11 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
b_mtp_index=b_mtp_index,
b_has_out=b_has_out,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
mask=b_has_out,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -493,6 +517,11 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
b_req_idx=b_req_idx,
mtp_accept_len=mtp_accept_len,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
mask=accepted_index == 1,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down Expand Up @@ -571,6 +600,11 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
b_mtp_index=b_mtp_index,
b_has_out=b_has_out,
)
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
mask=b_has_out,
)

# spec prefill: MTP
draft_micro_input0, draft_micro_input1 = micro_input0, micro_input1
Expand Down Expand Up @@ -733,6 +767,12 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
b_req_idx=b_req_idx,
mtp_accept_len=mtp_accept_len,
)

g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
b_req_idx=b_req_idx,
next_token_ids=next_token_ids,
mask=accepted_index == 1,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
Expand Down