Skip to content

Commit c0dc20d

Browse files
committed
docstring, fix test
1 parent 3a4bfc6 commit c0dc20d

File tree

4 files changed

+69
-46
lines changed

4 files changed

+69
-46
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ _set_transformer.py
1010
__unet.py
1111
cifar10.ipynb
1212
grfs.ipynb
13-
simple.ipynb
13+
simple.ipynb
14+
mnist_clouds.py

data/utils.py

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,42 @@ def loop(
167167
)
168168

169169

170+
def maybe_convert(a):
    """Convert a JAX array to a NumPy array; return any other input unchanged."""
    if isinstance(a, jnp.ndarray):
        return np.asarray(a)
    return a
172+
173+
174+
class TensorDataset(torch.utils.data.Dataset):
    """Dataset wrapping up to three aligned tensors, named "x", "q" and "a".

    Each entry of `tensors` may be None (it is then omitted from the tuples
    returned by `__getitem__`); non-None entries are copied into torch
    tensors (JAX arrays are converted via `maybe_convert` first). An optional
    per-name transform is applied to each element on access.
    """

    def __init__(self, tensors, x_transform=None, q_transform=None, a_transform=None):
        self.names = ["x", "q", "a"]

        # Copy each provided array into a torch tensor; keep None placeholders.
        self.data = {}
        for name, tensor in zip(self.names, tensors):
            if exists(tensor):
                self.data[name] = torch.as_tensor(np.copy(maybe_convert(tensor)))
            else:
                self.data[name] = None

        # Pair each tensor name with its (possibly absent) transform.
        self.transforms = {}
        for name, transform in zip(self.names, (x_transform, q_transform, a_transform)):
            self.transforms[name] = transform if exists(transform) else None

        # Sanity check: all non-None tensors must share the same first dimension.
        # NOTE(review): `assert` is stripped under `python -O`.
        lengths = {tensor.shape[0] for tensor in self.data.values() if tensor is not None}
        assert len(lengths) == 1, "All input tensors must have the same length."

    def __getitem__(self, index):
        """Return the index-th element of each present tensor, transformed, as a tuple."""
        items = []
        for name in self.names:
            tensor = self.data.get(name)
            if not exists(tensor):
                continue
            value = tensor[index]
            transform = self.transforms[name]
            if transform:
                value = transform(value)
            items.append(value)
        return tuple(items)

    def __len__(self):
        """Shared first-dimension length of the present tensors."""
        first_dims = (tensor.shape[0] for tensor in self.data.values() if tensor is not None)
        return next(first_dims)
204+
205+
170206
@jaxtyped(typechecker=typechecker)
171207
@dataclass
172208
class ScalerDataset:
@@ -220,43 +256,6 @@ class ScalerDataset:
220256
]
221257

222258

223-
def maybe_convert(a):
224-
return np.asarray(a) if isinstance(a, jnp.ndarray) else a
225-
226-
227-
class TensorDataset(torch.utils.data.Dataset):
228-
def __init__(self, tensors, x_transform=None, q_transform=None, a_transform=None):
229-
self.names = ["x", "q", "a"]
230-
self.data = {
231-
name: torch.as_tensor(np.copy(maybe_convert(t))) if exists(t) else None
232-
for name, t in zip(self.names, tensors)
233-
}
234-
235-
self.transforms = {
236-
name: transform if exists(transform) else None
237-
for name, transform in zip(self.names, [x_transform, q_transform, a_transform])
238-
}
239-
240-
# Sanity check: all non-None tensors must have same first dimension
241-
lengths = [v.shape[0] for v in self.data.values() if v is not None]
242-
assert len(set(lengths)) == 1, "All input tensors must have the same length."
243-
244-
def __getitem__(self, index):
245-
output = []
246-
for key in self.names:
247-
tensor = self.data.get(key)
248-
if exists(tensor):
249-
val = tensor[index]
250-
if self.transforms[key]:
251-
val = self.transforms[key](val)
252-
val = jnp.asarray(val.numpy())
253-
output.append(val)
254-
return tuple(output)
255-
256-
def __len__(self):
257-
return next(v.shape[0] for v in self.data.values() if v is not None)
258-
259-
260259
@jaxtyped(typechecker=typechecker)
261260
def dataset_from_tensors(
262261
X: Float[Array, "n ..."],

sbgm/_train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import jax.numpy as jnp
99
import jax.random as jr
1010
import equinox as eqx
11-
from jaxtyping import Key, Array, Float, PyTree, jaxtyped
11+
from jaxtyping import Key, Array, Float, Scalar, PyTree, jaxtyped
1212
from beartype import beartype as typechecker
13-
from ml_collections import ConfigDict
1413
import optax
14+
from ml_collections import ConfigDict
1515
from tqdm.auto import trange
1616

1717
from .sde import SDE
@@ -69,7 +69,7 @@ def accumulate_gradients_scan(
6969
n_minibatches: int,
7070
*,
7171
grad_fn: Callable
72-
) -> Tuple[Float[Array, ""], PyTree]:
72+
) -> Tuple[Scalar, PyTree]:
7373
batch_size = xqat[0].shape[0]
7474
minibatch_size = batch_size // n_minibatches
7575

@@ -124,9 +124,9 @@ def single_loss_fn(
124124
x: Float[Array, "..."],
125125
q: Optional[Float[Array, "..."]],
126126
a: Optional[Float[Array, "..."]],
127-
t: Float[Array, ""],
127+
t: Scalar,
128128
key: Key
129-
) -> Float[Array, ""]:
129+
) -> Scalar:
130130
key_noise, key_apply = jr.split(key)
131131
mean, std = sde.marginal_prob(x, t)
132132
noise = jr.normal(key_noise, x.shape)
@@ -183,7 +183,7 @@ def make_step(
183183
sharding: Optional[jax.sharding.NamedSharding] = None,
184184
replicated_sharding: Optional[jax.sharding.NamedSharding] = None
185185
) -> Tuple[
186-
Float[Array, ""], Model, Key[jnp.ndarray, "..."], optax.OptState
186+
Scalar, Model, Key[jnp.ndarray, "..."], optax.OptState
187187
]:
188188
model = eqx.nn.inference_mode(model, False)
189189

sbgm/models/__init__.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import equinox as eqx
33
from jaxtyping import Key
44
import numpy as np
5-
import ml_collections
5+
from ml_collections import ConfigDict
66

77
from ._mixer import Mixer2d
88
from ._mlp import ResidualNetwork
@@ -13,17 +13,40 @@
1313
def get_model(
1414
model_key: Key,
1515
model_type: str,
16-
config: ml_collections.ConfigDict,
16+
config: ConfigDict,
1717
data_shape: Sequence[int],
1818
context_shape: Optional[Sequence[int]] = None,
1919
parameter_dim: Optional[int] = None
2020
) -> eqx.Module:
21+
"""
22+
Get the model based on the specified type and configuration.
23+
24+
Args:
25+
model_key: JAX random key for model initialization.
26+
model_type: Type of the model to create (e.g., "Mixer", "UNet", "mlp", "DiT").
27+
config: Configuration dictionary containing model parameters.
28+
data_shape: Shape of the input data (e.g. image dimensions, channels first).
29+
context_shape: Shape of the context map, if applicable.
30+
parameter_dim: Dimension of the additional conditioning.
31+
Returns:
32+
An initialized instance of the specified model type.
33+
34+
Raises:
35+
ValueError: If the model type is not recognized.
36+
"""
37+
2138
# Grab channel assuming 'q' is a map like x
2239
if context_shape is not None:
2340
context_channels, *_ = context_shape.shape
2441
else:
2542
context_channels = None
2643

44+
if model_type not in ["Mixer", "UNet", "mlp", "DiT"]:
45+
raise ValueError(
46+
f"Model type {model_type} is not recognized. "
47+
"Choose from 'Mixer', 'UNet', 'mlp', or 'DiT'."
48+
)
49+
2750
if model_type == "Mixer":
2851
model = Mixer2d(
2952
data_shape,

0 commit comments

Comments
 (0)