lm_eval/models/huggingface.py (107 additions, 0 deletions)

@@ -1575,3 +1575,110 @@ def get_model_sha(pretrained: str, revision: str) -> str:
if self.delta:
model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
return model_info


class HFLM_Accelerate(HFLM):
"""
Supports custom HF models with accelerate already loaded.
"""

AUTO_MODEL_CLASS = None
_DEFAULT_MAX_LENGTH = 2048

def __init__(
self,
pretrained: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
accelerator: Accelerator,
backend: Literal["default", "causal", "seq2seq"] = "default",
# override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
revision="main",
subfolder: str = "",
truncation: bool | None = False,
logits_cache: bool = True,
max_length: int | None = None,
softmax_dtype: str | torch.dtype | None = None,
mixed_precision_dtype: str | torch.dtype | None = None,
batch_size: int = 1,
max_batch_size: int | None = 64,
trust_remote_code: bool | None = False,
use_fast_tokenizer: bool | None = True,
add_bos_token: bool | None = False,
prefix_token_id: int | None = None,
# PEFT, delta weights and quantization options
peft: str | None = None,
delta: str | None = None,
gguf_file: str | None = None,
        # end-of-thinking token, given either as the literal string or as an int token id;
        # if provided, the response is split on this token and the text after it is kept.
        think_end_token: str | int | None = None,
enable_thinking: bool | None = None,
chat_template_args: dict[str, Any] | None = None,
) -> None:
        TemplateLM.__init__(self)  # skip HFLM.__init__ (it would load a model); initialize grandparent TemplateLM directly
        # use the caller's already-initialized transformers.PreTrainedModel
self._model = pretrained
self._config = self._model.config

self._get_backend(
config=self.config, backend=backend, trust_remote_code=trust_remote_code
)

        # set up the tokenizer (unlike HFLM, the model is already loaded at this point)
self._create_tokenizer(
pretrained,
tokenizer,
revision=revision,
subfolder=subfolder,
trust_remote_code=trust_remote_code,
use_fast_tokenizer=use_fast_tokenizer,
gguf_file=gguf_file,
add_bos_token=add_bos_token,
)

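        # if think_end_token arrived as a numeric string (e.g. from a CLI arg),
        # coerce it to an int token id; otherwise keep the string/int/None as-is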
self.think_end_token = (
int(think_end_token)
if (isinstance(think_end_token, str) and think_end_token.isdigit())
else think_end_token
)
self.truncation = truncation
self.logits_cache = logits_cache
self.vocab_size = self.tokenizer.vocab_size
# select (or create) a pad token to use
self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
        # merge user-supplied chat template args with the enable_thinking override
        self.chat_template_args = dict(chat_template_args or {})
        if enable_thinking is not None:
            self.chat_template_args["enable_thinking"] = enable_thinking

self.add_bos_token = add_bos_token
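        # gemma tokenizers expect a leading BOS token, so force it on for gemma models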
if "gemma" in getattr(self.config, "model_type", ""):
self.add_bos_token = True

self._max_length = max_length
self.pretrained = pretrained
self.delta = delta
self.peft = peft
self.revision = revision
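        # state used by HFLM's automatic batch-size detection (batch_size="auto")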
self.batch_schedule = 1
self.batch_sizes = {}
self.max_batch_size = max_batch_size
self.softmax_dtype = (
get_dtype(softmax_dtype) if softmax_dtype is not None else None
)
self.mixed_precision_dtype = (
get_dtype(mixed_precision_dtype)
if mixed_precision_dtype is not None
else None
)
self.batch_size_per_gpu = int(batch_size)
# Accelerate already loaded
        self._device = accelerator.device
self.accelerator = accelerator

self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes

self.custom_prefix_token_id = prefix_token_id
self.model.tie_weights()
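
A minimal usage sketch (not part of the diff; the model name, task, and
`simple_evaluate` call are illustrative, assuming the rest of lm-eval is set
up as usual):

    import torch
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import lm_eval
    from lm_eval.models.huggingface import HFLM_Accelerate

    # caller owns model loading and accelerate preparation
    accelerator = Accelerator()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model = accelerator.prepare(model)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # hand the prepared model, tokenizer, and accelerator to the wrapper
    lm = HFLM_Accelerate(
        pretrained=model,
        tokenizer=tokenizer,
        accelerator=accelerator,
        batch_size=8,
    )
    results = lm_eval.simple_evaluate(model=lm, tasks=["lambada_openai"])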