@@ -13,7 +13,9 @@
 # limitations under the License.
 """Perplexity Metric."""

+import numpy as np
 import torch
+from torch.nn import CrossEntropyLoss
 from transformers import AutoModelForCausalLM, AutoTokenizer

 import datasets
@@ -53,23 +55,30 @@
         >>> perplexity = datasets.load_metric("perplexity")
         >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
         >>> results = perplexity.compute(model_id='gpt2',
-        ...                              input_texts=input_texts,
-        ...                              stride=1)
-        >>> round(results["perplexity"], 1)
-        78.2
+        ...                              add_start_token=False,
+        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
+        >>> print(list(results.keys()))
+        ['perplexities', 'mean_perplexity']
+        >>> print(round(results["mean_perplexity"], 2))
+        78.22
+        >>> print(round(results["perplexities"][0], 2))
+        11.11

     Example 2:
         >>> perplexity = datasets.load_metric("perplexity")
         >>> input_texts = datasets.load_dataset("wikitext",
         ...                                     "wikitext-2-raw-v1",
-        ...                                     split="test")["text"][:10] # doctest:+ELLIPSIS
+        ...                                     split="test")["text"][:50] # doctest:+ELLIPSIS
         [...]
+        >>> input_texts = [s for s in input_texts if s!='']
         >>> results = perplexity.compute(model_id='gpt2',
-        ...                              input_texts=input_texts,
-        ...                              stride=256)
-        >>> round(results["perplexity"], 1)
-        117.9
-
+        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
+        >>> print(list(results.keys()))
+        ['perplexities', 'mean_perplexity']
+        >>> print(round(results["mean_perplexity"], 2))
+        1977.55
+        >>> print(round(results["perplexities"][0], 2))
+        1349.56
 """


@@ -88,7 +97,7 @@ def _info(self):
             reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
         )

-    def _compute(self, input_texts, model_id, stride=512, device=None):
+    def _compute(self, input_texts, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):

         if device is not None:
             assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
@@ -100,51 +109,79 @@ def _compute(self, input_texts, model_id, stride=512, device=None):
         model = AutoModelForCausalLM.from_pretrained(model_id)
         model = model.to(device)

-        tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>")
-
-        encodings = tokenizer(input_texts, padding=True, return_tensors="pt", return_special_tokens_mask=True).to(
-            device
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # if batch_size > 1 (which generally leads to padding being required), and
+        # if there is not an already assigned pad_token, assign an existing
+        # special token to also be the padding token
+        if tokenizer.pad_token is None and batch_size > 1:
+            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
+            # check that the model already has at least one special token defined
+            assert (
+                len(existing_special_tokens) > 0
+            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
+            # assign one of the special tokens to also be the pad token
+            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
+
+        if add_start_token:
+            # leave room for <BOS> token to be added:
+            assert (
+                tokenizer.bos_token is not None
+            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
+            max_tokenized_len = model.config.max_length - 1
+        else:
+            max_tokenized_len = model.config.max_length
+
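+        # tokenize the whole corpus in one call, padding to the longest sequence and truncating to max_tokenized_len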
+        encodings = tokenizer(
+            input_texts,
+            add_special_tokens=False,
+            padding=True,
+            truncation=True,
+            max_length=max_tokenized_len,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).to(device)

         encoded_texts = encodings["input_ids"]
-        special_tokens_masks = encodings["special_tokens_mask"]
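+        # the attention mask marks real tokens (1) versus padding (0)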
+        attn_masks = encodings["attention_mask"]

-        max_model_length = model.config.n_positions
+        # check that each input is long enough:
+        if add_start_token:
+            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
+        else:
+            assert torch.all(
+                torch.ge(attn_masks.sum(1), 2)
+            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

         ppls = []
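+        # keep per-token losses (reduction="none") so padded positions can be masked out before averaging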
+        loss_fct = CrossEntropyLoss(reduction="none")

-        for text_index in logging.tqdm(range(0, len(encoded_texts))):
-            encoded_text = encoded_texts[text_index]
-            special_tokens_mask = special_tokens_masks[text_index]
-
-            encoded_text_length = len(encoded_text) - special_tokens_mask.sum()
-
-            nlls = []
-
-            target_index = max(1, min(stride - 1, encoded_text_length - 1))
-
-            while target_index < encoded_text_length:
-                start_index = max(0, target_index - (max_model_length - 1))
-
-                input_ids = encoded_text[start_index : target_index + 1]
-
-                target_ids = input_ids.clone()
-                target_ids[:-1] = -100
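+        # score the corpus in batches of batch_size sequences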
+        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
+            end_index = min(start_index + batch_size, len(encoded_texts))
+            encoded_batch = encoded_texts[start_index:end_index]
+            attn_mask = attn_masks[start_index:end_index]

-                attn_mask = torch.ones(len(input_ids)).to(device)
-                attn_mask[-1] = 0
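+            # prepend a BOS token to every sequence and extend the attention mask so the model attends to it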
+            if add_start_token:
+                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
+                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
+                attn_mask = torch.cat(
+                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
+                )

-                with torch.no_grad():
-                    outputs = model(input_ids, labels=target_ids, attention_mask=attn_mask)
-                    neg_log_likelihood = outputs[0]
+            labels = encoded_batch

-                nlls.append(neg_log_likelihood)
+            with torch.no_grad():
+                out_logits = model(encoded_batch, attention_mask=attn_mask).logits

-                target_index += stride
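+            # shift by one position so the logits at position t are scored against the token at position t + 1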
+            shift_logits = out_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

-            if len(nlls) > 0:
-                ppls.append(torch.exp2(torch.mean(torch.stack(nlls))))
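+            # per-sequence perplexity: 2 ** (sum of masked token losses / number of real tokens)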
+            perplexity_batch = torch.exp2(
+                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
+                / shift_attention_mask_batch.sum(1)
+            )

-        ppl = torch.mean(torch.stack(ppls))
+            ppls += perplexity_batch.tolist()

-        return {"perplexity": float(ppl)}
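+        # return the perplexity of each input text along with the corpus mean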
+        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}