
Commit ce1dbb5

NathanHB and Copilot authored
Allow for model kwargs when loading transformers from pretrained (#754)
Allow for model kwargs when loading transformers from pretrained (#754)

## Pull Request Overview

This PR introduces support for passing custom keyword arguments when loading pretrained transformer models, enabling more flexible configuration of model loading. It also replaces the fixed `generation_size` parameter with a more general `model_loading_kwargs` field.

- Removed the fixed `generation_size` parameter.
- Added a new `model_loading_kwargs` field to the configuration.
- Updated the auto model creation to copy the provided kwargs.

* suggestion from copilot

Co-authored-by: Copilot <[email protected]>
1 parent: 317cb50 · commit: ce1dbb5
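As a usage sketch (not part of this commit), the new field can be set directly on the config. The `model_name` field name and the specific loading kwargs below are assumptions for illustration; the commit itself only adds the pass-through dict.

```python
# Hypothetical usage sketch: `model_name` is an assumed field name and may
# differ between lighteval versions; the kwargs shown are ordinary
# transformers from_pretrained() arguments, given purely as examples.
from lighteval.models.transformers.transformers_model import TransformersModelConfig

config = TransformersModelConfig(
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    revision="main",
    # Copied into the loading kwargs by _create_auto_model (see diff below).
    model_loading_kwargs={
        "attn_implementation": "flash_attention_2",
        "low_cpu_mem_usage": True,
    },
)
```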

File tree

1 file changed: +3 −3 lines

src/lighteval/models/transformers/transformers_model.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -27,7 +27,7 @@
 import torch
 import torch.nn.functional as F
 import transformers
-from pydantic import PositiveInt
+from pydantic import Field, PositiveInt
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -137,8 +137,8 @@ class TransformersModelConfig(ModelConfig):
     subfolder: str | None = None
     revision: str = "main"
     batch_size: PositiveInt | None = None
-    generation_size: PositiveInt = 256
     max_length: PositiveInt | None = None
+    model_loading_kwargs: dict = Field(default_factory=dict)
     add_special_tokens: bool = True
     model_parallel: bool | None = None
     dtype: str | None = None
@@ -384,7 +384,7 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:

         pretrained_config = self.transformers_config

-        kwargs = {}
+        kwargs = self.config.model_loading_kwargs.copy()
         if "quantization_config" not in pretrained_config.to_dict():
             kwargs["quantization_config"] = quantization_config

```
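To make the role of `.copy()` explicit, here is a simplified, hypothetical view of the loading path after this change. It is not the actual lighteval method, which handles more (dtype, device maps, quantization detection), and `config.model_name` is an assumed field name.

```python
# Simplified, hypothetical sketch of the loading path; not the real
# _create_auto_model. `config.model_name` is an assumed field name.
import transformers


def create_auto_model(config, quantization_config=None):
    pretrained_config = transformers.AutoConfig.from_pretrained(config.model_name)

    # Copy the user-supplied kwargs so that adding keys below never mutates
    # the dict stored on the config object.
    kwargs = config.model_loading_kwargs.copy()
    if "quantization_config" not in pretrained_config.to_dict():
        kwargs["quantization_config"] = quantization_config

    return transformers.AutoModelForCausalLM.from_pretrained(
        config.model_name,
        revision=config.revision,
        **kwargs,
    )
```

Because the dict is copied, injecting `quantization_config` here does not leak back into `config.model_loading_kwargs`, which is presumably why the PR copies the kwargs rather than using the dict in place.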
0 commit comments
