Skip to content

Commit 0fb876b

Browse files
committed
Fix default settings for SIR and warn
1 parent dfad40b commit 0fb876b

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

bayesflow/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ def setup():
3030
logger = logging.getLogger(__name__)
3131
logger.setLevel(logging.INFO)
3232

33+
from bayesflow.utils import logging
34+
3335
if keras.backend.backend() == "torch":
3436
# turn off gradients by default
3537
import torch
3638

3739
torch.autograd.set_grad_enabled(False)
3840

39-
from bayesflow.utils import logging
41+
logging.warning("Disabling gradients by default. Use\nwith torch.enable_grad():\nin custom training loops.")
4042

4143
logging.debug(f"Using backend {keras.backend.backend()!r}")
4244

bayesflow/simulators/benchmark_simulators/sir.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(
1111
T: int = 160,
1212
I0: float = 1.0,
1313
R0: float = 0.0,
14-
subsample: int = 10,
14+
subsample: int = None,
1515
total_count: int = 1000,
1616
scale_by_total: bool = True,
1717
rng: np.random.Generator = None,
@@ -87,7 +87,7 @@ def observation_model(self, params: np.ndarray):
8787
8888
Returns
8989
-------
90-
x : np.ndarray of shape (subsample,) or (T,) if subsample=None
90+
x : np.ndarray of shape (subsample, ) or (T, 1) if subsample=None
9191
The time series of simulated infected individuals. A trailing dimension of 1 should
9292
be added by a BayesFlow configurator if the data is (properly) to be treated as time series.
9393
"""
@@ -107,6 +107,8 @@ def observation_model(self, params: np.ndarray):
107107
# Subsample evenly the specified number of points, if specified
108108
if self.subsample is not None:
109109
irt = irt[:: (self.T // self.subsample)]
110+
else:
111+
irt = irt[:, None]
110112

111113
# Truncate irt, so that small underflow below zero becomes zero
112114
irt = np.maximum(irt, 0.0)
@@ -115,4 +117,5 @@ def observation_model(self, params: np.ndarray):
115117
x = self.rng.binomial(n=self.total_count, p=irt / self.N)
116118
if self.scale_by_total:
117119
x = x / self.total_count
120+
118121
return x

0 commit comments

Comments
 (0)