diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index a5dbf12a3..fb2e95a56 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -537,7 +537,7 @@ def _sample( ) batch_shape = keras.ops.shape(inference_conditions)[:-1] else: - batch_shape = keras.ops.shape(inference_conditions)[1:-1] + batch_shape = (num_samples,) return self.inference_network.sample( batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample) diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 3c4d2fd4c..a56802a3e 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -68,6 +68,9 @@ def point_inference_network_with_multiple_parametric_scores(): def point_approximator_with_single_parametric_score(adapter, point_inference_network, summary_network): from bayesflow import PointApproximator + if "-> 'inference_conditions'" not in str(adapter) and "-> 'summary_conditions'" not in str(adapter): + pytest.skip("point approximator does not support unconditional estimation") + return PointApproximator( adapter=adapter, inference_network=point_inference_network, @@ -81,6 +84,9 @@ def point_approximator_with_multiple_parametric_scores( ): from bayesflow import PointApproximator + if "-> 'inference_conditions'" not in str(adapter) and "-> 'summary_conditions'" not in str(adapter): + pytest.skip("point approximator does not support unconditional estimation") + return PointApproximator( adapter=adapter, inference_network=point_inference_network_with_multiple_parametric_scores, @@ -128,7 +134,16 @@ def adapter_with_sample_weight(): ) -@pytest.fixture(params=["adapter_without_sample_weight", "adapter_with_sample_weight"]) +@pytest.fixture() +def adapter_unconditional(): + from bayesflow import ContinuousApproximator + + return ContinuousApproximator.build_adapter( + inference_variables=["mean", "std"], + ) + + +@pytest.fixture(params=["adapter_unconditional", "adapter_without_sample_weight", "adapter_with_sample_weight"]) def adapter(request): return request.getfixturevalue(request.param)