Skip to content

Commit 7cea06d

Browse files
committed
revert changes to ContinuousApproximator.sample()
1 parent 48bb2d9 commit 7cea06d

File tree

1 file changed

+10
-50
lines changed

1 file changed

+10
-50
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from bayesflow.adapters import Adapter
1212
from bayesflow.networks import InferenceNetwork, SummaryNetwork
1313
from bayesflow.types import Tensor
14-
from bayesflow.utils import filter_kwargs, logging, split_arrays, concatenate_valid
14+
from bayesflow.utils import filter_kwargs, logging, split_arrays
1515
from .approximator import Approximator
1616

1717

@@ -213,7 +213,6 @@ def get_config(self):
213213
def sample(
214214
self,
215215
*,
216-
num_datasets: int,
217216
num_samples: int,
218217
conditions: dict[str, np.ndarray],
219218
split: bool = False,
@@ -225,8 +224,6 @@ def sample(
225224
226225
Parameters
227226
----------
228-
num_datasets: int, optional
229-
Number of datasets to generate.
230227
num_samples : int
231228
Number of samples to generate.
232229
conditions : dict[str, np.ndarray]
@@ -242,54 +239,17 @@ def sample(
242239
dict[str, np.ndarray]
243240
Dictionary containing generated samples with the same keys as `conditions`.
244241
"""
245-
conditions = self.adapter(conditions, strict=False)
246-
247-
if "inference_conditions" in conditions:
248-
inference_conditions = keras.ops.convert_to_tensor(conditions["inference_conditions"])
249-
else:
250-
inference_conditions = None
251-
252-
if "summary_conditions" in conditions:
253-
# we are directly supplied the summary network output
254-
if self.summary_network is None:
255-
raise ValueError("Cannot supply direct summary conditions without a summary network.")
256-
summary_conditions = keras.ops.convert_to_tensor(conditions["summary_conditions"])
257-
elif "summary_variables" in conditions:
258-
# we are supplied the summary network input
259-
if self.summary_network is None:
260-
raise ValueError("Cannot supply summary variables without a summary network.")
261-
summary_variables = keras.ops.convert_to_tensor(conditions["summary_variables"])
262-
summary_conditions = self.summary_network(summary_variables)
263-
else:
264-
# we are not supplied any summary statistics
265-
if self.summary_network is not None:
266-
raise ValueError("Summary conditions are required when a summary network is present.")
267-
summary_conditions = None
268-
269-
shape = [num_datasets, num_samples, keras.ops.shape(inference_conditions)[-1]]
270-
inference_conditions = keras.ops.broadcast_to(inference_conditions, shape)
271-
272-
shape = [num_datasets, num_samples, keras.ops.shape(summary_conditions)[-1]]
273-
summary_conditions = keras.ops.broadcast_to(summary_conditions, shape)
274-
275-
samples = self.inference_network.sample(
276-
(num_datasets, num_samples),
277-
conditions=concatenate_valid([inference_conditions, summary_conditions], axis=-1),
278-
**kwargs,
279-
)
280-
281-
samples = keras.ops.convert_to_numpy(samples)
282-
283-
samples = {
284-
"inference_variables": samples,
285-
**conditions,
286-
}
287-
samples = self.adapter(samples, inverse=True, strict=False)
242+
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
243+
# at inference time, inference_variables are estimated by the networks and thus ignored in conditions
244+
conditions.pop("inference_variables", None)
245+
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
246+
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
247+
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
248+
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
288249

289250
if split:
290-
samples = split_arrays(samples, axis=-1)
291-
292-
return samples
251+
conditions = split_arrays(conditions, axis=-1)
252+
return conditions
293253

294254
def _sample(
295255
self,

0 commit comments

Comments
 (0)