@@ -11,7 +11,7 @@ def __init__(
1111 T : int = 160 ,
1212 I0 : float = 1.0 ,
1313 R0 : float = 0.0 ,
14- subsample : int = None ,
14+ subsample : int | str = "original" ,
1515 total_count : int = 1000 ,
1616 scale_by_total : bool = True ,
1717 rng : np .random .Generator = None ,
@@ -27,15 +27,17 @@ def __init__(
2727 The size of the simulated population.
2828 T: int, optional, default: 160
2929 The duration (time horizon) of the simulation.
30+ The last time-point is not included.
3031 I0: float, optional, default: 1.0
3132 The number of initially infected individuals.
3233 R0: float, optional, default: 0.0
3334 The number of initially recovered individuals.
34- subsample: int or None, optional, default: 10
35+ subsample: int or None, optional, default: 'original'
3536 The number of evenly spaced time points to return. If `None`,
3637 no subsampling will be performed, all `T` timepoints will be returned
3738 and a trailing dimension will be added. If an integer is provided,
3839 subsampling is performed and no trailing dimension will be added.
40+ 'original' reproduces the original benchmark task subsampling of 10 points.
3941 total_count: int, optional, default: 1000
4042 The `N` parameter of the binomial noise distribution. Used just
4143 for scaling the data and magnifying the effect of noise, such that
@@ -100,14 +102,16 @@ def observation_model(self, params: np.ndarray):
100102 # Unpack parameter vector into scalars
101103 beta , gamma = params
102104
103- # Prepate time vector between 0 and T of length T
104- t_vec = np .linspace ( 0 , self . T , self .T )
105+ # Prepare time vector between 0 and T of length T
106+ t_vec = np .arange ( 0 , self .T )
105107
106108 # Integrate using scipy and retain only infected (2-nd dimension)
107109 irt = odeint (self ._deriv , x0 , t_vec , args = (self .N , beta , gamma ))[:, 1 ]
108110
109111 # Subsample evenly the specified number of points, if specified
110- if self .subsample is not None :
112+ if self .subsample == "original" :
113+ irt = irt [::17 ]
114+ elif self .subsample is not None :
111115 irt = irt [:: (self .T // self .subsample )]
112116 else :
113117 irt = irt [:, None ]
0 commit comments