 from transformers import PreTrainedTokenizer, LogitsWarper, StoppingCriteria
 import torch
 
-
 class StringStoppingCriteria(StoppingCriteria):
     def __init__(self, tokenizer: PreTrainedTokenizer, prompt_length: int):
         self.tokenizer = tokenizer
@@ -62,27 +61,24 @@ def __call__(
 
         return False
 
-
 class OutputNumbersTokens(LogitsWarper):
     def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
-        self.whitelist_tokens = []
-        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
         self.tokenizer = tokenizer
+        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
+        vocab_size = len(tokenizer)
+        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
 
         for _, token_id in tokenizer.get_vocab().items():
-            token_str = tokenizer.decode(token_id)
-            token_str = token_str.strip()
+            token_str = tokenizer.decode(token_id).strip()
 
             if token_str == "" or (
                 all(c.isdigit() or c == "." for c in token_str)
                 and token_str.count(".") <= 1
             ):
-                self.whitelist_tokens.append(token_id)
+                self.allowed_mask[token_id] = True
 
-    def __call__(self, input_ids, scores):
-        input_ids = input_ids[:, len(self.tokenized_prompt["input_ids"][0]) :]
+    def __call__(self, _, scores):
+        mask = self.allowed_mask.expand_as(scores)
+        scores[~mask] = -float("inf")
 
-        scores[
-            :, [i for i in range(len(scores[0])) if i not in self.whitelist_tokens]
-        ] = -float("inf")
         return scores
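
For context, a minimal usage sketch (not part of this commit) of how the rewritten OutputNumbersTokens warper could be plugged into Hugging Face generate(). The model name, prompt, and generation settings are placeholders, and the repository's real calling code may wire things differently:

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

# Placeholder model; any causal-LM tokenizer/model pair works for the sketch.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = '{"temperature": '
inputs = tokenizer(prompt, return_tensors="pt")

# The warper sets every logit to -inf unless its decoded token is empty or
# consists only of digits with at most one ".", so sampling can only emit a number.
warper = OutputNumbersTokens(tokenizer, prompt)

output_ids = model.generate(
    **inputs,
    max_new_tokens=6,
    do_sample=True,
    logits_processor=LogitsProcessorList([warper]),
)

# Decode only the newly generated tokens after the prompt.
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:]))

A StoppingCriteriaList (for example built around StringStoppingCriteria with the prompt length) would be passed via stopping_criteria= in the same call; the diff does not show which criterion the surrounding code pairs with the number warper.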