
Commit 5fd6c55

add parallel simulator
1 parent c21fa4a commit 5fd6c55

File tree

1 file changed: +109 -0 lines changed


bayesflow/simulators/sequential_simulator.py

Lines changed: 109 additions & 0 deletions
@@ -70,3 +70,112 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
         }
 
         return data
+
+    def _single_sample(self, batch_shape_ext, **kwargs) -> dict[str, np.ndarray]:
+        """
+        Draw a single sample; used internally by parallel sampling.
+
+        Parameters
+        ----------
+        batch_shape_ext : tuple
+            Trailing batch dimensions appended after the leading sample dimension of size 1.
+        **kwargs
+            Keyword arguments passed to the simulators.
+
+        Returns
+        -------
+        dict
+            A single sample result.
+        """
+        return self.sample(batch_shape=(1, *tuple(batch_shape_ext)), **kwargs)
+
+    def sample_parallel(
+        self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 0, **kwargs
+    ) -> dict[str, np.ndarray]:
+        """
+        Sample in parallel from the sequential simulator.
+
+        Parameters
+        ----------
+        batch_shape : Shape
+            The shape of the batch to sample. Typically a tuple indicating the number of samples,
+            but an int is also accepted.
+        n_jobs : int, optional
+            Number of parallel jobs. -1 uses all available cores. Default is -1.
+        verbose : int, optional
+            Verbosity level for joblib. Default is 0 (no output).
+        **kwargs
+            Additional keyword arguments passed to each simulator. These may include previously
+            sampled outputs used as inputs for subsequent simulators.
+
+        Returns
+        -------
+        data : dict of str to np.ndarray
+            A dictionary containing the combined outputs from all simulators. Keys are output names
+            and values are sampled arrays. If `expand_outputs` is True, 1D arrays are expanded to
+            have shape (..., 1).
+        """
+        try:
+            from joblib import Parallel, delayed
+        except ImportError as e:
+            raise ImportError(
+                "joblib is required for parallel sampling. Please install it via 'pip install joblib'."
+            ) from e
+
+        # normalize batch_shape to a tuple
+        if isinstance(batch_shape, int):
+            bs = (batch_shape,)
+        else:
+            bs = tuple(batch_shape)
+        if len(bs) == 0:
+            raise ValueError("batch_shape must be a positive integer or a non-empty tuple")
+
+        # draw one single-sample batch per leading index, then combine the results
+        results = Parallel(n_jobs=n_jobs, verbose=verbose)(
+            delayed(self._single_sample)(batch_shape_ext=bs[1:], **kwargs) for _ in range(bs[0])
+        )
+        return self._combine_results(results)
+
+    @staticmethod
+    def _combine_results(results: list[dict]) -> dict[str, np.ndarray]:
+        """
+        Combine a list of single-sample results into stacked arrays.
+
+        Parameters
+        ----------
+        results : list of dict
+            List of dictionaries from individual samples.
+
+        Returns
+        -------
+        dict
+            Combined results with stacked arrays.
+        """
+        if not results:
+            return {}
+
+        # union of all keys across results
+        all_keys = set()
+        for r in results:
+            all_keys.update(r.keys())
+
+        combined_data: dict[str, np.ndarray] = {}
+
+        for key in all_keys:
+            values = []
+            for result in results:
+                if key in result:
+                    value = result[key]
+                    # strip the leading batch dimension of size 1 before stacking
+                    if isinstance(value, np.ndarray) and value.shape[:1] == (1,):
+                        values.append(value[0])
+                    else:
+                        values.append(value)
+                else:
+                    values.append(None)
+
+            try:
+                if all(isinstance(v, np.ndarray) for v in values):
+                    combined_data[key] = np.stack(values, axis=0)
+                else:
+                    combined_data[key] = np.array(values, dtype=object)
+            except ValueError:
+                # shapes differ across samples; fall back to an object array
+                combined_data[key] = np.array(values, dtype=object)
+
+        return combined_data
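
A minimal usage sketch of the new method, assuming a sequential simulator built from two toy functions via BayesFlow's make_simulator helper; the function names, sizes, and printed shapes below are illustrative and not part of this commit:

import numpy as np
import bayesflow as bf

def prior():
    # illustrative prior over a 2-dimensional parameter vector
    return dict(theta=np.random.normal(size=2))

def likelihood(theta):
    # illustrative likelihood: 10 observations conditioned on theta
    return dict(x=np.random.normal(loc=theta, size=(10, 2)))

simulator = bf.make_simulator([prior, likelihood])

# serial and parallel sampling should produce dicts with matching keys and shapes
data_serial = simulator.sample((64,))
data_parallel = simulator.sample_parallel(64, n_jobs=-1)

print({k: v.shape for k, v in data_parallel.items()})
# e.g. {'theta': (64, 2), 'x': (64, 10, 2)}

Since joblib's default backend runs each single-sample call in a separate worker process, reproducing exact draws across runs may require explicit per-task seeding.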

0 commit comments
