File tree Expand file tree Collapse file tree 2 files changed +8
-3
lines changed
simulators/benchmark_simulators Expand file tree Collapse file tree 2 files changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -30,13 +30,15 @@ def setup():
3030 logger = logging .getLogger (__name__ )
3131 logger .setLevel (logging .INFO )
3232
33+ from bayesflow .utils import logging
34+
3335 if keras .backend .backend () == "torch" :
3436 # turn off gradients by default
3537 import torch
3638
3739 torch .autograd .set_grad_enabled (False )
3840
39- from bayesflow . utils import logging
41+ logging . warning ( "Disabling gradients by default. Use \n with torch.enable_grad(): \n in custom training loops." )
4042
4143 logging .debug (f"Using backend { keras .backend .backend ()!r} " )
4244
Original file line number Diff line number Diff line change @@ -11,7 +11,7 @@ def __init__(
1111 T : int = 160 ,
1212 I0 : float = 1.0 ,
1313 R0 : float = 0.0 ,
14- subsample : int = 10 ,
14+ subsample : int = None ,
1515 total_count : int = 1000 ,
1616 scale_by_total : bool = True ,
1717 rng : np .random .Generator = None ,
@@ -87,7 +87,7 @@ def observation_model(self, params: np.ndarray):
8787
8888 Returns
8989 -------
90- x : np.ndarray of shape (subsample,) or (T,) if subsample=None
90+ x : np.ndarray of shape (subsample, ) or (T, 1 ) if subsample=None
9191 The time series of simulated infected individuals. A trailing dimension of 1 should
9292 be added by a BayesFlow configurator if the data is (properly) to be treated as time series.
9393 """
@@ -107,6 +107,8 @@ def observation_model(self, params: np.ndarray):
107107 # Subsample evenly the specified number of points, if specified
108108 if self .subsample is not None :
109109 irt = irt [:: (self .T // self .subsample )]
110+ else :
111+ irt = irt [:, None ]
110112
111113 # Truncate irt, so that small underflow below zero becomes zero
112114 irt = np .maximum (irt , 0.0 )
@@ -115,4 +117,5 @@ def observation_model(self, params: np.ndarray):
115117 x = self .rng .binomial (n = self .total_count , p = irt / self .N )
116118 if self .scale_by_total :
117119 x = x / self .total_count
120+
118121 return x
You can’t perform that action at this time.
0 commit comments