Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions bayesflow/simulators/sequential_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,112 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
}

return data

def _single_sample(self, batch_shape_ext, **kwargs) -> dict[str, np.ndarray]:
    """
    Draw one sample; worker helper used by parallel sampling.

    Parameters
    ----------
    batch_shape_ext
        Trailing batch dimensions, appended after a leading singleton
        batch dimension of size 1.
    **kwargs
        Keyword arguments passed to simulators.

    Returns
    -------
    dict
        Single sample result.
    """
    full_shape = (1,) + tuple(batch_shape_ext)
    return self.sample(batch_shape=full_shape, **kwargs)

def sample_parallel(
    self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 0, **kwargs
) -> dict[str, np.ndarray]:
    """
    Sample in parallel from the sequential simulator.

    Parameters
    ----------
    batch_shape : Shape
        The shape of the batch to sample. Typically, a tuple indicating the number of samples,
        but it also accepts an int.
    n_jobs : int, optional
        Number of parallel jobs. -1 uses all available cores. Default is -1.
    verbose : int, optional
        Verbosity level for joblib. Default is 0 (no output).
    **kwargs
        Additional keyword arguments passed to each simulator. These may include previously
        sampled outputs used as inputs for subsequent simulators.

    Returns
    -------
    data : dict of str to np.ndarray
        A dictionary containing the combined outputs from all simulators. Keys are output names
        and values are sampled arrays. If `expand_outputs` is True, 1D arrays are expanded to
        have shape (..., 1).
    """
    # joblib is an optional dependency; import lazily so the rest of the
    # simulator works without it.
    try:
        from joblib import Parallel, delayed
    except ImportError as e:
        raise ImportError(
            "joblib is required for parallel sampling. Please install it via 'pip install joblib'."
        ) from e

    # Normalize the requested batch shape to a nonempty tuple.
    shape = (batch_shape,) if isinstance(batch_shape, int) else tuple(batch_shape)
    if not shape:
        raise ValueError("batch_shape must be a positive integer or a nonempty tuple")

    n_samples, extra_dims = shape[0], shape[1:]

    # One job per leading-dimension sample; each worker draws a batch of 1.
    worker = delayed(self._single_sample)
    samples = Parallel(n_jobs=n_jobs, verbose=verbose)(
        worker(batch_shape_ext=extra_dims, **kwargs) for _ in range(n_samples)
    )

    return self._combine_results(samples)

@staticmethod
def _combine_results(results: list[dict]) -> dict[str, np.ndarray]:
"""
Combine a list of single-sample results into arrays.
Parameters
----------
results : list of dict
List of dictionaries from individual samples.
Returns
-------
dict
Combined results with arrays.
"""
if not results:
return {}

# union of all keys across results
all_keys = set()
for r in results:
all_keys.update(r.keys())

combined_data: dict[str, np.ndarray] = {}

for key in all_keys:
values = []
for result in results:
if key in result:
value = result[key]
if isinstance(value, np.ndarray) and value.shape[:1] == (1,):
values.append(value[0])
else:
values.append(value)
else:
values.append(None)

try:
if all(isinstance(v, np.ndarray) for v in values):
combined_data[key] = np.stack(values, axis=0)
else:
combined_data[key] = np.array(values, dtype=object)
except ValueError:
combined_data[key] = np.array(values, dtype=object)

return combined_data
Loading