File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed
bayesflow/networks/flow_matching Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -107,6 +107,9 @@ def __init__(
107107 self .optimal_transport_kwargs = FlowMatching .OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
108108
109109 self .loss_fn = keras .losses .get (loss_fn )
110+ self .time_sampling_alpha = kwargs .pop ("time_sampling_alpha" , 0.0 ) # 0 is uniform, <0 favors smaller t
111+ if self .time_sampling_alpha <= - 1.0 :
112+ raise ValueError ("'time_sampling_alpha' must be greater than -1.0." )
110113
111114 self .seed_generator = keras .random .SeedGenerator ()
112115
@@ -307,7 +310,9 @@ def compute_metrics(
307310 # conditions must be resampled along with x1
308311 conditions = keras .ops .take (conditions , assignments , axis = 0 )
309312
310- t = keras .random .uniform ((keras .ops .shape (x0 )[0 ],), seed = self .seed_generator )
313+ u = keras .random .uniform ((keras .ops .shape (x0 )[0 ],), seed = self .seed_generator )
314+ # p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform
315+ t = u ** (1 + self .alpha )
311316 t = expand_right_as (t , x0 )
312317
313318 x = t * x1 + (1 - t ) * x0
You can’t perform that action at this time.
0 commit comments