Skip to content

Commit 493d794

Browse files
committed
add fm schedule
1 parent 55c18e2 commit 493d794

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def __init__(
107107
self.optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
108108

109109
self.loss_fn = keras.losses.get(loss_fn)
110+
self.time_sampling_alpha = kwargs.pop("time_sampling_alpha", 0.0) # 0 is uniform, <0 favors smaller t
111+
if self.time_sampling_alpha <= -1.0:
112+
raise ValueError("'time_sampling_alpha' must be greater than -1.0.")
110113

111114
self.seed_generator = keras.random.SeedGenerator()
112115

@@ -307,7 +310,9 @@ def compute_metrics(
307310
# conditions must be resampled along with x1
308311
conditions = keras.ops.take(conditions, assignments, axis=0)
309312

310-
t = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
313+
u = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
314+
# p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform
315+
t = u ** (1 + self.alpha)
311316
t = expand_right_as(t, x0)
312317

313318
x = t * x1 + (1 - t) * x0

0 commit comments

Comments
 (0)