@@ -58,7 +58,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
         self.mem_manager = mem_manager
-        self.req_sample_parms_manager = None
+        self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num
         self.HOLD_REQUEST_ID = max_request_num

@@ -68,8 +68,6 @@ def alloc(self):
     def free(self, free_req_indexes: List[int], free_token_index):
         for req_index in free_req_indexes:
             self.req_list.free(req_index)
-        if self.req_sample_parms_manager is not None:
-            self.req_sample_parms_manager.p_token_vocabs[free_req_indexes] = 0

         if self.req_list.is_all_free():
             logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
@@ -88,3 +86,38 @@ def free_token(self, free_token_index):
     def free_all(self):
         self.req_list = _ReqLinkedList(self.max_request_num)
         return
+
+
+class ReqSamplingParamsManager:
+    """
+    ReqSamplingParamsManager keeps the relatively fixed output sampling parameters in GPU buffers, so that a
+    CUDA kernel can quickly gather them into batch-level sampling parameters. Post-processing parameters that
+    are more dynamic or need special handling are still read from InferSamplingParams and batched dynamically;
+    see lightllm/server/router/model_infer/mode_backend/generic_post_process.py for how they are used.
97+ """
98+
99+ def __init__ (self , max_request_num ):
100+ self .req_to_presence_penalty = torch .zeros (max_request_num + 1 , dtype = torch .float32 , device = "cuda" )
101+ self .req_to_frequency_penalty = torch .zeros (max_request_num + 1 , dtype = torch .float32 , device = "cuda" )
102+ self .req_to_repetition_penalty = torch .zeros (max_request_num + 1 , dtype = torch .float32 , device = "cuda" )
103+ self .req_to_temperature = torch .zeros (max_request_num + 1 , dtype = torch .float32 , device = "cuda" )
104+ self .req_to_exponential_decay_length_penalty = torch .zeros (
105+ max_request_num + 1 , dtype = torch .float32 , device = "cuda"
106+ )
107+
+    def init_req_sampling_params(self, req):
+        # imported here to avoid a circular import
+        from lightllm.server.router.model_infer.infer_batch import InferReq
+
+        req: InferReq = req
+
+        shm_param = req.sampling_param.shm_param
+        self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty)
+        self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty)
+        self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty)
+        self.req_to_temperature[req.req_idx].fill_(shm_param.temperature)
+        exponential_decay_length_penalty = shm_param.exponential_decay_length_penalty.to_tuple()
+        self.req_to_exponential_decay_length_penalty[req.req_idx].fill_(exponential_decay_length_penalty[1])
+
+    def get_sampling_batch_params(self, req_idx_list: List[int]):
+        pass
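
For illustration only, here is a minimal sketch of how the batch-level gather that get_sampling_batch_params is meant to perform could look if done with plain advanced indexing instead of a dedicated CUDA kernel. The function name gather_batch_sampling_params and the returned dictionary keys are hypothetical, not lightllm API; the sketch assumes an already constructed ReqSamplingParamsManager and an available CUDA device.

import torch

def gather_batch_sampling_params(manager, req_idx_list):
    # Gather per-request sampling parameters into batch-level tensors with one
    # indexed read per buffer; the buffers already live on the GPU, so no
    # host-to-device copy is needed per decoding step.
    b_req_idx = torch.tensor(req_idx_list, dtype=torch.int64, device="cuda")
    return {
        "presence_penalties": manager.req_to_presence_penalty[b_req_idx],
        "frequency_penalties": manager.req_to_frequency_penalty[b_req_idx],
        "repetition_penalties": manager.req_to_repetition_penalty[b_req_idx],
        "temperatures": manager.req_to_temperature[b_req_idx],
        "exponential_decay_length_penalties": manager.req_to_exponential_decay_length_penalty[b_req_idx],
    }

Keeping only the stable parameters in these buffers keeps the gather cheap, while parameters that change per step or need special handling stay on the InferSamplingParams path described in the docstring.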