1111from bayesflow .adapters import Adapter
1212from bayesflow .networks import InferenceNetwork , SummaryNetwork
1313from bayesflow .types import Tensor
14- from bayesflow .utils import filter_kwargs , logging , split_arrays
14+ from bayesflow .utils import filter_kwargs , logging , split_arrays , concatenate_valid
1515from .approximator import Approximator
1616
1717
@@ -213,6 +213,7 @@ def get_config(self):
213213 def sample (
214214 self ,
215215 * ,
216+ num_datasets : int ,
216217 num_samples : int ,
217218 conditions : dict [str , np .ndarray ],
218219 split : bool = False ,
@@ -224,6 +225,8 @@ def sample(
224225
225226 Parameters
226227 ----------
228+ num_datasets: int, optional
229+ Number of datasets to generate.
227230 num_samples : int
228231 Number of samples to generate.
229232 conditions : dict[str, np.ndarray]
@@ -239,17 +242,54 @@ def sample(
239242 dict[str, np.ndarray]
240243 Dictionary containing generated samples with the same keys as `conditions`.
241244 """
242- conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
243- # at inference time, inference_variables are estimated by the networks and thus ignored in conditions
244- conditions .pop ("inference_variables" , None )
245- conditions = keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
246- conditions = {"inference_variables" : self ._sample (num_samples = num_samples , ** conditions , ** kwargs )}
247- conditions = keras .tree .map_structure (keras .ops .convert_to_numpy , conditions )
248- conditions = self .adapter (conditions , inverse = True , strict = False , ** kwargs )
245+ conditions = self .adapter (conditions , strict = False )
246+
247+ if "inference_conditions" in conditions :
248+ inference_conditions = keras .ops .convert_to_tensor (conditions ["inference_conditions" ])
249+ else :
250+ inference_conditions = None
251+
252+ if "summary_conditions" in conditions :
253+ # we are directly supplied the summary network output
254+ if self .summary_network is None :
255+ raise ValueError ("Cannot supply direct summary conditions without a summary network." )
256+ summary_conditions = keras .ops .convert_to_tensor (conditions ["summary_conditions" ])
257+ elif "summary_variables" in conditions :
258+ # we are supplied the summary network input
259+ if self .summary_network is None :
260+ raise ValueError ("Cannot supply summary variables without a summary network." )
261+ summary_variables = keras .ops .convert_to_tensor (conditions ["summary_variables" ])
262+ summary_conditions = self .summary_network (summary_variables )
263+ else :
264+ # we are not supplied any summary statistics
265+ if self .summary_network is not None :
266+ raise ValueError ("Summary conditions are required when a summary network is present." )
267+ summary_conditions = None
268+
269+ shape = [num_datasets , num_samples , keras .ops .shape (inference_conditions )[- 1 ]]
270+ inference_conditions = keras .ops .broadcast_to (inference_conditions , shape )
271+
272+ shape = [num_datasets , num_samples , keras .ops .shape (summary_conditions )[- 1 ]]
273+ summary_conditions = keras .ops .broadcast_to (summary_conditions , shape )
274+
275+ samples = self .inference_network .sample (
276+ (num_datasets , num_samples ),
277+ conditions = concatenate_valid ([inference_conditions , summary_conditions ], axis = - 1 ),
278+ ** kwargs ,
279+ )
280+
281+ samples = keras .ops .convert_to_numpy (samples )
282+
283+ samples = {
284+ "inference_variables" : samples ,
285+ ** conditions ,
286+ }
287+ samples = self .adapter (samples , inverse = True , strict = False )
249288
250289 if split :
251- conditions = split_arrays (conditions , axis = - 1 )
252- return conditions
290+ samples = split_arrays (samples , axis = - 1 )
291+
292+ return samples
253293
254294 def _sample (
255295 self ,
0 commit comments