@@ -96,6 +96,9 @@ def __init__(
9696 into a single vector or passed as separate arguments. If set to False, the subnet
9797 must accept three separate inputs: 'x' (noisy parameters), 't' (time),
9898 and optional 'conditions'. Default is True.
99+ time_power_law_alpha: float, optional
100+ Change the distribution of sampled times during training. Time is sampled from a power law distribution
101+ p(t) ∝ t^(1/(1+α)), where α is the provided value. Default is α=0, which corresponds to uniform sampling.
99102 **kwargs
100103 Additional keyword arguments passed to the subnet and other components.
101104 """
@@ -107,9 +110,9 @@ def __init__(
107110 self .optimal_transport_kwargs = FlowMatching .OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
108111
109112 self .loss_fn = keras .losses .get (loss_fn )
110- self .time_sampling_alpha = kwargs .pop ("time_sampling_alpha " , 0.0 ) # 0 is uniform, <0 favors smaller t
111- if self .time_sampling_alpha <= - 1.0 :
112- raise ValueError ("'time_sampling_alpha ' must be greater than -1.0." )
113+ self .time_power_law_alpha = float ( kwargs .pop ("time_power_law_alpha " , 0.0 ) ) # 0 is uniform, <0 favors smaller t
114+ if self .time_power_law_alpha <= - 1.0 :
115+ raise ValueError ("'time_power_law_alpha ' must be greater than -1.0." )
113116
114117 self .seed_generator = keras .random .SeedGenerator ()
115118
@@ -167,6 +170,7 @@ def get_config(self):
167170 "integrate_kwargs" : self .integrate_kwargs ,
168171 "optimal_transport_kwargs" : self .optimal_transport_kwargs ,
169172 "concatenate_subnet_input" : self ._concatenate_subnet_input ,
173+ "time_power_law_alpha" : self .time_power_law_alpha ,
170174 # we do not need to store subnet_kwargs
171175 }
172176
@@ -312,7 +316,7 @@ def compute_metrics(
312316
313317 u = keras .random .uniform ((keras .ops .shape (x0 )[0 ],), seed = self .seed_generator )
314318 # p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform
315- t = u ** (1 + self .time_sampling_alpha )
319+ t = u ** (1 + self .time_power_law_alpha )
316320 t = expand_right_as (t , x0 )
317321
318322 x = t * x1 + (1 - t ) * x0
0 commit comments