File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed
Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -106,13 +106,12 @@ def __init__(self, max_request_num):
106106 "LIGHTLLM_ENABLE_GPU_BUFFER_FOR_OUT_TOKEN_ID_COUNTER"
107107 )
108108         self.vocab_size = get_vocab_size(get_env_start_args().model_dir)
109-         self.req_to_next_token_id = torch.zeros(max_request_num + 1, dtype=torch.int64, device="cuda")
110109         self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
111110         self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
112111         self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
113112         self.req_to_next_token_ids = torch.zeros(
114113             (max_request_num + 1, 8),
115-             dtype=torch.int32,
114+             dtype=torch.int64,
116115             device="cuda",
117116         )
118117 self .req_to_exponential_decay_length_penalty = torch .zeros (
@@ -143,7 +142,7 @@ def init_req_sampling_params(self, req):
143142         req: InferReq = req
144143
145144         shm_param = req.sampling_param.shm_param
146-         self.req_to_next_token_id[req.req_idx].fill_(req.get_last_gen_token())
145+         self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token())
147146         self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty)
148147         self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty)
149148         self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty)
You can’t perform that action at this time.
0 commit comments