diff --git a/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py b/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py
index bfff291cd..638ad92d3 100644
--- a/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py
+++ b/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py
@@ -121,37 +121,54 @@ def _token_id_counter_update_kernel(
     counter_stride_m,
     counter_stride_n,
     next_token_ids_ptr,
+    mask_ptr,
     batch_size,
+    HAS_MASK: tl.constexpr,
     BLOCK: tl.constexpr,
 ):
     block_start_index = tl.program_id(0) * BLOCK
     offs = block_start_index + tl.arange(0, BLOCK)
-    mask = offs < batch_size
-
-    req_idx = tl.load(b_req_idx_ptr + offs, mask=mask, other=0)
-    token_ids = tl.load(next_token_ids_ptr + offs, mask=mask, other=0)
-
-    tl.atomic_add(
-        req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n, 1, mask=mask
-    )
+    loc_mask = offs < batch_size
+
+    req_idx = tl.load(b_req_idx_ptr + offs, mask=loc_mask, other=0)
+    token_ids = tl.load(next_token_ids_ptr + offs, mask=loc_mask, other=0)
+
+    if HAS_MASK:
+        mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False)
+        tl.atomic_add(
+            req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
+            1,
+            mask=loc_mask & mask,
+        )
+    else:
+        tl.atomic_add(
+            req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
+            1,
+            mask=loc_mask,
+        )
     return
 
 
 @torch.no_grad()
 def update_req_to_token_id_counter(
-    b_req_idx: torch.Tensor, next_token_ids: torch.Tensor, req_to_out_token_id_counter: torch.Tensor
+    b_req_idx: torch.Tensor,
+    next_token_ids: torch.Tensor,
+    req_to_out_token_id_counter: torch.Tensor,
+    mask: torch.Tensor = None,
 ):
     batch_size = b_req_idx.shape[0]
     BLOCK = 256
-
+    has_mask = mask is not None
     _token_id_counter_update_kernel[(triton.cdiv(batch_size, BLOCK),)](
         b_req_idx_ptr=b_req_idx,
         req_to_out_token_id_counter_ptr=req_to_out_token_id_counter,
         counter_stride_m=req_to_out_token_id_counter.stride(0),
         counter_stride_n=req_to_out_token_id_counter.stride(1),
         next_token_ids_ptr=next_token_ids,
+        mask_ptr=mask,
         batch_size=batch_size,
+        HAS_MASK=has_mask,
         BLOCK=BLOCK,
         num_warps=1,
     )
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index d4cf24f75..0b383185e 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -162,6 +162,22 @@ def init_req_sampling_params(self, req):
 
         return
 
+    def update_reqs_out_token_counter_gpu(
+        self, b_req_idx: torch.Tensor, next_token_ids: torch.Tensor, mask: torch.Tensor = None
+    ):
+        if self.penalty_counter_mode not in ["gpu_counter", "pin_mem_counter"]:
+            return
+
+        assert b_req_idx.is_cuda and next_token_ids.is_cuda and b_req_idx.shape[0] == next_token_ids.shape[0]
+
+        update_req_to_token_id_counter(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            req_to_out_token_id_counter=self.req_to_out_token_id_counter,
+            mask=mask,
+        )
+        return
+
     def update_reqs_token_counter(
         self, req_objs: List, next_token_ids: List[int], accept_mark: Optional[List[List[bool]]] = None
     ):
@@ -169,22 +185,12 @@ def update_reqs_token_counter(
 
         req_objs: List[InferReq] = req_objs
 
-        if self.penalty_counter_mode == "cpu_counter":
-            for req_obj, next_token_id in zip(req_objs, next_token_ids):
-                if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0:
-                    req_obj.out_token_id_count[next_token_id] += 1
-        else:
-            b_req_idx = torch.tensor(
-                [req.req_idx for req in req_objs], dtype=torch.int32, device="cpu", pin_memory=True
-            ).cuda(non_blocking=True)
-            next_token_ids = (
-                torch.tensor(next_token_ids, dtype=torch.int32, device="cpu").pin_memory().cuda(non_blocking=True)
-            )
-            update_req_to_token_id_counter(
-                b_req_idx=b_req_idx,
-                next_token_ids=next_token_ids,
-                req_to_out_token_id_counter=self.req_to_out_token_id_counter,
-            )
+        if self.penalty_counter_mode != "cpu_counter":
+            return
+
+        for req_obj, next_token_id in zip(req_objs, next_token_ids):
+            if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0:
+                req_obj.out_token_id_count[next_token_id] += 1
         return
 
     def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 0b930c72d..67d69aa38 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -38,6 +38,7 @@ def register(
         self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
     ):
         self.req_manager = req_manager
+        self.req_sampling_manager = self.req_manager.req_sampling_params_manager
         self.radix_cache = radix_cache
         self.shm_req_manager = shm_req_manager
 
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index 0b34ba4e7..39d345ff5 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -115,6 +115,11 @@ def prefill_normal(
             b_mtp_index=model_input.b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -158,6 +163,10 @@ def decode_normal(
             model_input.b_req_idx,
             model_input.b_mtp_index,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -205,6 +214,11 @@ def prefill_mtp(
             b_mtp_index=model_input.b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -305,6 +319,13 @@ def decode_mtp(
             b_req_idx=model_input.b_req_idx,
             mtp_accept_len=mtp_accept_len,
         )
+
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=model_input.b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=accepted_index == 1,
+        )
+
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index 90d65642d..a90a946fd 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -134,6 +134,12 @@ def prefill_normal(
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
+
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -182,6 +188,10 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
             b_req_idx=b_req_idx,
             b_mtp_index=b_mtp_index,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -254,6 +264,11 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -318,6 +333,10 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
             b_req_idx=b_req_idx,
             b_mtp_index=b_mtp_index,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -374,6 +393,11 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -493,6 +517,11 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
             b_req_idx=b_req_idx,
             mtp_accept_len=mtp_accept_len,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=accepted_index == 1,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
         )
@@ -571,6 +600,11 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
             b_mtp_index=b_mtp_index,
             b_has_out=b_has_out,
         )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
 
         # spec prefill: MTP
         draft_micro_input0, draft_micro_input1 = micro_input0, micro_input1
@@ -733,6 +767,12 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
             b_req_idx=b_req_idx,
             mtp_accept_len=mtp_accept_len,
         )
+
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=accepted_index == 1,
+        )
         next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
             next_token_ids, next_token_logprobs
        )
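
Reviewer note: as a sanity check on the new `HAS_MASK` path, the kernel's masked update is just a masked scatter-add into the per-request `(req, token)` counter table. Below is a minimal CPU/GPU reference sketch of that semantics; `reference_masked_counter_update` is a hypothetical helper written for illustration, not part of this PR or the lightllm API:

```python
import torch


def reference_masked_counter_update(
    b_req_idx: torch.Tensor,       # (batch,) int32 request slot per sample
    next_token_ids: torch.Tensor,  # (batch,) int32 sampled token ids
    counter: torch.Tensor,         # (max_reqs, vocab) out-token counts
    mask: torch.Tensor = None,     # (batch,) bool; only True rows are counted
):
    # Mirrors _token_id_counter_update_kernel: drop rows where mask is False,
    # e.g. b_has_out in the prefill paths or accepted_index == 1 in MTP decode.
    if mask is not None:
        b_req_idx = b_req_idx[mask]
        next_token_ids = next_token_ids[mask]
    # index_put_ with accumulate=True behaves like tl.atomic_add, so duplicate
    # (req, token) pairs within one batch still accumulate correctly.
    counter.index_put_(
        (b_req_idx.long(), next_token_ids.long()),
        torch.ones_like(next_token_ids, dtype=counter.dtype),
        accumulate=True,
    )
```

One design point worth noting: because `HAS_MASK` is a `tl.constexpr`, Triton specializes the kernel into two compiled variants, so the common unmasked decode path pays nothing for the extra mask load.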