Skip to content

Commit a55612e

Browse files
committed
Fix quantile level serialization and add save/load to notebook
1 parent c99fd01 commit a55612e

File tree

4 files changed

+288
-235
lines changed

4 files changed

+288
-235
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ def estimate(
2424
split: bool = False,
2525
**kwargs,
2626
) -> dict[str, dict[str, np.ndarray]]:
27-
if not self.built:
28-
raise AssertionError("PointApproximator needs to be built before predicting with it.")
29-
3027
conditions = self._prepare_conditions(conditions, **kwargs)
3128
estimates = self._estimate(**conditions, **kwargs)
3229
estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
@@ -48,9 +45,6 @@ def sample(
4845
split: bool = False,
4946
**kwargs,
5047
) -> dict[str, np.ndarray]:
51-
if not self.built:
52-
raise AssertionError("This model needs to be built before using it for sampling.")
53-
5448
conditions = self._prepare_conditions(conditions, **kwargs)
5549
samples = self._sample(num_samples, **conditions, **kwargs)
5650
samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)

bayesflow/scores/quantile_score.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def __init__(self, q: Sequence[float] = None, links=None, **kwargs):
2323
q = [0.1, 0.5, 0.9]
2424
logging.info(f"QuantileScore was not provided with argument `q`. Using the default quantile levels: {q}.")
2525

26+
# force a conversion to list for proper serialization
27+
q = list(q)
2628
self.q = q
2729
self._q = keras.ops.convert_to_tensor(q, dtype="float32")
2830
self.links = links or {"value": OrderedQuantiles(q=q)}

examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb

Lines changed: 284 additions & 228 deletions
Large diffs are not rendered by default.

tests/test_networks/test_point_inference_network/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import numpy as np
23

34

45
@pytest.fixture()
@@ -72,5 +73,5 @@ def quantile_point_inference_network():
7273
from bayesflow.scores import QuantileScore
7374

7475
return PointInferenceNetwork(
75-
scores=dict(quantiles=QuantileScore(q=[0.1, 0.4, 0.5, 0.7], subnets=dict(value="mlp"))),
76+
scores=dict(quantiles=QuantileScore(q=np.array([0.1, 0.4, 0.5, 0.7]), subnets=dict(value="mlp"))),
7677
)

0 commit comments

Comments
 (0)