
Commit 818a2cf

fix slow tests (#689)
tests on the AWS runner were hanging; the culprit was multiprocessing when loading datasets.
1 parent 989f5f5 commit 818a2cf
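
A minimal sketch of the suspected failure mode, assuming dataset_loading_processes is ultimately forwarded to the num_proc argument of datasets.load_dataset (the dataset name below is purely illustrative, not lighteval code):

    from datasets import load_dataset

    # num_proc > 1 makes `datasets` fork worker processes to prepare the data;
    # on a constrained CI runner this multiprocessing step can deadlock and hang.
    ds = load_dataset("imdb", num_proc=8)

    # The fix in this commit pins dataset loading to the main process instead:
    ds = load_dataset("imdb", num_proc=1)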

File tree (5 files changed: +19, -19 lines)


.github/workflows/slow_tests.yaml
Lines changed: 1 addition & 6 deletions

@@ -37,11 +37,6 @@ jobs:
       - name: Install the project
         run: uv sync --extra dev

-      - name: Ensure cache directories exist
-        run: mkdir -p cache/models cache/datasets

       - name: Run tests
-        env:
-          HF_HOME: "cache/models"
-          HF_DATASETS_CACHE: "cache/datasets"
-        run: uv run pytest --disable-pytest-warnings --runslow tests/slow_tests
+        run: uv run pytest --disable-pytest-warnings -o log_cli=true -o log_cli_level=INFO --runslow tests/slow_tests/
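
The new -o flags override pytest ini options from the command line: log_cli=true streams log records live while tests run, and log_cli_level=INFO sets their verbosity, so if a test hangs again its last log line is visible in the CI output. A rough local equivalent through pytest's Python entry point (a sketch, not part of this commit):

    import pytest

    # Mirrors the CI invocation above; pytest.main accepts the same CLI arguments.
    pytest.main([
        "--disable-pytest-warnings",
        "-o", "log_cli=true",        # stream captured log records as tests run
        "-o", "log_cli_level=INFO",  # a hanging test then shows its last INFO line
        "--runslow",
        "tests/slow_tests/",
    ])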

examples/model_configs/transformers_model.yaml
Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ model_parameters:
   dtype: "float16"
   compile: false
   model_parallel: false
+  batch_size: 1
   multichoice_continuations_start_space: null # If true/false, will force multiple choice continuations to start/not start with a space. If none, will do nothing
   generation_parameters:
     temperature: 0.2
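
The new batch_size key sits under model_parameters next to the other model options; it is what the model reads back as self.config.batch_size in the next file's diff. A quick shape check with plain PyYAML (illustrative only, not lighteval's own config loader):

    import yaml

    with open("examples/model_configs/transformers_model.yaml") as f:
        cfg = yaml.safe_load(f)

    # batch_size now rides along with dtype, compile, model_parallel, ...
    assert cfg["model_parameters"]["batch_size"] == 1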

src/lighteval/models/transformers/transformers_model.py
Lines changed: 15 additions & 11 deletions

@@ -226,6 +226,12 @@ def __init__(
             model_size=str(model_size),
         )

+    def cleanup(self):
+        """Clean up operations if needed, such as closing an endpoint."""
+        del self.model
+        del self._tokenizer
+        torch.cuda.empty_cache()
+
     @classmethod
     def from_model(
         cls,
@@ -543,7 +549,7 @@ def greedy_until(
                 longest_context_continuation_size_in_split, self.max_length
             )
             batch_size = self._get_batch_size(
-                override_bs=self.batch_size,
+                override_bs=self.config.batch_size,
                 max_input_length=max_context_continuation_size_allowed,
                 starting_batch_size=starting_batch_size,
             )
@@ -710,7 +716,6 @@ def _generate(
     def loglikelihood(
         self,
         requests: list[LoglikelihoodRequest],
-        override_bs: Optional[int] = None,
     ) -> list[LoglikelihoodResponse]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
@@ -731,12 +736,11 @@ def loglikelihood(
                 request.context, request.choice, pairwise=self.pairwise_tokenization
             )

-        return self._loglikelihood_tokens(requests, override_bs=override_bs)
+        return self._loglikelihood_tokens(requests)

     def loglikelihood_rolling(
         self,
         requests: list[LoglikelihoodRollingRequest],
-        override_bs=None,
     ) -> list[LoglikelihoodResponse]:
         """This function is used to compute the log likelihood of the context for perplexity metrics."""

@@ -746,7 +750,6 @@ def loglikelihood_rolling(

         results = self._loglikelihood_tokens(
             requests,
-            override_bs=override_bs,
             return_bool_score=False,
             rolling=True,
         )
@@ -755,7 +758,6 @@ def loglikelihood_rolling(
     def _loglikelihood_tokens(
         self,
         requests: list[LoglikelihoodRequest],
-        override_bs: int = -1,
         return_bool_score: bool = True,
         rolling: bool = False,
     ) -> list[LoglikelihoodResponse]:
@@ -774,7 +776,7 @@ def _loglikelihood_tokens(
             )

             batch_size = self._get_batch_size(
-                override_bs=override_bs,
+                override_bs=self.config.batch_size,
                 max_input_length=max_context_continuation_size_allowed,
                 starting_batch_size=starting_batch_size,
             )
@@ -967,7 +969,8 @@ def pad_and_gather(
         return output_tensor, length_tensor

     def loglikelihood_single_token(
-        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
+        self,
+        requests: list[LoglikelihoodSingleTokenRequest],
     ) -> list[LoglikelihoodSingleTokenResponse]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
@@ -996,10 +999,11 @@ def loglikelihood_single_token(
             )
             request.tokenized_continuation = continuations_enc

-        return self._loglikelihood_single_token(requests, override_bs=override_bs)
+        return self._loglikelihood_single_token(requests)

     def _loglikelihood_single_token(
-        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1
+        self,
+        requests: list[LoglikelihoodSingleTokenRequest],
     ) -> list[LoglikelihoodSingleTokenResponse]:
         dataset = LoglikelihoodSingleTokenDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         starting_batch_size = STARTING_BATCH_SIZE
@@ -1008,7 +1012,7 @@ def _loglikelihood_single_token(
         for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
             context_enc = dataset[0].tokenized_context
             max_context = len(context_enc[-self.max_length :])
-            batch_size = self._get_batch_size(override_bs=override_bs, max_input_length=max_context)
+            batch_size = self._get_batch_size(override_bs=self.config.batch_size, max_input_length=max_context)
             starting_batch_size = batch_size * 2

             dataloader = DataLoader(dataset, batch_size=starting_batch_size, collate_fn=lambda batch: batch)
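
Taken together, these hunks drop the per-call override_bs argument, read the batch size from the model config instead, and add a cleanup() method that releases GPU memory between runs. A hedged sketch of the caller-facing change (only the method names come from the diff; the surrounding function is assumed):

    def score_and_release(model, requests):
        # After this commit, callers no longer pass override_bs; the batch
        # size comes from the model config (e.g. batch_size: 1 in the YAML).
        responses = model.loglikelihood(requests)

        # New: drop the model/tokenizer references and empty the CUDA cache,
        # so the next model in a test session starts from a clean GPU.
        model.cleanup()
        return responses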

tests/slow_tests/test_accelerate_model.py
Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def run_model(model_name: str, use_chat_template: bool):
         tasks=TASKS_PATH,
         use_chat_template=use_chat_template,
         output_dir="",
-        dataset_loading_processes=8,
+        dataset_loading_processes=1,
         save_details=False,
         max_samples=10,
         custom_tasks=CUSTOM_TASKS_PATH,

tests/slow_tests/test_vllm_model.py
Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def run_model(model_name: str, use_chat_template: bool):
         tasks=TASKS_PATH,
         use_chat_template=use_chat_template,
         output_dir="",
-        dataset_loading_processes=8,
+        dataset_loading_processes=1,
         save_details=False,
         max_samples=10,
         custom_tasks=CUSTOM_TASKS_PATH,
