@@ -479,66 +479,6 @@ def train_flow(
479479 return train_flow , train_epoch , train_step
480480
481481
482- def sample_flow_matching (self , n_samples , rng_key , steps = 100 ):
483- # Prior: standard normal
484- x0 = jax .random .normal (rng_key , (n_samples , self .ndim ))
485- t0 , t1 = 0.0 , 1.0
486- ts = jnp .linspace (t0 , t1 , steps )
487-
488- def vector_field (t , x , args ):
489- # x shape: (n_samples, ndim)
490- t_vec = jnp .full ((x .shape [0 ],), t )
491- return self .flow .apply ({"params" : self .state .params }, x , t_vec )
492-
493- term = diffrax .ODETerm (vector_field )
494- solver = diffrax .Dopri5 ()
495- saveat = diffrax .SaveAt (t1 = True )
496-
497- # Integrate for each sample
498- def integrate_single (x0 ):
499- sol = diffrax .diffeqsolve (
500- term , solver , t0 = t0 , t1 = t1 , dt0 = 1e-2 , y0 = x0 , saveat = saveat
501- )
502- return sol .ys [0 ]
503-
504- xs = integrate_single (x0 )
505- return xs
506-
507-
508- def log_prob_flow_matching (self , x_samples , steps = 100 ):
509- D = x_samples .shape [1 ]
510- t0 , t1 = 0.0 , 1.0
511-
512- def reverse_vector_field (t , y , args ):
513- x , log_det = y [:- 1 ], y [- 1 ]
514- t_val = 1.0 - t # Reverse time
515- def flow_fn (x_single ):
516- return self .flow .apply ({"params" : self .state .params }, x_single [None , :], jnp .array ([t_val ]))[0 ]
517- jac = jax .jacobian (flow_fn )(x )
518- div = jnp .trace (jac )
519- v = - flow_fn (x )
520- d_log_det = - div
521- return jnp .concatenate ([v , jnp .array ([d_log_det ])])
522-
523- def get_z_and_logdet (x ):
524- y0 = jnp .concatenate ([x , jnp .array ([0.0 ])])
525- term = diffrax .ODETerm (reverse_vector_field )
526- solver = diffrax .Dopri5 ()
527- solution = diffrax .diffeqsolve (
528- term , solver , t0 = t0 , t1 = t1 , dt0 = 1e-2 , y0 = y0 ,
529- saveat = diffrax .SaveAt (t1 = True )
530- )
531- z = solution .ys [0 ][:- 1 ]
532- log_det = solution .ys [0 ][- 1 ]
533- return z , log_det
534-
535- zs , log_dets = jax .vmap (get_z_and_logdet )(x_samples )
536- # Prior log density (standard normal)
537- prior = stats .multivariate_normal (mean = np .zeros (D ), cov = np .eye (D ))
538- log_p_zs = prior .logpdf (np .array (zs ))
539- log_densities = log_p_zs + np .array (log_dets )
540- return jnp .array (log_densities )
541-
542482
543483class FlowMatchingModel (FlowModel ):
544484 """Flow Matching model using an MLP for v(x, t)."""
@@ -611,9 +551,75 @@ def fit(
611551 self .loss_values = np .array (loss_values )
612552 return
613553
554+ def sample_flow_matching (self , n_samples , rng_key , steps = 100 ):
555+ # Prior: standard normal
556+ x0 = jax .random .normal (rng_key , (n_samples , self .ndim )) * self .temperature
557+ t0 , t1 = 0.0 , 1.0
558+ ts = jnp .linspace (t0 , t1 , steps )
559+
560+ def vector_field (t , x , args ):
561+ # x shape: (n_samples, ndim)
562+ t_vec = jnp .full ((x .shape [0 ],), t )
563+ return self .flow .apply ({"params" : self .state .params }, x , t_vec )
564+
565+ term = diffrax .ODETerm (vector_field )
566+ solver = diffrax .Dopri5 ()
567+ saveat = diffrax .SaveAt (t1 = True )
568+
569+ # Integrate for each sample
570+ def integrate_single (x0 ):
571+ sol = diffrax .diffeqsolve (
572+ term , solver , t0 = t0 , t1 = t1 , dt0 = 1e-2 , y0 = x0 , saveat = saveat
573+ )
574+ return sol .ys [0 ]
575+
576+ xs = integrate_single (x0 )
577+ return xs
578+
579+
580+ def log_prob_flow_matching (self , x_samples , steps = 100 ):
581+ D = x_samples .shape [1 ]
582+ t0 , t1 = 0.0 , 1.0
583+
584+ def reverse_vector_field (t , y , args ):
585+ x , log_det = y [:- 1 ], y [- 1 ]
586+ t_val = 1.0 - t # Reverse time
587+ def flow_fn (x_single ):
588+ return self .flow .apply ({"params" : self .state .params }, x_single [None , :], jnp .array ([t_val ]))[0 ]
589+ jac = jax .jacobian (flow_fn )(x )
590+ div = jnp .trace (jac )
591+ v = - flow_fn (x )
592+ d_log_det = - div
593+ return jnp .concatenate ([v , jnp .array ([d_log_det ])])
594+
595+ def get_z_and_logdet (x ):
596+ y0 = jnp .concatenate ([x , jnp .array ([0.0 ])])
597+ term = diffrax .ODETerm (reverse_vector_field )
598+ solver = diffrax .Dopri5 ()
599+ solution = diffrax .diffeqsolve (
600+ term , solver , t0 = t0 , t1 = t1 , dt0 = 1e-2 , y0 = y0 ,
601+ saveat = diffrax .SaveAt (t1 = True )
602+ )
603+ z = solution .ys [0 ][:- 1 ]
604+ log_det = solution .ys [0 ][- 1 ]
605+ return z , log_det
606+
607+ zs , log_dets = jax .vmap (get_z_and_logdet )(x_samples )
608+ # Prior log density (standard normal)
609+ prior = stats .multivariate_normal (mean = np .zeros (D ), cov = np .eye (D )* self .temperature )
610+ log_p_zs = prior .logpdf (np .array (zs ))
611+ log_densities = log_p_zs + np .array (log_dets )
612+ return jnp .array (log_densities )
614613
    def sample(self, n_sample: int, rng_key=jax.random.PRNGKey(0)) -> jnp.ndarray:
        """Draw ``n_sample`` points from the model; delegates to sample_flow_matching."""
        # NOTE(review): the default rng_key is created once at definition time,
        # so repeated calls without an explicit key return identical samples —
        # confirm this determinism is intended.
        return self.sample_flow_matching(n_sample, rng_key)
617616
    def log_prob(self, x: jnp.ndarray) -> jnp.ndarray:
        """Exact per-sample log-density of batched ``x``; delegates to log_prob_flow_matching."""
        return self.log_prob_flow_matching(x)
    def predict(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Predict the log_e posterior for batched input x using flow matching.

        Equivalent to ``log_prob``: both delegate to log_prob_flow_matching.
        """
        return self.log_prob_flow_matching(x)
0 commit comments