
Commit 4a93d21

Merge remote-tracking branch 'upstream/dev' into fix-mvnorm-stability
2 parents 9fb0e1b + fb3191b

File tree

8 files changed: +361 -80 lines changed

bayesflow/__init__.py

Lines changed: 78 additions & 20 deletions
@@ -1,27 +1,64 @@
-from . import (
-    approximators,
-    adapters,
-    augmentations,
-    datasets,
-    diagnostics,
-    distributions,
-    experimental,
-    networks,
-    simulators,
-    utils,
-    workflows,
-    wrappers,
-)
-
-from .adapters import Adapter
-from .approximators import ContinuousApproximator, PointApproximator
-from .datasets import OfflineDataset, OnlineDataset, DiskDataset
-from .simulators import make_simulator
-from .workflows import BasicWorkflow
+# ruff: noqa: E402
+# disable E402 to allow for setup code before importing any internals (which could import keras)
 
 
 def setup():
     # perform any necessary setup without polluting the namespace
+    import os
+    from importlib.util import find_spec
+
+    issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md"
+
+    if "KERAS_BACKEND" not in os.environ:
+        # check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
+        class Backend:
+            def __init__(self, display_name, package_name, env_name, install_url, priority):
+                self.display_name = display_name
+                self.package_name = package_name
+                self.env_name = env_name
+                self.install_url = install_url
+                self.priority = priority
+
+        backends = [
+            Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation", 0),
+            Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/", 1),
+            Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install", 2),
+        ]
+
+        found_backends = []
+        for backend in backends:
+            if find_spec(backend.package_name) is not None:
+                found_backends.append(backend)
+
+        if not found_backends:
+            message = "No suitable backend found. Please install one of the following:\n"
+            for backend in backends:
+                message += f"{backend.display_name}\n"
+            message += "\n"
+
+            message += f"If you continue to see this error, please file a bug report at {issue_url}.\n"
+            message += (
+                "You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
+            )
+            message += "https://keras.io/getting_started/#configuring-your-backend"
+
+            raise ImportError(message)
+
+        if len(found_backends) > 1:
+            import warnings
+
+            found_backends.sort(key=lambda b: b.priority)
+            chosen_backend = found_backends[0]
+
+            warnings.warn(
+                f"Multiple Keras-compatible backends detected ({', '.join(b.display_name for b in found_backends)}).\n"
+                f"Defaulting to {chosen_backend.display_name}.\n"
+                "To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n"
+                "See: https://keras.io/getting_started/#configuring-your-backend"
+            )
+            os.environ["KERAS_BACKEND"] = chosen_backend.env_name
+        else:
+            os.environ["KERAS_BACKEND"] = found_backends[0].env_name
+
     import keras
     import logging

@@ -60,3 +97,24 @@ def setup():
 # call and clean up namespace
 setup()
 del setup
+
+from . import (
+    approximators,
+    adapters,
+    augmentations,
+    datasets,
+    diagnostics,
+    distributions,
+    experimental,
+    networks,
+    simulators,
+    utils,
+    workflows,
+    wrappers,
+)
+
+from .adapters import Adapter
+from .approximators import ContinuousApproximator, PointApproximator
+from .datasets import OfflineDataset, OnlineDataset, DiskDataset
+from .simulators import make_simulator
+from .workflows import BasicWorkflow
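
A note on the setup logic above: because the detection only runs when KERAS_BACKEND is absent, setting the variable before the first import bypasses it entirely. A minimal sketch (the variable name and accepted values follow the Keras configuration guide linked in the diff; "jax" here is an arbitrary choice):

import os

# Pin the backend explicitly; setup() then skips auto-detection entirely.
# This must run before the first `import bayesflow`, which imports keras.
os.environ["KERAS_BACKEND"] = "jax"  # or "torch" / "tensorflow"

import bayesflow as bf  # noqa: E402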

bayesflow/links/positive_definite.py

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ def compute_input_shape(self, output_shape):
 
         There are m nonzero elements of a lower triangular nxn matrix with m = n * (n + 1) / 2.
 
-        Example
-        -------
+        Examples
+        --------
         >>> PositiveDefinite().compute_input_shape((None, 3, 3))
         6
         """

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 57 additions & 20 deletions
@@ -4,7 +4,7 @@
 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -67,6 +67,11 @@ def __init__(
             Final number of discretization steps
         subnet_kwargs: dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (time),
+            and optional 'conditions'. Default is True.
         **kwargs : dict, optional, default: {}
             Additional keyword arguments
         """
@@ -77,6 +82,7 @@ def __init__(
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
 
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(
@@ -119,6 +125,7 @@ def get_config(self):
             "eps": self.eps,
             "s0": self.s0,
             "s1": self.s1,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
             # we do not need to store subnet_kwargs
         }

@@ -161,18 +168,23 @@ def build(self, xz_shape, conditions_shape=None):
 
         input_shape = list(xz_shape)
 
-        # time vector
-        input_shape[-1] += 1
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
 
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-
-        input_shape = tuple(input_shape)
-
-        self.subnet.build(input_shape)
-
-        input_shape = self.subnet.compute_output_shape(input_shape)
-        self.output_projector.build(input_shape)
+            self.subnet.build(input_shape)
+            out_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            out_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
+        self.output_projector.build(out_shape)
 
         # Choose coefficient according to [2] Section 3.3
         self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
@@ -256,6 +268,35 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
         x = self.consistency_function(x_n, t, conditions=conditions, training=training)
         return x
 
+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares and passes the input to the subnet, either by concatenating the latent variable `x`,
+        the time `t`, and optional conditions, or by passing them as separate arguments.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The time tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
     def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
         """Compute consistency function.
 
271312
Whether internal layers (e.g., dropout) should behave in train or inference mode.
272313
"""
273314

274-
if conditions is not None:
275-
xtc = ops.concatenate([x, t, conditions], axis=-1)
276-
else:
277-
xtc = ops.concatenate([x, t], axis=-1)
278-
279-
f = self.output_projector(self.subnet(xtc, training=training))
315+
subnet_out = self._apply_subnet(x, t, conditions, training=training)
316+
f = self.output_projector(subnet_out)
280317

281318
# Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
282319
# Thus, we can do a cross product with the time vector which is (batch_size, 1) for
@@ -316,8 +353,8 @@ def compute_metrics(
316353

317354
log_p = ops.log(p)
318355
times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
319-
t1 = ops.take(discretized_time, times)[..., None]
320-
t2 = ops.take(discretized_time, times + 1)[..., None]
356+
t1 = expand_right_as(ops.take(discretized_time, times), x)
357+
t2 = expand_right_as(ops.take(discretized_time, times + 1), x)
321358

322359
# generate noise vector
323360
noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
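
To make the concatenate_subnet_input=False contract concrete, here is a minimal sketch of a custom subnet that accepts the three separate inputs named in the docstring. The class name, layer widths, and additive combination are illustrative assumptions rather than part of the commit; only the keyword-based call/build/compute_output_shape signatures are dictated by how build() and _apply_subnet invoke the subnet in the diff above:

import keras


class SeparateInputSubnet(keras.layers.Layer):
    """Illustrative subnet taking 'x', 't', and optional 'conditions' separately."""

    def __init__(self, width=128, **kwargs):
        super().__init__(**kwargs)
        self.width = width
        self.x_proj = keras.layers.Dense(width)
        self.t_proj = keras.layers.Dense(width)
        self.cond_proj = keras.layers.Dense(width)
        self.out = keras.layers.Dense(width, activation="gelu")

    def build(self, x_shape, t_shape, conditions_shape=None):
        # mirrors the keyword-based build call used in build() above
        self.x_proj.build(x_shape)
        self.t_proj.build(t_shape)
        if conditions_shape is not None:
            self.cond_proj.build(conditions_shape)
        self.out.build(tuple(x_shape[:-1]) + (self.width,))

    def compute_output_shape(self, x_shape, t_shape=None, conditions_shape=None):
        return tuple(x_shape[:-1]) + (self.width,)

    def call(self, x, t, conditions=None, training=False):
        # project each input separately, then combine additively
        h = self.x_proj(x) + self.t_proj(t)
        if conditions is not None:
            h = h + self.cond_proj(conditions)
        return self.out(h)

Assuming find_network passes ready-made network instances through unchanged, this could be plugged in as ConsistencyModel(subnet=SeparateInputSubnet(), concatenate_subnet_input=False), which routes x, t, and conditions through the keyword branch of _apply_subnet instead of concatenate_valid.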

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 62 additions & 19 deletions
@@ -85,6 +85,12 @@ def __init__(
             Additional keyword arguments passed to the noise schedule constructor. Default is None.
         integrate_kwargs : dict[str, any], optional
             Configuration dictionary for integration during training or inference. Default is None.
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
+            and optional 'conditions'. Default is True.
+
         **kwargs
             Additional keyword arguments passed to the base class and internal components.
         """
@@ -116,6 +122,7 @@
         if subnet == "mlp":
             subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
         self.subnet = find_network(subnet, **subnet_kwargs)
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
 
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
 
@@ -128,15 +135,23 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
         self.output_projector.units = xz_shape[-1]
         input_shape = list(xz_shape)
 
-        # construct time vector
-        input_shape[-1] += 1
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
 
-        input_shape = tuple(input_shape)
+            self.subnet.build(input_shape)
+            out_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            out_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
 
-        self.subnet.build(input_shape)
-        out_shape = self.subnet.compute_output_shape(input_shape)
         self.output_projector.build(out_shape)
 
     def get_config(self):
@@ -149,6 +164,8 @@ def get_config(self):
             "prediction_type": self._prediction_type,
             "loss_type": self._loss_type,
             "integrate_kwargs": self.integrate_kwargs,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
+            # we do not need to store subnet_kwargs
         }
         return base_config | serialize(config)

@@ -197,6 +214,35 @@ def convert_prediction_to_x(
             return (z + sigma_t**2 * pred) / alpha_t
         raise ValueError(f"Unknown prediction type {self._prediction_type}.")
 
+    def _apply_subnet(
+        self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares and passes the input to the subnet, either by concatenating the latent variable `xz`,
+        the log signal-to-noise ratio `log_snr`, and optional conditions, or by passing them as separate arguments.
+
+        Parameters
+        ----------
+        xz : Tensor
+            The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
+        log_snr : Tensor
+            The log signal-to-noise ratio tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)
+
     def velocity(
         self,
         xz: Tensor,
@@ -221,7 +267,7 @@ def velocity(
             If True, computes the velocity for the stochastic formulation (SDE).
             If False, uses the deterministic formulation (ODE).
         conditions : Tensor, optional
-            Optional conditional inputs to the network, such as conditioning variables
+            Conditional inputs to the network, such as conditioning variables
             or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
         training : bool, optional
             Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,12 +284,10 @@ def velocity(
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
 
-        if conditions is None:
-            xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
-        else:
-            xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
-
-        pred = self.output_projector(self.subnet(xtc, training=training), training=training)
+        subnet_out = self._apply_subnet(
+            xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
+        )
+        pred = self.output_projector(subnet_out, training=training)
 
         x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

@@ -461,11 +505,10 @@
         diffused_x = alpha_t * x + sigma_t * eps_t
 
         # calculate output of the network
-        if conditions is None:
-            xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
-        else:
-            xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
-        pred = self.output_projector(self.subnet(xtc, training=training), training=training)
+        subnet_out = self._apply_subnet(
+            diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
+        )
+        pred = self.output_projector(subnet_out, training=training)
 
         x_pred = self.convert_prediction_to_x(
             pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
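
For the diffusion side, a hedged usage sketch of the new flag (DiffusionModel and the "mlp" default come from this file; SeparateInputSubnet refers to the illustrative class sketched after the consistency-model diff, and it is an assumption that DiffusionModel is exported from bayesflow.networks and that find_network accepts layer instances):

from bayesflow.networks import DiffusionModel

# Default: xz, the transformed log-SNR, and conditions are concatenated for the subnet.
model = DiffusionModel(subnet="mlp")

# Advanced: pass the inputs separately; requires a subnet with the keyword-based
# call/build signatures sketched earlier.
model_separate = DiffusionModel(
    subnet=SeparateInputSubnet(),
    concatenate_subnet_input=False,  # consumed from **kwargs in __init__
)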
