@@ -236,6 +236,7 @@ def cleanup(self):
     def from_model(
         cls,
         model: Union[AutoModelForCausalLM, LightevalModel],
+        config: TransformersModelConfig = None,
         accelerator: "Accelerator" = None,
         tokenizer_name: str = None,  # custom tokenizer
         trust_remote_code: bool = False,
@@ -253,16 +254,14 @@ def from_model(

         # Instanciate the object without using __init__
         self = cls.__new__(cls)
-        self._config = model.config
-        self._max_length = self._init_max_length(max_length=model.config.max_length)
-        self._tokenizer = self._create_auto_tokenizer_with_name(
-            model_name=model.name_or_path,
-            revision=model.config._commit_hash,
-            trust_remote_code=trust_remote_code,
-            tokenizer_name=tokenizer_name,
-        )
+        self.config = config
+        self.transformers_config = model.config
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self._max_length = self._init_max_length()
+        self._tokenizer = self._create_auto_tokenizer()
+        self.batch_size = config.batch_size
         self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = model.config._commit_hash
+        self.model_sha = config.get_model_sha()

         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
@@ -274,14 +273,14 @@ def from_model(
             self._device = accelerator.device
             self.model = self.accelerator.prepare(self.model.to(accelerator.device))
         else:
-            self._device = "cpu"
+            self._device = self.config.device

         self.use_chat_template = use_chat_template
         self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
         self.pairwise_tokenization = pairwise_tokenization
         self.multichoice_continuations_start_space = multichoice_continuations_start_space

-        self.precision = _get_dtype(model.dtype, config=self._config)
+        self.precision = _get_dtype(model.dtype, config=self.transformers_config)

         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
@@ -450,6 +449,7 @@ def _init_max_length(self) -> int:
         Returns:
             int: Max length to use depending on the available args and config
         """
+
         if self.config.max_length is not None:
             return self.config.max_length

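For context, a minimal sketch of how the reworked classmethod could be called after this change, based only on the signature shown in the diff above. The enclosing class name (TransformersModel), the lighteval import path, and the TransformersModelConfig constructor argument are assumptions, not taken from this commit.

# Sketch only: class name, module path, and config field name are assumed.
from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import (  # assumed import path
    TransformersModel,
    TransformersModelConfig,
)

# Load the underlying Hugging Face model as usual.
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")

# After this change, evaluation settings (batch size, max length, device,
# generation parameters, model sha) are read from the config object instead
# of being re-derived from model.config inside from_model.
config = TransformersModelConfig(model_name="gpt2")  # field name assumed

lighteval_model = TransformersModel.from_model(model=hf_model, config=config)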