 import copy
 import logging
-from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from lm_eval.models.utils import (
     Collator,
     flatten_image_list,
     handle_stop_sequences,
-    pad_and_concat,
     replace_placeholders,
     resize_image,
-    stop_sequences_criteria,
 )
+from lm_eval.models.utils_hf import pad_and_concat, stop_sequences_criteria


 DEFAULT_IMAGE_PLACEHOLDER = "<image>"
@@ -39,19 +37,19 @@ class HFMultimodalLM(HFLM):

     def __init__(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        image_token_id: Optional[int] = None,
-        image_string: Optional[str] = None,
+        pretrained: str | transformers.PreTrainedModel,
+        image_token_id: int | None = None,
+        image_string: str | None = None,
         interleave: bool = True,
         # TODO: handle whitespace in image placeholder (replacement)
-        max_images: Optional[int] = 999,
+        max_images: int | None = 999,
         convert_img_format=False,
         # For image resizing
-        min_pixels: Optional[int] = None,
-        max_pixels: Optional[int] = None,
-        image_width: Optional[int] = None,
-        image_height: Optional[int] = None,
-        image_max_side: Optional[int] = None,
+        min_pixels: int | None = None,
+        max_pixels: int | None = None,
+        image_width: int | None = None,
+        image_height: int | None = None,
+        image_max_side: int | None = None,
         **kwargs,
     ):
         self.image_width = image_width
@@ -113,15 +111,10 @@ def __init__(

     def _create_tokenizer(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.ProcessorMixin,
-            ]
-        ],
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
+        pretrained: str | transformers.PreTrainedModel,
+        tokenizer: str | transformers.ProcessorMixin | None,
+        revision: str | None = "main",
+        trust_remote_code: bool | None = False,
         **kwargs,
     ) -> None:
         """
@@ -223,7 +216,7 @@ def _encode_multimodal_pair(self, context, continuation, images):
         return context_enc, continuation_enc, image_enc

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         self.chat_applied = True
         if not self.interleave:
@@ -279,7 +272,7 @@ def apply_chat_template(
             continue_final_message=not add_generation_prompt,
         )

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         if hasattr(self.processor, "apply_chat_template"):
             _tokenizer = self.tokenizer
             self.tokenizer = self.processor
@@ -293,14 +286,14 @@ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str

     def tok_batch_multimodal_encode(
         self,
-        strings: List[str],  # note that input signature of this fn is different
-        images: List[List],  # TODO: images are pil.Image at the moment, update typehint
+        strings: list[str],  # note that input signature of this fn is different
+        images: list[list],  # TODO: images are pil.Image at the moment, update typehint
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Union[
-        BatchEncoding, Dict[str, torch.Tensor]
-    ]:  # note that this return signature differs from HFLM tok_batch_encode.
+    ) -> (
+        BatchEncoding | dict[str, torch.Tensor]
+    ):  # note that this return signature differs from HFLM tok_batch_encode.
         # NOTE: here, we replace <image> tags with our model's corresponding image_token string value.
         if not self.chat_applied:
             # TODO<baber>: This still keeps the whitespace in the image placeholder, which is not ideal.
@@ -356,7 +349,7 @@ def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None):

     def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
-        do_sample = generation_kwargs.get("do_sample", None)
+        do_sample = generation_kwargs.get("do_sample")

         # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
         if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -398,7 +391,7 @@ def _batch_images(self, image_encs):
             )
         return batched_imgs

-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         if requests and len(requests[0].args) < 3:
             # Fall back to non-multimodal generation.
             return super().loglikelihood_rolling(requests=requests)
@@ -408,8 +401,8 @@ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         )

     def loglikelihood(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[tuple[float, bool]]:
         if requests and len(requests[0].args) < 3:
             # Fall back to non-multimodal generation.
             return super().loglikelihood(requests=requests, disable_tqdm=disable_tqdm)
@@ -445,16 +438,16 @@ def loglikelihood(

     def _multimodal_loglikelihood_tokens(
         self,
-        requests: List[
-            Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
+        requests: list[
+            tuple[tuple[None, str, str], list[int], list[int], list[int]]
         ],  # TODO: update typehint to be correct
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []

         # TODO: **improve multimodal collation.** We currently ignore image size when ordering docs. ideally we'd take them into account
-        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
@@ -465,7 +458,7 @@ def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
             toks = req[1] + req[2]
             return -len(toks), tuple(toks)

-        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key to group and lookup one-token continuations"""
             # Use with group_by="contexts" (optional)"
             # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
@@ -477,7 +470,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
             requests,
             sort_fn=_collate,
             group_by="contexts"  # TODO: can't group-by just "contexts" any more, need to incorporate imgs
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM  # noqa: SIM300
             and self.logits_cache
             else None,
             group_fn=_lookup_one_token_cont,
@@ -572,9 +565,9 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                 request_str,
                 ctx_tokens,
                 _,
-                image_encs,
+                _image_encs,
             ), logits, inplen, cont_toks in zip(
-                chunk, multi_logits, inplens, cont_toks_list
+                chunk, multi_logits, inplens, cont_toks_list, strict=False
             ):
                 # Slice to original seq length
                 contlen = len(cont_toks)
@@ -584,7 +577,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                 # from prompt/prefix tuning tokens, if applicable
                 ctx_len = (
                     inplen + (logits.shape[0] - padding_len_inp)
-                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM  # noqa: SIM300
                     else None
                 )
                 logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
@@ -598,30 +591,30 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                 # original args. Otherwise, expands the logits batch dimension and yields each
                 # batch along with matching continuation tokens and prompt strings.
                 # logits -> [1, seq, vocab]
-                for request_str, cont_toks, logits in re_ord.get_cache(
+                for _request_str, _cont_toks, _logits in re_ord.get_cache(
                     req_str=request_str,
                     cxt_toks=ctx_tokens,
                     cont_toks=cont_toks,
                     logits=logits,
                 ):
-                    cont_toks = torch.tensor(
-                        cont_toks, dtype=torch.long, device=self.device
+                    _cont_toks = torch.tensor(
+                        _cont_toks, dtype=torch.long, device=self.device
                     ).unsqueeze(0)  # [1, seq]
-                    max_equal = (greedy_tokens == cont_toks).all()
+                    max_equal = (greedy_tokens == _cont_toks).all()

                     # Obtain log-probs at the corresponding continuation token indices
                     # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-                        -1
-                    )  # [1, seq]
+                    _logits = torch.gather(
+                        _logits, 2, _cont_toks.unsqueeze(-1)
+                    ).squeeze(-1)  # [1, seq]

                     # Answer: (log prob, is-exact-match)
-                    answer = (float(logits.sum()), bool(max_equal))
+                    answer = (float(_logits.sum()), bool(max_equal))

                     res.append(answer)

                     self.cache_hook.add_partial(
-                        "loglikelihood", request_str, answer
+                        "loglikelihood", _request_str, answer
                     )  # TODO: choose convention for adding images into the cache key
                     pbar.update(1)

@@ -630,8 +623,8 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         if requests and len(requests[0].args) < 3:
             # Fall back to non-multimodal generation.
             return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
@@ -669,7 +662,7 @@ def _collate(x):
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
         eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
-            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
+            contexts, all_gen_kwargs, aux_arguments = zip(*chunk, strict=False)

             visuals = [
                 [
@@ -732,7 +725,7 @@ def _collate(x):
             ### essentially same as HFLM beyond this line!

             cont_toks_list = cont.tolist()
-            for cont_toks, context in zip(cont_toks_list, contexts):
+            for cont_toks, context in zip(cont_toks_list, contexts, strict=False):
                 # discard context + left-padding toks if using causal decoder-only VLM
                 cont_toks = cont_toks[context_enc.shape[1] :]

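
The diff above applies two idioms throughout the file: PEP 604 union syntax in place of Optional/Union, and an explicit strictness flag on zip(). The snippet below is a minimal standalone sketch of those idioms only, not code from this repository.

from __future__ import annotations  # lets "int | None" annotations parse on Python < 3.10


def first_token(token_ids: list[int], default: int | None = None) -> int | None:
    # "int | None" is equivalent to typing.Optional[int], i.e. Union[int, None].
    return token_ids[0] if token_ids else default


tokens = ["a", "b", "c"]
token_ids = [1, 2]

# zip(..., strict=False) spells out the default behaviour: pairing stops at the
# shorter iterable. strict=True would raise ValueError here because the lengths
# differ. Note the strict keyword itself requires Python 3.10+.
pairs = list(zip(tokens, token_ids, strict=False))
assert pairs == [("a", 1), ("b", 2)]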