 )
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerateOutput
-from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel, ModelConfig
@@ -245,39 +244,34 @@ def cleanup(self):
     @classmethod
     def from_model(
         cls,
-        model: Union[AutoModelForCausalLM, LightevalModel],
-        config: TransformersModelConfig = None,
-        accelerator: "Accelerator" = None,
-        tokenizer_name: str = None,  # custom tokenizer
-        trust_remote_code: bool = False,
-        add_special_tokens: bool = True,
-        skip_special_tokens: bool = True,
-        pairwise_tokenization: bool = False,
-        multichoice_continuations_start_space: bool = None,
-    ):
-        # Slightly hackish way to test if the model is an AutoModelForCausalLM, since the instances don't
-        # derive from this class explicitly
-        assert isinstance(model, LightevalModel) or type(model).__name__ in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
-
-        if isinstance(model, LightevalModel):
-            return model
+        model: AutoModelForCausalLM,
+        config: TransformersModelConfig,
+        accelerator: Accelerator | None = None,
+    ) -> "TransformersModel":
+        if config is None:
+            raise ValueError("Config must be provided to initialize the TransformersModel via `from_model` method.")
 
         # Instantiate the object without using __init__
         self = cls.__new__(cls)
+
         self.transformers_config = model.config
-        if isinstance(model, TransformersModel):
-            self.config = model.config
-        else:
-            self.config = (
-                config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
-            )
-        if config is not None:
-            self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+
+        self.config = config
+        self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
+        self._add_special_tokens = config.add_special_tokens
+        self.skip_special_tokens = config.skip_special_tokens
+        self.pairwise_tokenization = config.pairwise_tokenization
+        self.batch_size = config.batch_size
+        self.continuous_batching = config.continuous_batching
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+
+        self.model_name = config.model_name
+        self.model_sha = config.get_model_sha()
         self._max_length = self._init_max_length()
         self._tokenizer = self._create_auto_tokenizer()
-        self.batch_size = getattr(config, "batch_size", None)
-        self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = self.config.get_model_sha()
+        self.use_chat_template = uses_chat_template(
+            tokenizer=self._tokenizer, override_chat_template=config.override_chat_template
+        )
 
         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
@@ -291,16 +285,6 @@ def from_model(
         else:
             self._device = self.config.device
 
-        self.use_chat_template = uses_chat_template(
-            tokenizer=self._tokenizer, override_chat_template=config.override_chat_template
-        )
-        self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
-        self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True
-        self.pairwise_tokenization = pairwise_tokenization
-        self.multichoice_continuations_start_space = multichoice_continuations_start_space
-
-        self.precision = _get_dtype(model.dtype, config=self.transformers_config)
-
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
             model_size = convert_bytes(model_size)
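
For reference, a minimal caller-side sketch of the updated entry point. The module path, checkpoint name, and reliance on `TransformersModelConfig` defaults here are assumptions for illustration, not part of this diff:

# Hypothetical usage sketch: `from_model` now takes a required
# TransformersModelConfig instead of a handful of loose keyword arguments.
from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import (  # assumed module path
    TransformersModel,
    TransformersModelConfig,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
config = TransformersModelConfig(model_name="gpt2")  # tokenization/batching options now live here

# Wrap the already-instantiated model for evaluation; config=None raises a ValueError.
lighteval_model = TransformersModel.from_model(model, config=config)

Centralizing these options in the config object removes the old special-casing for `LightevalModel`/`TransformersModel` inputs and the `None`-default fallbacks scattered through the constructor.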