-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathutils.py
More file actions
134 lines (113 loc) · 4.63 KB
/
utils.py
File metadata and controls
134 lines (113 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb.
# %% ../nbs/utils.ipynb 4
from __future__ import print_function, division, annotations
from .imports import *
import jax_dataloader as jdl
import collections
# %% auto 0
__all__ = [
    'Config',
    'get_config',
    'manual_seed',
    'check_pytorch_installed',
    'has_pytorch_tensor',
    'check_hf_installed',
    'check_tf_installed',
    'Generator',
    'asnumpy',
]
# %% ../nbs/utils.ipynb 7
@dataclass
class Config:
    """Library-wide settings.

    The fields are plain mutable attributes, so module helpers can update a
    shared instance in place.
    """

    # Stored as-is; how this value is consumed is not visible in this module.
    rng_reserve_size: int
    # Seed used by `Generator` when it is built without an explicit generator.
    global_seed: int

    @classmethod
    def default(cls) -> Config:
        """Build the configuration with the library's built-in default values."""
        defaults = dict(global_seed=42, rng_reserve_size=1)
        return cls(**defaults)
# %% ../nbs/utils.ipynb 8
# Module-level configuration instance, created once at import time from
# `Config.default()`. `get_config()` returns this object and `manual_seed()`
# mutates its `global_seed` field in place.
main_config: Config = Config.default()
# %% ../nbs/utils.ipynb 9
def get_config() -> Config:
    """Return the library's shared configuration object.

    The very same instance is returned on every call (not a copy), so any
    mutation made through it is visible to all other callers.
    """
    shared = main_config
    return shared
# %% ../nbs/utils.ipynb 10
def manual_seed(seed: int):
    """Change the library-wide default seed.

    Only the shared configuration is updated. A `Generator` reads this value
    when it is constructed without an explicit generator, so generators that
    already exist keep the seed they captured.
    """
    config = main_config
    config.global_seed = seed
# %% ../nbs/utils.ipynb 13
def check_pytorch_installed():
    """Raise ``ModuleNotFoundError`` when ``torch_data`` is ``None``.

    NOTE(review): ``torch_data`` comes from ``.imports``; presumably it is set to
    ``None`` there when PyTorch cannot be imported — confirm.
    """
    if torch_data is not None:
        return
    message = (
        "`pytorch` library needs to be installed. "
        "Try `pip install torch`. Please refer to pytorch documentation for details: "
        "https://pytorch.org/get-started/."
    )
    raise ModuleNotFoundError(message)
# %% ../nbs/utils.ipynb 15
def has_pytorch_tensor(batch) -> bool:
    """Return True if a PyTorch tensor appears anywhere in ``batch``.

    The layout is inferred from the first element only:

    - If it is a ``torch.Tensor``, the answer is True.
    - If it is a tuple/list, ``batch`` is treated as a sequence of records.
      The records are transposed into per-field columns and each column is
      checked recursively.
    - Anything else yields False.

    An empty batch contains no tensor and yields False.
    """
    try:
        first = batch[0]
    except IndexError:
        # Empty batch: nothing to inspect (previously an uncaught IndexError).
        return False
    # NOTE(review): this assumes `torch` is a real module here. If `.imports`
    # binds it to None when PyTorch is missing, this raises AttributeError —
    # confirm.
    if isinstance(first, torch.Tensor):
        return True
    if isinstance(first, (tuple, list)):
        # zip(*batch) transposes records into columns. The generator lets any()
        # stop at the first column that contains a tensor.
        return any(has_pytorch_tensor(column) for column in zip(*batch))
    return False
# %% ../nbs/utils.ipynb 16
def check_hf_installed():
    """Raise ``ModuleNotFoundError`` when ``hf_datasets`` is ``None``.

    NOTE(review): ``hf_datasets`` comes from ``.imports``; presumably it is set
    to ``None`` there when the HuggingFace `datasets` package is missing — confirm.
    """
    if hf_datasets is not None:
        return
    message = (
        "`datasets` library needs to be installed. "
        "Try `pip install datasets`. Please refer to huggingface documentation for details: "
        "https://huggingface.co/docs/datasets/installation.html."
    )
    raise ModuleNotFoundError(message)
# %% ../nbs/utils.ipynb 18
def check_tf_installed():
    """Raise ``ModuleNotFoundError`` when ``tf`` is ``None``.

    NOTE(review): ``tf`` comes from ``.imports``; presumably it is set to
    ``None`` there when TensorFlow cannot be imported — confirm.
    """
    if tf is not None:
        return
    message = (
        "`tensorflow` library needs to be installed. "
        "Try `pip install tensorflow`. Please refer to tensorflow documentation for details: "
        "https://www.tensorflow.org/install/pip."
    )
    raise ModuleNotFoundError(message)
# %% ../nbs/utils.ipynb 21
class Generator:
    """Holder for a seed that can supply both a JAX PRNG key and a ``torch.Generator``.

    Construct it with no argument to use the library-wide seed from
    ``get_config()``, or wrap an existing JAX key array or ``torch.Generator``.
    The missing counterpart is derived lazily from the stored seed when it is
    first requested.
    """

    def __init__(
        self,
        *,
        generator: jax.Array | torch.Generator | None = None,
    ):
        """
        Args:
            generator: ``None`` (use ``get_config().global_seed``), a JAX PRNG
                key array, or a ``torch.Generator``.

        Raises:
            ValueError: if ``generator`` is none of the accepted kinds.
        """
        self._seed = None
        self._jax_generator = None
        self._torch_generator = None
        if generator is None:
            self._seed = get_config().global_seed
        # Check the JAX case before touching `torch.Generator`, so accepting a
        # JAX key never depends on PyTorch attribute access. NOTE(review): this
        # matters if `.imports` binds `torch` to None when PyTorch is absent — confirm.
        elif isinstance(generator, jax.Array):
            self._jax_generator = generator
        elif isinstance(generator, torch.Generator):
            self._torch_generator = generator
            # Keep the torch seed so a matching JAX key can be derived later.
            self._seed = generator.initial_seed()
        else:
            raise ValueError(f"generator=`{generator}` is invalid. Must be either a `jax.random.PRNGKey` or a `torch.Generator`.")

    def seed(self) -> Optional[int]:
        """Return the seed this generator was built or reseeded with.

        Returns ``None`` when the generator was constructed from a JAX key and
        ``manual_seed`` has not been called since.
        """
        # TODO: the seed cannot be recovered from a raw `jax.random.PRNGKey`.
        return self._seed

    def manual_seed(self, seed: int) -> Generator:
        """Reseed; this overrides the stored seed and any existing generators.

        A stored torch generator is *replaced* with a fresh default-device
        ``torch.Generator``. It is not reseeded in place, so the caller's
        original generator object is left untouched.

        Returns:
            ``self``, to allow chaining.
        """
        if self._jax_generator is not None:
            self._jax_generator = jrand.PRNGKey(seed)
        if self._torch_generator is not None:
            self._torch_generator = torch.Generator().manual_seed(seed)
        self._seed = seed
        return self

    def jax_generator(self) -> jax.Array:
        """Return the JAX PRNG key, creating it from the stored seed on first use."""
        if self._jax_generator is None:
            # NOTE(review): a torch `initial_seed()` can exceed the int range
            # accepted by `jrand.PRNGKey` — confirm for large torch seeds.
            self._jax_generator = jrand.PRNGKey(self._seed)
        return self._jax_generator

    def torch_generator(self) -> torch.Generator:
        """Return the ``torch.Generator``, creating it from the stored seed on first use.

        Raises:
            ModuleNotFoundError: if PyTorch is not installed.
            ValueError: if there is neither a torch generator nor a seed. This
                happens when the object was built from a JAX key and never reseeded.
        """
        check_pytorch_installed()
        if self._torch_generator is None and self._seed is not None:
            self._torch_generator = torch.Generator().manual_seed(self._seed)
        if self._torch_generator is None:
            raise ValueError("Neither pytorch generator or seed is specified.")
        return self._torch_generator
# %% ../nbs/utils.ipynb 26
def asnumpy(x) -> np.ndarray:
if isinstance(x, np.ndarray):
return x
elif isinstance(x, jnp.ndarray):
return x.__array__()
elif torch_data and isinstance(x, torch.Tensor):
return x.detach().cpu().numpy()
elif tf and isinstance(x, tf.Tensor):
return x.numpy()
elif isinstance(x, (tuple, list)):
return map(asnumpy, x)
else:
raise ValueError(f"Unknown type: {type(x)}")