Skip to content

Commit 34c5f10

Browse files
committed
Add keep_conditions argument to point_approximator.sample
1 parent 08d5474 commit 34c5f10

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def sample(
8989
num_samples: int,
9090
conditions: Mapping[str, np.ndarray],
9191
split: bool = False,
92+
keep_conditions: bool = False,
9293
**kwargs,
9394
) -> dict[str, dict[str, np.ndarray]]:
9495
"""
@@ -107,6 +108,14 @@ def sample(
107108
split : bool, optional
108109
If True, the sampled arrays are split along the last axis, by default False.
109110
Currently not supported for :py:class:`PointApproximator` .
111+
keep_conditions : bool, default=False
112+
If True, the returned dict will include each of the original
113+
conditioning variables, **repeated** along the sample axis so that
114+
they align 1:1 with the generated samples. Each condition array
115+
will have shape ``(num_datasets, num_samples, *condition_variable_shape)``.
116+
117+
By default conditions are not included in the returned dict.
118+
110119
**kwargs
111120
Additional keyword arguments passed to underlying processing functions.
112121
@@ -115,11 +124,11 @@ def sample(
115124
samples : dict[str, np.ndarray or dict[str, np.ndarray]]
116125
Samples for all inference variables and all parametric scoring rules in a nested dictionary.
117126
118-
1. Each first-level key is the name of an inference variable.
127+
1. Each first-level key is the name of an inference variable or condition.
119128
2. (If there are multiple parametric scores, each second-level key is the name of such a score.)
120129
121130
Each output (i.e., dictionary value that is not itself a dictionary) is an array
122-
of shape (num_datasets, num_samples, variable_block_size).
131+
of shape ``(num_datasets, num_samples, *variable_shape)``.
123132
"""
124133
# Adapt, optionally standardize and convert conditions to tensor.
125134
conditions = self._prepare_data(conditions, **kwargs)
@@ -141,6 +150,14 @@ def sample(
141150
# Squeeze sample dictionary if there's only one key-value pair.
142151
samples = self._squeeze_parametric_score_major_dict(samples)
143152

153+
if keep_conditions:
154+
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
155+
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
156+
repeated_conditions = keras.tree.map_structure(
157+
lambda value: np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1), conditions
158+
)
159+
samples = repeated_conditions | samples
160+
144161
return samples
145162

146163
def log_prob(

0 commit comments

Comments
 (0)