Skip to content

Commit 6b9671b

Browse files
committed
Merge branch 'dev' into compositional_sampling_diffusion
2 parents 5601d20 + bc2bda8 commit 6b9671b

File tree

5 files changed

+114
-98
lines changed

5 files changed

+114
-98
lines changed

.github/workflows/tests.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
name: Multi-Backend Tests
32

43
on:
@@ -16,15 +15,14 @@ defaults:
1615
run:
1716
shell: bash
1817

19-
2018
jobs:
2119
test:
2220
name: Run Multi-Backend Tests
2321

2422
strategy:
2523
matrix:
2624
os: [ubuntu-latest, windows-latest]
27-
python-version: ["3.10"] # we usually only need to test the oldest python version
25+
python-version: ["3.10"] # we usually only need to test the oldest python version
2826
backend: ["jax", "tensorflow", "torch"]
2927

3028
runs-on: ${{ matrix.os }}

CHANGELOG.rst

Lines changed: 0 additions & 92 deletions
This file was deleted.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ complex to be described analytically.
5151

5252
## Install
5353

54-
We currently support Python 3.10 to 3.12. You can install the latest stable version from PyPI using:
54+
We currently support Python 3.10 to 3.13. You can install the latest stable version from PyPI using:
5555

5656
```bash
5757
pip install "bayesflow>=2.0"

bayesflow/simulators/sequential_simulator.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,112 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
7070
}
7171

7272
return data
73+
74+
def _single_sample(self, batch_shape_ext, **kwargs) -> dict[str, np.ndarray]:
    """
    Draw a single sample; the unit of work for parallel sampling.

    Parameters
    ----------
    batch_shape_ext : tuple
        The trailing batch dimensions (everything after the leading sample
        axis of the full batch shape). May be empty for a flat batch.
    **kwargs
        Keyword arguments passed to simulators.

    Returns
    -------
    dict
        A dictionary of sampled outputs whose leading batch axis has size 1.
    """
    # Prepend a singleton axis so self.sample produces exactly one draw;
    # the singleton is stripped again when results are combined.
    return self.sample(batch_shape=(1, *tuple(batch_shape_ext)), **kwargs)
89+
90+
def sample_parallel(
    self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 0, **kwargs
) -> dict[str, np.ndarray]:
    """
    Draw samples from the sequential simulator using parallel workers.

    Parameters
    ----------
    batch_shape : Shape
        The shape of the batch to sample. Typically, a tuple indicating the number of samples,
        but it also accepts an int.
    n_jobs : int, optional
        Number of parallel jobs. -1 uses all available cores. Default is -1.
    verbose : int, optional
        Verbosity level for joblib. Default is 0 (no output).
    **kwargs
        Additional keyword arguments passed to each simulator. These may include previously
        sampled outputs used as inputs for subsequent simulators.

    Returns
    -------
    data : dict of str to np.ndarray
        A dictionary containing the combined outputs from all simulators. Keys are output names
        and values are sampled arrays. If `expand_outputs` is True, 1D arrays are expanded to
        have shape (..., 1).

    Raises
    ------
    ImportError
        If joblib is not installed.
    ValueError
        If ``batch_shape`` is an empty tuple.
    """
    # joblib is an optional dependency, so only import it on demand.
    try:
        from joblib import Parallel, delayed
    except ImportError as e:
        raise ImportError(
            "joblib is required for parallel sampling. Please install it via 'pip install joblib'."
        ) from e

    # Accept either an int or any iterable of dims; work with a plain tuple.
    shape = (batch_shape,) if isinstance(batch_shape, int) else tuple(batch_shape)
    if not shape:
        raise ValueError("batch_shape must be a positive integer or a nonempty tuple")

    n_samples, trailing = shape[0], shape[1:]
    # One job per sample along the leading axis; trailing dims are sampled
    # inside each worker and re-stacked afterwards.
    runner = Parallel(n_jobs=n_jobs, verbose=verbose)
    samples = runner(
        delayed(self._single_sample)(batch_shape_ext=trailing, **kwargs)
        for _ in range(n_samples)
    )
    return self._combine_results(samples)
135+
136+
@staticmethod
def _combine_results(results: list[dict]) -> dict[str, np.ndarray]:
    """
    Merge per-sample result dictionaries into combined arrays.

    Parameters
    ----------
    results : list of dict
        One dictionary per drawn sample.

    Returns
    -------
    dict
        Mapping from output name to a combined array. Values that are
        equally-shaped arrays across all samples are stacked along a new
        leading axis; anything else (missing keys, non-array values,
        ragged shapes) falls back to an object array, with None standing
        in for samples that lack the key.
    """
    if not results:
        return {}

    # Collect every output name that appears in at least one sample.
    keys = set().union(*(r.keys() for r in results))

    combined: dict[str, np.ndarray] = {}
    for name in keys:
        entries = []
        for sample in results:
            if name not in sample:
                entries.append(None)
                continue
            item = sample[name]
            if isinstance(item, np.ndarray) and item.shape[:1] == (1,):
                # Drop a leading singleton axis (e.g. the one added by
                # single-sample draws) before stacking.
                entries.append(item[0])
            else:
                entries.append(item)

        try:
            if all(isinstance(e, np.ndarray) for e in entries):
                combined[name] = np.stack(entries, axis=0)
            else:
                combined[name] = np.array(entries, dtype=object)
        except ValueError:
            # Ragged shapes cannot be stacked; keep them as objects.
            combined[name] = np.array(entries, dtype=object)

    return combined

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,18 @@ classifiers = [
1717
"Programming Language :: Python :: 3.10",
1818
"Programming Language :: Python :: 3.11",
1919
"Programming Language :: Python :: 3.12",
20+
"Programming Language :: Python :: 3.13",
2021
"Topic :: Scientific/Engineering :: Artificial Intelligence",
2122
]
2223
description = "Amortizing Bayesian Inference With Neural Networks"
2324
readme = { file = "README.md", content-type = "text/markdown" }
2425
license = { file = "LICENSE" }
2526

26-
requires-python = ">= 3.10, < 3.13"
27+
requires-python = ">= 3.10, < 3.14"
2728
dependencies = [
2829
"keras >= 3.9",
2930
"matplotlib",
30-
"numpy >= 1.24, <2.0",
31+
"numpy >= 1.24",
3132
"pandas",
3233
"scipy",
3334
"seaborn",

0 commit comments

Comments
 (0)