
Commit d146f9b

qubvel and NathanHB authored
Update split iteration for DynamicBatchingDataset (huggingface#684)
This PR aims to make iterating over splits a bit more intuitive, at least in my opinion. Open to feedback though! If the current behavior was intentional, feel free to close.

---------

Co-authored-by: Nathan Habib <[email protected]>
1 parent ea72931 commit d146f9b

8 files changed: +78 −99 lines changed

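In short, the stateful start/end bookkeeping on the dataset is replaced by an iterator that yields `torch.utils.data.Subset` views, so each call site batches over the split it is handed. A minimal before/after sketch of the call-site pattern, condensed from the diffs below:

```python
# Before: callers ignored the yielded (start, end) pair; the dataset itself
# tracked which split was "current", and DataLoader was pointed at the whole dataset.
for _, _ in dataset.splits_start_end_iterator():
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)

# After: each split is a self-contained view; DataLoader is pointed at it directly.
for split in dataset.splits_iterator():
    dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
```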

src/lighteval/data.py

Lines changed: 12 additions & 31 deletions
```diff
@@ -22,11 +22,11 @@
 
 import logging
 import math
-from typing import Iterator, Tuple
+from typing import Iterator
 
 import torch
 from packaging import version
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, Subset
 
 
 if version.parse(torch.__version__) >= version.parse("2.5.0"):
@@ -82,8 +82,6 @@ def __init__(
 
         self.num_dataset_splits, self.splits = self.init_split_limits(num_dataset_splits)
 
-        self.split_start, self.split_end = self.splits[0]
-
     def init_split_limits(self, num_dataset_splits):
         if num_dataset_splits >= self.total_size:
             logger.warning(
@@ -121,48 +119,31 @@ def get_original_order(self, new_arr: list) -> list:
 
         return original_order
 
-    def get_split_start_end(self, split_id: int) -> Tuple[int, int]:
-        """
-        Get the start and end indices of a dataset split.
-
-        Args:
-            split_id (int): The ID of the split.
-
-        Returns:
-            tuple: A tuple containing the start and end indices of the split.
-        """
-        self.split_start, self.split_end = self.splits[split_id]
-        return self.split_start, self.split_end
-
-    def splits_start_end_iterator(self) -> Iterator[Tuple[int, int]]:
+    def splits_iterator(self) -> Iterator[Subset]:
         """
-        Iterator that yields the start and end indices of each dataset split.
-        Also updates the starting batch size for each split (trying to double
-        the batch every time we move to a new split).
+        Iterator that yields the dataset splits based on the split limits.
 
         Yields:
-            tuple: A tuple containing the start and end indices of a split.
+            Subset: A subset of the dataset.
         """
         split_range = self.num_dataset_splits
         if self.total_size == 0:
             split_range = 0
-        for split_id in range(split_range):
-            yield self.get_split_start_end(split_id)
+        for i in range(split_range):
+            split_start, split_end = self.splits[i]
+            yield Subset(self, range(split_start, split_end))
 
     def __getitem__(self, index) -> Request:
         """
-        Get an item from the dataset depending on the split we are currently in.
-        For instance, if we are in split 0, we will get the item at index 0, if
-        we are in split 1, we will get the item at index self.split_size, etc.
-        Used for dynamic batching.
+        Get an item from the dataset.
 
         Args:
            index (int): The index of the item.
 
         Returns:
            Any: The item at the specified index.
        """
-        return self.sorted_data[index + self.split_start]
+        return self.sorted_data[index]
 
     def __len__(self) -> int:
         """
@@ -173,7 +154,7 @@ def __len__(self) -> int:
         Returns:
             int: The length of the dataset.
         """
-        return self.split_end - self.split_start
+        return len(self.sorted_data)
 
     def __iter__(self) -> Iterator[Request]:
         """
@@ -186,7 +167,7 @@ def __iter__(self) -> Iterator[Request]:
         Yields:
             Any: The items of the dataset.
         """
-        for i in range(self.split_start, self.split_end):
+        for i in range(len(self)):
             yield self.sorted_data[i]
 
     def _sorting_criteria(self, request) -> int:
```

src/lighteval/models/endpoints/endpoint_model.py

Lines changed: 9 additions & 9 deletions
```diff
@@ -463,14 +463,14 @@ def greedy_until(
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
         results: List[str] = []
 
-        for _, _ in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
             disable=self.disable_tqdm,
         ):
-            dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
+            dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
 
             for batch in tqdm(
                 dataloader, desc="Greedy generation", position=1, leave=False, disable=self.disable_tqdm
@@ -512,14 +512,14 @@ def loglikelihood(
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
         results: List[str] = []
 
-        for _, _ in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
             disable=self.disable_tqdm,
         ):
-            dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
+            dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
 
             for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm):
                 if self.use_async:
@@ -563,14 +563,14 @@ def loglikelihood_rolling(
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
         results: List[str] = []
 
-        for _, _ in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
             disable=self.disable_tqdm,
         ):
-            dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
+            dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
 
             for batch in tqdm(
                 dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm
```

src/lighteval/models/endpoints/inference_providers_model.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -210,15 +210,15 @@ def greedy_until(
         dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         results = []
 
-        for _ in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
             disable=False,  # self.disable_tqdm,
         ):
-            contexts = [c.context for c in dataset]
-            num_samples = dataset[0].num_samples
+            contexts = [sample.context for sample in split]
+            num_samples = split[0].num_samples
 
             responses = asyncio.run(self.__call_api_parallel(contexts, num_samples))
```

src/lighteval/models/endpoints/openai_model.py

Lines changed: 13 additions & 15 deletions
```diff
@@ -184,17 +184,17 @@ def greedy_until(
         dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         results = []
 
-        for _ in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
             disable=False,  # self.disable_tqdm,
         ):
-            max_new_tokens = dataset[0].generation_size  # could be none
-            return_logits = dataset[0].use_logits
-            num_samples = dataset[0].num_samples
-            contexts = [c.context for c in dataset]
+            max_new_tokens = split[0].generation_size  # could be none
+            return_logits = split[0].use_logits
+            num_samples = split[0].num_samples
+            contexts = [sample.context for sample in split]
 
             responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)
@@ -251,24 +251,22 @@ def _loglikelihood_tokens(
         dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=1)
         results = []
 
-        for _ in tqdm(dataset.splits_start_end_iterator()):
-            inputs = [dataset[i].context for i in range(len(dataset))]
-            logit_biass = []
-            max_new_tokens = [len(dataset[i].tokenized_continuation) for i in range(len(dataset))]
+        for split in tqdm(dataset.splits_iterator()):
+            inputs = [sample.context for sample in split]
+            max_new_tokens = [len(sample.tokenized_continuation) for sample in split]
 
             assert all(
                 new_tokens == 1 for new_tokens in max_new_tokens
             ), "Only single token continuations are supported when using openai API."
 
-            for i in range(len(dataset)):
-                logit_bias = {tok: 100 for tok in dataset[i].tokenized_continuation}
-                logit_biass.append(logit_bias)
+            logit_biases = [{tok: 100 for tok in sample.tokenized_continuation} for sample in split]
 
             outputs = self.__call_api_parallel(
-                inputs, return_logits=True, max_new_tokens=max_new_tokens, num_samples=1, logit_bias=logit_biass
+                inputs, return_logits=True, max_new_tokens=max_new_tokens, num_samples=1, logit_bias=logit_biases
             )
 
-            for output, input in zip(outputs, dataset):
+            for i, output in enumerate(outputs):
+                input = split[i]
                 continuation_logprobs = [content.logprob for content in output.choices[0].logprobs.content]
                 answer = LoglikelihoodResponse(
                     input_tokens=input.tokenized_context + input.tokenized_continuation,
```
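Beyond the split refactor, this hunk folds the `logit_biass` accumulator loop into a single comprehension (also fixing the variable's spelling). The behavior is unchanged: each request gets a bias of +100 on its one continuation token, steering the API toward emitting exactly that token. A quick equivalence check with toy items:

```python
class Item:
    """Toy stand-in for a request with a single-token continuation."""

    def __init__(self, tok):
        self.tokenized_continuation = [tok]


split = [Item(11), Item(42), Item(7)]

# Old style: explicit loop appending one bias dict per request.
logit_biases_loop = []
for i in range(len(split)):
    logit_biases_loop.append({tok: 100 for tok in split[i].tokenized_continuation})

# New style: one comprehension, same result.
logit_biases = [{tok: 100 for tok in sample.tokenized_continuation} for sample in split]

assert logit_biases == logit_biases_loop  # [{11: 100}, {42: 100}, {7: 100}]
```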

src/lighteval/models/litellm_model.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -227,17 +227,17 @@ def greedy_until(
         dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         results = []
 
-        for _ in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
-            disable=False,  # self.disable_tqdm,
+            disable=self.disable_tqdm,
         ):
-            contexts = [c.context for c in dataset]
-            max_new_tokens = dataset[0].generation_size  # could be none
-            return_logits = dataset[0].use_logits
-            num_samples = dataset[0].num_samples
+            contexts = [sample.context for sample in split]
+            max_new_tokens = split[0].generation_size  # could be none
+            return_logits = split[0].use_logits
+            num_samples = split[0].num_samples
             stop_sequence = requests[0].stop_sequence
 
             responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples, stop_sequence)
```

src/lighteval/models/sglang/sglang_model.py

Lines changed: 10 additions & 9 deletions
```diff
@@ -177,8 +177,8 @@ def greedy_until(
         dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         results = []
 
-        for _ in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
@@ -187,12 +187,12 @@ def greedy_until(
             if self.use_chat_template:
                 stop_tokens = []
             else:
-                stop_tokens = dataset[0].stop_sequence
+                stop_tokens = split[0].stop_sequence
 
-            max_new_tokens = dataset[0].generation_size  # could be none
-            num_samples = dataset[0].num_samples
+            max_new_tokens = split[0].generation_size  # could be none
+            num_samples = split[0].num_samples
 
-            context = [c.context for c in dataset]
+            context = [sample.context for sample in split]
             tokenized = self.tokenizer(context, add_special_tokens=self.add_special_tokens)
 
             # The main question for this step is the following:
@@ -298,14 +298,15 @@ def _loglikelihood_tokens(
         dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=1)
         res = []
 
-        for _ in tqdm(dataset.splits_start_end_iterator(), disable=False):
+        for split in tqdm(dataset.splits_iterator(), disable=False):
             # the last token is an eos token, so we don't need to add it
-            inputs = [dataset[i].tokenized_context + dataset[i].tokenized_continuation for i in range(len(dataset))]
+            inputs = [sample.tokenized_context + sample.tokenized_continuation for sample in split]
             # Left truncate the inputs to the maximum length
             inputs = [input[-self.max_length :] for input in inputs]
             outputs = self._generate(inputs, generate=False)
 
-            for output, input in zip(outputs, dataset):
+            for i, output in enumerate(outputs):
+                input = split[i]
                 continuation_logprobs = []
                 meta_info = output["meta_info"]
                 input_token_logprobs = meta_info["input_token_logprobs"][::-1]
```
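The switch from `zip(outputs, dataset)` to `enumerate(outputs)` with `split[i]` is not cosmetic: now that `__iter__` walks all of `sorted_data`, zipping against `dataset` would pair a split's outputs with the first items of the whole dataset whenever there is more than one split. A toy illustration of the misalignment, with plain lists standing in for the dataset and its `Subset`:

```python
all_data = ["r0", "r1", "r2", "r3", "r4", "r5"]  # full sorted_data, two splits of 3
split = all_data[3:6]                            # second split, like Subset(ds, range(3, 6))
outputs = ["o3", "o4", "o5"]                     # one output per item of this split

# Zipping against the full dataset silently pairs o3 with r0 -- the wrong request.
print(list(zip(outputs, all_data)))  # [('o3', 'r0'), ('o4', 'r1'), ('o5', 'r2')]

# Indexing into the split keeps each output with its own request.
for i, output in enumerate(outputs):
    print(output, split[i])  # o3 r3 / o4 r4 / o5 r5
```

In these `_loglikelihood_tokens` call sites `num_dataset_splits` is 1, so the old zip happened to line up, but the indexed form stays correct if that ever changes.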

src/lighteval/models/transformers/transformers_model.py

Lines changed: 12 additions & 14 deletions
```diff
@@ -530,21 +530,19 @@ def greedy_until(
         starting_batch_size = STARTING_BATCH_SIZE
         results = []
 
-        for split_start, split_end in tqdm(
-            dataset.splits_start_end_iterator(),
+        for split in tqdm(
+            dataset.splits_iterator(),
             total=dataset.num_dataset_splits,
             desc="Splits",
             position=0,
             disable=self.disable_tqdm,
         ):
-            if dataset[0].generation_size is None:
+            if split[0].generation_size is None:
                 # No constraints on the generation size: max length allowed is the max model context
                 max_context_continuation_size_allowed = self.max_length
             else:
                 # Longest context in the current split is the first item (since we sort reversed)
-                longest_context_continuation_size_in_split = (
-                    len(dataset[0].tokenized_context) + dataset[0].generation_size
-                )
+                longest_context_continuation_size_in_split = len(split[0].tokenized_context) + split[0].generation_size
                 max_context_continuation_size_allowed = min(
                     longest_context_continuation_size_in_split, self.max_length
                 )
@@ -556,7 +554,7 @@ def greedy_until(
             # For next iteration, since the batch will be smaller, we'll test a bigger batch size
             starting_batch_size = batch_size * 2
 
-            dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
+            dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
             if self.accelerator:
                 dataloader = self.accelerator.prepare(dataloader)
@@ -765,9 +763,9 @@ def _loglikelihood_tokens(
         starting_batch_size = STARTING_BATCH_SIZE
         res = []
 
-        for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
-            context_enc = dataset[0].tokenized_context
-            continuation_enc = dataset[0].tokenized_continuation
+        for split in tqdm(dataset.splits_iterator()):
+            context_enc = split[0].tokenized_context
+            continuation_enc = split[0].tokenized_continuation
             if rolling:  # we take all the sequence in rolling mode
                 max_context_continuation_size_allowed = len(context_enc + continuation_enc)
             else:  # in normal mode, we left cut the context if needed
@@ -782,7 +780,7 @@ def _loglikelihood_tokens(
             )
             starting_batch_size = batch_size * 2
 
-            dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
+            dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)
             if self.accelerator:
                 dataloader = self.accelerator.prepare(dataloader)
@@ -1009,13 +1007,13 @@ def _loglikelihood_single_token(
         starting_batch_size = STARTING_BATCH_SIZE
         res = []
 
-        for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
-            context_enc = dataset[0].tokenized_context
+        for split in tqdm(dataset.splits_iterator()):
+            context_enc = split[0].tokenized_context
             max_context = len(context_enc[-self.max_length :])
             batch_size = self._get_batch_size(override_bs=self.config.batch_size, max_input_length=max_context)
             starting_batch_size = batch_size * 2
 
-            dataloader = DataLoader(dataset, batch_size=starting_batch_size, collate_fn=lambda batch: batch)
+            dataloader = DataLoader(split, batch_size=starting_batch_size, collate_fn=lambda batch: batch)
             if self.accelerator is not None:
                 dataloader = self.accelerator.prepare(dataloader)
```