
Commit 9134ca8

Fix 341 (#346)
* Split greedy and sampling generation paths + remove the small old HELM mechanism
* Add do_sample to the generative task sorting criteria
* Quick fix vllm (#361):
  * Fix max length management in vllm
  * Fix the maj@n qem being run on the same samples; the sort and split needed to be managed accordingly
  * Add temperature to the vllm config

---------

Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
1 parent 635e581 commit 9134ca8

12 files changed (+59, -127 lines)

src/lighteval/data.py

Lines changed: 2 additions & 2 deletions
@@ -264,7 +264,7 @@ def init_split_limits(self, num_dataset_splits):
         splits_indices = [tuple(e) for e in splits_indices]
         return num_dataset_splits, splits_indices

-    def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, list, int]:
+    def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, list, int]:
         """
         Collate function for generating batches.

@@ -279,7 +279,7 @@ def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, list, int]:
         # The generative task has no limit except the model context
         if gen_length is None:
             gen_length = 0
-        return request.use_logits, request.stop_sequence, -(len(toks) + gen_length)
+        return request.do_sample, request.use_logits, request.stop_sequence, -(len(toks) + gen_length)


 class GenerativeTaskDatasetNanotron(GenerativeTaskDataset):
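The extra boolean widens the sort key so sampling requests and greedy requests land in separate, contiguous blocks, each with homogeneous generation settings. A minimal sketch of the idea, using a hypothetical FakeRequest stand-in rather than lighteval's real classes:

# Minimal sketch: a tuple sort key that groups requests by (do_sample, use_logits,
# stop_sequence) and orders longest contexts first within each group.
from dataclasses import dataclass, field

@dataclass
class FakeRequest:  # hypothetical stand-in for GreedyUntilRequest
    do_sample: bool
    use_logits: bool
    stop_sequence: tuple
    tokenized_context: list = field(default_factory=list)

def sorting_criteria(req: FakeRequest) -> tuple:
    return (req.do_sample, req.use_logits, req.stop_sequence, -len(req.tokenized_context))

reqs = [
    FakeRequest(True, False, ("\n",), [0] * 50),
    FakeRequest(False, False, ("\n",), [0] * 10),
    FakeRequest(False, False, ("\n",), [0] * 30),
]
print([(r.do_sample, len(r.tokenized_context)) for r in sorted(reqs, key=sorting_criteria)])
# [(False, 30), (False, 10), (True, 50)]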

src/lighteval/metrics/__init__.py

Lines changed: 14 additions & 31 deletions
@@ -90,14 +90,21 @@ def apply_generative_metric( # noqa: C901
     formatted_docs: list[Doc],
     metrics: list[Metric],
     output_regex: str = None,
-    max_num_samples: int = 1,
 ):
     outputs = []

     for sample_id, results, formatted_doc in zip(sample_ids, responses, formatted_docs):
         output = {}

+        # Extracting gold
+        try:
+            golds = formatted_doc.get_golds()
+        except (KeyError, IndexError):
+            golds = None
+
+        # Post processing prediction
         if len(results) > 1:
+            # In case of sampling, it's a list of one list of n samples
             raise Exception("You returned more than one result for a sample with a generative metric.")
         results = results[0]

@@ -112,38 +119,14 @@ def apply_generative_metric( # noqa: C901
                 pred = pred_raw
             preds.append(pred)

-        # Extracting gold
-        try:
-            golds = formatted_doc.get_golds()
-        except (KeyError, IndexError):
-            golds = None
-
-        # Specific process for HELM like evals # hrm
-        # if "label_to_choices" in formatted_doc:
-        if formatted_doc.specific is not None and "label_to_choices" in formatted_doc.specific:
-            # Helm predicts on labels keys (A/B/C/D), but computes metrics on choices
-            preds = [formatted_doc.specific["label_to_choices"].get(p) for p in preds]
-            golds = [formatted_doc.specific["label_to_choices"][g] for g in golds]
-
         for metric in metrics:
-            if metric.category == MetricCategory.GENERATIVE:
-                output.update(
-                    metric.compute(
-                        golds=golds,
-                        predictions=as_list(preds[0]) if max_num_samples > 1 else preds,
-                        formatted_doc=formatted_doc,
-                    )
+            output.update(
+                metric.compute(
+                    golds=golds,
+                    predictions=preds,
+                    formatted_doc=formatted_doc,
                 )
-            if metric.category == MetricCategory.GENERATIVE_LOGPROB:
-                output.update(
-                    metric.compute(
-                        golds=golds,
-                        predictions=as_list(preds[0]) if max_num_samples > 1 else preds,
-                        formatted_doc=formatted_doc,
-                    )
-                )
-            if metric.category == MetricCategory.GENERATIVE_SAMPLING:
-                output.update(metric.compute(golds=golds, predictions=preds, formatted_doc=formatted_doc))
+            )
         outputs.append(output)

     return outputs
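After this change every generative metric, whatever its category, goes through the same compute(golds=..., predictions=..., formatted_doc=...) call. A toy sketch of that contract (ToyExactMatch is illustrative, not a lighteval class):

# Illustrative only: a toy metric exercising the unified compute() interface.
class ToyExactMatch:
    metric_name = "toy_em"

    def compute(self, golds, predictions, formatted_doc=None):
        # 1.0 if any prediction exactly matches any gold answer, else 0.0
        hit = any(p.strip() == g.strip() for p in predictions for g in (golds or []))
        return {self.metric_name: float(hit)}

output = {}
for metric in [ToyExactMatch()]:
    output.update(metric.compute(golds=["Paris"], predictions=["Paris"], formatted_doc=None))
print(output)  # {'toy_em': 1.0}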

src/lighteval/models/base_model.py

Lines changed: 4 additions & 5 deletions
@@ -537,6 +537,7 @@ def greedy_until(
             max_new_tokens = batch[0].generation_size
             returns_logits = batch[0].use_logits
             num_samples = batch[0].num_samples
+            do_sample = batch[0].do_sample

             context = [c.context for c in batch]

@@ -590,6 +591,7 @@ def greedy_until(
                 stop_tokens=stop_tokens,
                 returns_logits=returns_logits,
                 num_samples=num_samples,
+                do_sample=do_sample,
             )
             results.extend(cur_reponses)

@@ -602,6 +604,7 @@ def _generate(
         stop_tokens: list[str],
         returns_logits: Optional[bool] = False,
         num_samples: Optional[int] = 1,
+        do_sample: Optional[bool] = False,
     ) -> list[GenerativeResponse]:
         """Contains the actual logic of the generation.
         First computes the stop sequences, then generates the predictions, then converts the outputs to GenerativeResponse.
@@ -619,7 +622,7 @@ def _generate(
             return_dict_in_generate=True,
             output_scores=True,
             eos_token_id=self.tokenizer.eos_token_id,
-            do_sample=num_samples > 1,
+            do_sample=do_sample,
             num_return_sequences=num_samples,
         )
         if returns_logits:
@@ -660,10 +663,6 @@ def _generate(

                 decoded_generations.append(decoded_generation)

-            if num_samples == 1:  # We only return one item
-                result_generations = result_generations[0]
-                decoded_generations = decoded_generations[0]
-
             cur_response = GenerativeResponse(
                 result=decoded_generations,
                 logits=logits[ix][: len_logits[ix]] if returns_logits else None,
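The decoding mode is now taken from the request's do_sample flag instead of being inferred from num_samples. A minimal transformers sketch with the same call shape (the model name and prompt are placeholders, not lighteval code):

# Minimal sketch: greedy vs. sampled generation with Hugging Face transformers.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # small placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

do_sample, num_samples = True, 4  # sampling metrics (e.g. maj@4) set do_sample=True
outputs = model.generate(
    **inputs,
    max_new_tokens=16,
    do_sample=do_sample,               # greedy decoding when False (then num_samples stays at 1)
    num_return_sequences=num_samples,  # n generations per prompt when sampling
    return_dict_in_generate=True,
    output_scores=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
print(outputs.sequences.shape[0])  # num_samples generations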

src/lighteval/models/model_config.py

Lines changed: 1 addition & 0 deletions
@@ -229,6 +229,7 @@ class VLLMModelConfig:
     pairwise_tokenization: bool = False  # whether to tokenize the context and continuation separately or together.

     subfolder: Optional[str] = None
+    temperature: float = 0.6  # will be used for multi sampling tasks, for tasks requiring no sampling, this will be ignored and set to 0.


 @dataclass
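Per the comment above, the configured temperature only matters when a task actually samples; greedy tasks override it to 0. A standalone sketch of that selection logic (not the library's code):

# Effective temperature: the config value applies only when more than one sample is requested.
def effective_temperature(config_temperature: float, num_samples: int) -> float:
    return float(config_temperature) if num_samples > 1 else 0.0

print(effective_temperature(0.6, 1))  # 0.0 -> greedy decoding
print(effective_temperature(0.6, 8))  # 0.6 -> sampling, e.g. for maj@8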

src/lighteval/models/vllm_model.py

Lines changed: 9 additions & 35 deletions
@@ -77,10 +77,7 @@ def __init__(
         self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
         self._tokenizer = self._create_auto_tokenizer(config, env_config)

-        if config.max_model_length is not None:
-            self._max_length = int(config.max_model_length)
-        else:
-            self._max_length = self.tokenizer.model_max_length or self.tokenizer.max_position_embeddings
+        self._max_length = int(config.max_model_length) if config.max_model_length is not None else None

         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = self._create_auto_model(config, env_config)
@@ -152,6 +149,13 @@ def _create_auto_model(self, config: VLLMModelConfig, env_config: EnvConfig) ->
             return None

         model = LLM(**self.model_args)
+
+        # If the max_length can't get extracted from the config, it will be inferred from the model
+        # Inferring from the tokenizer will cause vllm to bug for models with mismatches between model
+        # config and tk config, like mistralai/Mistral-7B-v0.1
+        if self._max_length is None:
+            self._max_length = model.llm_engine.model_config.max_seq_len_to_capture
+
         return model

     def _create_auto_tokenizer(self, config: VLLMModelConfig, env_config: EnvConfig):
@@ -164,36 +168,6 @@ def _create_auto_tokenizer(self, config: VLLMModelConfig, env_config: EnvConfig)
         tokenizer.pad_token = tokenizer.eos_token
         return tokenizer

-    def _init_max_length(self, max_length) -> int:
-        """Return the maximum sequence length of the model.
-        NOTE: Different model configurations have different max sequence length
-        attribute names.
-        - n_positions: (CTRLConfig)
-        - max_position_embeddings: (BartConfig, RoFormerConfig)
-        - n_ctx: (GPT2Config)
-        NOTE: For relative position encoded models you should specify the max
-        sequence length of the model in the constructor via `max_length`.
-
-        Args:
-            max_length (Optional[int]): The maximum length of the input sequence. If not provided, it will be determined
-                based on the model's configuration or tokenizer's model_max_length attribute.
-
-        Returns:
-            int: Max length to use depending on the available args and config
-        """
-        if max_length is not None:
-            return int(max_length)
-        # Try to get the sequence length from the model config.
-        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
-
-        for attr in seqlen_config_attrs:
-            if hasattr(self._config, attr):
-                return getattr(self._config, attr)
-
-        # Default max sequence length setting for when no `max_length` is provided
-        # or no max length config setting is found in the model or tokenizer.
-        return 2048
-
     def greedy_until(
         self,
         requests: list[GreedyUntilRequest],
@@ -300,7 +274,7 @@ def _generate(
         """Contains the actual logic of the generation."""
         if generate:
             sampling_params = SamplingParams(
-                temperature=1.0 if num_samples > 1 else 0.0,
+                temperature=float(self._config.temperature) if num_samples > 1 else 0.0,
                 n=num_samples,
                 max_tokens=max_new_tokens,
                 stop=stop_tokens,
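Two things change on the vLLM side: the context length now falls back to what the engine itself reports rather than the tokenizer, and the sampling temperature comes from the config while greedy requests stay at 0. A minimal vLLM sketch of both, assuming a GPU machine and a recent vLLM release (engine attribute paths mirror the diff above but can vary between versions):

# Minimal sketch, not lighteval's code.
from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mistral-7B-v0.1")

# Fall back to the engine's view of the max sequence length when no explicit
# max_model_length was configured.
max_length = llm.llm_engine.model_config.max_seq_len_to_capture

num_samples, config_temperature = 8, 0.6
sampling_params = SamplingParams(
    temperature=config_temperature if num_samples > 1 else 0.0,  # 0.0 -> greedy
    n=num_samples,
    max_tokens=256,
    stop=["\n\n"],
)
outputs = llm.generate(["Question: 2 + 2 = ?\nAnswer:"], sampling_params)
print(max_length, len(outputs[0].outputs))  # context length, num_samples completions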

src/lighteval/pipeline.py

Lines changed: 4 additions & 1 deletion
@@ -302,7 +302,10 @@ def _compute_metrics(self, sample_id_to_responses):
             metric_category_metrics = [metric for metric in task.metrics if metric.category == metric_category]

             outputs = metric_function(
-                sample_ids=sample_ids, responses=responses, formatted_docs=docs, metrics=metric_category_metrics
+                sample_ids=sample_ids,
+                responses=responses,
+                formatted_docs=docs,
+                metrics=metric_category_metrics,
             )

             for output, doc, response in zip(outputs, docs, responses):

src/lighteval/tasks/default_prompts.py

Lines changed: 1 addition & 4 deletions
@@ -786,7 +786,7 @@ def hellaswag_harness(line, task_name: str = None):
     )


-def hellaswag_helm(line, task_name: str = None):
+def hellaswag_generative(line, task_name: str = None):
     query = "The following are multiple choice questions (with answers) about common sense.\n\n"
     query += f"Question: {line['activity_label']}: {line['ctx_a']} {line['ctx_b'].capitalize()}\n"
     query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["endings"])])
@@ -800,9 +800,6 @@ def hellaswag_helm(line, task_name: str = None):
         gold_index=gold_ix,  # -1 for test,
         instruction="The following are multiple choice questions (with answers) about common sense.\n\n",
         target_for_fewshot_sorting=line["endings"][gold_ix] if gold_ix > -1 else "",
-        specific={
-            "label_to_choices": {f" {key}": choice for key, choice in zip(LETTER_INDICES, line["endings"])},
-        },
     )


src/lighteval/tasks/default_tasks.py

Lines changed: 2 additions & 2 deletions
@@ -8708,10 +8708,10 @@
     trust_dataset=True,
     version=0,
 )
-hellaswag_helm = LightevalTaskConfig(
+hellaswag_generative = LightevalTaskConfig(
     name="hellaswag",
     suite=["helm", "helm_general"],
-    prompt_function=prompt.hellaswag_helm,
+    prompt_function=prompt.hellaswag_generative,
     hf_repo="hellaswag",
     hf_subset="default",
     hf_avail_splits=["train", "test", "validation"],

src/lighteval/tasks/lighteval_task.py

Lines changed: 21 additions & 6 deletions
@@ -395,12 +395,29 @@ def construct_requests(
                 metric_categories=[MetricCategory.PERPLEXITY],
             )
         ]
+        if self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]:
+            # All the possible sampling tasks require the same generation process - we can do them in one step
+            # so we select the maximum number of samples and the metrics will select only the
+            # relevant number of tiems
+            requests[RequestType.GREEDY_UNTIL] += [
+                GreedyUntilRequest(
+                    task_name=current_task_name,
+                    sample_index=document_id_seed,
+                    request_index=0,
+                    context=context,
+                    stop_sequence=self.stop_sequence,
+                    generation_size=self.generation_size,
+                    generation_grammar=self.generation_grammar,
+                    num_samples=max(self.num_samples),
+                    do_sample=True,
+                    use_logits=False,
+                    metric_categories=[MetricCategory.GENERATIVE_SAMPLING],
+                )
+            ]
         if (
-            self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]
-            or self.has_metric_category[MetricCategory.GENERATIVE]
+            self.has_metric_category[MetricCategory.GENERATIVE]
            or self.has_metric_category[MetricCategory.GENERATIVE_LOGPROB]
         ):
-            # All these tasks require the same generation process - we can do them in one step
             use_logits = self.has_metric_category[MetricCategory.GENERATIVE_LOGPROB]
             requests[RequestType.GREEDY_UNTIL] += [
                 GreedyUntilRequest(
@@ -411,12 +428,11 @@ def construct_requests(
                     stop_sequence=self.stop_sequence,
                     generation_size=self.generation_size,
                     generation_grammar=self.generation_grammar,
-                    num_samples=max(self.num_samples),  # If we have several samplings to apply, we use the max
+                    num_samples=1,
                     use_logits=use_logits,
                     metric_categories=[
                         c
                         for c in [
-                            MetricCategory.GENERATIVE_SAMPLING,
                             MetricCategory.GENERATIVE,
                             MetricCategory.GENERATIVE_LOGPROB,
                         ]
@@ -443,7 +459,6 @@ def construct_requests(
                 )
                 for i, choice in enumerate(formatted_doc.choices)
             ]
-
         if self.has_metric_category[MetricCategory.MULTICHOICE_PMI]:
             assert (
                 formatted_doc.unconditioned_query is not None
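Sampling metrics now get a dedicated request with do_sample=True and n set to the maximum the task needs, while plain generative metrics run a single greedy pass (num_samples=1); each metric then keeps only the samples it needs. A toy illustration of the downstream idea for a maj@n-style metric (purely illustrative, not lighteval's implementation):

# Toy maj@n: majority vote over the first n sampled predictions against a gold answer.
from collections import Counter

def maj_at_n(predictions: list[str], gold: str, n: int) -> float:
    votes = Counter(p.strip() for p in predictions[:n])  # keep only the n samples this metric needs
    majority, _ = votes.most_common(1)[0]
    return float(majority == gold.strip())

samples = ["4", "4", "5", "4", "3", "4", "4", "2"]  # e.g. eight sampled generations
print(maj_at_n(samples, "4", n=4))  # 1.0
print(maj_at_n(samples, "4", n=8))  # 1.0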

src/lighteval/tasks/requests.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ class GreedyUntilRequest(Request):
     request_type = RequestType.GREEDY_UNTIL
     tokenized_context: list[int] = None
     num_samples: int = None
+    do_sample: bool = False
     use_logits: bool = False

