|
7 | 7 | import collections |
8 | 8 |
|
9 | 9 | # %% auto 0 |
10 | | -__all__ = ['GeneratorType', 'Config', 'get_config', 'manual_seed', 'check_pytorch_installed', 'has_pytorch_tensor', |
11 | | - 'check_hf_installed', 'check_tf_installed', 'Generator', 'asnumpy'] |
| 10 | +__all__ = ['Config', 'get_config', 'manual_seed', 'check_pytorch_installed', 'has_pytorch_tensor', 'check_hf_installed', |
| 11 | + 'check_tf_installed', 'Generator', 'asnumpy'] |
12 | 12 |
|
13 | 13 | # %% ../nbs/utils.ipynb 7 |
14 | 14 | @dataclass |
@@ -67,10 +67,12 @@ def check_tf_installed(): |
67 | 67 |
|
68 | 68 | # %% ../nbs/utils.ipynb 21 |
69 | 69 | class Generator: |
| 70 | + """A wrapper around JAX and PyTorch generators. This is used to generate random numbers in a reproducible way.""" |
| 71 | + |
70 | 72 | def __init__( |
71 | 73 | self, |
72 | 74 | *, |
73 | | - generator: jrand.Array | torch.Generator = None, |
| 75 | + generator: jax.Array | torch.Generator = None, # Optional generator |
74 | 76 | ): |
75 | 77 | self._seed = None |
76 | 78 | self._jax_generator = None |
@@ -118,8 +120,6 @@ def torch_generator(self) -> torch.Generator: |
118 | 120 | raise ValueError("Neither pytorch generator or seed is specified.") |
119 | 121 | return self._torch_generator |
120 | 122 |
|
121 | | -GeneratorType = Union[Generator, jax.Array, 'torch.Generator'] |
122 | | - |
123 | 123 | # %% ../nbs/utils.ipynb 26 |
124 | 124 | def asnumpy(x) -> np.ndarray: |
125 | 125 | if isinstance(x, np.ndarray): |
|
0 commit comments