Skip to content

Commit ee1214f

Browse files
committed
Move GeneratorType to types.py
1 parent f519bf2 commit ee1214f

File tree

15 files changed

+44
-18
lines changed

15 files changed

+44
-18
lines changed

jax_dataloader/_modidx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
'jax_dataloader/tests.py'),
119119
'jax_dataloader.tests.test_shuffle_reproducible': ( 'tests.html#test_shuffle_reproducible',
120120
'jax_dataloader/tests.py')},
121+
'jax_dataloader.types': {},
121122
'jax_dataloader.utils': { 'jax_dataloader.utils.Config': ('utils.html#config', 'jax_dataloader/utils.py'),
122123
'jax_dataloader.utils.Config.default': ('utils.html#config.default', 'jax_dataloader/utils.py'),
123124
'jax_dataloader.utils.Generator': ('utils.html#generator', 'jax_dataloader/utils.py'),

jax_dataloader/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .utils import *
99
from .datasets import *
1010
from .loaders import *
11+
from .types import *
1112

1213
# %% auto 0
1314
__all__ = ['SUPPORTED_DATASETS', 'DataloaderBackends', 'get_backend_compatibilities', 'DataLoader']

jax_dataloader/loaders/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
# %% ../../nbs/loader.base.ipynb 3
44
from __future__ import print_function, division, annotations
55
from ..imports import *
6-
from ..utils import Generator, GeneratorType
6+
from ..utils import Generator
7+
from ..types import GeneratorType
78

89
# %% auto 0
910
__all__ = ['BaseDataLoader']

jax_dataloader/loaders/jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from ..imports import *
66
from ..datasets import ArrayDataset, JAXDataset
77
from . import BaseDataLoader
8-
from ..utils import get_config, asnumpy, Generator, GeneratorType
8+
from ..utils import get_config, asnumpy, Generator
9+
from ..types import GeneratorType
910
from ..tests import *
1011
import jax_dataloader as jdl
1112
from threading import Thread, Event

jax_dataloader/loaders/tensorflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from ..imports import *
66
from . import BaseDataLoader
77
from ..datasets import Dataset, ArrayDataset, JAXDataset
8-
from ..utils import check_tf_installed, get_config, Generator, GeneratorType
8+
from ..utils import check_tf_installed, get_config, Generator
9+
from ..types import GeneratorType
910
from ..tests import *
1011
from jax.tree_util import tree_map
1112
import warnings

jax_dataloader/loaders/torch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from ..imports import *
66
from . import BaseDataLoader
77
from ..datasets import Dataset, ArrayDataset, JAXDataset
8-
from ..utils import check_pytorch_installed, get_config, Generator, GeneratorType
8+
from ..utils import check_pytorch_installed, get_config, Generator
9+
from ..types import GeneratorType
910
from ..tests import *
1011
from jax.tree_util import tree_map
1112
import warnings

jax_dataloader/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .imports import *
2+
from .utils import Generator
3+
4+
GeneratorType = Union[Generator, jax.Array, 'torch.Generator']

jax_dataloader/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import collections
88

99
# %% auto 0
10-
__all__ = ['GeneratorType', 'Config', 'get_config', 'manual_seed', 'check_pytorch_installed', 'has_pytorch_tensor',
11-
'check_hf_installed', 'check_tf_installed', 'Generator', 'asnumpy']
10+
__all__ = ['Config', 'get_config', 'manual_seed', 'check_pytorch_installed', 'has_pytorch_tensor', 'check_hf_installed',
11+
'check_tf_installed', 'Generator', 'asnumpy']
1212

1313
# %% ../nbs/utils.ipynb 7
1414
@dataclass
@@ -67,10 +67,12 @@ def check_tf_installed():
6767

6868
# %% ../nbs/utils.ipynb 21
6969
class Generator:
70+
"""A wrapper around JAX and PyTorch generators. This is used to generate random numbers in a reproducible way."""
71+
7072
def __init__(
7173
self,
7274
*,
73-
generator: jrand.Array | torch.Generator = None,
75+
generator: jax.Array | torch.Generator = None, # Optional generator
7476
):
7577
self._seed = None
7678
self._jax_generator = None
@@ -118,8 +120,6 @@ def torch_generator(self) -> torch.Generator:
118120
raise ValueError("Neither pytorch generator or seed is specified.")
119121
return self._torch_generator
120122

121-
GeneratorType = Union[Generator, jax.Array, 'torch.Generator']
122-
123123
# %% ../nbs/utils.ipynb 26
124124
def asnumpy(x) -> np.ndarray:
125125
if isinstance(x, np.ndarray):

nbs/core.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@
4545
"from jax_dataloader.imports import *\n",
4646
"from jax_dataloader.utils import *\n",
4747
"from jax_dataloader.datasets import *\n",
48-
"from jax_dataloader.loaders import *"
48+
"from jax_dataloader.loaders import *\n",
49+
"from jax_dataloader.types import *"
4950
]
5051
},
5152
{

nbs/loader.base.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
"#| export\n",
4040
"from __future__ import print_function, division, annotations\n",
4141
"from jax_dataloader.imports import *\n",
42-
"from jax_dataloader.utils import Generator, GeneratorType"
42+
"from jax_dataloader.utils import Generator\n",
43+
"from jax_dataloader.types import GeneratorType"
4344
]
4445
},
4546
{

0 commit comments

Comments
 (0)