Skip to content

Commit c9f46bd

Browse files
committed
vectorize OutputNumbersTokens logit warper
1 parent 4fe34dc commit c9f46bd

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

jsonformer/logits_processors.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from transformers import PreTrainedTokenizer, LogitsWarper, StoppingCriteria
33
import torch
44

5-
65
class StringStoppingCriteria(StoppingCriteria):
76
def __init__(self, tokenizer: PreTrainedTokenizer, prompt_length: int):
87
self.tokenizer = tokenizer
@@ -62,27 +61,24 @@ def __call__(
6261

6362
return False
6463

65-
6664
class OutputNumbersTokens(LogitsWarper):
6765
def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
68-
self.whitelist_tokens = []
69-
self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
7066
self.tokenizer = tokenizer
67+
self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
68+
vocab_size = len(tokenizer)
69+
self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
7170

7271
for _, token_id in tokenizer.get_vocab().items():
73-
token_str = tokenizer.decode(token_id)
74-
token_str = token_str.strip()
72+
token_str = tokenizer.decode(token_id).strip()
7573

7674
if token_str == "" or (
7775
all(c.isdigit() or c == "." for c in token_str)
7876
and token_str.count(".") <= 1
7977
):
80-
self.whitelist_tokens.append(token_id)
78+
self.allowed_mask[token_id] = True
8179

82-
def __call__(self, input_ids, scores):
83-
input_ids = input_ids[:, len(self.tokenized_prompt["input_ids"][0]) :]
80+
def __call__(self, _, scores):
81+
mask = self.allowed_mask.expand_as(scores)
82+
scores[~mask] = -float("inf")
8483

85-
scores[
86-
:, [i for i in range(len(scores[0])) if i not in self.whitelist_tokens]
87-
] = -float("inf")
8884
return scores

0 commit comments

Comments
 (0)