Skip to content

Commit 8705d0a

Browse files
committed
fix next token ids.
1 parent 2d81be7 commit 8705d0a

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

lightllm/common/req_manager.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,12 @@ def __init__(self, max_request_num):
106106
"LIGHTLLM_ENABLE_GPU_BUFFER_FOR_OUT_TOKEN_ID_COUNTER"
107107
)
108108
self.vocab_size = get_vocab_size(get_env_start_args().model_dir)
109-
self.req_to_next_token_id = torch.zeros(max_request_num + 1, dtype=torch.int64, device="cuda")
110109
self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
111110
self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
112111
self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
113112
self.req_to_next_token_ids = torch.zeros(
114113
(max_request_num + 1, 8),
115-
dtype=torch.int32,
114+
dtype=torch.int64,
116115
device="cuda",
117116
)
118117
self.req_to_exponential_decay_length_penalty = torch.zeros(
@@ -143,7 +142,7 @@ def init_req_sampling_params(self, req):
143142
req: InferReq = req
144143

145144
shm_param = req.sampling_param.shm_param
146-
self.req_to_next_token_id[req.req_idx].fill_(req.get_last_gen_token())
145+
self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token())
147146
self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty)
148147
self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty)
149148
self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty)

0 commit comments

Comments (0)