1111from bayesflow .adapters import Adapter
1212from bayesflow .networks import InferenceNetwork , SummaryNetwork
1313from bayesflow .types import Tensor
14- from bayesflow .utils import filter_kwargs , logging , split_arrays , concatenate_valid
14+ from bayesflow .utils import filter_kwargs , logging , split_arrays
1515from .approximator import Approximator
1616
1717
@@ -213,7 +213,6 @@ def get_config(self):
213213 def sample (
214214 self ,
215215 * ,
216- num_datasets : int ,
217216 num_samples : int ,
218217 conditions : dict [str , np .ndarray ],
219218 split : bool = False ,
@@ -225,8 +224,6 @@ def sample(
225224
226225 Parameters
227226 ----------
228- num_datasets: int, optional
229- Number of datasets to generate.
230227 num_samples : int
231228 Number of samples to generate.
232229 conditions : dict[str, np.ndarray]
@@ -242,54 +239,17 @@ def sample(
242239 dict[str, np.ndarray]
243240 Dictionary containing generated samples with the same keys as `conditions`.
244241 """
245- conditions = self .adapter (conditions , strict = False )
246-
247- if "inference_conditions" in conditions :
248- inference_conditions = keras .ops .convert_to_tensor (conditions ["inference_conditions" ])
249- else :
250- inference_conditions = None
251-
252- if "summary_conditions" in conditions :
253- # we are directly supplied the summary network output
254- if self .summary_network is None :
255- raise ValueError ("Cannot supply direct summary conditions without a summary network." )
256- summary_conditions = keras .ops .convert_to_tensor (conditions ["summary_conditions" ])
257- elif "summary_variables" in conditions :
258- # we are supplied the summary network input
259- if self .summary_network is None :
260- raise ValueError ("Cannot supply summary variables without a summary network." )
261- summary_variables = keras .ops .convert_to_tensor (conditions ["summary_variables" ])
262- summary_conditions = self .summary_network (summary_variables )
263- else :
264- # we are not supplied any summary statistics
265- if self .summary_network is not None :
266- raise ValueError ("Summary conditions are required when a summary network is present." )
267- summary_conditions = None
268-
269- shape = [num_datasets , num_samples , keras .ops .shape (inference_conditions )[- 1 ]]
270- inference_conditions = keras .ops .broadcast_to (inference_conditions , shape )
271-
272- shape = [num_datasets , num_samples , keras .ops .shape (summary_conditions )[- 1 ]]
273- summary_conditions = keras .ops .broadcast_to (summary_conditions , shape )
274-
275- samples = self .inference_network .sample (
276- (num_datasets , num_samples ),
277- conditions = concatenate_valid ([inference_conditions , summary_conditions ], axis = - 1 ),
278- ** kwargs ,
279- )
280-
281- samples = keras .ops .convert_to_numpy (samples )
282-
283- samples = {
284- "inference_variables" : samples ,
285- ** conditions ,
286- }
287- samples = self .adapter (samples , inverse = True , strict = False )
242+ conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
243+ # at inference time, inference_variables are estimated by the networks and thus ignored in conditions
244+ conditions .pop ("inference_variables" , None )
245+ conditions = keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
246+ conditions = {"inference_variables" : self ._sample (num_samples = num_samples , ** conditions , ** kwargs )}
247+ conditions = keras .tree .map_structure (keras .ops .convert_to_numpy , conditions )
248+ conditions = self .adapter (conditions , inverse = True , strict = False , ** kwargs )
288249
289250 if split :
290- samples = split_arrays (samples , axis = - 1 )
291-
292- return samples
251+ conditions = split_arrays (conditions , axis = - 1 )
252+ return conditions
293253
294254 def _sample (
295255 self ,
0 commit comments