7 changes: 2 additions & 5 deletions paddlenlp/prompt/prompt_tokenizer.py
@@ -129,17 +129,15 @@ def _create_max_lengths_from_do_truncate(self, part_text: List[str], part_do_tru
         text_length = sum([len(x) for x in part_text])
         num_special_token = self.tokenizer.num_special_tokens_to_add()
         max_length = self.max_length - num_special_token
+        max_lengths = [len(part) for part in part_text]
         if text_length <= max_length:
-            return [None] * len(part_text)
-        max_lengths = [None for _ in range(len(part_text))]
+            return max_lengths
         do_truncate = [int(x) for x in part_do_truncate]

         # Remove parts that can not be truncated.
         for index, part in enumerate(part_text):
             if not part_do_truncate[index]:
                 max_length -= len(part)
-            else:
-                max_lengths[index] = len(part)
         if sum(do_truncate) == 0:
             logger.warning(
                 f"Can not truncate the sequence with length {text_length}. Set more `truncate` attributes as True."
@@ -154,7 +152,6 @@ def _create_max_lengths_from_do_truncate(self, part_text: List[str], part_do_tru
         for index, part in enumerate(part_text):
             if do_truncate[index] == 1 and len(part) <= avg_max_length:
                 do_truncate[index] = 0
-                max_lengths[index] = len(part)
                 max_length -= len(part)
                 has_short = True
         if max_length < 0:
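Taken together, the change seeds max_lengths with each part's actual length once at the top of _create_max_lengths_from_do_truncate, so the no-truncation early return now yields concrete integers instead of [None] * len(part_text), and the scattered max_lengths[index] = len(part) assignments in the two loops become redundant. Below is a minimal standalone sketch of the before/after behavior at that early return; the function names, the plain max_length parameter, and the elided budget loop are illustrative assumptions, not the real PaddleNLP signatures.

from typing import List, Optional

# Hypothetical standalone sketch, not the real method: tokenizer bookkeeping
# (self.max_length, num_special_tokens_to_add) is replaced by a plain
# max_length parameter, and the truncation-budget loop is elided.

def max_lengths_before(part_text: List[str], max_length: int) -> List[Optional[int]]:
    # Old behavior: when the total already fits, every slot is None and
    # callers must treat None as "no limit".
    if sum(len(x) for x in part_text) <= max_length:
        return [None] * len(part_text)
    max_lengths: List[Optional[int]] = [None] * len(part_text)
    ...  # budget distribution elided
    return max_lengths

def max_lengths_after(part_text: List[str], max_length: int) -> List[int]:
    # New behavior: seed max_lengths with each part's real length up front,
    # so the early return yields concrete integers and the later per-part
    # assignments removed by this diff are no longer needed.
    max_lengths = [len(part) for part in part_text]
    if sum(len(x) for x in part_text) <= max_length:
        return max_lengths
    ...  # budget distribution elided
    return max_lengths

parts = ["soft prompt", "input text"]
print(max_lengths_before(parts, 128))  # [None, None]
print(max_lengths_after(parts, 128))  # [11, 10]

If this reading is right, the practical win is that downstream padding and truncation code gets a single int-only path and no longer needs a None branch for the no-truncation case.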