
Commit ea1dd18

Debug log likelihood evals which are broken for accelerate (#901)
* padding should be better now, also better batch size management to prevent OOMs
* fix in case bs is 0
* added tests to avoid this padding break mess
1 parent 64f93b0 commit ea1dd18

File tree: 7 files changed, +431 −19 lines changed

src/lighteval/models/model_input.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -47,6 +47,8 @@ class GenerationParameters(BaseModel, extra="forbid"):
     top_p: NonNegativeFloat | None = None  # vllm, transformers, tgi, litellm, sglang
     truncate_prompt: bool | None = None  # vllm, tgi
 
+    cache_implementation: str | None = None  # transformers
+
     # response format to be followed by the model,
     # more info here https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
     response_format: str | None = None  # inference_providers
@@ -192,6 +194,7 @@ def to_transformers_dict(self) -> dict:
             "num_blocks": self.num_blocks,
             "block_size": self.block_size,
             "return_dict_in_generate": True,
+            "cache_implementation": self.cache_implementation,
         }
         return {k: v for k, v in args.items() if v is not None}
```
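A minimal sketch of how the new field flows through (field values and the printed dict are illustrative only; the exact keys depend on which other options are set): anything left as `None` is dropped by the filter on the method's last line.

```python
from lighteval.models.model_input import GenerationParameters

params = GenerationParameters(cache_implementation="static")
transformers_args = params.to_transformers_dict()
# None-valued fields are filtered out, so only explicitly set options survive,
# e.g. {'return_dict_in_generate': True, 'cache_implementation': 'static'}
print(transformers_args)
```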

src/lighteval/models/transformers/transformers_model.py

Lines changed: 41 additions & 17 deletions

```diff
@@ -923,8 +923,12 @@ def _loglikelihood_tokens(  # noqa: C901
                 starting_batch_size=starting_batch_size,
             )
             starting_batch_size = batch_size * 2
+            max_num_choices = max(len(d.choices) for d in split)
+            # We divide the batch size by the number of choices as batch is samples * num choices
+            # then round up to closest 8 multiple
+            batch_size = max(1, round(batch_size // max_num_choices / 8) * 8)
             logger.warning(
-                f"batch size is set to {batch_size} however, logliklehood evaluates on n choices per samples so batch size will be muiltiplied by number of choices per sample"
+                f"batch size is set to {batch_size} (it should be understood as '{batch_size} times the maximum number of choices per sample, {max_num_choices}')"
             )
 
             dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
```
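A quick worked example of the rescaling above, with made-up numbers:

```python
# Hypothetical values: the automatic search settled on 60, and samples have up to 4 choices.
batch_size, max_num_choices = 60, 4
# 60 // 4 = 15 samples, snapped to the nearest multiple of 8 -> 16 samples,
# i.e. at most 16 * 4 = 64 sequences actually go through the model per step.
batch_size = max(1, round(batch_size // max_num_choices / 8) * 8)
assert batch_size == 16
# The max(1, ...) guard keeps the batch size from collapsing to 0 when
# batch_size // max_num_choices is small (e.g. 3 // 4 -> 0).
```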
```diff
@@ -1026,33 +1030,46 @@ def _loglikelihood_tokens(  # noqa: C901
                 if self.accelerator:
                     # Convert lists to tensors for proper gathering
                     # Pad and stack the tensors to make them gatherable
-                    choices_lengths = [len(choices) for choices in batch_tokenized_continuations_processed]
-                    choices_lengths_tensor = torch.tensor(choices_lengths, device=self.device)
-                    gathered_choices_lengths = self.accelerator.gather_for_metrics(choices_lengths_tensor)
-                    global_max_choices = gathered_choices_lengths.max().item()
+                    shape_choices = [
+                        choices.shape for choices in batch_tokenized_continuations_processed
+                    ]  # num_choices * max len choices
+                    num_choices_tensor = torch.tensor([shape[0] for shape in shape_choices], device=self.device)
+                    len_choices_tensor = torch.tensor([shape[1] for shape in shape_choices], device=self.device)
+                    gathered_num_choices = self.accelerator.gather_for_metrics(num_choices_tensor)
+                    gathered_len_choices = self.accelerator.gather_for_metrics(len_choices_tensor)
+                    max_num_choices = gathered_num_choices.max().item()
+                    max_len_choices = gathered_len_choices.max().item()
+                    len_context_tensor = torch.tensor(
+                        [len(ctx) for ctx in batch_tokenized_contexts_processed], device=self.device
+                    )
+                    gathered_len_context = self.accelerator.gather_for_metrics(len_context_tensor)
+                    max_len_context = gathered_len_context.max().item()
 
-                    # Pad logits_sum_batch to same size
+                    # 1d - Pad logits_sum and max_equals to same number of choices
                     padded_logits_sums = []
                     for logits_sum_doc in batch_logits_sums:
-                        pad_amount = global_max_choices - len(logits_sum_doc)
+                        pad_amount = max_num_choices - len(logits_sum_doc)
                         padded = F.pad(logits_sum_doc, (0, pad_amount), value=-1)
                         padded_logits_sums.append(padded)
 
                     padded_max_equals = []
                     for max_equals_doc in batch_max_equals:
-                        pad_amount = global_max_choices - len(max_equals_doc)
+                        pad_amount = max_num_choices - len(max_equals_doc)
                         padded = F.pad(max_equals_doc, (0, pad_amount), value=False)
                         padded_max_equals.append(padded)
 
+                    # 2d - Pad continuations to max number of choice and max length
                     padded_continuations = []
                     for cont_batch in batch_tokenized_continuations_processed:
-                        pad_amount = global_max_choices - cont_batch.shape[0]
-                        padded = F.pad(cont_batch, (0, pad_amount), value=-1)
+                        pad_amount_num = max_num_choices - cont_batch.shape[0]
+                        pad_amount_len = max_len_choices - cont_batch.shape[1]
+                        padded = F.pad(cont_batch, (0, pad_amount_len, 0, pad_amount_num), value=-1)
                         padded_continuations.append(padded)
 
+                    # 1d - Pad context to maximum context size
                     padded_contexts = []
                     for ctx_batch in batch_tokenized_contexts_processed:
-                        pad_amount = global_max_choices - ctx_batch.shape[0]
+                        pad_amount = max_len_context - len(ctx_batch)
                         padded = F.pad(ctx_batch, (0, pad_amount), value=-1)
                         padded_contexts.append(padded)
 
```
```diff
@@ -1075,12 +1092,19 @@ def _loglikelihood_tokens(  # noqa: C901
                     batch_tokenized_contexts_processed = []
 
                     # Only process if we have gathered results
-                    for i, actual_count in enumerate(gathered_choices_lengths):
-                        # Extract non-padded values based on actual counts
-                        batch_logits_sums.append(gathered_logits_sums[i][:actual_count])
-                        batch_max_equals.append(gathered_max_equals[i][:actual_count])
-                        batch_tokenized_continuations_processed.append(gathered_continuations[i][:actual_count])
-                        batch_tokenized_contexts_processed.append(gathered_contexts[i][:actual_count])
+                    for i, num_choices in enumerate(gathered_num_choices):
+                        # Extract non-padded values
+                        # 1d on num choices
+                        batch_logits_sums.append(gathered_logits_sums[i][:num_choices])
+                        batch_max_equals.append(gathered_max_equals[i][:num_choices])
+                        # 2d on num choices and max len
+                        len_choice = gathered_len_choices[i]
+                        batch_tokenized_continuations_processed.append(
+                            gathered_continuations[i][:num_choices][:len_choice]
+                        )
+                        # 1d on max len context
+                        len_context = gathered_len_context[i]
+                        batch_tokenized_contexts_processed.append(gathered_contexts[i][:len_context])
 
                 # Process the gathered results
                 for i in range(len(batch_logits_sums)):
```
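Putting the two halves together, the overall idea is: pad every per-sample tensor to a shared shape so it can be stacked and gathered, record the true sizes, then trim the padding back off afterwards. A self-contained sketch of that round trip (single process, made-up tensors, no accelerate involved — not the lighteval code itself):

```python
import torch
import torch.nn.functional as F

# Ragged per-sample continuation tensors of shape (num_choices, choice_len).
samples = [torch.arange(6).reshape(2, 3), torch.arange(12).reshape(3, 4)]
num_choices = torch.tensor([s.shape[0] for s in samples])
len_choices = torch.tensor([s.shape[1] for s in samples])
max_num, max_len = int(num_choices.max()), int(len_choices.max())

# Pad each sample to (max_num, max_len) so the batch can be stacked/gathered.
padded = torch.stack(
    [F.pad(s, (0, max_len - s.shape[1], 0, max_num - s.shape[0]), value=-1) for s in samples]
)  # shape (2, max_num, max_len)

# After "gathering", drop the padding again using the recorded sizes.
restored = [padded[i, : num_choices[i], : len_choices[i]] for i in range(len(samples))]
assert all(torch.equal(a, b) for a, b in zip(restored, samples))
```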

src/lighteval/models/utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -136,7 +136,8 @@ def _parse_args(args: str) -> dict:
     for key, value in matches:
         key = key.strip()
         if key == "generation_parameters":
-            gen_params = re.sub(r"(\w+):", r'"\1":', value)
+            # regex by lysandre
+            gen_params = re.sub(r"(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r'"\1":"\2"', value)
             generation_parameters_dict = json.loads(gen_params)
 
     args = re.sub(r"generation_parameters=\{.*?\},?", "", args).strip(",")
```
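A quick check of the new substitution on a bare-word value (the input string is made up for the example): both the key and the unquoted value get wrapped in double quotes so the fragment becomes valid JSON.

```python
import json
import re

value = "{cache_implementation:static}"
gen_params = re.sub(r"(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r'"\1":"\2"', value)
print(gen_params)              # {"cache_implementation":"static"}
print(json.loads(gen_params))  # {'cache_implementation': 'static'}
```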

src/lighteval/tasks/default_tasks.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -7893,7 +7893,7 @@
     version=0,
 )
 gpqa_lighteval = LightevalTaskConfig(
-    name="gpqa",
+    name="gpqa:mc",
     suite=["lighteval"],
     prompt_function=prompt.gpqa,
     hf_repo="Idavidrein/gpqa",
```

tests/models/endpoints/test_endpoint_model.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -54,6 +54,7 @@ class TestInferenceEndpointModelConfig:
             "generation_parameters": {
                 "num_blocks": None,
                 "block_size": None,
+                "cache_implementation": None,
                 "early_stopping": None,
                 "frequency_penalty": None,
                 "length_penalty": None,
```

tests/models/endpoints/test_tgi_model.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -40,6 +40,7 @@ class TestTGIModelConfig:
            "generation_parameters": {
                "block_size": None,
                "num_blocks": None,
+               "cache_implementation": None,
                "early_stopping": None,
                "frequency_penalty": None,
                "length_penalty": None,
```
