@@ -89,6 +89,7 @@ def sample(
8989 num_samples : int ,
9090 conditions : Mapping [str , np .ndarray ],
9191 split : bool = False ,
92+ keep_conditions : bool = False ,
9293 ** kwargs ,
9394 ) -> dict [str , dict [str , np .ndarray ]]:
9495 """
@@ -107,6 +108,14 @@ def sample(
107108 split : bool, optional
108109 If True, the sampled arrays are split along the last axis, by default False.
109110 Currently not supported for :py:class:`PointApproximator` .
111+ keep_conditions : bool, default=False
112+ If True, the returned dict will include each of the original
113+ conditioning variables, **repeated** along the sample axis so that
114+ they align 1:1 with the generated samples. Each condition array
115+ will have shape ``(num_datasets, num_samples, *condition_variable_shape)``.
116+
117+ By default conditions are not included in the returned dict.
118+
110119 **kwargs
111120 Additional keyword arguments passed to underlying processing functions.
112121
@@ -115,11 +124,11 @@ def sample(
115124 samples : dict[str, np.ndarray or dict[str, np.ndarray]]
116125 Samples for all inference variables and all parametric scoring rules in a nested dictionary.
117126
118- 1. Each first-level key is the name of an inference variable.
127+ 1. Each first-level key is the name of an inference variable or condition .
119128 2. (If there are multiple parametric scores, each second-level key is the name of such a score.)
120129
121130 Each output (i.e., dictionary value that is not itself a dictionary) is an array
122- of shape (num_datasets, num_samples, variable_block_size) .
131+ of shape `` (num_datasets, num_samples, *variable_shape)`` .
123132 """
124133 # Adapt, optionally standardize and convert conditions to tensor.
125134 conditions = self ._prepare_data (conditions , ** kwargs )
@@ -141,6 +150,14 @@ def sample(
141150 # Squeeze sample dictionary if there's only one key-value pair.
142151 samples = self ._squeeze_parametric_score_major_dict (samples )
143152
153+ if keep_conditions :
154+ conditions = keras .tree .map_structure (keras .ops .convert_to_numpy , conditions )
155+ conditions = self .adapter (conditions , inverse = True , strict = False , ** kwargs )
156+ repeated_conditions = keras .tree .map_structure (
157+ lambda value : np .repeat (np .expand_dims (value , axis = 1 ), num_samples , axis = 1 ), conditions
158+ )
159+ samples = repeated_conditions | samples
160+
144161 return samples
145162
146163 def log_prob (
0 commit comments