diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 0b383185e..59f607a01 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -155,7 +155,7 @@ def init_req_sampling_params(self, req):
         else:
             self.req_to_out_token_id_counter[req.req_idx].fill_(0)
         if req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics:
-            prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids()).pin_memory().cuda(non_blocking=True)
+            prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True)
             token_id_counter(
                 prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx]
             )
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index e4b345063..06a728925 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -188,6 +188,9 @@ def link_logprobs_shm_array(self):
     def get_prompt_ids(self):
         return self.shm_prompt_ids.arr[: self.input_len].tolist()
 
+    def get_prompt_ids_numpy(self):
+        return self.shm_prompt_ids.arr[: self.input_len]
+
     def to_router_rpc_obj(self):
         if hasattr(self, "multimodal_params"):
             return (
diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
index abddea356..143f6081c 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
@@ -116,15 +116,14 @@ def generate_new_batch(self, current_batch: Batch):
                 if ok_insert:
                     can_run_list.extend(cur_group_reqs)
 
+        new_batch = None
         if len(can_run_list) != 0:
             new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
-            for req in abort_req_list:
-                self.router.shm_req_manager.put_back_req_obj(req)
-            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
-            return new_batch
-        else:
-            return None
+        for req in abort_req_list:
+            self.router.shm_req_manager.put_back_req_obj(req)
+        self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+        return new_batch
 
     def _add_to_group(self, cur_group_reqs, req: Req):
         if len(cur_group_reqs) == 0:
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 660271ab6..f1dae4cac 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -91,15 +91,13 @@ def generate_new_batch(self, current_batch: Batch):
                 can_run_list.append(req)
             else:
                 break
-
+        new_batch = None
         if len(can_run_list) != 0:
             new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
-            for req in abort_req_list:
-                self.router.shm_req_manager.put_back_req_obj(req)
-            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
-            return new_batch
-        else:
-            return None
+        for req in abort_req_list:
+            self.router.shm_req_manager.put_back_req_obj(req)
+        self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+        return new_batch
 
     def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
         is_busy = self.is_busy()
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
index d8bf36680..e89dda66e 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
@@ -48,15 +48,13 @@ def generate_new_batch(self, current_batch: Batch):
                 can_run_list.append(req)
             else:
                 break
-
+        new_batch = None
         if len(can_run_list) != 0:
             new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
-            for req in abort_req_list:
-                self.router.shm_req_manager.put_back_req_obj(req)
-            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
-            return new_batch
-        else:
-            return None
+        for req in abort_req_list:
+            self.router.shm_req_manager.put_back_req_obj(req)
+        self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+        return new_batch
 
     def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
         is_busy = self.is_busy()
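
Note on the `get_prompt_ids_numpy()` change: `torch.from_numpy()` only accepts a `numpy.ndarray`, while `get_prompt_ids()` returns a Python list (it calls `.tolist()` on the shared-memory array), so the old call in `req_manager.py` could not succeed as written. Returning the raw array slice lets `from_numpy` wrap the shared buffer without an extra Python-list round trip; `pin_memory()` then stages it in page-locked host memory so `cuda(non_blocking=True)` can copy asynchronously. A minimal sketch of the mechanics, with `shm_prompt_ids` and `input_len` as hypothetical stand-ins for the request's shared-memory fields:

```python
import numpy as np
import torch

# Hypothetical stand-ins for Req's shared-memory prompt buffer and length.
shm_prompt_ids = np.arange(8, dtype=np.int64)
input_len = 4

# Old path: .tolist() produces a Python list, which torch.from_numpy rejects.
# torch.from_numpy(shm_prompt_ids[:input_len].tolist())  # TypeError: expected np.ndarray

# New path: a basic slice is a contiguous numpy view, so from_numpy shares its
# memory with the buffer (no copy). pin_memory() then copies once into
# page-locked host memory, which is what makes the non-blocking H2D transfer
# actually asynchronous.
host_ids = torch.from_numpy(shm_prompt_ids[:input_len]).pin_memory()
if torch.cuda.is_available():  # guard so the sketch also runs on CPU-only boxes
    prompt_ids = host_ids.cuda(non_blocking=True)
```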
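Note on the `generate_new_batch()` restructuring (the same pattern appears in all three queue impls): with the old early return, an empty `can_run_list` meant the requests collected in `abort_req_list` were never handed back to `shm_req_manager` and `waiting_req_list` was never trimmed, so aborted entries leaked and stayed at the head of the queue. Hoisting the cleanup out of the `if` makes it unconditional. A minimal sketch of the fixed control flow, using simplified stand-ins rather than the real `Batch` and router objects:

```python
import uuid
from typing import Callable, List, Optional, Tuple

def generate_new_batch_sketch(
    can_run_list: List[object],
    abort_req_list: List[object],
    aborted_count: int,
    waiting_req_list: List[object],
    put_back_req_obj: Callable[[object], None],
) -> Tuple[Optional[tuple], List[object]]:
    # Build a batch only when there is something to run ...
    new_batch = None
    if len(can_run_list) != 0:
        new_batch = (uuid.uuid4().int, can_run_list)  # stand-in for Batch(...)
    # ... but always recycle aborted requests and trim the waiting list.
    # Before the fix this block lived inside the `if`, so an empty
    # can_run_list skipped it and leaked every req in abort_req_list.
    for req in abort_req_list:
        put_back_req_obj(req)
    waiting_req_list = waiting_req_list[len(can_run_list) + aborted_count:]
    return new_batch, waiting_req_list
```

The sketch returns the trimmed list instead of mutating `self.waiting_req_list`, but the ordering is the point: cleanup runs on both the batch and the no-batch path, and the function still returns `None` when nothing could be scheduled.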