
Commit 6f2502c

Authored by radulescupetru (Petru Radulescu) and lhoestq
Sample without replacement option when interleaving datasets (#7786)
* Sample without replacement option
* Exit early for non arrow iterable.
* Add new stopping strategy
* Remove sample_with_replacement argument
* Fix CyclingMultiSourcesExamplesIterable.shard_data_sources
* Add sampling without replacement logic for map style datasets.
* Update process.mdx
* Update stream.mdx

---------

Co-authored-by: Petru Radulescu <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 095c7dc commit 6f2502c

File tree: 5 files changed (+113, -39)

* docs/source/process.mdx
* docs/source/stream.mdx
* src/datasets/arrow_dataset.py
* src/datasets/combine.py
* src/datasets/iterable_dataset.py

docs/source/process.mdx (1 addition, 0 deletions)

@@ -657,6 +657,7 @@ In this case, the new dataset is constructed by getting examples one by one from
 You can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is a subsampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
 You can specify `stopping_strategy=all_exhausted` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, it means that if a dataset is exhausted, it will return to the beginning of this dataset until the stop criterion has been reached.
 Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets*nb_dataset` samples.
+There is also `stopping_strategy=all_exhausted_without_replacement` to ensure that every sample is seen exactly once.

 ```py
 >>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
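For illustration, here is how the new strategy could be used with map-style datasets (a sketch based on the surrounding docs; the interleaving order depends on the seed, but each sample appears exactly once):

```py
>>> from datasets import Dataset, interleave_datasets
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13, 14]})
>>> ds = interleave_datasets(
...     [d1, d2], probabilities=[0.5, 0.5], seed=42,
...     stopping_strategy="all_exhausted_without_replacement",
... )
>>> sorted(ds["a"])  # len(d1) + len(d2) samples: none repeated, none dropped
[0, 1, 2, 10, 11, 12, 13, 14]
```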

docs/source/stream.mdx (1 addition, 0 deletions)

@@ -197,6 +197,7 @@ Around 80% of the final dataset is made of the `es_dataset`, and 20% of the `fr_
 You can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is a subsampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
 You can specify `stopping_strategy=all_exhausted` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, it means that if a dataset is exhausted, it will return to the beginning of this dataset until the stop criterion has been reached.
 Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets*nb_dataset` samples.
+There is also `stopping_strategy=all_exhausted_without_replacement` to ensure that every sample is seen exactly once.

 ## Rename, remove, and cast
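The streaming equivalent (a sketch mirroring the map-style example; `to_iterable_dataset` is used here only to build small iterable datasets):

```py
>>> from datasets import Dataset, interleave_datasets
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]}).to_iterable_dataset()
>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13, 14]}).to_iterable_dataset()
>>> ds = interleave_datasets(
...     [d1, d2], probabilities=[0.8, 0.2], seed=42,
...     stopping_strategy="all_exhausted_without_replacement",
... )
>>> sorted(ex["a"] for ex in ds)  # every sample yielded exactly once
[0, 1, 2, 10, 11, 12, 13, 14]
```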

src/datasets/arrow_dataset.py (15 additions, 6 deletions)

@@ -6566,7 +6566,9 @@ def _interleave_map_style_datasets(
     seed: Optional[int] = None,
     info: Optional[DatasetInfo] = None,
     split: Optional[NamedSplit] = None,
-    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+    stopping_strategy: Literal[
+        "first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
+    ] = "first_exhausted",
     **kwargs,
 ) -> "Dataset":
     """
@@ -6586,6 +6588,7 @@ def _interleave_map_style_datasets(
     Two strategies are proposed right now.
     By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
     If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
+    When the strategy is `all_exhausted_without_replacement`, we make sure that each sample in each dataset is sampled only once.
     Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
     - with no probabilities, the resulting dataset will have max_length_datasets*nb_dataset samples.
     - with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting.
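For intuition (an illustrative example, not part of the diff): with two datasets of 3 and 5 samples and no probabilities, `all_exhausted` yields 5*2 = 10 samples (the smaller dataset cycles back), `all_exhausted_without_replacement` yields exactly 3 + 5 = 8 samples, and `first_exhausted` stops once the 3-sample dataset runs out.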
@@ -6594,7 +6597,7 @@ def _interleave_map_style_datasets(
     Output:
         :class:`datasets.Dataset`
     """
-    if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
+    if stopping_strategy not in ["first_exhausted", "all_exhausted", "all_exhausted_without_replacement"]:
         raise ValueError(
             f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
         )
@@ -6637,7 +6640,9 @@ def _interleave_map_style_datasets(

     # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
     # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
-    bool_strategy_func = np.all if oversampling else np.any
+    bool_strategy_func = (
+        np.all if (oversampling or stopping_strategy == "all_exhausted_without_replacement") else np.any
+    )

     def iter_random_indices():
         """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
@@ -6655,13 +6660,17 @@ def iter_random_indices():
             break

         # let's add the example at the current index of the `source_idx`-th dataset
-        indices.append(current_index[source_idx] + offsets[source_idx])
-        current_index[source_idx] += 1
+        # For sampling without replacement we additionally need to make sure the current source is not exhausted, so that we do not oversample.
+        if stopping_strategy != "all_exhausted_without_replacement" or not is_exhausted[source_idx]:
+            indices.append(current_index[source_idx] + offsets[source_idx])
+            current_index[source_idx] += 1

         # we've run out of examples for the current dataset, let's update our boolean array and bring the current_index back to 0
         if current_index[source_idx] >= lengths[source_idx]:
             is_exhausted[source_idx] = True
-            current_index[source_idx] = 0
+            # We don't want to reset the iterator when the stopping strategy is without replacement.
+            if stopping_strategy != "all_exhausted_without_replacement":
+                current_index[source_idx] = 0

     return concatenated_datasets.select(indices, **kwargs)
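Putting the hunk above together, the without-replacement selection amounts to: skip a source once it is exhausted, and never rewind `current_index`. A condensed, self-contained sketch of that loop (simplified names, not the library's exact code):

```python
import numpy as np

def build_indices_without_replacement(lengths, probabilities=None, seed=None):
    """Pick flat indices into the concatenated datasets until every source
    is exhausted, visiting each sample exactly once."""
    rng = np.random.default_rng(seed)
    offsets = np.cumsum([0] + list(lengths[:-1]))  # start of each source in the concatenation
    current_index = [0] * len(lengths)
    is_exhausted = [False] * len(lengths)
    indices = []
    while not all(is_exhausted):
        source_idx = int(rng.choice(len(lengths), p=probabilities))
        if is_exhausted[source_idx]:
            continue  # never revisit an exhausted source
        indices.append(current_index[source_idx] + offsets[source_idx])
        current_index[source_idx] += 1
        if current_index[source_idx] >= lengths[source_idx]:
            is_exhausted[source_idx] = True  # and do NOT reset to 0: no replacement
    return indices

# e.g. build_indices_without_replacement([3, 5], seed=0) contains each of the
# 8 concatenated indices exactly once, in interleaved order
```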

src/datasets/combine.py (12 additions, 4 deletions)

@@ -21,7 +21,9 @@ def interleave_datasets(
     seed: Optional[int] = None,
     info: Optional[DatasetInfo] = None,
     split: Optional[NamedSplit] = None,
-    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+    stopping_strategy: Literal[
+        "first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
+    ] = "first_exhausted",
 ) -> DatasetType:
     """
     Interleave several datasets (sources) into a single dataset.
@@ -55,9 +57,10 @@ def interleave_datasets(
             Name of the dataset split.
             <Added version="2.4.0"/>
         stopping_strategy (`str`, defaults to `first_exhausted`):
-            Two strategies are proposed right now, `first_exhausted` and `all_exhausted`.
+            Three strategies are proposed right now: `first_exhausted`, `all_exhausted`, and `all_exhausted_without_replacement`.
             By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
             If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
+            When the strategy is `all_exhausted_without_replacement`, we make sure that each sample in each dataset is sampled only once.
             Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
             - with no probabilities, the resulting dataset will have `max_length_datasets*nb_dataset` samples.
             - with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting.
@@ -143,15 +146,20 @@ def interleave_datasets(
         raise ValueError(
             f"Unable to interleave a {dataset_type.__name__} (at position 0) with a {other_type.__name__} (at position {i}). Expected a list of Dataset objects or a list of IterableDataset objects."
         )
-    if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
+    if stopping_strategy not in ["first_exhausted", "all_exhausted", "all_exhausted_without_replacement"]:
         raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")
     if dataset_type is Dataset:
         return _interleave_map_style_datasets(
             datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
         )
     else:
         return _interleave_iterable_datasets(
-            datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
+            datasets,
+            probabilities,
+            seed,
+            info=info,
+            split=split,
+            stopping_strategy=stopping_strategy,
         )
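End to end, the public entry point now accepts the new strategy for both dataset types and still rejects unknown ones. A quick usage sketch:

```python
from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"a": [0, 1, 2]})
d2 = Dataset.from_dict({"a": [10, 11, 12, 13, 14]})

# Dataset inputs dispatch to _interleave_map_style_datasets
ds = interleave_datasets(
    [d1, d2],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy="all_exhausted_without_replacement",
)
assert sorted(ds["a"]) == [0, 1, 2, 10, 11, 12, 13, 14]  # each sample exactly once

# An unsupported strategy hits the ValueError above
try:
    interleave_datasets([d1, d2], stopping_strategy="bogus")
except ValueError as err:
    print(err)  # "bogus is not supported. Please enter a valid stopping_strategy."
```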

src/datasets/iterable_dataset.py (84 additions, 29 deletions)

@@ -673,15 +673,20 @@ class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable):
     def __init__(
         self,
         ex_iterables: list[_BaseExamplesIterable],
-        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+        stopping_strategy: Literal[
+            "first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
+        ] = "first_exhausted",
     ):
         super().__init__()
         self.ex_iterables = ex_iterables
         self.stopping_strategy = stopping_strategy

         # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
         # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
-        self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any
+        # if sampling without replacement ("all_exhausted_without_replacement"), we stop once every sample of every dataset has been visited exactly once.
+        self.bool_strategy_func = (
+            np.all if (stopping_strategy in ("all_exhausted", "all_exhausted_without_replacement")) else np.any
+        )

     @property
     def is_typed(self):
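The stopping test stays a single reduction over the per-source `is_exhausted` flags; for instance (illustrative only):

```python
import numpy as np

is_exhausted = np.array([True, False, True])

np.any(is_exhausted)  # True  -> "first_exhausted" would stop here
np.all(is_exhausted)  # False -> "all_exhausted" and "all_exhausted_without_replacement" keep going
```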
@@ -734,6 +739,9 @@ def _iter_arrow(self):
             # if the stopping criteria is met, break the main for loop
             if self.bool_strategy_func(is_exhausted):
                 break
+            # Skip exhausted iterators if we sample without replacement
+            if is_exhausted[i] and self.stopping_strategy in ["all_exhausted_without_replacement"]:
+                continue
             # let's pick one example from the iterator at index i
             if nexts[i] is None:
                 nexts[i] = next(iterators[i], False)
@@ -747,12 +755,13 @@ def _iter_arrow(self):
                 is_exhausted[i] = True
                 if self._state_dict:
                     self._state_dict["is_exhausted"][i] = True
-                # we reset it in case the stopping criteria isn't met yet
-                nexts[i] = None
-                if self._state_dict:
-                    self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
-                    self._state_dict["previous_states"][i] = None
-                iterators[i] = self.ex_iterables[i].iter_arrow()
+                # we reset it in case the stopping criteria isn't met yet and we sample with replacement
+                if self.stopping_strategy not in ["all_exhausted_without_replacement"]:
+                    nexts[i] = None
+                    if self._state_dict:
+                        self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
+                        self._state_dict["previous_states"][i] = None
+                    iterators[i] = self.ex_iterables[i]._iter_arrow()

             if result is not False:
                 yield result
@@ -777,6 +786,8 @@ def __iter__(self):
             if self.bool_strategy_func(is_exhausted):
                 break
             # let's pick one example from the iterator at index i
+            if is_exhausted[i] and self.stopping_strategy in ["all_exhausted_without_replacement"]:
+                continue
             if nexts[i] is None:
                 nexts[i] = next(iterators[i], False)
             result = nexts[i]
@@ -790,12 +801,12 @@ def __iter__(self):
                 if self._state_dict:
                     self._state_dict["is_exhausted"][i] = True
                 # we reset it in case the stopping criteria isn't met yet
-                nexts[i] = None
-                if self._state_dict:
-                    self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
-                    self._state_dict["previous_states"][i] = None
-                iterators[i] = iter(self.ex_iterables[i])
-
+                if self.stopping_strategy not in ["all_exhausted_without_replacement"]:
+                    nexts[i] = None
+                    if self._state_dict:
+                        self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
+                        self._state_dict["previous_states"][i] = None
+                    iterators[i] = iter(self.ex_iterables[i])
             if result is not False:
                 yield result
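Stripped of state-dict bookkeeping, the new `__iter__` behavior reduces to a round-robin that drops exhausted sources for good. A toy sketch (not the library code):

```python
def cycle_without_replacement(sources):
    """Round-robin over iterables, permanently dropping each one once exhausted."""
    iterators = [iter(source) for source in sources]
    is_exhausted = [False] * len(iterators)
    while not all(is_exhausted):
        for i, iterator in enumerate(iterators):
            if is_exhausted[i]:
                continue  # without replacement: never restart this source
            try:
                yield next(iterator)
            except StopIteration:
                is_exhausted[i] = True

assert list(cycle_without_replacement([[0, 1, 2], "abcde"])) == [0, "a", 1, "b", 2, "c", "d", "e"]
```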

@@ -806,16 +817,33 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "CyclingMultiS

     @property
     def num_shards(self) -> int:
-        return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
+        return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) if self.ex_iterables else 0

     def shard_data_sources(
         self, num_shards: int, index: int, contiguous=True
     ) -> "CyclingMultiSourcesExamplesIterable":
         """Either keep only the requested shard, or propagate the request to the underlying iterable."""
-        return CyclingMultiSourcesExamplesIterable(
-            [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables],
-            stopping_strategy=self.stopping_strategy,
-        )
+        if num_shards < self.num_shards:
+            return CyclingMultiSourcesExamplesIterable(
+                [
+                    iterable.shard_data_sources(num_shards, index, contiguous=contiguous)
+                    for iterable in self.ex_iterables
+                ],
+                stopping_strategy=self.stopping_strategy,
+            )
+        elif index < self.num_shards:
+            return CyclingMultiSourcesExamplesIterable(
+                [
+                    iterable.shard_data_sources(self.num_shards, index, contiguous=contiguous)
+                    for iterable in self.ex_iterables
+                ],
+                stopping_strategy=self.stopping_strategy,
+            )
+        else:
+            return CyclingMultiSourcesExamplesIterable(
+                [],
+                stopping_strategy=self.stopping_strategy,
+            )


 class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable):
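The `shard_data_sources` rewrite (together with the `num_shards == 0` guard above) handles requests for more shards than are actually available. The branch logic, as a standalone sketch:

```python
def resolve_shard_request(available: int, num_shards: int, index: int):
    """Mirror of the three-way branch above: returns the (num_shards, index)
    to propagate to the sources, or None for an empty iterable."""
    if num_shards < available:
        return num_shards, index  # enough underlying shards: propagate as-is
    elif index < available:
        return available, index   # clamp: hand each requester at most one real shard
    else:
        return None               # requester index beyond the real shards: empty

# e.g. with 2 underlying shards and 4 requested shards, requesters 2 and 3 get
# empty iterables instead of an error: resolve_shard_request(2, 4, 3) is None
```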
@@ -987,12 +1015,13 @@ def __init__(
         ex_iterables: list[_BaseExamplesIterable],
         generator: np.random.Generator,
         probabilities: Optional[list[float]] = None,
-        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+        stopping_strategy: Literal[
+            "first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
+        ] = "first_exhausted",
     ):
         super().__init__(ex_iterables, stopping_strategy)
         self.generator = deepcopy(generator)
         self.probabilities = probabilities
-        # TODO(QL): implement iter_arrow

     @property
     def is_typed(self):
@@ -1056,12 +1085,33 @@ def shard_data_sources(
         self, num_shards: int, index: int, contiguous=True
     ) -> "RandomlyCyclingMultiSourcesExamplesIterable":
         """Either keep only the requested shard, or propagate the request to the underlying iterable."""
-        return RandomlyCyclingMultiSourcesExamplesIterable(
-            [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables],
-            self.generator,
-            self.probabilities,
-            self.stopping_strategy,
-        )
+        if num_shards < self.num_shards:
+            return RandomlyCyclingMultiSourcesExamplesIterable(
+                [
+                    iterable.shard_data_sources(num_shards, index, contiguous=contiguous)
+                    for iterable in self.ex_iterables
+                ],
+                self.generator,
+                self.probabilities,
+                self.stopping_strategy,
+            )
+        elif index < self.num_shards:
+            return RandomlyCyclingMultiSourcesExamplesIterable(
+                [
+                    iterable.shard_data_sources(self.num_shards, index, contiguous=contiguous)
+                    for iterable in self.ex_iterables
+                ],
+                self.generator,
+                self.probabilities,
+                self.stopping_strategy,
+            )
+        else:
+            return RandomlyCyclingMultiSourcesExamplesIterable(
+                [],
+                self.generator,
+                self.probabilities,
+                self.stopping_strategy,
+            )


 def _table_output_to_arrow(output) -> pa.Table:
@@ -4489,7 +4539,9 @@ def _interleave_iterable_datasets(
     seed: Optional[int] = None,
     info: Optional[DatasetInfo] = None,
     split: Optional[NamedSplit] = None,
-    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+    stopping_strategy: Literal[
+        "first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
+    ] = "first_exhausted",
 ) -> IterableDataset:
     """
     Interleave several iterable datasets (sources) into a single iterable dataset.
@@ -4535,7 +4587,10 @@ def _interleave_iterable_datasets(
     else:
         generator = np.random.default_rng(seed)
     ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable(
-        ex_iterables, generator=generator, probabilities=probabilities, stopping_strategy=stopping_strategy
+        ex_iterables,
+        generator=generator,
+        probabilities=probabilities,
+        stopping_strategy=stopping_strategy,
     )
     # Set new info - we update the features
     # setting the features also ensures to fill missing columns with None
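A quick way to check the exactly-once guarantee on the streaming path (a test-style sketch, not part of the commit):

```python
from collections import Counter

from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"a": [0, 1, 2]}).to_iterable_dataset()
d2 = Dataset.from_dict({"a": [10, 11, 12, 13, 14]}).to_iterable_dataset()

ds = interleave_datasets(
    [d1, d2],
    probabilities=[0.7, 0.3],
    seed=42,
    stopping_strategy="all_exhausted_without_replacement",
)

counts = Counter(example["a"] for example in ds)
assert all(count == 1 for count in counts.values())  # no sample repeated
assert len(counts) == 8                              # no sample dropped
```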

0 commit comments

Comments
 (0)