
Commit 59b3ba2

move torch/hf utils to utils_hf.py (#3450)
1 parent c3c7f04 commit 59b3ba2

File tree

9 files changed (+314, -324 lines)
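
Every file touched by this commit makes the same two mechanical changes: the torch/transformers-dependent helpers (`stop_sequences_criteria`, `pad_and_concat`) are now imported from the new `lm_eval.models.utils_hf` module instead of `lm_eval.models.utils`, and type annotations are modernized to the `X | Y` / built-in-generic style. A minimal sketch of the import change as callers see it (only the two helpers visible in the diffs below are assumed to have moved; the remaining helpers stay in `utils`):

# Before #3450: torch/HF-specific helpers lived alongside the generic utilities.
# from lm_eval.models.utils import pad_and_concat, stop_sequences_criteria

# After #3450: generic helpers stay in utils, torch/HF helpers come from utils_hf.
from lm_eval.models.utils import Collator, replace_placeholders
from lm_eval.models.utils_hf import pad_and_concat, stop_sequences_criteria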


lm_eval/models/hf_audiolm.py

Lines changed: 21 additions & 27 deletions
@@ -1,5 +1,4 @@
 import copy
-from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import transformers
@@ -12,8 +11,8 @@
 from lm_eval.models.utils import (
     Collator,
     replace_placeholders,
-    stop_sequences_criteria,
 )
+from lm_eval.models.utils_hf import stop_sequences_criteria
 
 
 DEFAULT_AUDIO_PLACEHOLDERS = ["<audio>"]
@@ -30,8 +29,8 @@ class HFAUDIOLMQWEN(HFLM):
 
     def __init__(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        max_audios: Optional[int] = 5,
+        pretrained: str | transformers.PreTrainedModel,
+        max_audios: int | None = 5,
         **kwargs,
     ):
         # We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
@@ -42,15 +41,10 @@ def __init__(
 
     def _create_tokenizer(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.ProcessorMixin,
-            ]
-        ],
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
+        pretrained: str | transformers.PreTrainedModel,
+        tokenizer: str | transformers.ProcessorMixin | None,
+        revision: str | None = "main",
+        trust_remote_code: bool | None = False,
         **kwargs,
     ) -> None:
         """
@@ -89,7 +83,7 @@ def _create_tokenizer(
         self.tokenizer = self.processor.tokenizer
 
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
@@ -103,7 +97,7 @@ def apply_chat_template(
 
     def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
-        do_sample = generation_kwargs.get("do_sample", None)
+        do_sample = generation_kwargs.get("do_sample")
 
         # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
         if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -129,14 +123,14 @@ def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwar
 
     def tok_batch_multimodal_encode(
         self,
-        strings: List[str],  # note that input signature of this fn is different
-        audios: List[List],
+        strings: list[str],  # note that input signature of this fn is different
+        audios: list[list],
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Union[
-        BatchEncoding, Dict[str, torch.Tensor]
-    ]:  # note that this return signature differs from HFLM tok_batch_encode.
+    ) -> (
+        BatchEncoding | dict[str, torch.Tensor]
+    ):  # note that this return signature differs from HFLM tok_batch_encode.
         # NOTE: here, we replace <audio> tags with our model's corresponding image_token string value.
         def _replace_placeholder(placeholder, strings):
             return [
@@ -169,8 +163,8 @@ def _replace_placeholder(placeholder, strings):
         return encoding
 
     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []
 
         def _collate(x):
@@ -204,7 +198,7 @@ def _collate(x):
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
 
         for chunk in chunks:
-            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
+            contexts, all_gen_kwargs, aux_arguments = zip(*chunk, strict=False)
 
             audios = []
             for audio_lst_dict in aux_arguments:
@@ -276,7 +270,7 @@ def _collate(x):
            ### essentially same as HFLM beyond this line!
 
             cont_toks_list = cont.tolist()
-            for cont_toks, context in zip(cont_toks_list, contexts):
+            for cont_toks, context in zip(cont_toks_list, contexts, strict=False):
                 # discard context + left-padding toks if using causal decoder-only VLM
                 cont_toks = cont_toks[context_enc.shape[1] :]
 
@@ -293,15 +287,15 @@ def _collate(x):
         pbar.close()
         return res
 
-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         raise NotImplementedError(
             "model type `hf-audiolm` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ",
             "this is because we do not support measuring the loglikelihood a model assigns to an image.",
         )
 
     def loglikelihood(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[tuple[float, bool]]:
         raise NotImplementedError(
             "'loglikelihood' requests for model type `hf-audiolm` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
         )
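
The annotation rewrites above are purely syntactic modernization: `Optional[X]` becomes `X | None`, `Union[A, B]` becomes `A | B`, and `typing.List`/`Dict`/`Tuple` become the built-in `list`/`dict`/`tuple` generics, which lets the `from typing import ...` line be dropped. A small self-contained sketch of the equivalence (function names here are illustrative, not from the repo); the built-in generics need Python 3.9+ and the `|` unions Python 3.10+ when evaluated at runtime, or `from __future__ import annotations` on older versions:

from typing import Dict, List, Optional, Union


def old_style(pretrained: Union[str, bytes], revision: Optional[str] = "main") -> List[Dict[str, str]]:
    # Pre-#3450 spelling: typing aliases for unions, optionals, and containers.
    return [{"pretrained": str(pretrained), "revision": revision or "main"}]


def new_style(pretrained: str | bytes, revision: str | None = "main") -> list[dict[str, str]]:
    # Post-#3450 spelling: same runtime behavior, no typing import needed.
    return [{"pretrained": str(pretrained), "revision": revision or "main"}]


assert old_style("gpt2") == new_style("gpt2")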

lm_eval/models/hf_vlms.py

Lines changed: 47 additions & 54 deletions
@@ -1,6 +1,5 @@
 import copy
 import logging
-from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -15,11 +14,10 @@
     Collator,
     flatten_image_list,
     handle_stop_sequences,
-    pad_and_concat,
     replace_placeholders,
     resize_image,
-    stop_sequences_criteria,
 )
+from lm_eval.models.utils_hf import pad_and_concat, stop_sequences_criteria
 
 
 DEFAULT_IMAGE_PLACEHOLDER = "<image>"
@@ -39,19 +37,19 @@ class HFMultimodalLM(HFLM):
 
     def __init__(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        image_token_id: Optional[int] = None,
-        image_string: Optional[str] = None,
+        pretrained: str | transformers.PreTrainedModel,
+        image_token_id: int | None = None,
+        image_string: str | None = None,
         interleave: bool = True,
         # TODO: handle whitespace in image placeholder (replacement)
-        max_images: Optional[int] = 999,
+        max_images: int | None = 999,
         convert_img_format=False,
         # For image resizing
-        min_pixels: Optional[int] = None,
-        max_pixels: Optional[int] = None,
-        image_width: Optional[int] = None,
-        image_height: Optional[int] = None,
-        image_max_side: Optional[int] = None,
+        min_pixels: int | None = None,
+        max_pixels: int | None = None,
+        image_width: int | None = None,
+        image_height: int | None = None,
+        image_max_side: int | None = None,
         **kwargs,
     ):
         self.image_width = image_width
@@ -113,15 +111,10 @@ def __init__(
 
     def _create_tokenizer(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.ProcessorMixin,
-            ]
-        ],
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
+        pretrained: str | transformers.PreTrainedModel,
+        tokenizer: str | transformers.ProcessorMixin | None,
+        revision: str | None = "main",
+        trust_remote_code: bool | None = False,
         **kwargs,
     ) -> None:
         """
@@ -223,7 +216,7 @@ def _encode_multimodal_pair(self, context, continuation, images):
         return context_enc, continuation_enc, image_enc
 
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         self.chat_applied = True
         if not self.interleave:
@@ -279,7 +272,7 @@ def apply_chat_template(
             continue_final_message=not add_generation_prompt,
         )
 
-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         if hasattr(self.processor, "apply_chat_template"):
             _tokenizer = self.tokenizer
             self.tokenizer = self.processor
@@ -293,14 +286,14 @@ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str
 
     def tok_batch_multimodal_encode(
         self,
-        strings: List[str],  # note that input signature of this fn is different
-        images: List[List],  # TODO: images are pil.Image at the moment, update typehint
+        strings: list[str],  # note that input signature of this fn is different
+        images: list[list],  # TODO: images are pil.Image at the moment, update typehint
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Union[
-        BatchEncoding, Dict[str, torch.Tensor]
-    ]:  # note that this return signature differs from HFLM tok_batch_encode.
+    ) -> (
+        BatchEncoding | dict[str, torch.Tensor]
+    ):  # note that this return signature differs from HFLM tok_batch_encode.
         # NOTE: here, we replace <image> tags with our model's corresponding image_token string value.
         if not self.chat_applied:
             # TODO<baber>: This still keeps the whitespace in the image placeholder, which is not ideal.
@@ -356,7 +349,7 @@ def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None):
 
     def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
-        do_sample = generation_kwargs.get("do_sample", None)
+        do_sample = generation_kwargs.get("do_sample")
 
         # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
         if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -398,7 +391,7 @@ def _batch_images(self, image_encs):
         )
         return batched_imgs
 
-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         if requests and len(requests[0].args) < 3:
             # Fall back to non-multimodal generation.
             return super().loglikelihood_rolling(requests=requests)
@@ -408,8 +401,8 @@ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         )
 
     def loglikelihood(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[tuple[float, bool]]:
         if requests and len(requests[0].args) < 3:
             # Fall back to non-multimodal generation.
             return super().loglikelihood(requests=requests, disable_tqdm=disable_tqdm)
@@ -445,16 +438,16 @@ def loglikelihood(
 
     def _multimodal_loglikelihood_tokens(
         self,
-        requests: List[
-            Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
+        requests: list[
+            tuple[tuple[None, str, str], list[int], list[int], list[int]]
         ],  # TODO: update typehint to be correct
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
 
         # TODO: **improve multimodal collation.** We currently ignore image size when ordering docs. ideally we'd take them into account
-        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
@@ -465,7 +458,7 @@ def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
             toks = req[1] + req[2]
             return -len(toks), tuple(toks)
 
-        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key to group and lookup one-token continuations"""
             # Use with group_by="contexts" (optional)"
             # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
@@ -477,7 +470,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
             requests,
             sort_fn=_collate,
             group_by="contexts"  # TODO: can't group-by just "contexts" any more, need to incorporate imgs
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM  # noqa: SIM300
             and self.logits_cache
             else None,
             group_fn=_lookup_one_token_cont,
@@ -572,9 +565,9 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                 request_str,
                 ctx_tokens,
                 _,
-                image_encs,
+                _image_encs,
             ), logits, inplen, cont_toks in zip(
-                chunk, multi_logits, inplens, cont_toks_list
+                chunk, multi_logits, inplens, cont_toks_list, strict=False
            ):
                 # Slice to original seq length
                 contlen = len(cont_toks)
@@ -584,7 +577,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                 # from prompt/prefix tuning tokens, if applicable
                 ctx_len = (
                     inplen + (logits.shape[0] - padding_len_inp)
-                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM  # noqa: SIM300
                     else None
                 )
                 logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
@@ -598,30 +591,30 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                 # original args. Otherwise, expands the logits batch dimension and yields each
                 # batch along with matching continuation tokens and prompt strings.
                 # logits -> [1, seq, vocab]
-                for request_str, cont_toks, logits in re_ord.get_cache(
+                for _request_str, _cont_toks, _logits in re_ord.get_cache(
                     req_str=request_str,
                     cxt_toks=ctx_tokens,
                     cont_toks=cont_toks,
                     logits=logits,
                 ):
-                    cont_toks = torch.tensor(
-                        cont_toks, dtype=torch.long, device=self.device
+                    _cont_toks = torch.tensor(
+                        _cont_toks, dtype=torch.long, device=self.device
                     ).unsqueeze(0)  # [1, seq]
-                    max_equal = (greedy_tokens == cont_toks).all()
+                    max_equal = (greedy_tokens == _cont_toks).all()
 
                     # Obtain log-probs at the corresponding continuation token indices
                     # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-                        -1
-                    )  # [1, seq]
+                    _logits = torch.gather(
+                        _logits, 2, _cont_toks.unsqueeze(-1)
+                    ).squeeze(-1)  # [1, seq]
 
                     # Answer: (log prob, is-exact-match)
-                    answer = (float(logits.sum()), bool(max_equal))
+                    answer = (float(_logits.sum()), bool(max_equal))
 
                     res.append(answer)
 
                     self.cache_hook.add_partial(
-                        "loglikelihood", request_str, answer
+                        "loglikelihood", _request_str, answer
                     )  # TODO: choose convention for adding images into the cache key
                     pbar.update(1)
 
@@ -630,8 +623,8 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
         return re_ord.get_original(res)
 
     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         if requests and len(requests[0].args) < 3:
             # Fall back to non-multimodal generation.
             return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
@@ -669,7 +662,7 @@ def _collate(x):
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
         eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
-            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
+            contexts, all_gen_kwargs, aux_arguments = zip(*chunk, strict=False)
 
             visuals = [
                 [
@@ -732,7 +725,7 @@ def _collate(x):
             ### essentially same as HFLM beyond this line!
 
             cont_toks_list = cont.tolist()
-            for cont_toks, context in zip(cont_toks_list, contexts):
+            for cont_toks, context in zip(cont_toks_list, contexts, strict=False):
                 # discard context + left-padding toks if using causal decoder-only VLM
                 cont_toks = cont_toks[context_enc.shape[1] :]

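The block renamed to `_logits`/`_cont_toks` above is the standard continuation-scoring step: pull out the log-probability the model assigned to each gold continuation token, sum them, and record whether greedy decoding would have reproduced the continuation exactly. A toy, self-contained sketch of that gather with made-up shapes (random values, no batching or caching, not the harness's real code path):

import torch

# Pretend logits for a 4-token continuation over a 10-symbol vocabulary: [1, seq, vocab].
logits = torch.log_softmax(torch.randn(1, 4, 10), dim=-1)
cont_toks = torch.tensor([[3, 1, 4, 1]])  # gold continuation token ids: [1, seq]

# Exact-match check: would argmax decoding have produced the continuation?
greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all()

# Gather each continuation token's log-prob along the vocab dim, then sum.
token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
answer = (float(token_logprobs.sum()), bool(max_equal))
print(answer)  # e.g. (-9.13, False)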