@@ -439,14 +439,21 @@ def sample(
439439 Whether to split the output arrays along the last axis and return one column vector per target variable
440440 samples.
441441 keep_conditions : bool, default=False
442- Whether the output should contain a repeated version of the conditions corresponding to generated samples.
442+ If True, the returned dict will include each of the original
443+ conditioning variables, **repeated** along the sample axis so that
444+ they align 1:1 with the generated samples. Each condition array
445+ will have shape ``(num_datasets, num_samples, *condition_variable_shape)``.
446+
447+ By default conditions are not included in the returned dict.
443448 **kwargs : dict
444449 Additional keyword arguments for the adapter and sampling process.
445450
446451 Returns
447452 -------
448453 dict[str, np.ndarray]
449- Dictionary containing generated samples with the same keys as `conditions`.
454+ Dictionary containing generated samples and optionally the corresponding conditions.
455+
456+ Dictionary values are arrays of shape ``(num_datasets, num_samples, *variable_shape)``.
450457 """
451458 # Adapt, optionally standardize and convert conditions to tensor
452459 conditions = self ._prepare_data (conditions , ** kwargs )
@@ -473,7 +480,7 @@ def sample(
473480 conditions = keras .tree .map_structure (keras .ops .convert_to_numpy , conditions )
474481 conditions = self .adapter (conditions , inverse = True , strict = False , ** kwargs )
475482 repeated_conditions = keras .tree .map_structure (
476- lambda tensor : np .repeat (np .expand_dims (tensor , axis = 1 ), num_samples , axis = 1 ), conditions
483+ lambda value : np .repeat (np .expand_dims (value , axis = 1 ), num_samples , axis = 1 ), conditions
477484 )
478485 samples = repeated_conditions | samples
479486
0 commit comments