Commit ea72931

Vectorrent and NathanHB authored
Fix TransformersModel.from_model() method (huggingface#691)
* lift protobuf restriction
* fix typos
* remove incorrect comment
* initial attempt to fix the from_model() method
* ensure input tensors are moved to proper device
* remove conditional, since device should never be None here

Co-authored-by: Nathan Habib <[email protected]>
1 parent e67ed9c commit ea72931

File tree

1 file changed: +11 -11 lines changed

src/lighteval/models/transformers/transformers_model.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -236,6 +236,7 @@ def cleanup(self):
     def from_model(
         cls,
         model: Union[AutoModelForCausalLM, LightevalModel],
+        config: TransformersModelConfig = None,
         accelerator: "Accelerator" = None,
         tokenizer_name: str = None,  # custom tokenizer
         trust_remote_code: bool = False,
@@ -253,16 +254,14 @@ def from_model(
 
         # Instanciate the object without using __init__
         self = cls.__new__(cls)
-        self._config = model.config
-        self._max_length = self._init_max_length(max_length=model.config.max_length)
-        self._tokenizer = self._create_auto_tokenizer_with_name(
-            model_name=model.name_or_path,
-            revision=model.config._commit_hash,
-            trust_remote_code=trust_remote_code,
-            tokenizer_name=tokenizer_name,
-        )
+        self.config = config
+        self.transformers_config = model.config
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self._max_length = self._init_max_length()
+        self._tokenizer = self._create_auto_tokenizer()
+        self.batch_size = config.batch_size
         self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = model.config._commit_hash
+        self.model_sha = config.get_model_sha()
 
         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
@@ -274,14 +273,14 @@ def from_model(
             self._device = accelerator.device
             self.model = self.accelerator.prepare(self.model.to(accelerator.device))
         else:
-            self._device = "cpu"
+            self._device = self.config.device
 
         self.use_chat_template = use_chat_template
         self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
         self.pairwise_tokenization = pairwise_tokenization
         self.multichoice_continuations_start_space = multichoice_continuations_start_space
 
-        self.precision = _get_dtype(model.dtype, config=self._config)
+        self.precision = _get_dtype(model.dtype, config=self.transformers_config)
 
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
@@ -450,6 +449,7 @@ def _init_max_length(self) -> int:
         Returns:
             int: Max length to use depending on the available args and config
         """
+
         if self.config.max_length is not None:
             return self.config.max_length
 
```

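For orientation, here is a minimal sketch of how the patched classmethod might be called after this change. The import path, the config field names, and the checkpoint name are assumptions drawn from the diff and typical lighteval usage, not something this commit defines:

```python
# Hypothetical usage sketch of the patched from_model(); not part of the commit.
# Assumptions: TransformersModelConfig is importable from the same module and its
# constructor accepts the fields the diff reads (batch_size, max_length, device).
from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import (
    TransformersModel,
    TransformersModelConfig,
)

# A model that was already loaded (or fine-tuned) outside of lighteval.
model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical checkpoint

# The config now drives tokenizer creation, max length, batch size, and device
# selection, instead of those values being re-derived from model.config.
config = TransformersModelConfig(
    model_name="gpt2",  # hypothetical; used for naming / model_sha lookup
    batch_size=8,
    max_length=1024,
    device="cuda:0",
)

# Wrap the already-instantiated model without reloading its weights.
lighteval_model = TransformersModel.from_model(model, config=config)
```

Note that although config defaults to None in the new signature, the body dereferences it immediately (config.generation_parameters, config.batch_size, config.get_model_sha()), so callers effectively have to pass one.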
