diff --git a/paddlenlp/prompt/prompt_tokenizer.py b/paddlenlp/prompt/prompt_tokenizer.py index 8e41162c5ab6..1e50c92c74dd 100644 --- a/paddlenlp/prompt/prompt_tokenizer.py +++ b/paddlenlp/prompt/prompt_tokenizer.py @@ -129,17 +129,15 @@ def _create_max_lengths_from_do_truncate(self, part_text: List[str], part_do_tru text_length = sum([len(x) for x in part_text]) num_special_token = self.tokenizer.num_special_tokens_to_add() max_length = self.max_length - num_special_token + max_lengths = [len(part) for part in part_text] if text_length <= max_length: - return [None] * len(part_text) - max_lengths = [None for _ in range(len(part_text))] + return max_lengths do_truncate = [int(x) for x in part_do_truncate] # Remove parts that can not be truncated. for index, part in enumerate(part_text): if not part_do_truncate[index]: max_length -= len(part) - else: - max_lengths[index] = len(part) if sum(do_truncate) == 0: logger.warning( f"Can not truncate the sequence with length {text_length}. Set more `truncate` attributes as True." @@ -154,7 +152,6 @@ def _create_max_lengths_from_do_truncate(self, part_text: List[str], part_do_tru for index, part in enumerate(part_text): if do_truncate[index] == 1 and len(part) <= avg_max_length: do_truncate[index] = 0 - max_lengths[index] = len(part) max_length -= len(part) has_short = True if max_length < 0: