Skip to content

Commit b57126b

Browse files
Make sampling and logprobs attributesa nd add predict method
1 parent 4836907 commit b57126b

File tree

1 file changed

+68
-62
lines changed

1 file changed

+68
-62
lines changed

harmonic/model.py

Lines changed: 68 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -479,66 +479,6 @@ def train_flow(
479479
return train_flow, train_epoch, train_step
480480

481481

482-
def sample_flow_matching(self, n_samples, rng_key, steps=100):
483-
# Prior: standard normal
484-
x0 = jax.random.normal(rng_key, (n_samples, self.ndim))
485-
t0, t1 = 0.0, 1.0
486-
ts = jnp.linspace(t0, t1, steps)
487-
488-
def vector_field(t, x, args):
489-
# x shape: (n_samples, ndim)
490-
t_vec = jnp.full((x.shape[0],), t)
491-
return self.flow.apply({"params": self.state.params}, x, t_vec)
492-
493-
term = diffrax.ODETerm(vector_field)
494-
solver = diffrax.Dopri5()
495-
saveat = diffrax.SaveAt(t1=True)
496-
497-
# Integrate for each sample
498-
def integrate_single(x0):
499-
sol = diffrax.diffeqsolve(
500-
term, solver, t0=t0, t1=t1, dt0=1e-2, y0=x0, saveat=saveat
501-
)
502-
return sol.ys[0]
503-
504-
xs = integrate_single(x0)
505-
return xs
506-
507-
508-
def log_prob_flow_matching(self, x_samples, steps=100):
509-
D = x_samples.shape[1]
510-
t0, t1 = 0.0, 1.0
511-
512-
def reverse_vector_field(t, y, args):
513-
x, log_det = y[:-1], y[-1]
514-
t_val = 1.0 - t # Reverse time
515-
def flow_fn(x_single):
516-
return self.flow.apply({"params": self.state.params}, x_single[None, :], jnp.array([t_val]))[0]
517-
jac = jax.jacobian(flow_fn)(x)
518-
div = jnp.trace(jac)
519-
v = -flow_fn(x)
520-
d_log_det = -div
521-
return jnp.concatenate([v, jnp.array([d_log_det])])
522-
523-
def get_z_and_logdet(x):
524-
y0 = jnp.concatenate([x, jnp.array([0.0])])
525-
term = diffrax.ODETerm(reverse_vector_field)
526-
solver = diffrax.Dopri5()
527-
solution = diffrax.diffeqsolve(
528-
term, solver, t0=t0, t1=t1, dt0=1e-2, y0=y0,
529-
saveat=diffrax.SaveAt(t1=True)
530-
)
531-
z = solution.ys[0][:-1]
532-
log_det = solution.ys[0][-1]
533-
return z, log_det
534-
535-
zs, log_dets = jax.vmap(get_z_and_logdet)(x_samples)
536-
# Prior log density (standard normal)
537-
prior = stats.multivariate_normal(mean=np.zeros(D), cov=np.eye(D))
538-
log_p_zs = prior.logpdf(np.array(zs))
539-
log_densities = log_p_zs + np.array(log_dets)
540-
return jnp.array(log_densities)
541-
542482

543483
class FlowMatchingModel(FlowModel):
544484
"""Flow Matching model using an MLP for v(x, t)."""
@@ -611,9 +551,75 @@ def fit(
611551
self.loss_values = np.array(loss_values)
612552
return
613553

554+
def sample_flow_matching(self, n_samples, rng_key, steps=100):
555+
# Prior: standard normal
556+
x0 = jax.random.normal(rng_key, (n_samples, self.ndim)) * self.temperature
557+
t0, t1 = 0.0, 1.0
558+
ts = jnp.linspace(t0, t1, steps)
559+
560+
def vector_field(t, x, args):
561+
# x shape: (n_samples, ndim)
562+
t_vec = jnp.full((x.shape[0],), t)
563+
return self.flow.apply({"params": self.state.params}, x, t_vec)
564+
565+
term = diffrax.ODETerm(vector_field)
566+
solver = diffrax.Dopri5()
567+
saveat = diffrax.SaveAt(t1=True)
568+
569+
# Integrate for each sample
570+
def integrate_single(x0):
571+
sol = diffrax.diffeqsolve(
572+
term, solver, t0=t0, t1=t1, dt0=1e-2, y0=x0, saveat=saveat
573+
)
574+
return sol.ys[0]
575+
576+
xs = integrate_single(x0)
577+
return xs
578+
579+
580+
def log_prob_flow_matching(self, x_samples, steps=100):
581+
D = x_samples.shape[1]
582+
t0, t1 = 0.0, 1.0
583+
584+
def reverse_vector_field(t, y, args):
585+
x, log_det = y[:-1], y[-1]
586+
t_val = 1.0 - t # Reverse time
587+
def flow_fn(x_single):
588+
return self.flow.apply({"params": self.state.params}, x_single[None, :], jnp.array([t_val]))[0]
589+
jac = jax.jacobian(flow_fn)(x)
590+
div = jnp.trace(jac)
591+
v = -flow_fn(x)
592+
d_log_det = -div
593+
return jnp.concatenate([v, jnp.array([d_log_det])])
594+
595+
def get_z_and_logdet(x):
596+
y0 = jnp.concatenate([x, jnp.array([0.0])])
597+
term = diffrax.ODETerm(reverse_vector_field)
598+
solver = diffrax.Dopri5()
599+
solution = diffrax.diffeqsolve(
600+
term, solver, t0=t0, t1=t1, dt0=1e-2, y0=y0,
601+
saveat=diffrax.SaveAt(t1=True)
602+
)
603+
z = solution.ys[0][:-1]
604+
log_det = solution.ys[0][-1]
605+
return z, log_det
606+
607+
zs, log_dets = jax.vmap(get_z_and_logdet)(x_samples)
608+
# Prior log density (standard normal)
609+
prior = stats.multivariate_normal(mean=np.zeros(D), cov=np.eye(D)*self.temperature)
610+
log_p_zs = prior.logpdf(np.array(zs))
611+
log_densities = log_p_zs + np.array(log_dets)
612+
return jnp.array(log_densities)
614613

615614
def sample(self, n_sample: int, rng_key=jax.random.PRNGKey(0)) -> jnp.ndarray:
616-
return sample_flow_matching(self, n_sample, rng_key)
615+
return self.sample_flow_matching(n_sample, rng_key)
617616

618617
def log_prob(self, x: jnp.ndarray) -> jnp.ndarray:
619-
return log_prob_flow_matching(self, x)
618+
return self.log_prob_flow_matching(x)
619+
620+
621+
def predict(self, x: jnp.ndarray) -> jnp.ndarray:
622+
"""
623+
Predict the log_e posterior for batched input x using flow matching.
624+
"""
625+
return self.log_prob_flow_matching(x)

0 commit comments

Comments
 (0)