@@ -236,6 +236,7 @@ def cleanup(self):
     def from_model(
         cls,
         model: Union[AutoModelForCausalLM, LightevalModel],
+        config: TransformersModelConfig = None,
         accelerator: "Accelerator" = None,
         tokenizer_name: str = None,  # custom tokenizer
         trust_remote_code: bool = False,
@@ -253,16 +254,14 @@ def from_model(
 
         # Instanciate the object without using __init__
         self = cls.__new__(cls)
-        self._config = model.config
-        self._max_length = self._init_max_length(max_length=model.config.max_length)
-        self._tokenizer = self._create_auto_tokenizer_with_name(
-            model_name=model.name_or_path,
-            revision=model.config._commit_hash,
-            trust_remote_code=trust_remote_code,
-            tokenizer_name=tokenizer_name,
-        )
+        self.config = config
+        self.transformers_config = model.config
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self._max_length = self._init_max_length()
+        self._tokenizer = self._create_auto_tokenizer()
+        self.batch_size = config.batch_size
         self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = model.config._commit_hash
+        self.model_sha = config.get_model_sha()
 
         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
@@ -274,14 +273,14 @@ def from_model(
             self._device = accelerator.device
             self.model = self.accelerator.prepare(self.model.to(accelerator.device))
         else:
-            self._device = "cpu"
+            self._device = self.config.device
 
         self.use_chat_template = use_chat_template
         self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
         self.pairwise_tokenization = pairwise_tokenization
         self.multichoice_continuations_start_space = multichoice_continuations_start_space
 
-        self.precision = _get_dtype(model.dtype, config=self._config)
+        self.precision = _get_dtype(model.dtype, config=self.transformers_config)
 
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
@@ -450,6 +449,7 @@ def _init_max_length(self) -> int:
         Returns:
             int: Max length to use depending on the available args and config
         """
+
         if self.config.max_length is not None:
             return self.config.max_length
 
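For context, the sketch below shows how a caller might use the updated `from_model` signature, which now takes an explicit `TransformersModelConfig` instead of deriving the tokenizer, max length, device, and revision from `model.config`. Only the `batch_size`, `device`, `max_length`, and `generation_parameters` attributes and the `get_model_sha()` method are confirmed by this diff; the import path and the `model_name` constructor field are assumptions about lighteval's API, not verified against it.

```python
# Hedged usage sketch for the new from_model signature shown in this diff.
# Import path and TransformersModelConfig field names are assumptions;
# only batch_size, device, and max_length usage is confirmed by the diff.
from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import (  # assumed module path
    TransformersModel,
    TransformersModelConfig,
)

# Load a plain Hugging Face model that lighteval should wrap.
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")

# Build the config that from_model now requires; it supplies the device,
# batch size, max length, generation parameters, and model SHA lookup
# that the old code pulled from model.config.
config = TransformersModelConfig(
    model_name="gpt2",  # assumed field name
    batch_size=8,
    device="cpu",       # replaces the previously hard-coded "cpu" device
    max_length=1024,
)

# Instantiate the lighteval wrapper around the already-loaded model.
lighteval_model = TransformersModel.from_model(hf_model, config=config)
```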