Skip to content

Commit 1729bdf

Browse files
committed
Format test_flowmatching.py for improved readability
1 parent 8a2ced0 commit 1729bdf

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

test/unit/test_flowmatching.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,31 @@
1313
import equinox as eqx
1414
import optax
1515

16+
1617
def get_simple_mlp(n_input, n_hidden, n_output, key):
1718
"""Simple 2-layer MLP for testing."""
18-
shape = [n_input] + ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden)) + [n_output]
19-
return MLP(
20-
shape=shape,
21-
key=key,
22-
activation=jax.nn.swish
19+
shape = (
20+
[n_input]
21+
+ ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden))
22+
+ [n_output]
2323
)
24+
return MLP(shape=shape, key=key, activation=jax.nn.swish)
25+
2426

2527
##############################
2628
# Solver Tests
2729
##############################
2830

31+
2932
class TestSolver:
3033
@pytest.fixture
3134
def solver(self):
3235
key = jax.random.PRNGKey(0)
3336
n_dim = 3
3437
n_hidden = 4
35-
mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key)
38+
mlp = get_simple_mlp(
39+
n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key
40+
)
3641
return Solver(model=mlp, method=Dopri5()), key, n_dim
3742

3843
def test_sample_shape_and_finiteness(self, solver):
@@ -57,10 +62,12 @@ def test_sample_various_dt(self, solver, dt):
5762
assert samples.shape == (3, n_dim)
5863
assert jnp.isfinite(samples).all()
5964

65+
6066
##############################
6167
# Path & Scheduler Tests
6268
##############################
6369

70+
6471
class TestPathAndScheduler:
6572
def test_path_sample_shapes_and_values(self):
6673
n_dim = 2
@@ -82,25 +89,29 @@ def test_condotscheduler_call_output(self, t):
8289
assert len(out) == 4
8390
assert all(isinstance(float(x), float) for x in out)
8491

92+
8593
##############################
8694
# FlowMatchingModel Tests
8795
##############################
8896

97+
8998
class TestFlowMatchingModel:
9099
@pytest.fixture
91100
def model(self):
92101
key = jax.random.PRNGKey(42)
93102
n_dim = 2
94103
n_hidden = 8
95-
mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key)
104+
mlp = get_simple_mlp(
105+
n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key
106+
)
96107
solver = Solver(model=mlp, method=Dopri5())
97108
scheduler = CondOTScheduler()
98109
path = Path(scheduler=scheduler)
99110
model = FlowMatchingModel(
100111
solver=solver,
101112
path=path,
102113
data_mean=jnp.zeros(n_dim),
103-
data_cov=jnp.eye(n_dim)
114+
data_cov=jnp.eye(n_dim),
104115
)
105116
return model, key, n_dim
106117

@@ -130,21 +141,27 @@ def test_log_prob_edge_cases(self, model):
130141
logp = model.log_prob(arr)
131142
logp_arr = jnp.asarray(logp)
132143
assert logp_arr.size == 1
133-
assert jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all() # may be nan for extreme values
144+
assert (
145+
jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all()
146+
) # may be nan for extreme values
134147

135148
def test_save_and_load(self, tmp_path, model):
136149
model, key, n_dim = model
137150
save_path = str(tmp_path / "test_model")
138151
model.save_model(save_path)
139152
loaded = model.load_model(save_path)
140153
x = jax.random.normal(key, (2, n_dim))
141-
assert jnp.allclose(eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x))
154+
assert jnp.allclose(
155+
eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x)
156+
)
142157

143158
def test_properties(self, model):
144159
model, key, n_dim = model
145160
mean = jnp.arange(n_dim)
146161
cov = jnp.eye(n_dim) * 2
147-
model2 = FlowMatchingModel(solver=model.solver, path=model.path, data_mean=mean, data_cov=cov)
162+
model2 = FlowMatchingModel(
163+
solver=model.solver, path=model.path, data_mean=mean, data_cov=cov
164+
)
148165
assert model2.n_features == n_dim
149166
assert jnp.allclose(model2.data_mean, mean)
150167
assert jnp.allclose(model2.data_cov, cov)
@@ -157,7 +174,6 @@ def test_print_parameters_notimplemented(self, model):
157174
def test_train_step_and_epoch(self, model):
158175
model, key, n_dim = model
159176
n_batch = 5
160-
n_hidden = 8
161177
x0 = jax.random.normal(key, (n_batch, n_dim))
162178
x1 = jax.random.normal(key, (n_batch, n_dim))
163179
t = jax.random.uniform(key, (n_batch, 1))
@@ -170,6 +186,8 @@ def test_train_step_and_epoch(self, model):
170186
assert jnp.isfinite(loss)
171187
assert isinstance(model2, FlowMatchingModel)
172188
data = (x0, x1, t)
173-
loss_epoch, model3, state3 = model.train_epoch(key, optim, state, data, batch_size=n_batch)
189+
loss_epoch, model3, state3 = model.train_epoch(
190+
key, optim, state, data, batch_size=n_batch
191+
)
174192
assert jnp.isfinite(loss_epoch)
175193
assert isinstance(model3, FlowMatchingModel)

0 commit comments

Comments
 (0)