Commit 1cdccf1

fix

Committed by: none
1 parent d8e687e commit 1cdccf1

File tree: 1 file changed, +46 −37 lines


lightllm/server/core/objs/sampling_params.py

Lines changed: 46 additions & 37 deletions
@@ -1,6 +1,6 @@
 import os
 import ctypes
-from typing import List, Tuple, Union
+from typing import Optional, List, Tuple, Union
 from transformers import GenerationConfig
 from lightllm.server.req_id_generator import MAX_BEST_OF

@@ -27,21 +27,24 @@ class StopSequence(ctypes.Structure):
         ("sequence_str_len", ctypes.c_int),
     ]

-    def initialize(self, sequence: List[int], sequence_str: str = ""):
+    def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
         self.size = len(sequence)
         assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
         assert all(isinstance(e, int) for e in sequence), "all must be int"
         self.sequence[: self.size] = sequence[:]

-        sequence_str_bytes = sequence_str.encode("utf-8")
-        assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
-        self.sequence_str = sequence_str_bytes
-        self.sequence_str_len = len(sequence_str_bytes)
+        if sequence_str is not None:
+            sequence_str_bytes = sequence_str.encode("utf-8")
+            assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
+            self.sequence_str = sequence_str_bytes
+            self.sequence_str_len = len(sequence_str_bytes)
+        else:
+            self.sequence_str_len = 0

-    def to_list(self):
+    def to_list(self) -> List[int]:
         return list(self.sequence[0 : self.size])

-    def to_string(self):
+    def to_string(self) -> str:
         return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")


@@ -52,45 +55,51 @@ class StopSequenceGroups(ctypes.Structure):
         ("size", ctypes.c_int),
     ]

-    def initialize(self, stop_sequences: Union[str, List], tokenizer):
+    def initialize(self, stop_sequences: Union[str, List[Union[List[int], str]]], tokenizer):
+        if stop_sequences is None:
+            stop_sequences = []
+        elif isinstance(stop_sequences, str):
+            stop_sequences = [stop_sequences]
+
         groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
         self.size = len(groups)
         assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
-        if isinstance(stop_sequences, str):
-            stop_sequences = [stop_sequences]
-        for group_idx in range(self.size):
-            self.groups[group_idx].initialize(groups[group_idx], stop_sequences[group_idx])

-    def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
-        if stop_sequences is None:
-            stop_sequences = []
-        else:
-            if isinstance(stop_sequences, str):
-                stop_sequences = [stop_sequences]
-
-            new_stop_sequences = []
-            for stop_info in stop_sequences:
-                if isinstance(stop_info, str):
-                    stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
-                    if stop_str_ids is not None and len(stop_str_ids) > 0:
-                        new_stop_sequences.append(stop_str_ids)
-                if isinstance(stop_info, list):
-                    if all(isinstance(x, int) for x in stop_info):
-                        if len(stop_info) > 0:
-                            new_stop_sequences.append(stop_info)
-            stop_sequences = new_stop_sequences
-        return stop_sequences
-
-    def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
+        for group_idx in range(self.size):
+            if isinstance(stop_sequences[group_idx], str):
+                self.groups[group_idx].initialize(groups[group_idx], sequence_str=stop_sequences[group_idx])
+            else:
+                self.groups[group_idx].initialize(groups[group_idx])
+
+    def stop_sentences_to_token_ids(self, stop_sequences: List[Union[List[int], str]], tokenizer) -> List[List[int]]:
+        new_stop_sequences = []
+        for stop_info in stop_sequences:
+            if isinstance(stop_info, str):
+                stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
+                if stop_str_ids is not None and len(stop_str_ids) > 0:
+                    new_stop_sequences.append(stop_str_ids)
+            if isinstance(stop_info, list):
+                if all(isinstance(x, int) for x in stop_info):
+                    if len(stop_info) > 0:
+                        new_stop_sequences.append(stop_info)
+                else:
+                    assert False, "stop_sequences item must be type List[int] when it is a list."
+        return new_stop_sequences
+
+    def _stop_str_to_token_ids(self, stop_str: str, tokenizer) -> List[int]:
         stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
         return stop_str_ids

-    def to_list(self):
+    def to_list(self) -> List[List[int]]:
         return [self.groups[i].to_list() for i in range(self.size)]

-    def to_string(self):
+    def to_strings(self) -> List[str]:
         # Sort in descending length order: when both "\n\n" and "\n" are present, "\n\n" is matched first
-        return sorted([self.groups[i].to_string() for i in range(self.size)], key=len, reverse=True)
+        return sorted(
+            [self.groups[i].to_string() for i in range(self.size) if self.groups[i].sequence_str_len > 0],
+            key=len,
+            reverse=True,
+        )


 class RegularConstraint(ctypes.Structure):
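
Taken together, the patch moves the None/str normalization to the front of StopSequenceGroups.initialize and rejects lists that mix non-int items. Below is a minimal runnable sketch of that normalization logic; StubTokenizer and normalize_stop_sequences are illustrative stand-ins written for this note, not lightllm's API.

from typing import List, Optional, Union


class StubTokenizer:
    # Hypothetical stand-in for a Hugging Face tokenizer: one fake id per character.
    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        return [ord(c) for c in text]


def normalize_stop_sequences(
    stop_sequences: Optional[Union[str, List[Union[List[int], str]]]],
    tokenizer,
) -> List[List[int]]:
    # Same order as the patched initialize(): None -> [], bare str -> [str],
    # then per-item conversion to token-id groups.
    if stop_sequences is None:
        stop_sequences = []
    elif isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]

    groups: List[List[int]] = []
    for stop_info in stop_sequences:
        if isinstance(stop_info, str):
            ids = tokenizer.encode(stop_info, add_special_tokens=False)
            if ids:
                groups.append(ids)
        elif isinstance(stop_info, list):
            # Mirrors the new assert: a list item must contain only ints.
            assert all(isinstance(x, int) for x in stop_info), (
                "stop_sequences item must be type List[int] when it is a list."
            )
            if stop_info:
                groups.append(stop_info)
    return groups


tok = StubTokenizer()
print(normalize_stop_sequences(None, tok))              # []
print(normalize_stop_sequences("\n\n", tok))            # [[10, 10]]
print(normalize_stop_sequences(["ok", [13, 13]], tok))  # [[111, 107], [13, 13]]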

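The Optional[str] change in StopSequence.initialize exists so that token-id groups, which have no source string, can leave sequence_str_len at 0. A trimmed-down sketch of that behavior follows; MiniStopSequence and both length constants are assumed values for illustration, not lightllm's real definitions.

import ctypes
from typing import List, Optional

STOP_SEQUENCE_MAX_LENGTH = 10      # assumed size, not lightllm's real constant
STOP_SEQUENCE_STR_MAX_LENGTH = 32  # assumed size, not lightllm's real constant


class MiniStopSequence(ctypes.Structure):
    _fields_ = [
        ("size", ctypes.c_int),
        ("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
        ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
        ("sequence_str_len", ctypes.c_int),
    ]

    def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
        self.size = len(sequence)
        self.sequence[: self.size] = sequence[:]
        if sequence_str is not None:
            raw = sequence_str.encode("utf-8")
            self.sequence_str = raw
            self.sequence_str_len = len(raw)
        else:
            # Token-id groups carry no source string; record that explicitly.
            self.sequence_str_len = 0


s = MiniStopSequence()
s.initialize([13, 13])            # a token-id group with no string form
print(s.sequence_str_len)         # 0
s.initialize([10, 10], "\n\n")
print(s.sequence_str_len)         # 2

Because unset groups keep sequence_str_len == 0, the renamed to_strings() can drop them via its new sequence_str_len > 0 guard instead of returning empty strings.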