Open
Description
Currently, the only way to set the seed is with mlx.data.core.set_state, but this only controls the seed for .shuffle(). When using .prefetch() with num_threads > 1, the samples returned are not deterministic and therefore not reproducible.
Is there a way to set the seed when prefetching with more than one thread?
import mlx.data.core as dmx
from mlx.data.datasets import load_mnist
dmx.set_state(42)
train = load_mnist(root=None, train=True)
dset = (
    train.shuffle()
    .to_stream()
    .key_transform("image", lambda x: x.astype("float32") / 255)
    .batch(32)
    .prefetch(prefetch_size=4, num_threads=4)  # non-deterministic with > 1 thread
)
for i, data in enumerate(dset):
    print(data["image"].sum())
    if i == 2:
        break
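
To make the report easier to verify, here is a rough sketch of a reproducibility check that rebuilds a pipeline several times and compares the first few batches. The helper names (first_fingerprints, is_reproducible) and the two stand-in pipelines are hypothetical and use plain Python rather than mlx.data, so this only detects non-determinism; it does not fix it.

```python
import itertools
import random


def first_fingerprints(make_pipeline, fingerprint, n=3):
    """Build a fresh pipeline and fingerprint its first n batches."""
    return [fingerprint(batch) for batch in itertools.islice(make_pipeline(), n)]


def is_reproducible(make_pipeline, fingerprint, n=3, runs=5):
    """Return True if every rebuild yields the same first n fingerprints."""
    reference = first_fingerprints(make_pipeline, fingerprint, n)
    return all(
        first_fingerprints(make_pipeline, fingerprint, n) == reference
        for _ in range(runs - 1)
    )


# Stand-in pipelines (plain Python, NOT mlx.data) to show the helper in use.
def seeded_pipeline():
    rng = random.Random(42)  # same seed on every rebuild
    return ([rng.random() for _ in range(4)] for _ in range(10))


def unseeded_pipeline():
    rng = random.Random()  # seeded from OS entropy, differs on every rebuild
    return ([rng.random() for _ in range(4)] for _ in range(10))


print(is_reproducible(seeded_pipeline, sum))
print(is_reproducible(unseeded_pipeline, sum))
```

For the snippet in this issue, make_pipeline would presumably call dmx.set_state(42) and rebuild dset, and fingerprint would be something like lambda d: float(d["image"].sum()). Comparing num_threads=1 against num_threads=4 this way should show whether only the multi-threaded case varies between runs.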