
Commit 2236e17

Removed duplicate code and a useless function, added stronger deletion of items on OOM, and updated the generation-size logic to respect what the user asks (#1073)
1 parent 99162f1 commit 2236e17

File tree: 1 file changed, +22 -16 lines


src/lighteval/models/transformers/transformers_model.py

Lines changed: 22 additions & 16 deletions
@@ -66,6 +66,9 @@
 
 STARTING_BATCH_SIZE = 512
 
+# Thread local param
+torch.set_grad_enabled(False)
+
 
 class TransformersModelConfig(ModelConfig):
     """Configuration class for HuggingFace Transformers models.
@@ -218,12 +221,6 @@ def __init__(
         if config.model_parallel is False and self.config.dtype not in ["4bit", "8bit"]:
             logger.info(f"Using Data Parallelism, putting model on device {self._device}")
             self.model = self.model.to(self._device)
-        if config.compile:
-            try:
-                logger.info("Compiling the model")
-                self.model.model.compile()
-            except AttributeError as e:
-                logger.warning("Could not compile the model because: ", e)
 
         self.model_name = _simplify_name(config.model_name)
 
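The deleted block guarded an in-place compilation with try/except. If compilation is still wanted, the same guarded pattern works outside the constructor; a sketch assuming PyTorch 2.x (the Linear module stands in for the wrapped transformer):

import logging

import torch

logger = logging.getLogger(__name__)

model = torch.nn.Linear(8, 2)  # stand-in for the real model
try:
    logger.info("Compiling the model")
    model.compile()  # in-place torch.compile(); compilation happens lazily on first call
except AttributeError as e:
    logger.warning("Could not compile the model: %s", e)

Note the removed line passed the exception as a stray format argument to logger.warning, which fails to format at emit time; %s formatting is the usual fix.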
@@ -410,7 +407,7 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
         )
         # model.to(self.device)
         model.eval()
-        torch.set_grad_enabled(False)
+
         if self.continuous_batching:
             generation_config = GenerationConfig(
                 **self.generation_config_dict,
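
In the continuous-batching branch, the stored generation parameters are splatted into a transformers.GenerationConfig. A minimal sketch of that construction (dict values are illustrative):

from transformers import GenerationConfig

generation_config_dict = {"max_new_tokens": 256, "do_sample": False}
generation_config = GenerationConfig(**generation_config_dict)
print(generation_config.max_new_tokens)  # 256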
@@ -497,9 +494,6 @@ def _check_continuations_start_space(self, continuation: str) -> str:
             continuation = continuation.lstrip()
         return continuation
 
-    def _model_call(self, inputs: torch.Tensor) -> torch.Tensor:
-        return self.model(inputs).logits
-
     def _get_batch_size(self, max_input_length: int, override_bs: int | None, starting_batch_size: int = 512) -> int:
         if override_bs is not None:
             return override_bs
@@ -509,10 +503,18 @@ def _get_batch_size(self, max_input_length: int, override_bs: int | None, starting_batch_size: int = 512) -> int:
             starting_batch_size=starting_batch_size
         )  # if OOM, then halves batch_size and tries again
         def forward_batch(batch_size):
-            test_batch = torch.ones(
-                (batch_size + int(0.1 * batch_size), max_input_length), device=self.device
-            ).long()  # We add 10% for marging :)
-            F.log_softmax(self._model_call(test_batch).float(), dim=-1).cpu()
+            fake_batch, fake_output = None, None
+            with torch.no_grad():
+                try:
+                    fake_batch = torch.ones((batch_size, max_input_length), device=self.device).int()
+                    fake_output = F.log_softmax(self.model(fake_batch).logits, dim=-1).cpu()
+                except Exception as e:
+                    for fake_item in [fake_batch, fake_output]:
+                        if fake_item is not None:
+                            fake_item.detach()
+                        del fake_item
+
+                    raise e
             return batch_size
 
         batch_size = forward_batch()
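
The decorator whose call site appears above halves the batch size and retries whenever the probe OOMs; this matches accelerate's find_executable_batch_size utility. A standalone sketch of the halving-retry pattern (plain Python, not accelerate's actual implementation; the 128-item threshold is made up):

def find_largest_batch_size(fn, starting_batch_size=512):
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return fn(batch_size)
        except RuntimeError as e:
            if "out of memory" not in str(e):  # only swallow OOM-style failures
                raise
            batch_size //= 2  # halve and retry
    raise RuntimeError("no executable batch size found")

def probe(batch_size):
    if batch_size > 128:  # pretend anything larger OOMs
        raise RuntimeError("CUDA out of memory")
    return batch_size

print(find_largest_batch_size(probe))  # 128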
@@ -645,10 +647,14 @@ def _padded_greedy_until(
             position=0,
             disable=self.disable_tqdm,
         ):
-            if split[0].generation_size is None:
+            if self.generation_config_dict.get("max_new_tokens", None) is not None:
+                # The user forces a specific generation size
+                max_context_continuation_size_allowed = self.generation_config_dict["max_new_tokens"]
+            elif split[0].generation_size is None:
                 # No constraints on the generation size: max length allowed is the max model context
                 max_context_continuation_size_allowed = self.max_length
             else:
+                # The task forces a specific generation size
                 context = self.prompt_manager.prepare_prompt(split[0])
                 tokenized_context = self.tokenizer(context)
 
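The reordered branches establish a clear precedence: a user-supplied max_new_tokens wins over the task's generation_size, which in turn wins over falling back to the model's maximum context. A simplified sketch of just that precedence (the real else branch also measures the tokenized context before settling on a size):

def resolve_generation_size(generation_config_dict, task_generation_size, max_length):
    if generation_config_dict.get("max_new_tokens") is not None:
        return generation_config_dict["max_new_tokens"]  # the user forces a size
    if task_generation_size is None:
        return max_length  # no constraint: use the model context
    return task_generation_size  # the task forces a size

print(resolve_generation_size({"max_new_tokens": 100}, 256, 4096))  # 100
print(resolve_generation_size({}, 256, 4096))  # 256
print(resolve_generation_size({}, None, 4096))  # 4096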
@@ -953,7 +959,7 @@ def _loglikelihood_tokens(  # noqa: C901
                 max_context=None,  # computed as model max length in the function
             )
 
-            model_output = self._model_call(prepared_batch.input_ids)
+            model_output = self.model(prepared_batch.input_ids).logits
             logits = F.log_softmax(model_output, dim=-1)  # [batch, sequence_length, vocab]
 
             flat_index = 0
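
With _model_call gone, the logits are taken directly from the model output and log-softmaxed over the vocabulary. A toy sketch of the readout that follows (random logits stand in for a real forward pass):

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)  # plays the role of model(...).logits
logprobs = F.log_softmax(logits, dim=-1)     # [batch, sequence_length, vocab]

# per-token log-likelihood of some target ids: one gathered value per position
targets = torch.randint(vocab, (batch, seq_len))
token_ll = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
print(token_ll.shape)  # torch.Size([2, 5])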
