Skip to content

Commit 08b9119

Browse files
committed
Add split arrays as an option to sample
1 parent 6fb8b83 commit 08b9119

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from bayesflow.adapters import Adapter
1212
from bayesflow.networks import InferenceNetwork, SummaryNetwork
1313
from bayesflow.types import Tensor
14-
from bayesflow.utils import logging
14+
from bayesflow.utils import logging, split_arrays
1515
from .approximator import Approximator
1616

1717

@@ -136,6 +136,7 @@ def sample(
136136
*,
137137
num_samples: int,
138138
conditions: dict[str, np.ndarray],
139+
split: bool = False,
139140
**kwargs,
140141
) -> dict[str, np.ndarray]:
141142
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
@@ -144,6 +145,8 @@ def sample(
144145
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
145146
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
146147

148+
if split:
149+
conditions = split_arrays(conditions, axis=-1)
147150
return conditions
148151

149152
def _sample(

bayesflow/utils/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@
33
logging,
44
numpy_utils,
55
)
6-
from .dict_utils import (
7-
convert_args,
8-
convert_kwargs,
9-
filter_kwargs,
10-
keras_kwargs,
11-
split_tensors,
12-
)
6+
from .dict_utils import convert_args, convert_kwargs, filter_kwargs, keras_kwargs, split_tensors, split_arrays
137
from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net
148
from .ecdf import simultaneous_ecdf_bands, ranks
159
from .functional import batched_call

bayesflow/utils/dict_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,24 @@ def split_tensors(data: Mapping[any, Tensor], axis: int = -1) -> Mapping[any, Te
105105
return result
106106

107107

108+
def split_arrays(data: Mapping[any, np.ndarray], axis: int = -1) -> Mapping[any, np.ndarray]:
109+
"""Split tensors in the dictionary along the given axis."""
110+
result = {}
111+
112+
for key, value in data.items():
113+
if value.shape[axis] == 1:
114+
result[key] = np.squeeze(value, axis=axis)
115+
continue
116+
117+
splits = np.split(value, value.shape[axis], axis=axis)
118+
splits = [np.squeeze(split, axis=axis) for split in splits]
119+
120+
for i, split in enumerate(splits):
121+
result[f"{key}_{i + 1}"] = split
122+
123+
return result
124+
125+
108126
def dicts_to_arrays(
109127
post_variables: dict[str, np.ndarray] | np.ndarray,
110128
prior_variables: dict[str, np.ndarray] | np.ndarray = None,

0 commit comments

Comments
 (0)