Skip to content

Commit a5ce7e7

Browse files
committed
Improved docstring for sample
1 parent f57af0c commit a5ce7e7

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,14 +439,21 @@ def sample(
439439
Whether to split the output arrays along the last axis and return one column vector per target variable
440440
samples.
441441
keep_conditions : bool, default=False
442-
Whether the output should contain a repeated version of the conditions corresponding to generated samples.
442+
If True, the returned dict will include each of the original
443+
conditioning variables, **repeated** along the sample axis so that
444+
they align 1:1 with the generated samples. Each condition array
445+
will have shape ``(num_datasets, num_samples, *condition_variable_shape)``.
446+
447+
By default conditions are not included in the returned dict.
443448
**kwargs : dict
444449
Additional keyword arguments for the adapter and sampling process.
445450
446451
Returns
447452
-------
448453
dict[str, np.ndarray]
449-
Dictionary containing generated samples with the same keys as `conditions`.
454+
Dictionary containing generated samples and optionally the corresponding conditions.
455+
456+
Dictionary values are arrays of shape ``(num_datasets, num_samples, *variable_shape)``.
450457
"""
451458
# Adapt, optionally standardize and convert conditions to tensor
452459
conditions = self._prepare_data(conditions, **kwargs)
@@ -473,7 +480,7 @@ def sample(
473480
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
474481
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
475482
repeated_conditions = keras.tree.map_structure(
476-
lambda tensor: np.repeat(np.expand_dims(tensor, axis=1), num_samples, axis=1), conditions
483+
lambda value: np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1), conditions
477484
)
478485
samples = repeated_conditions | samples
479486

0 commit comments

Comments
 (0)