Commit 22e15d2
Various Improvements around Flow Matching (#351)
* make continuous approximator's build_adapter explicit
* enable optimal transport by default
* add subnet_kwargs explicitly
* merge subnet kwargs
* merge integrate and optimal transport kwargs
* improve naming of time variable
* add concatenate_valid and stack_valid
* reorder functions
* use is_symbolic_tensor in optimal transport
* improve keras_kwargs
* make fixed euler the default
* fix type hint in BasicWorkflow
* improve sampling (with explicit num_datasets for now)
* fix flow matching default integrate config
* fix concatenate_valid in continuous_time_consistency_model.py
* add reason to summary network test skips when summary network is None
* revert changes to `ContinuousApproximator.sample()`
Parent: 8426dd1

File tree

13 files changed: +541 −908 lines

bayesflow/approximators/continuous_approximator.py

Lines changed: 9 additions & 4 deletions

@@ -65,15 +65,20 @@ def build_adapter(
         summary_variables : Sequence of str, optional
             Names of the summary variables in the data
         """
-        adapter = Adapter.create_default(inference_variables)
+        adapter = Adapter()
+        adapter.to_array()
+        adapter.convert_dtype("float64", "float32")
+        adapter.concatenate(inference_variables, into="inference_variables")

         if inference_conditions is not None:
-            adapter = adapter.concatenate(inference_conditions, into="inference_conditions")
+            adapter.concatenate(inference_conditions, into="inference_conditions")

         if summary_variables is not None:
-            adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")
+            adapter.as_set(summary_variables)
+            adapter.concatenate(summary_variables, into="summary_variables")

-        adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize()
+        adapter.keep(["inference_variables", "inference_conditions", "summary_variables"])
+        adapter.standardize()

         return adapter
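Note: the rewritten `build_adapter` drops the reassignments, which relies on `Adapter` methods mutating the instance in place. A minimal usage sketch of the same pattern (the import path and the variable names "theta" and "x" are illustrative assumptions, not part of the commit):

```python
from bayesflow.adapters import Adapter  # import path assumed

adapter = Adapter()
adapter.to_array()
adapter.convert_dtype("float64", "float32")
adapter.concatenate(["theta"], into="inference_variables")  # placeholder names
adapter.concatenate(["x"], into="inference_conditions")
adapter.keep(["inference_variables", "inference_conditions"])
adapter.standardize()
```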

bayesflow/experimental/continuous_time_consistency_model.py

Lines changed: 4 additions & 4 deletions

@@ -9,7 +9,7 @@
 from bayesflow.types import Tensor
 from bayesflow.utils import (
     jvp,
-    concatenate,
+    concatenate_valid,
     find_network,
     keras_kwargs,
     expand_right_as,
@@ -201,7 +201,7 @@ def consistency_function(
         **kwargs : dict, optional, default: {}
             Additional keyword arguments passed to the inner network.
         """
-        xtc = concatenate(x / self.sigma_data, self.time_emb(t), conditions, axis=-1)
+        xtc = concatenate_valid([x / self.sigma_data, self.time_emb(t), conditions], axis=-1)
         f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
         out = ops.cos(t) * x - ops.sin(t) * self.sigma_data * f
         return out
@@ -240,7 +240,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
         r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here

         def f_teacher(x, t):
-            o = self.subnet(concatenate(x, self.time_emb(t), conditions, axis=-1), training=stage == "training")
+            o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
             return self.subnet_projector(o)

         primals = (xt / self.sigma_data, t)
@@ -254,7 +254,7 @@ def f_teacher(x, t):
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)

         # calculate output of the network
-        xtc = concatenate(xt / self.sigma_data, self.time_emb(t), conditions, axis=-1)
+        xtc = concatenate_valid([xt / self.sigma_data, self.time_emb(t), conditions], axis=-1)
         student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))

         # calculate the tangent
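Note: `concatenate_valid` takes a list instead of positional tensors; judging by call sites where `conditions` may be `None`, it presumably skips `None` entries so conditional and unconditional models share one code path. A minimal sketch under that assumption:

```python
import keras

def concatenate_valid(tensors, axis=-1):
    # Assumed behavior: drop None entries, concatenate the rest.
    tensors = [t for t in tensors if t is not None]
    return keras.ops.concatenate(tensors, axis=axis)
```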

bayesflow/experimental/free_form_flow/free_form_flow.py

Lines changed: 3 additions & 3 deletions

@@ -6,7 +6,7 @@
 from bayesflow.utils import (
     find_network,
     keras_kwargs,
-    concatenate,
+    concatenate_valid,
     jacobian,
     jvp,
     vjp,
@@ -181,7 +181,7 @@ def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, *
         if conditions is None:
             inp = x
         else:
-            inp = concatenate(x, conditions, axis=-1)
+            inp = concatenate_valid([x, conditions], axis=-1)
         network_out = self.encoder_projector(
             self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
         )
@@ -191,7 +191,7 @@ def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, *
         if conditions is None:
             inp = z
         else:
-            inp = concatenate(z, conditions, axis=-1)
+            inp = concatenate_valid([z, conditions], axis=-1)
         network_out = self.decoder_projector(
             self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
         )

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 26 additions & 34 deletions

@@ -45,21 +45,19 @@ class FlowMatching(InferenceNetwork):
     }

     INTEGRATE_DEFAULT_CONFIG = {
-        "method": "rk45",
-        "steps": "adaptive",
-        "tolerance": 1e-3,
-        "min_steps": 10,
-        "max_steps": 100,
+        "method": "euler",
+        "steps": 100,
     }

     def __init__(
         self,
         subnet: str | type = "mlp",
         base_distribution: str = "normal",
-        use_optimal_transport: bool = False,
+        use_optimal_transport: bool = True,
         loss_fn: str = "mse",
         integrate_kwargs: dict[str, any] = None,
         optimal_transport_kwargs: dict[str, any] = None,
+        subnet_kwargs: dict[str, any] = None,
         **kwargs,
     ):
         """
@@ -97,23 +95,17 @@ def __init__(

         self.use_optimal_transport = use_optimal_transport

-        new_integrate_kwargs = FlowMatching.INTEGRATE_DEFAULT_CONFIG.copy()
-        new_integrate_kwargs.update(integrate_kwargs or {})
-        self.integrate_kwargs = new_integrate_kwargs
-
-        new_optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG.copy()
-        new_optimal_transport_kwargs.update(optimal_transport_kwargs or {})
-        self.optimal_transport_kwargs = new_optimal_transport_kwargs
+        self.integrate_kwargs = FlowMatching.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
+        self.optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})

         self.loss_fn = keras.losses.get(loss_fn)

         self.seed_generator = keras.random.SeedGenerator()

+        subnet_kwargs = subnet_kwargs or {}
+
         if subnet == "mlp":
-            subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG.copy()
-            subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
-        else:
-            subnet_kwargs = kwargs.get("subnet_kwargs", {})
+            subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs

         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
@@ -154,23 +146,23 @@ def from_config(cls, config):
         config = deserialize_value_or_type(config, "subnet")
         return cls(**config)

-    def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
-        t = keras.ops.convert_to_tensor(t)
-        t = expand_right_as(t, xz)
-        t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))
+    def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
+        time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
+        time = expand_right_as(time, xz)
+        time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))

         if conditions is None:
-            xtc = keras.ops.concatenate([xz, t], axis=-1)
+            xtc = keras.ops.concatenate([xz, time], axis=-1)
         else:
-            xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)
+            xtc = keras.ops.concatenate([xz, time, conditions], axis=-1)

         return self.output_projector(self.subnet(xtc, training=training), training=training)

     def _velocity_trace(
-        self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
+        self, xz: Tensor, time: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
     ) -> (Tensor, Tensor):
         def f(x):
-            return self.velocity(x, t, conditions=conditions, training=training)
+            return self.velocity(x, time=time, conditions=conditions, training=training)

         v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)

@@ -181,8 +173,8 @@ def _forward(
     ) -> Tensor | tuple[Tensor, Tensor]:
         if density:

-            def deltas(t, xz):
-                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+            def deltas(time, xz):
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}

             state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
@@ -193,8 +185,8 @@ def deltas(t, xz):

             return z, log_density

-        def deltas(t, xz):
-            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+        def deltas(time, xz):
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

         state = {"xz": x}
         state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
@@ -208,8 +200,8 @@ def _inverse(
     ) -> Tensor | tuple[Tensor, Tensor]:
         if density:

-            def deltas(t, xz):
-                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+            def deltas(time, xz):
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}

             state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
@@ -220,8 +212,8 @@ def deltas(t, xz):

             return x, log_density

-        def deltas(t, xz):
-            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+        def deltas(time, xz):
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

         state = {"xz": z}
         state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
@@ -258,7 +250,7 @@ def compute_metrics(

         base_metrics = super().compute_metrics(x1, conditions, stage)

-        predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")
+        predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")

         loss = self.loss_fn(target_velocity, predicted_velocity)
         loss = keras.ops.mean(loss)
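Note: kwargs merging now uses the dict union operator (Python 3.9+); keys on the right override keys on the left, so user options win over the defaults:

```python
INTEGRATE_DEFAULT_CONFIG = {"method": "euler", "steps": 100}

# User-supplied options override defaults key by key:
merged = INTEGRATE_DEFAULT_CONFIG | {"steps": 250}
assert merged == {"method": "euler", "steps": 250}

# With no user options, the defaults pass through unchanged:
assert (INTEGRATE_DEFAULT_CONFIG | {}) == INTEGRATE_DEFAULT_CONFIG
```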

bayesflow/utils/__init__.py

Lines changed: 4 additions & 3 deletions

@@ -48,7 +48,7 @@
 )
 from .serialization import serialize_value_or_type, deserialize_value_or_type
 from .tensor_utils import (
-    concatenate,
+    concatenate_valid,
     expand,
     expand_as,
     expand_to,
@@ -59,12 +59,13 @@
     expand_right_as,
     expand_right_to,
     expand_tile,
+    pad,
+    searchsorted,
     size_of,
+    stack_valid,
     tile_axis,
     tree_concatenate,
     tree_stack,
-    pad,
-    searchsorted,
 )
 from .validators import check_lengths_same
 from .workflow_utils import find_inference_network, find_summary_network
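Note: `stack_valid` is exported alongside `concatenate_valid`; by analogy it presumably stacks only the non-`None` tensors along a new axis. A sketch under that assumption:

```python
import keras

def stack_valid(tensors, axis=0):
    # Assumed behavior, mirroring concatenate_valid: ignore None entries.
    tensors = [t for t in tensors if t is not None]
    return keras.ops.stack(tensors, axis=axis)
```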

bayesflow/utils/dict_utils.py

Lines changed: 5 additions & 6 deletions

@@ -49,7 +49,7 @@ def convert_kwargs(f: Callable, *args: any, **kwargs: any) -> dict[str, any]:
     return parameters


-def filter_kwargs(kwargs: Mapping[str, T], f: Callable) -> Mapping[str, T]:
+def filter_kwargs(kwargs: dict[str, T], f: Callable) -> dict[str, T]:
     """Filter keyword arguments for f"""
     signature = inspect.signature(f)

@@ -63,11 +63,10 @@ def filter_kwargs(kwargs: Mapping[str, T], f: Callable) -> Mapping[str, T]:
     return kwargs


-def keras_kwargs(kwargs: Mapping[str, T]) -> dict[str, T]:
-    """Keep dictionary keys that do not end with _kwargs. Used for propagating
-    keyword arguments in nested layer classes.
-    """
-    return {key: value for key, value in kwargs.items() if not key.endswith("_kwargs")}
+def keras_kwargs(kwargs: dict[str, T]) -> dict[str, T]:
+    """Filter keyword arguments for keras.Layer"""
+    valid_keys = ["dtype", "name", "trainable"]
+    return {key: value for key, value in kwargs.items() if key in valid_keys}


# TODO: rename and streamline and make protected
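Note: `keras_kwargs` switches from a blacklist (drop keys ending in `_kwargs`) to a whitelist of arguments that `keras.Layer` accepts. The effect, using the new definition from this diff:

```python
kwargs = {"name": "flow", "trainable": False, "subnet_kwargs": {"depth": 4}}

# Only whitelisted keys survive; nested *_kwargs no longer reach keras.Layer:
assert keras_kwargs(kwargs) == {"name": "flow", "trainable": False}
```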

bayesflow/utils/integrate.py

Lines changed: 4 additions & 2 deletions

@@ -3,6 +3,8 @@

 import keras

+from typing import Literal
+
 from bayesflow.types import Tensor
 from bayesflow.utils import filter_kwargs
 from . import logging
@@ -238,8 +240,8 @@ def integrate(
     stop_time: ArrayLike,
     min_steps: int = 10,
     max_steps: int = 10_000,
-    steps: int = "adaptive",
-    method: str = "rk45",
+    steps: int | Literal["adaptive"] = 100,
+    method: str = "euler",
     **kwargs,
 ) -> dict[str, ArrayLike]:
     match steps:
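Note: with `method="euler"` and `steps=100` as defaults, integration takes fixed uniform steps instead of adaptive RK45. A simplified standalone sketch of fixed-step Euler over a dict-valued state (illustrative only; the real `integrate` also filters keyword arguments and supports adaptive stepping):

```python
def euler_integrate(deltas, state, start_time, stop_time, steps=100):
    """Fixed-step Euler: state <- state + deltas(time, state) * dt, repeated."""
    dt = (stop_time - start_time) / steps
    time = start_time
    for _ in range(steps):
        d = deltas(time, **state)  # simplified: assumes deltas accepts all state keys
        state = {key: value + d[key] * dt for key, value in state.items()}
        time = time + dt
    return state
```

In the flow matching diff above, `_inverse` integrates from `start_time=0.0` to `stop_time=1.0` (sampling) and `_forward` runs the reverse direction, so `dt` simply changes sign.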

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 6 additions & 2 deletions

@@ -6,6 +6,7 @@
 from ..dispatch import find_cost
 from .. import logging
 from ..numpy_utils import softmax
+from ..tensor_utils import is_symbolic_tensor


 def sinkhorn(
@@ -134,8 +135,8 @@ def sinkhorn_indices(
     rng = np.random.default_rng(seed)

     indices = []
-    for row in range(cost.shape[0]):
-        index = rng.choice(cost.shape[1], p=plan[row])
+    for row in range(plan.shape[0]):
+        index = rng.choice(plan.shape[1], p=plan[row])
         indices.append(index)

     indices = np.array(indices)
@@ -190,6 +191,9 @@ def sinkhorn_plan_keras(cost: Tensor, regularization: float, max_steps: int, tol
     # initialize the transport plan from a gaussian kernel
     plan = keras.ops.exp(-0.5 * cost / regularization)

+    if is_symbolic_tensor(plan):
+        return plan
+
     def is_converged(plan):
         # check convergence: the plan should be doubly stochastic
         marginals = keras.ops.sum(plan, axis=0), keras.ops.sum(plan, axis=1)
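Note: the early return guards Keras's symbolic build pass, where tensors carry shapes but no values, so the iterative Sinkhorn normalization cannot run; returning the Gaussian-kernel initialization keeps shape inference working. A plausible sketch of the helper, assuming it detects Keras placeholder tensors:

```python
import keras

def is_symbolic_tensor(x) -> bool:
    # Assumption: KerasTensor instances are the symbolic placeholders
    # created while tracing/building a model.
    return isinstance(x, keras.KerasTensor)
```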
