@@ -54,7 +54,7 @@ def compute_loss(bijection, draw, grad, logp):
5454 return cost
5555
5656 costs = jax .vmap (compute_loss , [None , 0 , 0 , 0 ])(
57- flow . bijection ,
57+ flow ,
5858 draws ,
5959 grads ,
6060 logps ,
@@ -69,7 +69,7 @@ def compute_loss(bijection, draw, grad, logp):
6969 return cost
7070
7171 costs = jax .vmap (compute_loss , [None , 0 , 0 , 0 ])(
72- flow . bijection ,
72+ flow ,
7373 draws ,
7474 grads ,
7575 logps ,
@@ -86,7 +86,7 @@ def compute_loss(bijection, draw, grad, logp):
8686 else :
8787
8888 def transform (draw , grad , logp ):
89- return flow .bijection . inverse_gradient_and_val_ (draw , grad , logp )
89+ return flow .inverse_gradient_and_val_ (draw , grad , logp )
9090
9191 draws , grads , logps = jax .vmap (transform , [0 , 0 , 0 ], (0 , 0 , 0 ))(
9292 draws , grads , logps
@@ -98,9 +98,7 @@ def transform(draw, grad, logp):
9898
9999
100100def fit_flow (key , bijection , loss_fn , draws , grads , logps , ** kwargs ):
101- flow = flowjax .flows .Transformed (
102- flowjax .distributions .StandardNormal (bijection .shape ), bijection
103- )
101+ flow = bijection
104102
105103 key , train_key = jax .random .split (key )
106104
@@ -113,7 +111,7 @@ def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs):
113111 return_best = True ,
114112 ** kwargs ,
115113 )
116- return fit . bijection , losses , losses ["opt_state" ]
114+ return fit , losses , losses ["opt_state" ]
117115
118116
119117@eqx .filter_jit
@@ -298,9 +296,7 @@ def update(self, seed, positions, gradients, logps):
298296
299297 fit = self ._make_flow_fn (seed , positions , gradients , n_layers = 0 )
300298
301- flow = flowjax .flows .Transformed (
302- flowjax .distributions .StandardNormal (fit .shape ), fit
303- )
299+ flow = fit
304300 params , static = eqx .partition (flow , eqx .is_inexact_array )
305301 new_loss = self ._loss_fn (params , static , positions , gradients , logps )
306302
@@ -341,9 +337,7 @@ def update(self, seed, positions, gradients, logps):
341337 untransformed_dim = self ._untransformed_dim ,
342338 zero_init = self ._zero_init ,
343339 )
344- flow = flowjax .flows .Transformed (
345- flowjax .distributions .StandardNormal (base .shape ), base
346- )
340+ flow = base
347341 params , static = eqx .partition (flow , eqx .is_inexact_array )
348342 if self ._verbose :
349343 print (
@@ -356,9 +350,9 @@ def update(self, seed, positions, gradients, logps):
356350 self ._loss_fn (
357351 params ,
358352 static ,
359- positions [- 100 :],
360- gradients [- 100 :],
361- logps [- 100 :],
353+ positions [- 128 :],
354+ gradients [- 128 :],
355+ logps [- 128 :],
362356 ),
363357 )
364358 else :
@@ -392,10 +386,7 @@ def update(self, seed, positions, gradients, logps):
392386 self ._opt_state = None
393387 return
394388
395- flow = flowjax .flows .Transformed (
396- flowjax .distributions .StandardNormal (self ._bijection .shape ),
397- self ._bijection ,
398- )
389+ flow = self ._bijection
399390 params , static = eqx .partition (flow , eqx .is_inexact_array )
400391 old_loss = self ._loss_fn (
401392 params , static , positions [- 128 :], gradients [- 128 :], logps [- 128 :]
@@ -420,9 +411,7 @@ def update(self, seed, positions, gradients, logps):
420411 max_patience = self ._max_patience ,
421412 )
422413
423- flow = flowjax .flows .Transformed (
424- flowjax .distributions .StandardNormal (fit .shape ), fit
425- )
414+ flow = fit
426415 params , static = eqx .partition (flow , eqx .is_inexact_array )
427416 new_loss = self ._loss_fn (
428417 params , static , positions [- 128 :], gradients [- 128 :], logps [- 128 :]
@@ -432,10 +421,7 @@ def update(self, seed, positions, gradients, logps):
432421 print (f"Chain { self ._chain } : New loss { new_loss } , old loss { old_loss } " )
433422
434423 if not np .isfinite (old_loss ):
435- flow = flowjax .flows .Transformed (
436- flowjax .distributions .StandardNormal (self ._bijection .shape ),
437- self ._bijection ,
438- )
424+ flow = self ._bijection
439425 params , static = eqx .partition (flow , eqx .is_inexact_array )
440426 print (
441427 self ._loss_fn (
@@ -449,9 +435,7 @@ def update(self, seed, positions, gradients, logps):
449435 )
450436
451437 if not np .isfinite (new_loss ):
452- flow = flowjax .flows .Transformed (
453- flowjax .distributions .StandardNormal (fit .shape ), fit
454- )
438+ flow = fit
455439 params , static = eqx .partition (flow , eqx .is_inexact_array )
456440 print (
457441 self ._loss_fn (
0 commit comments