1313import equinox as eqx
1414import optax
1515
16+
1617def get_simple_mlp (n_input , n_hidden , n_output , key ):
1718 """Simple 2-layer MLP for testing."""
18- shape = [n_input ] + ([n_hidden ] if isinstance (n_hidden , int ) else list (n_hidden )) + [n_output ]
19- return MLP (
20- shape = shape ,
21- key = key ,
22- activation = jax .nn .swish
19+ shape = (
20+ [n_input ]
21+ + ([n_hidden ] if isinstance (n_hidden , int ) else list (n_hidden ))
22+ + [n_output ]
2323 )
24+ return MLP (shape = shape , key = key , activation = jax .nn .swish )
25+
2426
2527##############################
2628# Solver Tests
2729##############################
2830
31+
2932class TestSolver :
3033 @pytest .fixture
3134 def solver (self ):
3235 key = jax .random .PRNGKey (0 )
3336 n_dim = 3
3437 n_hidden = 4
35- mlp = get_simple_mlp (n_input = n_dim + 1 , n_hidden = n_hidden , n_output = n_dim , key = key )
38+ mlp = get_simple_mlp (
39+ n_input = n_dim + 1 , n_hidden = n_hidden , n_output = n_dim , key = key
40+ )
3641 return Solver (model = mlp , method = Dopri5 ()), key , n_dim
3742
3843 def test_sample_shape_and_finiteness (self , solver ):
@@ -57,10 +62,12 @@ def test_sample_various_dt(self, solver, dt):
5762 assert samples .shape == (3 , n_dim )
5863 assert jnp .isfinite (samples ).all ()
5964
65+
6066##############################
6167# Path & Scheduler Tests
6268##############################
6369
70+
6471class TestPathAndScheduler :
6572 def test_path_sample_shapes_and_values (self ):
6673 n_dim = 2
@@ -82,25 +89,29 @@ def test_condotscheduler_call_output(self, t):
8289 assert len (out ) == 4
8390 assert all (isinstance (float (x ), float ) for x in out )
8491
92+
8593##############################
8694# FlowMatchingModel Tests
8795##############################
8896
97+
8998class TestFlowMatchingModel :
9099 @pytest .fixture
91100 def model (self ):
92101 key = jax .random .PRNGKey (42 )
93102 n_dim = 2
94103 n_hidden = 8
95- mlp = get_simple_mlp (n_input = n_dim + 1 , n_hidden = n_hidden , n_output = n_dim , key = key )
104+ mlp = get_simple_mlp (
105+ n_input = n_dim + 1 , n_hidden = n_hidden , n_output = n_dim , key = key
106+ )
96107 solver = Solver (model = mlp , method = Dopri5 ())
97108 scheduler = CondOTScheduler ()
98109 path = Path (scheduler = scheduler )
99110 model = FlowMatchingModel (
100111 solver = solver ,
101112 path = path ,
102113 data_mean = jnp .zeros (n_dim ),
103- data_cov = jnp .eye (n_dim )
114+ data_cov = jnp .eye (n_dim ),
104115 )
105116 return model , key , n_dim
106117
@@ -130,21 +141,27 @@ def test_log_prob_edge_cases(self, model):
130141 logp = model .log_prob (arr )
131142 logp_arr = jnp .asarray (logp )
132143 assert logp_arr .size == 1
133- assert jnp .isfinite (logp_arr ).all () or jnp .isnan (logp_arr ).all () # may be nan for extreme values
144+ assert (
145+ jnp .isfinite (logp_arr ).all () or jnp .isnan (logp_arr ).all ()
146+ ) # may be nan for extreme values
134147
135148 def test_save_and_load (self , tmp_path , model ):
136149 model , key , n_dim = model
137150 save_path = str (tmp_path / "test_model" )
138151 model .save_model (save_path )
139152 loaded = model .load_model (save_path )
140153 x = jax .random .normal (key , (2 , n_dim ))
141- assert jnp .allclose (eqx .filter_vmap (model .log_prob )(x ), eqx .filter_vmap (loaded .log_prob )(x ))
154+ assert jnp .allclose (
155+ eqx .filter_vmap (model .log_prob )(x ), eqx .filter_vmap (loaded .log_prob )(x )
156+ )
142157
143158 def test_properties (self , model ):
144159 model , key , n_dim = model
145160 mean = jnp .arange (n_dim )
146161 cov = jnp .eye (n_dim ) * 2
147- model2 = FlowMatchingModel (solver = model .solver , path = model .path , data_mean = mean , data_cov = cov )
162+ model2 = FlowMatchingModel (
163+ solver = model .solver , path = model .path , data_mean = mean , data_cov = cov
164+ )
148165 assert model2 .n_features == n_dim
149166 assert jnp .allclose (model2 .data_mean , mean )
150167 assert jnp .allclose (model2 .data_cov , cov )
@@ -157,7 +174,6 @@ def test_print_parameters_notimplemented(self, model):
157174 def test_train_step_and_epoch (self , model ):
158175 model , key , n_dim = model
159176 n_batch = 5
160- n_hidden = 8
161177 x0 = jax .random .normal (key , (n_batch , n_dim ))
162178 x1 = jax .random .normal (key , (n_batch , n_dim ))
163179 t = jax .random .uniform (key , (n_batch , 1 ))
@@ -170,6 +186,8 @@ def test_train_step_and_epoch(self, model):
170186 assert jnp .isfinite (loss )
171187 assert isinstance (model2 , FlowMatchingModel )
172188 data = (x0 , x1 , t )
173- loss_epoch , model3 , state3 = model .train_epoch (key , optim , state , data , batch_size = n_batch )
189+ loss_epoch , model3 , state3 = model .train_epoch (
190+ key , optim , state , data , batch_size = n_batch
191+ )
174192 assert jnp .isfinite (loss_epoch )
175193 assert isinstance (model3 , FlowMatchingModel )
0 commit comments