Skip to content

Commit 1c66b9c

Browse files
committed
Add keep_conditions argument to continuous_approximator.sample
1 parent f916855 commit 1c66b9c

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def sample(
422422
num_samples: int,
423423
conditions: Mapping[str, np.ndarray],
424424
split: bool = False,
425+
keep_conditions: bool = False,
425426
**kwargs,
426427
) -> dict[str, np.ndarray]:
427428
"""
@@ -437,6 +438,8 @@ def sample(
437438
split : bool, default=False
438439
Whether to split the output arrays along the last axis and return one column vector per target variable
439440
samples.
441+
keep_conditions : bool, default=False
442+
Whether the output should contain a repeated version of the conditions corresponding to generated samples.
440443
**kwargs : dict
441444
Additional keyword arguments for the adapter and sampling process.
442445
@@ -465,6 +468,13 @@ def sample(
465468

466469
if split:
467470
samples = split_arrays(samples, axis=-1)
471+
472+
if keep_conditions:
473+
repeated_conditions = keras.tree.map_structure(
474+
lambda tensor: np.repeat(np.expand_dims(tensor, axis=1), num_samples, axis=1), conditions
475+
)
476+
samples = repeated_conditions | samples
477+
468478
return samples
469479

470480
def _prepare_data(

0 commit comments

Comments
 (0)