33import json
44import logging
55import os
6- from typing import Dict , List , Optional , Tuple , Type , TypeVar , Union
6+ from typing import TYPE_CHECKING , Any , Iterable , Optional , Type , TypeVar , Union
77
8- import transformers
9- from sqlitedict import SqliteDict
108from tqdm import tqdm
119
1210from lm_eval import utils
1311
1412
13+ if TYPE_CHECKING :
14+ from sqlitedict import SqliteDict
15+
16+ from lm_eval .api .instance import Instance
17+
18+
# Module-level logger for this module's diagnostics (cache pass-through, etc.).
eval_logger = logging .getLogger (__name__ )

# TypeVar bound to LM so classmethod constructors can return the concrete subclass type.
T = TypeVar ("T" , bound = "LM" )
@@ -27,10 +31,10 @@ def __init__(self) -> None:
2731 # set rank and world size to a single process, by default.
2832 self ._rank = 0
2933 self ._world_size = 1
30- self .cache_hook = CacheHook (None )
34+ self .cache_hook : "CacheHook" = CacheHook (None )
3135
3236 @abc .abstractmethod
33- def loglikelihood (self , requests ) -> List [ Tuple [float , bool ]]:
37+ def loglikelihood (self , requests ) -> list [ tuple [float , bool ]]:
3438 """Compute log-likelihood of generating a continuation from a context.
3539 Downstream tasks should attempt to use loglikelihood instead of other
3640 LM calls whenever possible.
@@ -55,7 +59,7 @@ def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
5559 pass
5660
5761 @abc .abstractmethod
58- def loglikelihood_rolling (self , requests ) -> List [float ]:
62+ def loglikelihood_rolling (self , requests ) -> list [float ]:
5963 """Compute full log-likelihood of a string, with no truncation, for perplexity computation
6064 - We will use the full max context length of the model.
6165 - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -97,7 +101,7 @@ def loglikelihood_rolling(self, requests) -> List[float]:
97101
98102 # TODO: Add an optional max length
99103 @abc .abstractmethod
100- def generate_until (self , requests ) -> List [str ]:
104+ def generate_until (self , requests ) -> list [str ]:
101105 """Generate greedily until a stopping sequence
102106
103107 :param requests: list[Instance]
@@ -114,7 +118,7 @@ def generate_until(self, requests) -> List[str]:
114118 pass
115119
116120 def apply_chat_template (
117- self , chat_history : List [ Dict [str , str ]], add_generation_prompt = True
121+ self , chat_history : list [ dict [str , str ]], add_generation_prompt = True
118122 ) -> str :
119123 """
120124 Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
@@ -165,8 +169,7 @@ def create_from_arg_obj(
165169 - Instance of the LM class.
166170 """
167171
168- additional_config = {} if additional_config is None else additional_config
169- additional_config = {
172+ additional_config = additional_config or {} | {
170173 k : v for k , v in additional_config .items () if v is not None
171174 }
172175
@@ -204,56 +207,58 @@ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str
204207
205208 return ""
206209
207- def set_cache_hook (self , cache_hook ) -> None :
210+ def set_cache_hook (self , cache_hook : "CacheHook" ) -> None :
208211 self .cache_hook = cache_hook
209212
210213
211214### SQLite-based caching of LM responses
def hash_args(attr: str, args: Iterable[Any]) -> str:
    """Return a stable SHA-256 hex digest identifying (attr, args).

    Used as the cache key for LM responses: the request-type name plus the
    JSON-serialized request arguments. ``args`` must be JSON-serializable.
    """
    payload = json.dumps([attr, *args])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
215218
216219
class CacheHook:
    """Write-through hook letting an LM record responses into a CachingLM's cache.

    Constructed with ``None`` (the default hook on a bare LM), every
    ``add_partial`` call is a no-op.
    """

    def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
        if cachinglm is None:
            # No cache attached: add_partial degrades to a no-op.
            self.dbdict: Optional["SqliteDict"] = None
        else:
            # Share the CachingLM's backing store directly.
            self.dbdict = cachinglm.dbdict

    def add_partial(self, attr: str, req: Iterable[Any], res: Any) -> None:
        """Store ``res`` in the cache, keyed by a hash of (attr, req)."""
        if self.dbdict is not None:
            self.dbdict[hash_args(attr, req)] = res
230233
231234
232235class CachingLM :
233- def __init__ (self , lm , cache_db ) -> None :
236+ def __init__ (self , lm : LM , cache_db : str ) -> None :
234237 """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
235238
236239 :param lm: LM
237240 Underlying LM
238241 :param cache_db: str
239242 Path to cache db
240243 """
241- self .lm = lm
242- self .cache_db = cache_db
244+ from sqlitedict import SqliteDict
245+
246+ self .lm : LM = lm
247+ self .cache_db : str = cache_db
243248 if os .path .dirname (cache_db ):
244249 os .makedirs (os .path .dirname (cache_db ), exist_ok = True )
245250 self .dbdict = SqliteDict (cache_db , autocommit = True )
246251
247252 # add hook to lm
248253 lm .set_cache_hook (self .get_cache_hook ())
249254
250- def __getattr__ (self , attr : str ):
255+ def __getattr__ (self , attr : str ) -> Any :
251256 lm_attr = getattr (self .lm , attr )
252257 if attr not in ["loglikelihood" , "loglikelihood_rolling" , "generate_until" ]:
253258 eval_logger .debug (f"Passing through attribute '{ attr } ' to underlying LM" )
254259 return lm_attr
255260
256- def fn (requests ) :
261+ def _fn (requests : list [ "Instance" ]) -> list [ "Instance" ] :
257262 res = []
258263 remaining_reqs = []
259264 warned = False
@@ -306,9 +311,9 @@ def fn(requests):
306311
307312 return res
308313
309- return fn
314+ return _fn
310315
311- def get_cache_hook (self ):
316+ def get_cache_hook (self ) -> "CacheHook" :
312317 return CacheHook (self )
313318
314319
@@ -331,19 +336,23 @@ def prefix_token_id(self):
331336 return self .eot_token_id
332337
333338 @abc .abstractmethod
334- def tok_encode (self , string : str , ** kwargs ) -> List [int ]:
339+ def tok_encode (self , string : str , ** kwargs ) -> list [int ]:
335340 """
336341 Tokenize a string using the model's tokenizer and return a list of token IDs.
337342 """
338343 pass
339344
340345 @abc .abstractmethod
341- def _loglikelihood_tokens (self , requests , ** kwargs ) -> List [Tuple [float , bool ]]:
346+ def _loglikelihood_tokens (
347+ self , requests : list ["Instance" ], ** kwargs
348+ ) -> list [tuple [float , bool ]]:
342349 pass
343350
344351 def _encode_pair (
345352 self , context : str , continuation : str
346- ) -> Tuple [List [int ], List [int ]]:
353+ ) -> tuple [list [int ], list [int ]]:
354+ import transformers
355+
347356 n_spaces = len (context ) - len (context .rstrip ())
348357 if n_spaces > 0 :
349358 continuation = context [- n_spaces :] + continuation
@@ -364,8 +373,8 @@ def _encode_pair(
364373 return context_enc , continuation_enc
365374
366375 def loglikelihood (
367- self , requests , disable_tqdm : bool = False
368- ) -> List [ Tuple [float , bool ]]:
376+ self , requests : list [ "Instance" ] , disable_tqdm : bool = False
377+ ) -> list [ tuple [float , bool ]]:
369378 new_reqs = []
370379 for context , continuation in [req .args for req in requests ]:
371380 if context == "" :
@@ -384,11 +393,11 @@ def loglikelihood(
384393 @abc .abstractmethod
385394 def loglikelihood_rolling (
386395 self , requests , disable_tqdm : bool = False
387- ) -> List [float ]:
396+ ) -> list [float ]:
388397 pass
389398
390399 @abc .abstractmethod
391- def generate_until (self , requests , disable_tqdm : bool = False ) -> List [str ]:
400+ def generate_until (self , requests , disable_tqdm : bool = False ) -> list [str ]:
392401 pass
393402
394403 def chat_template (self , chat_template : Union [bool , str ] = False ) -> Optional [str ]:
0 commit comments