
Commit e90a7d4

Perplexity Speedup (#4108)
* perplexity update checkpoint
* perplexity speedup code checkpoint
* enable batching for faster calculations
* assert input strings are long enough
* clean up code for readability
* fix vocab size, style
* update examples
* fix padding token issue, which fixes output values
* update perplexity examples
* fix ppl examples
* add example caching fix
* update ppl examples
* suppress warnings in examples
* edit ppl example
* edit ppl example
* edit ppl example
* edit ppl example
* fix example output
* remove perplexity testing script
1 parent cb6e8e7 commit e90a7d4
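
For orientation, a minimal usage sketch of the metric as it looks after this commit, assembled from the updated docstring examples and the new `_compute` signature in the diff below; the example strings and explicit keyword values are illustrative only:

# Sketch only: mirrors the updated docstring examples in this commit.
import datasets

perplexity = datasets.load_metric("perplexity")
input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]

# batch_size and add_start_token replace the old stride argument.
results = perplexity.compute(
    model_id="gpt2",
    input_texts=input_texts,
    batch_size=16,          # default in the new signature
    add_start_token=False,  # as in the first docstring example
)
print(results["mean_perplexity"])  # single averaged score
print(results["perplexities"])     # one perplexity per input text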

File tree: 1 file changed, +83 -46 lines changed


metrics/perplexity/perplexity.py

Lines changed: 83 additions & 46 deletions
@@ -13,7 +13,9 @@
 # limitations under the License.
 """Perplexity Metric."""
 
+import numpy as np
 import torch
+from torch.nn import CrossEntropyLoss
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import datasets
@@ -53,23 +55,30 @@
         >>> perplexity = datasets.load_metric("perplexity")
         >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
         >>> results = perplexity.compute(model_id='gpt2',
-        ...                              input_texts=input_texts,
-        ...                              stride=1)
-        >>> round(results["perplexity"], 1)
-        78.2
+        ...                              add_start_token=False,
+        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
+        >>> print(list(results.keys()))
+        ['perplexities', 'mean_perplexity']
+        >>> print(round(results["mean_perplexity"], 2))
+        78.22
+        >>> print(round(results["perplexities"][0], 2))
+        11.11
 
     Example 2:
         >>> perplexity = datasets.load_metric("perplexity")
         >>> input_texts = datasets.load_dataset("wikitext",
         ...                                     "wikitext-2-raw-v1",
-        ...                                     split="test")["text"][:10] # doctest:+ELLIPSIS
+        ...                                     split="test")["text"][:50] # doctest:+ELLIPSIS
         [...]
+        >>> input_texts = [s for s in input_texts if s!='']
         >>> results = perplexity.compute(model_id='gpt2',
-        ...                              input_texts=input_texts,
-        ...                              stride=256)
-        >>> round(results["perplexity"], 1)
-        117.9
-
+        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
+        >>> print(list(results.keys()))
+        ['perplexities', 'mean_perplexity']
+        >>> print(round(results["mean_perplexity"], 2))
+        1977.55
+        >>> print(round(results["perplexities"][0], 2))
+        1349.56
 """
 
 
@@ -88,7 +97,7 @@ def _info(self):
             reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
         )
 
-    def _compute(self, input_texts, model_id, stride=512, device=None):
+    def _compute(self, input_texts, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
 
         if device is not None:
             assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
@@ -100,51 +109,79 @@ def _compute(self, input_texts, model_id, stride=512, device=None):
         model = AutoModelForCausalLM.from_pretrained(model_id)
         model = model.to(device)
 
-        tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>")
-
-        encodings = tokenizer(input_texts, padding=True, return_tensors="pt", return_special_tokens_mask=True).to(
-            device
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # if batch_size > 1 (which generally leads to padding being required), and
+        # if there is not an already assigned pad_token, assign an existing
+        # special token to also be the padding token
+        if tokenizer.pad_token is None and batch_size > 1:
+            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
+            # check that the model already has at least one special token defined
+            assert (
+                len(existing_special_tokens) > 0
+            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
+            # assign one of the special tokens to also be the pad token
+            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
+
+        if add_start_token:
+            # leave room for <BOS> token to be added:
+            assert (
+                tokenizer.bos_token is not None
+            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
+            max_tokenized_len = model.config.max_length - 1
+        else:
+            max_tokenized_len = model.config.max_length
+
+        encodings = tokenizer(
+            input_texts,
+            add_special_tokens=False,
+            padding=True,
+            truncation=True,
+            max_length=max_tokenized_len,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).to(device)
 
         encoded_texts = encodings["input_ids"]
-        special_tokens_masks = encodings["special_tokens_mask"]
+        attn_masks = encodings["attention_mask"]
 
-        max_model_length = model.config.n_positions
+        # check that each input is long enough:
+        if add_start_token:
+            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
+        else:
+            assert torch.all(
+                torch.ge(attn_masks.sum(1), 2)
+            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
 
         ppls = []
+        loss_fct = CrossEntropyLoss(reduction="none")
 
-        for text_index in logging.tqdm(range(0, len(encoded_texts))):
-            encoded_text = encoded_texts[text_index]
-            special_tokens_mask = special_tokens_masks[text_index]
-
-            encoded_text_length = len(encoded_text) - special_tokens_mask.sum()
-
-            nlls = []
-
-            target_index = max(1, min(stride - 1, encoded_text_length - 1))
-
-            while target_index < encoded_text_length:
-                start_index = max(0, target_index - (max_model_length - 1))
-
-                input_ids = encoded_text[start_index : target_index + 1]
-
-                target_ids = input_ids.clone()
-                target_ids[:-1] = -100
+        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
+            end_index = min(start_index + batch_size, len(encoded_texts))
+            encoded_batch = encoded_texts[start_index:end_index]
+            attn_mask = attn_masks[start_index:end_index]
 
-                attn_mask = torch.ones(len(input_ids)).to(device)
-                attn_mask[-1] = 0
+            if add_start_token:
+                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
+                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
+                attn_mask = torch.cat(
+                    [torch.zeros(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
+                )
 
-                with torch.no_grad():
-                    outputs = model(input_ids, labels=target_ids, attention_mask=attn_mask)
-                    neg_log_likelihood = outputs[0]
+            labels = encoded_batch
 
-                nlls.append(neg_log_likelihood)
+            with torch.no_grad():
+                out_logits = model(encoded_batch, attention_mask=attn_mask).logits
 
-                target_index += stride
+            shift_logits = out_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
 
-            if len(nlls) > 0:
-                ppls.append(torch.exp2(torch.mean(torch.stack(nlls))))
+            perplexity_batch = torch.exp2(
+                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
+                / shift_attention_mask_batch.sum(1)
+            )
 
-        ppl = torch.mean(torch.stack(ppls))
+            ppls += perplexity_batch.tolist()
 
-        return {"perplexity": float(ppl)}
+        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
