
Commit a9ba9e6

Merge remote-tracking branch 'upstream/dev' into docs-spatial-data

2 parents 5fc65c9 + 58ad41b

23 files changed: +768 -202 lines changed

bayesflow/__init__.py

Lines changed: 78 additions & 20 deletions
@@ -1,27 +1,64 @@
-from . import (
-    approximators,
-    adapters,
-    augmentations,
-    datasets,
-    diagnostics,
-    distributions,
-    experimental,
-    networks,
-    simulators,
-    utils,
-    workflows,
-    wrappers,
-)
-
-from .adapters import Adapter
-from .approximators import ContinuousApproximator, PointApproximator
-from .datasets import OfflineDataset, OnlineDataset, DiskDataset
-from .simulators import make_simulator
-from .workflows import BasicWorkflow
+# ruff: noqa: E402
+# disable E402 to allow for setup code before importing any internals (which could import keras)
 
 
 def setup():
     # perform any necessary setup without polluting the namespace
+    import os
+    from importlib.util import find_spec
+
+    issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md"
+
+    if "KERAS_BACKEND" not in os.environ:
+        # check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
+        class Backend:
+            def __init__(self, display_name, package_name, env_name, install_url, priority):
+                self.display_name = display_name
+                self.package_name = package_name
+                self.env_name = env_name
+                self.install_url = install_url
+                self.priority = priority
+
+        backends = [
+            Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation", 0),
+            Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/", 1),
+            Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install", 2),
+        ]
+
+        found_backends = []
+        for backend in backends:
+            if find_spec(backend.package_name) is not None:
+                found_backends.append(backend)
+
+        if not found_backends:
+            message = "No suitable backend found. Please install one of the following:\n"
+            for backend in backends:
+                message += f"{backend.display_name}\n"
+            message += "\n"
+
+            message += f"If you continue to see this error, please file a bug report at {issue_url}.\n"
+            message += (
+                "You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
+            )
+            message += "https://keras.io/getting_started/#configuring-your-backend"
+
+            raise ImportError(message)
+
+        if len(found_backends) > 1:
+            import warnings
+
+            found_backends.sort(key=lambda b: b.priority)
+            chosen_backend = found_backends[0]
+
+            warnings.warn(
+                f"Multiple Keras-compatible backends detected ({', '.join(b.display_name for b in found_backends)}).\n"
+                f"Defaulting to {chosen_backend.display_name}.\n"
+                "To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n"
+                "See: https://keras.io/getting_started/#configuring-your-backend"
+            )
+        else:
+            os.environ["KERAS_BACKEND"] = found_backends[0].env_name
+
     import keras
     import logging
 
@@ -60,3 +97,24 @@ def setup():
 # call and clean up namespace
 setup()
 del setup
+
+from . import (
+    approximators,
+    adapters,
+    augmentations,
+    datasets,
+    diagnostics,
+    distributions,
+    experimental,
+    networks,
+    simulators,
+    utils,
+    workflows,
+    wrappers,
+)
+
+from .adapters import Adapter
+from .approximators import ContinuousApproximator, PointApproximator
+from .datasets import OfflineDataset, OnlineDataset, DiskDataset
+from .simulators import make_simulator
+from .workflows import BasicWorkflow
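
Note: the detection above runs only when KERAS_BACKEND is unset, so the backend can
always be pinned explicitly. A minimal sketch (not part of this commit):

    # select a backend before the first bayesflow import; setup() then skips auto-detection
    import os

    os.environ["KERAS_BACKEND"] = "jax"  # or "torch" / "tensorflow"

    import bayesflow as bf  # noqa: F401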

bayesflow/approximators/continuous_approximator.py

Lines changed: 1 addition & 1 deletion
@@ -537,7 +537,7 @@ def _sample(
             )
             batch_shape = keras.ops.shape(inference_conditions)[:-1]
         else:
-            batch_shape = keras.ops.shape(inference_conditions)[1:-1]
+            batch_shape = (num_samples,)
 
         return self.inference_network.sample(
             batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)

bayesflow/approximators/point_approximator.py

Lines changed: 1 addition & 6 deletions
@@ -143,12 +143,7 @@ def sample(
 
         return samples
 
-    def log_prob(
-        self,
-        *,
-        data: Mapping[str, np.ndarray],
-        **kwargs,
-    ) -> np.ndarray | dict[str, np.ndarray]:
+    def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
         """
         Computes the log-probability of given data under the parametric distribution(s) for given input conditions.
 
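The change above removes the bare `*`, so `data` may now be passed positionally. A
minimal plain-Python sketch (not bayesflow code) of what that changes at call sites:

    def log_prob_keyword_only(*, data, **kwargs):  # old-style signature
        return data

    def log_prob_positional(data, **kwargs):  # new-style signature
        return data

    log_prob_positional({"theta": [0.1]})         # works with the new signature
    log_prob_keyword_only(data={"theta": [0.1]})  # the only form the old signature accepted
    # log_prob_keyword_only({"theta": [0.1]})     # TypeError under the old signature
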
bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 4 additions & 2 deletions
@@ -144,9 +144,11 @@ def calibration_ecdf(
         tq_targets = test_quantity_fn(data=targets)
         test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
 
-        # # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
+        # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
         num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
-        flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates)
+        flattened_estimates = keras.tree.map_structure(
+            lambda t: np.reshape(t, (num_conditions * num_samples, *t.shape[2:])), estimates
+        )
         flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
         test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))
 
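Spelling out num_conditions * num_samples instead of -1 makes the flatten/restore round
trip explicit. A minimal numpy sketch of the pattern, with illustrative shapes and a
stand-in test quantity:

    import numpy as np

    num_conditions, num_samples = 4, 10
    estimates = {"theta": np.random.rand(num_conditions, num_samples, 3)}

    # flatten the leading (conditions, samples) axes so the function sees one batch
    flattened = {k: np.reshape(t, (num_conditions * num_samples, *t.shape[2:])) for k, t in estimates.items()}
    assert flattened["theta"].shape == (40, 3)

    # apply a stand-in quantity and restore the (conditions, samples, 1) layout
    flat_tq = flattened["theta"].mean(axis=-1)
    assert np.reshape(flat_tq, (num_conditions, num_samples, 1)).shape == (4, 10, 1)
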
bayesflow/links/positive_definite.py

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ def compute_input_shape(self, output_shape):
 
         There are m nonzero elements of a lower triangular nxn matrix with m = n * (n + 1) / 2.
 
-        Example
-        -------
+        Examples
+        --------
         >>> PositiveDefinite().compute_output_shape((None, 3, 3))
         6
         """

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 57 additions & 20 deletions
@@ -4,7 +4,7 @@
 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -67,6 +67,11 @@ def __init__(
             Final number of discretization steps
         subnet_kwargs: dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (time),
+            and optional 'conditions'. Default is True.
         **kwargs : dict, optional, default: {}
             Additional keyword arguments
         """
@@ -77,6 +82,7 @@
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
 
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(
@@ -119,6 +125,7 @@ def get_config(self):
             "eps": self.eps,
             "s0": self.s0,
             "s1": self.s1,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
             # we do not need to store subnet_kwargs
         }
 
@@ -161,18 +168,23 @@ def build(self, xz_shape, conditions_shape=None):
 
         input_shape = list(xz_shape)
 
-        # time vector
-        input_shape[-1] += 1
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
 
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-
-        input_shape = tuple(input_shape)
-
-        self.subnet.build(input_shape)
-
-        input_shape = self.subnet.compute_output_shape(input_shape)
-        self.output_projector.build(input_shape)
+            self.subnet.build(input_shape)
+            out_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            out_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
+        self.output_projector.build(out_shape)
 
         # Choose coefficient according to [2] Section 3.3
         self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
@@ -256,6 +268,35 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
         x = self.consistency_function(x_n, t, conditions=conditions, training=training)
         return x
 
+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
+        the time `t`, and optional conditions or by returning them separately.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The time tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
     def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
         """Compute consistency function.
 
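
Aside: per the new docstring and the build() branch above, a subnet used with
concatenate_subnet_input=False must accept 'x', 't', and optional 'conditions' as
separate inputs, with build/compute_output_shape taking the matching shape keywords.
A hedged sketch of such a layer (illustrative only, not part of this commit):

    import keras

    class SeparateInputSubnet(keras.Layer):
        def __init__(self, width: int = 64, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(width, activation="gelu")

        def build(self, x_shape, t_shape, conditions_shape=None):
            cond_dim = conditions_shape[-1] if conditions_shape is not None else 0
            self.dense.build(tuple(x_shape[:-1]) + (x_shape[-1] + t_shape[-1] + cond_dim,))

        def compute_output_shape(self, x_shape, t_shape, conditions_shape=None):
            return tuple(x_shape[:-1]) + (self.dense.units,)

        def call(self, x, t, conditions=None, training=False):
            # how the inputs are combined is architecture-specific; plain
            # concatenation just keeps the sketch minimal
            parts = [x, t] if conditions is None else [x, t, conditions]
            return self.dense(keras.ops.concatenate(parts, axis=-1))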
@@ -271,12 +312,8 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """
 
-        if conditions is not None:
-            xtc = ops.concatenate([x, t, conditions], axis=-1)
-        else:
-            xtc = ops.concatenate([x, t], axis=-1)
-
-        f = self.output_projector(self.subnet(xtc, training=training))
+        subnet_out = self._apply_subnet(x, t, conditions, training=training)
+        f = self.output_projector(subnet_out)
 
         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
         # Thus, we can do a cross product with the time vector which is (batch_size, 1) for
@@ -316,8 +353,8 @@ def compute_metrics(
 
         log_p = ops.log(p)
         times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
-        t1 = ops.take(discretized_time, times)[..., None]
-        t2 = ops.take(discretized_time, times + 1)[..., None]
+        t1 = expand_right_as(ops.take(discretized_time, times), x)
+        t2 = expand_right_as(ops.take(discretized_time, times + 1), x)
 
         # generate noise vector
         noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
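
The switch from [..., None] to expand_right_as makes the time tensors broadcast against
x regardless of its rank. A hedged numpy sketch, assuming expand_right_as appends
trailing singleton axes until its first argument matches the rank of the second
(illustrative reimplementation, not the bayesflow utility):

    import numpy as np

    def expand_right_as(t, reference):
        while np.ndim(t) < np.ndim(reference):
            t = t[..., None]
        return t

    x = np.zeros((5, 7, 3))  # e.g. a batch of sequences
    t1 = expand_right_as(np.arange(5.0), x)
    assert t1.shape == (5, 1, 1)  # [..., None] alone would give (5, 1)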
