Skip to content

Commit 9cd880f

Browse files
committed
add comments
1 parent 1cf070f commit 9cd880f

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ def __init__(
9696
into a single vector or passed as separate arguments. If set to False, the subnet
9797
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
9898
and optional 'conditions'. Default is True.
99+
time_power_law_alpha: float, optional
100+
Change the distribution of sampled times during training. Time is sampled from a power law distribution
101+
p(t) ∝ t^(1/(1+α)), where α is the provided value. Default is α=0, which corresponds to uniform sampling.
99102
**kwargs
100103
Additional keyword arguments passed to the subnet and other components.
101104
"""
@@ -107,9 +110,9 @@ def __init__(
107110
self.optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
108111

109112
self.loss_fn = keras.losses.get(loss_fn)
110-
self.time_sampling_alpha = kwargs.pop("time_sampling_alpha", 0.0) # 0 is uniform, <0 favors smaller t
111-
if self.time_sampling_alpha <= -1.0:
112-
raise ValueError("'time_sampling_alpha' must be greater than -1.0.")
113+
self.time_power_law_alpha = float(kwargs.pop("time_power_law_alpha", 0.0)) # 0 is uniform, <0 favors smaller t
114+
if self.time_power_law_alpha <= -1.0:
115+
raise ValueError("'time_power_law_alpha' must be greater than -1.0.")
113116

114117
self.seed_generator = keras.random.SeedGenerator()
115118

@@ -167,6 +170,7 @@ def get_config(self):
167170
"integrate_kwargs": self.integrate_kwargs,
168171
"optimal_transport_kwargs": self.optimal_transport_kwargs,
169172
"concatenate_subnet_input": self._concatenate_subnet_input,
173+
"time_power_law_alpha": self.time_power_law_alpha,
170174
# we do not need to store subnet_kwargs
171175
}
172176

@@ -312,7 +316,7 @@ def compute_metrics(
312316

313317
u = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
314318
# p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform
315-
t = u ** (1 + self.time_sampling_alpha)
319+
t = u ** (1 + self.time_power_law_alpha)
316320
t = expand_right_as(t, x0)
317321

318322
x = t * x1 + (1 - t) * x0

0 commit comments

Comments
 (0)