
Commit 4d2bf42

NathanHB and Copilot authored
fixes from_model function and adds tests (#921)
* fixes from_model function and adds tests
* removes mock accelerator from tests
* fixes tests by adding model_configs where needed
* Apply suggestion from @Copilot
  Co-authored-by: Copilot <[email protected]>
* fixes from review
* fixes from review
* fixes from review
* fixes from review

---------

Co-authored-by: Copilot <[email protected]>
1 parent 458a3e9 commit 4d2bf42

File tree

3 files changed: +90 −47 lines changed

src/lighteval/models/transformers/transformers_model.py
src/lighteval/pipeline.py
tests/models/test_transformers_model.py

src/lighteval/models/transformers/transformers_model.py

Lines changed: 22 additions & 38 deletions
@@ -43,7 +43,6 @@
 )
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerateOutput
-from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel, ModelConfig
@@ -245,39 +244,34 @@ def cleanup(self):
     @classmethod
     def from_model(
         cls,
-        model: Union[AutoModelForCausalLM, LightevalModel],
-        config: TransformersModelConfig = None,
-        accelerator: "Accelerator" = None,
-        tokenizer_name: str = None,  # custom tokenizer
-        trust_remote_code: bool = False,
-        add_special_tokens: bool = True,
-        skip_special_tokens: bool = True,
-        pairwise_tokenization: bool = False,
-        multichoice_continuations_start_space: bool = None,
-    ):
-        # Slightly hackish way to test if the model is a AutoModelForCausalLM, since the instances don't
-        # derive from this class explicitely
-        assert isinstance(model, LightevalModel) or type(model).__name__ in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
-
-        if isinstance(model, LightevalModel):
-            return model
+        model: AutoModelForCausalLM,
+        config: TransformersModelConfig,
+        accelerator: Accelerator | None = None,
+    ) -> "TransformersModel":
+        if config is None:
+            raise ValueError("Config must be provided to initialize the TransformersModel via `from_model` method.")
 
         # Instanciate the object without using __init__
         self = cls.__new__(cls)
+
         self.transformers_config = model.config
-        if isinstance(model, TransformersModel):
-            self.config = model.config
-        else:
-            self.config = (
-                config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
-            )
-        if config is not None:
-            self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+
+        self.config = config
+        self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
+        self._add_special_tokens = config.add_special_tokens
+        self.skip_special_tokens = config.skip_special_tokens
+        self.pairwise_tokenization = config.pairwise_tokenization
+        self.batch_size = config.batch_size
+        self.continuous_batching = config.continuous_batching
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+
+        self.model_name = config.model_name
+        self.model_sha = config.get_model_sha()
         self._max_length = self._init_max_length()
         self._tokenizer = self._create_auto_tokenizer()
-        self.batch_size = getattr(config, "batch_size", None)
-        self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = self.config.get_model_sha()
+        self.use_chat_template = uses_chat_template(
+            tokenizer=self._tokenizer, override_chat_template=config.override_chat_template
+        )
 
         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
@@ -291,16 +285,6 @@ def from_model(
         else:
            self._device = self.config.device
 
-        self.use_chat_template = uses_chat_template(
-            tokenizer=self._tokenizer, override_chat_template=config.override_chat_template
-        )
-        self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
-        self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True
-        self.pairwise_tokenization = pairwise_tokenization
-        self.multichoice_continuations_start_space = multichoice_continuations_start_space
-
-        self.precision = _get_dtype(model.dtype, config=self.transformers_config)
-
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
             model_size = convert_bytes(model_size)
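
Net effect: `from_model` no longer accepts a `LightevalModel` passthrough or loose tokenizer keyword arguments; it takes a plain `AutoModelForCausalLM` plus a now-mandatory `TransformersModelConfig`, which becomes the single source of truth for tokenizer, batching, and generation settings. A minimal usage sketch (illustrative only: the `gpt2` checkpoint and the exact import paths are assumptions based on the file touched above, not part of the commit):

from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import (
    TransformersModel,
    TransformersModelConfig,
)

# Any causal LM loaded through transformers; "gpt2" is only an example.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Special-token handling, batch size, generation parameters, etc. are now
# read from the config rather than from individual keyword arguments.
config = TransformersModelConfig(model_name="gpt2")

lighteval_model = TransformersModel.from_model(model=model, config=config)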

src/lighteval/pipeline.py

Lines changed: 20 additions & 9 deletions
@@ -35,6 +35,7 @@
 
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics import apply_metric
+from lighteval.models.abstract_model import LightevalModel, ModelConfig
 from lighteval.models.model_loader import TransformersModel, load_model
 from lighteval.models.model_output import (
     ModelResponse,
@@ -155,7 +156,7 @@ def __init__(
         tasks: str,
         pipeline_parameters: PipelineParameters,
         evaluation_tracker: EvaluationTracker,
-        model_config=None,
+        model_config: ModelConfig | None = None,
         model=None,
         metric_options=None,
     ):
@@ -205,7 +206,24 @@ def _init_parallelism_manager(self):
 
     def _init_model(self, model_config, model):
         logger.info("--- LOADING MODEL ---")
-        if model_config is not None:
+
+        if model is not None and model_config is not None:
+            if isinstance(model, LightevalModel):
+                raise ValueError(
+                    "You are trying to provide both a LightevalModel and a model config. Please provide only one of them."
+                )
+            return TransformersModel.from_model(
+                model=model,
+                config=model_config,
+                accelerator=self.accelerator,
+            )
+
+        elif model is not None:
+            if isinstance(model, LightevalModel):
+                return model
+            raise ValueError("If not providing a model_config, you need to provide a Lighteval model.")
+
+        elif model_config is not None:
             if self.parallel_context:
                 return NanotronLightevalModel(
                     checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
@@ -218,13 +236,6 @@ def _init_model(self, model_config, model):
                 )
             else:
                 return load_model(config=model_config)
-        if isinstance(model, TransformersModel):
-            return model
-        else:
-            return TransformersModel.from_model(
-                model=model,
-                accelerator=self.accelerator,
-            )
 
     def _init_tasks_and_requests(self, tasks: str):
         with local_ranks_zero_first() if self.launcher_type == ParallelismManager.NANOTRON else nullcontext():
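
`_init_model` now dispatches explicitly on the `(model, model_config)` pair instead of falling through after the config branch. A sketch of the three caller combinations that are now valid (hedged: this assumes the class defined in `src/lighteval/pipeline.py` is `Pipeline` and that `tasks`, `pipeline_parameters`, and `evaluation_tracker` already exist in scope; `gpt2` is illustrative):

from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import (
    TransformersModel,
    TransformersModelConfig,
)
from lighteval.pipeline import Pipeline

config = TransformersModelConfig(model_name="gpt2")

# 1. Config only: the pipeline loads the model itself via load_model().
pipeline = Pipeline(tasks, pipeline_parameters, evaluation_tracker, model_config=config)

# 2. Raw transformers model plus its config: wrapped via TransformersModel.from_model().
raw_model = AutoModelForCausalLM.from_pretrained("gpt2")
pipeline = Pipeline(tasks, pipeline_parameters, evaluation_tracker, model_config=config, model=raw_model)

# 3. An already-wrapped LightevalModel with no config: used as-is.
wrapped = TransformersModel.from_model(model=raw_model, config=config)
pipeline = Pipeline(tasks, pipeline_parameters, evaluation_tracker, model=wrapped)

# Passing both a LightevalModel and a model_config, or a raw model without
# a model_config, now raises a ValueError.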

tests/models/test_transformers_model.py

Lines changed: 48 additions & 0 deletions
@@ -117,6 +117,54 @@ def test_model_creation_model(self):
         self.assertEqual(str(self.model.model), str(self.reference_model))
 
 
+class TestTransformersModelCreationFromModel(unittest.TestCase):
+    def setUp(self):
+        """Set up shared model instance for all tests."""
+        self.reference_model = AutoModelForCausalLM.from_pretrained("gpt2")
+        self.reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+        max_length = 1234
+        self.reference_tokenizer.model_max_length = max_length
+
+        self.config = TransformersModelConfig(model_name="gpt2", max_length=max_length)
+
+        # Create full model instance
+        self.model = TransformersModel.from_model(
+            model=self.reference_model,
+            config=self.config,
+        )
+
+    def test_model_creation_tokenizer(self):
+        for attribute in [
+            "name_or_path",
+            "vocab_size",
+            "model_max_length",
+            "is_fast",
+            "clean_up_tokenization_spaces",
+            "added_tokens_decoder",
+        ]:
+            with self.subTest(attribute=attribute):
+                self.assertEqual(
+                    getattr(self.model.tokenizer, attribute), getattr(self.reference_tokenizer, attribute)
+                )
+
+    def test_model_creation_attributes(self):
+        """Test that TransformersModel creates and initializes basic attributes correctly."""
+        # Test attributes are set correctly
+        self.assertEqual(self.model.config, self.config)
+        self.assertEqual(self.model.multichoice_continuations_start_space, None)
+        self.assertTrue(self.model._add_special_tokens)
+        self.assertFalse(self.model.pairwise_tokenization)
+        self.assertIsNone(self.model.batch_size)
+        self.assertFalse(self.model.continuous_batching)
+        self.assertEqual(self.model.model_name, self.config.model_name)
+        self.assertEqual(self.model.max_length, self.config.max_length)
+
+    def test_model_creation_model(self):
+        # We can't compare objects directly
+        self.assertEqual(str(self.model.model), str(self.reference_model))
+
+
 class TestTransformersModelProcessing(unittest.TestCase):
     @patch("lighteval.models.transformers.transformers_model.Accelerator")
     def setUp(self, mock_accelerator):
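
The new `TestTransformersModelCreationFromModel` class exercises `from_model` end to end against a real `gpt2` checkpoint, with no mocked accelerator (per the commit message). To run just this class locally, a standard pytest invocation should work, assuming pytest is installed and the checkpoint can be downloaded:

python -m pytest tests/models/test_transformers_model.py -k TestTransformersModelCreationFromModel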
