55from ..imports import *
66from . import BaseDataLoader
77from ..datasets import Dataset , ArrayDataset , JAXDataset
8- from ..utils import check_tf_installed , get_config
8+ from ..utils import check_tf_installed , get_config , Generator
9+ from ..types import GeneratorType
910from ..tests import *
1011from jax .tree_util import tree_map
12+ import warnings
1113
1214# %% auto 0
13- __all__ = ['to_tf_dataset' , 'DataLoaderTensorflow' ]
15+ __all__ = ['to_tf_dataset' , 'get_seed' , ' DataLoaderTensorflow' ]
1416
1517# %% ../../nbs/loader.tf.ipynb 4
1618@dispatch
@@ -26,6 +28,18 @@ def to_tf_dataset(dataset: HFDataset) -> tf.data.Dataset:
2628 return dataset .to_tf_dataset ()
2729
2830# %% ../../nbs/loader.tf.ipynb 5
def get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> Optional[int]:
    """Resolve the seed used when shuffling the TensorFlow dataset.

    Args:
        generator: Source of randomness. When `None`, a default `Generator()`
            is constructed. Any value that is not already a `Generator` is
            wrapped with `Generator(generator=...)`.

    Returns:
        The value reported by `generator.seed()`. This may be `None`; in that
        case a warning is emitted and `None` is handed back to the caller
        unchanged.
    """
    if generator is None:
        generator = Generator()
    elif not isinstance(generator, Generator):
        # Normalise raw generator objects into the project's `Generator` wrapper
        # so that a single `.seed()` call works for every accepted input.
        generator = Generator(generator=generator)

    seed = generator.seed()
    if seed is None:
        # The seed is not replaced here: the caller receives `None` and decides
        # how to proceed; the warning only flags the loss of reproducibility.
        warnings.warn("No random seed provided. Using default seed which may not guarantee reproducible results.")
    return seed
42+
2943class DataLoaderTensorflow (BaseDataLoader ):
3044 """Tensorflow Dataloader"""
3145
@@ -36,13 +50,17 @@ def __init__(
3650 batch_size : int = 1 , # Batch size
3751 shuffle : bool = False , # If true, dataloader shuffles before sampling each batch
3852 drop_last : bool = False , # Drop last batch or not
53+ generator : Optional [GeneratorType ] = None , # Random seed generator
3954 ** kwargs
4055 ):
4156 super ().__init__ (dataset , batch_size , shuffle , drop_last )
4257 check_tf_installed ()
58+ # get random seed from generator
59+ seed = get_seed (generator )
60+
4361 # Convert to tf dataset
4462 ds = to_tf_dataset (dataset )
45- ds = ds .shuffle (buffer_size = len (dataset ), seed = get_config (). global_seed ) if shuffle else ds
63+ ds = ds .shuffle (buffer_size = len (dataset ), seed = seed ) if shuffle else ds
4664 ds = ds .batch (batch_size , drop_remainder = drop_last )
4765 ds = ds .prefetch (tf .data .AUTOTUNE )
4866 self .dataloader = ds
0 commit comments