Commit cd3e0b5

improve postprocess
1 parent 3030b27 commit cd3e0b5

2 files changed: +40 −21 lines changed

lightllm/common/req_manager.py

Lines changed: 36 additions & 3 deletions
@@ -58,7 +58,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryManager):
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
         self.mem_manager = mem_manager
-        self.req_sample_parms_manager = None
+        self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num
         self.HOLD_REQUEST_ID = max_request_num

@@ -68,8 +68,6 @@ def alloc(self):
     def free(self, free_req_indexes: List[int], free_token_index):
         for req_index in free_req_indexes:
             self.req_list.free(req_index)
-            if self.req_sample_parms_manager is not None:
-                self.req_sample_parms_manager.p_token_vocabs[free_req_indexes] = 0

         if self.req_list.is_all_free():
             logger.debug(f"freed all request size {self.req_list.can_alloc_size}")

@@ -88,3 +86,38 @@ def free_token(self, free_token_index):
     def free_all(self):
         self.req_list = _ReqLinkedList(self.max_request_num)
         return
+
+
+class ReqSamplingParamsManager:
+    """
+    ReqSamplingParamsManager keeps the relatively fixed part of the output sampling parameters in GPU buffers,
+    so that a CUDA kernel can gather them into per-batch sampling parameters more quickly. Post-processing
+    parameters that are more dynamic, or that need special handling, are still read dynamically from
+    InferSamplingParams and batched on the fly; see
+    lightllm/server/router/model_infer/mode_backend/generic_post_process.py for how this is used.
+    """
+
+    def __init__(self, max_request_num):
+        self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
+        self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
+        self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
+        self.req_to_temperature = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
+        self.req_to_exponential_decay_length_penalty = torch.zeros(
+            max_request_num + 1, dtype=torch.float32, device="cuda"
+        )
+
+    def init_req_sampling_params(self, req):
+        # local import to avoid a circular import
+        from lightllm.server.router.model_infer.infer_batch import InferReq
+
+        req: InferReq = req
+
+        shm_param = req.sampling_param.shm_param
+        self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty)
+        self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty)
+        self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty)
+        self.req_to_temperature[req.req_idx].fill_(shm_param.temperature)
+        exponential_decay_length_penalty = shm_param.exponential_decay_length_penalty.to_tuple()
+        self.req_to_exponential_decay_length_penalty[req.req_idx].fill_(exponential_decay_length_penalty[1])
+
+    def get_sampling_batch_params(self, req_idx_list: List[int]):
+        pass
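
The commit leaves get_sampling_batch_params as a stub. As a rough illustration only (not code from this commit), a hypothetical gather over the new GPU buffers could look like the sketch below; the helper name and the dict return shape are assumptions.

# Hypothetical sketch, not part of this commit: gather the per-request GPU buffers
# into per-batch tensors with a single indexed read per parameter.
import torch

def gather_batch_sampling_params(manager, req_idx_list):
    # Build the batch index once on the GPU, then index every parameter buffer with it.
    req_idxs = torch.tensor(req_idx_list, dtype=torch.int64, device="cuda")
    return {
        "presence_penalty": manager.req_to_presence_penalty[req_idxs],
        "frequency_penalty": manager.req_to_frequency_penalty[req_idxs],
        "repetition_penalty": manager.req_to_repetition_penalty[req_idxs],
        "temperature": manager.req_to_temperature[req_idxs],
        "exponential_decay_length_penalty": manager.req_to_exponential_decay_length_penalty[req_idxs],
    }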

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 4 additions & 18 deletions
@@ -22,11 +22,6 @@
 logger = init_logger(__name__)


-class ReqSampleParmsManager:
-    def __init__(self, max_request_num, vocab_size):
-        self.p_token_vocabs = torch.zeros((max_request_num, vocab_size), dtype=torch.int16, device="cuda")
-
-
 @dataclass
 class InferenceContext:
     req_manager: ReqManager = None  # gpu request manager

@@ -42,8 +37,6 @@ class InferenceContext:
     def register(
         self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
     ):
-        if enable_env_vars("ENABLE_REQ_PARAM_CACHE"):
-            req_manager.req_sample_parms_manager = ReqSampleParmsManager(req_manager.max_request_num, vocab_size)
         self.req_manager = req_manager
         self.radix_cache = radix_cache
         self.shm_req_manager = shm_req_manager

@@ -272,18 +265,11 @@ def init_all(self):
         self.shm_req.link_logprobs_shm_array()
         self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)

-        if enable_env_vars("ENABLE_REQ_PARAM_CACHE"):
-            if self.sampling_param.shm_param.input_penalty:
-                idxs = torch.bincount(self.shm_req.get_prompt_ids())
-                g_infer_context.req_manager.req_sample_parms_manager.p_token_vocabs[self.req_idx][
-                    : len(idxs)
-                ] = idxs
-            self.out_token_id_count = None
+        g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
+        if self.sampling_param.shm_param.input_penalty:
+            self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids())
         else:
-            if self.sampling_param.shm_param.input_penalty:
-                self.out_token_id_count = collections.Counter(self.shm_req.get_prompt_ids())
-            else:
-                self.out_token_id_count = collections.defaultdict(int)
+            self.out_token_id_count = collections.defaultdict(int)

         self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
         # management object used only in token healing mode
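
For context on how such batched parameters are typically consumed, the sketch below shows one common way presence/frequency penalties and temperature are applied to a [batch, vocab] logits tensor. It is an illustration only, not code from this commit or from generic_post_process.py; the function name, argument layout, and penalty conventions are assumptions.

# Illustrative only; not taken from this repository.
import torch

def apply_sampling_params(logits, presence_penalty, frequency_penalty, temperature, token_counts):
    # logits: [batch, vocab]; penalties/temperature: [batch]; token_counts: [batch, vocab]
    # Presence penalty: flat offset for any token that has already appeared.
    logits = logits - presence_penalty.unsqueeze(1) * (token_counts > 0).float()
    # Frequency penalty: offset proportional to how often each token has appeared.
    logits = logits - frequency_penalty.unsqueeze(1) * token_counts.float()
    # Temperature: scale logits before softmax/sampling.
    logits = logits / temperature.unsqueeze(1)
    return logits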
