Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
alternative to multiprocessing prefetch in free-threading Python.
* Adds experimental support for static `{Map|Iter}Dataset` element
specification inference.
* Adds support for changing `IterDataset.mix` components and weights after a checkpoint.

* Breaking changes:

Expand Down
14 changes: 10 additions & 4 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

import abc
import builtins
from collections.abc import Awaitable, Callable, Iterable, Iterator, Sequence
from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence
import functools
import json
from typing import Any, Generic, TypeVar, Union, cast, overload
Expand Down Expand Up @@ -918,15 +918,20 @@ class IterDatasetMeta(abc.ABCMeta):

def mix(
cls,
datasets: Sequence[IterDataset[T]],
weights: Sequence[float] | None = None,
datasets: Sequence[IterDataset[T]] | Mapping[str, IterDataset[T]],
weights: Sequence[float] | Mapping[str, float] | None = None,
) -> IterDataset[T]:
"""Returns a dataset that mixes input datasets with the given weights.

NOTE: Stops producing elements once *any* input dataset is exhausted. If
you need an infinite mixed dataset, consider repeating the input datasets
before mixing.

If `datasets` is a mapping, it is possible to recover from a checkpoint with
different components and/or weights. Component states will be recovered from
the checkpoint by key. If a component state is not found, the component will
start from the beginning.

Example usage::

ds1 = MapDataset.range(5).to_iter_dataset()
Expand All @@ -937,7 +942,8 @@ def mix(
Args:
datasets: The datasets to mix.
weights: The weights to use for mixing. Defaults to uniform weights if not
specified.
specified. If `datasets` is a mapping, `weights` must be a mapping with
the same keys.

Returns:
A dataset that represents a mixture of the input datasets according to the
Expand Down
54 changes: 38 additions & 16 deletions grain/_src/python/dataset/transformations/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
import bisect
from collections.abc import Sequence
import sys
from typing import Any, TypeVar
from typing import Any, Mapping, TypeVar

from grain._src.core import exceptions
from grain._src.core import tree_lib
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats


Element = Any
T = TypeVar("T") # pylint: disable=invalid-name

Expand Down Expand Up @@ -148,10 +150,15 @@ class _MixedDatasetIterator(dataset.DatasetIterator[T]):

def __init__(
self,
parents: Sequence[dataset.DatasetIterator[T]],
proportions: Sequence[int] | None,
parents: (
Sequence[dataset.DatasetIterator[T]]
| Mapping[str, dataset.DatasetIterator[T]]
),
proportions: Sequence[int],
):
super().__init__(parents)
flat_parents = tree_lib.flatten(parents)
super().__init__(flat_parents)
self._parents_structure = parents
self._proportions = tuple(proportions)
self._index = 0
self._stop = False
Expand Down Expand Up @@ -181,14 +188,23 @@ def __next__(self):

def get_state(self):
return {
"parents": [parent.get_state() for parent in self._parents],
"parents": tree_lib.map_structure(
lambda p: p.get_state(), self._parents_structure
),
"index": self._index,
"stop": self._stop,
}

def set_state(self, state):
for parent, parent_state in zip(self._parents, state["parents"]):
parent.set_state(parent_state)
parents_state = state["parents"]
if isinstance(self._parents_structure, Sequence):
for parent, parent_state in zip(self._parents, parents_state):
parent.set_state(parent_state)
else:
for key, parent in self._parents_structure.items():
if (parent_state := parents_state.get(key)) is not None:
parent.set_state(parent_state)

self._index = state["index"]
self._stop = state["stop"]

Expand All @@ -204,24 +220,30 @@ class MixedIterDataset(dataset.IterDataset[T]):

def __init__(
self,
parents: Sequence[dataset.IterDataset],
proportions: Sequence[float] | None = None,
parents: (
Sequence[dataset.IterDataset] | Mapping[str, dataset.IterDataset]
),
proportions: Sequence[float] | Mapping[str, float] | None = None,
):
super().__init__(parents)
flat_parents = tree_lib.flatten(parents)
super().__init__(flat_parents)
self._parents_structure = parents
# Normalize proportions
if proportions is None:
proportions = [1] * len(parents)
elif 0 in proportions:
proportions = [1] * len(flat_parents)
else:
proportions = tree_lib.flatten(proportions)

if 0 in proportions:
raise ValueError("Must specify all non-zero proportions for mixing.")
else:
proportions = _float_to_int_proportions(proportions)
assert len(parents) == len(proportions)
assert len(flat_parents) == len(proportions)
self._proportions = proportions

def __iter__(self) -> _MixedDatasetIterator[T]:
parent_iters = [parent.__iter__() for parent in self.parents]
def __iter__(self) -> dataset.DatasetIterator[T]:
return _MixedDatasetIterator(
parent_iters,
tree_lib.map_structure(lambda p: p.__iter__(), self._parents_structure),
proportions=self._proportions,
)

Expand Down
95 changes: 95 additions & 0 deletions grain/_src/python/dataset/transformations/mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,101 @@ def test_checkpointing(self):
next(ds_iter), values_without_interruption[i]
)

def test_checkpoint_recovery_with_changed_weights(self):
  """Checkpoint from a keyed mixture restores into an identically-keyed mixture.

  NOTE(review): despite the test name, the weights before and after the
  restore are the same — confirm whether a weight change was intended.
  """

  def build_mixture():
    # Three disjoint integer ranges, mixed by key with fixed weights.
    components = {
        "ds1": dataset.MapDataset.range(10).to_iter_dataset(),
        "ds2": dataset.MapDataset.range(10, 20).to_iter_dataset(),
        "ds3": dataset.MapDataset.range(20, 30).to_iter_dataset(),
    }
    return dataset.IterDataset.mix(
        components, {"ds1": 0.2, "ds2": 0.3, "ds3": 0.5}
    )

  it = iter(build_mixture())
  produced = [next(it) for _ in range(10)]
  self.assertEqual(produced, [0, 10, 20, 11, 21, 1, 22, 12, 23, 24])
  saved = it.get_state()
  # Rebuild an equivalent mixture from scratch and resume from the checkpoint.
  restored = iter(build_mixture())
  restored.set_state(saved)
  self.assertEqual(
      list(restored), [2, 13, 25, 14, 26, 3, 27, 15, 28, 29, 4, 16]
  )

def test_checkpoint_recovery_with_fewer_datasets(self):
  """A three-component checkpoint restores into a two-component mixture.

  The surviving components ("ds1", "ds2") resume from their keyed states;
  the dropped component's state is simply ignored.
  """
  mixed = dataset.IterDataset.mix(
      {
          "ds1": dataset.MapDataset.range(10).to_iter_dataset(),
          "ds2": dataset.MapDataset.range(10, 20).to_iter_dataset(),
          "ds3": dataset.MapDataset.range(20, 30).to_iter_dataset(),
      },
      {"ds1": 0.2, "ds2": 0.3, "ds3": 0.5},
  )
  it = iter(mixed)
  consumed = [next(it) for _ in range(10)]
  self.assertEqual(consumed, [0, 10, 20, 11, 21, 1, 22, 12, 23, 24])
  saved = it.get_state()
  # Rebuild without "ds3" and restore.
  smaller = dataset.IterDataset.mix(
      {
          "ds1": dataset.MapDataset.range(10).to_iter_dataset(),
          "ds2": dataset.MapDataset.range(10, 20).to_iter_dataset(),
      },
      {"ds1": 0.2, "ds2": 0.3},
  )
  restored = iter(smaller)
  restored.set_state(saved)
  self.assertEqual(
      list(restored), [2, 13, 3, 14, 15, 4, 16, 5, 17, 18, 6, 19, 7]
  )

def test_checkpoint_recovery_with_more_datasets(self):
  """A two-component checkpoint restores into a three-component mixture.

  The new component ("ds3") has no saved state, so it starts from the
  beginning while the original components resume where they left off.
  """
  first = dataset.MapDataset.range(10).to_iter_dataset()
  second = dataset.MapDataset.range(10, 20).to_iter_dataset()
  it = iter(dataset.IterDataset.mix({"ds1": first, "ds2": second}))
  self.assertEqual(
      [next(it) for _ in range(10)], [0, 10, 1, 11, 2, 12, 3, 13, 4, 14]
  )
  saved = it.get_state()
  # Extend the mixture with a fresh component before restoring.
  third = dataset.MapDataset.range(20, 30).to_iter_dataset()
  restored = iter(
      dataset.IterDataset.mix({"ds1": first, "ds2": second, "ds3": third})
  )
  restored.set_state(saved)
  self.assertEqual(
      list(restored),
      [15, 20, 5, 16, 21, 6, 17, 22, 7, 18, 23, 8, 19, 24, 9],
  )

def test_checkpoint_recovery_with_different_mixture(self):
  """Only the overlapping key ("ds2") is restored; "ds4" starts fresh.

  The restored mixture shares just one key with the checkpointed one and
  uses different weights.
  """
  ds1 = dataset.MapDataset.range(10).to_iter_dataset()
  ds2 = dataset.MapDataset.range(10, 20).to_iter_dataset()
  ds3 = dataset.MapDataset.range(20, 30).to_iter_dataset()
  it = iter(
      dataset.IterDataset.mix(
          {"ds1": ds1, "ds2": ds2, "ds3": ds3},
          {"ds1": 1, "ds2": 1, "ds3": 1},
      )
  )
  head = [next(it) for _ in range(10)]
  self.assertEqual(head, [0, 10, 20, 1, 11, 21, 2, 12, 22, 3])
  saved = it.get_state()
  # Keep only "ds2" and introduce a brand-new "ds4" with new weights.
  ds4 = dataset.MapDataset.range(30, 40).to_iter_dataset()
  restored = iter(
      dataset.IterDataset.mix(
          {"ds2": ds2, "ds4": ds4}, {"ds2": 0.6, "ds4": 0.4}
      )
  )
  restored.set_state(saved)
  self.assertEqual(
      list(restored), [13, 14, 30, 15, 31, 16, 17, 32, 18, 33, 19]
  )


class ConcatenateLazyMapTest(absltest.TestCase):

Expand Down
Loading