14
14
import numpy as np
15
15
import pandas as pd
16
16
17
+ from ..constants import PREFIT_ADDITIONAL_DAYS
17
18
from .parameters import Parameters
18
19
19
20
@@ -68,31 +69,42 @@ def __init__(self, p: Parameters):
68
69
69
70
if p .mitigation_date is None :
70
71
self .i_day = 0 # seed to the full length
71
- raw = self .run_projection (p , [(self .beta , p .n_days )])
72
+ raw = self .run_projection (p , [
73
+ (self .beta , p .n_days + PREFIT_ADDITIONAL_DAYS )])
72
74
self .i_day = i_day = int (get_argmin_ds (raw ["census_hospitalized" ], p .current_hospitalized ))
73
75
74
- self .raw = self .run_projection (p , self .gen_policy (p ))
76
+ self .raw = self .run_projection (p , self .get_policies (p ))
75
77
76
78
logger .info ('Set i_day = %s' , i_day )
77
79
else :
78
- projections = {}
79
80
best_i_day = - 1
80
81
best_i_day_loss = float ('inf' )
81
- for i_day in range (p .n_days ):
82
- self .i_day = i_day
83
- raw = self .run_projection (p , self .gen_policy (p ))
82
+ for self .i_day in range (p .n_days + PREFIT_ADDITIONAL_DAYS ):
83
+ mitigation_day = - (p .current_date - p .mitigation_date ).days
84
+ if mitigation_day < - self .i_day :
85
+ mitigation_day = - self .i_day
86
+
87
+ total_days = self .i_day + p .n_days + PREFIT_ADDITIONAL_DAYS
88
+ pre_mitigation_days = self .i_day + mitigation_day
89
+ post_mitigation_days = total_days - pre_mitigation_days
90
+
91
+ raw = self .run_projection (p , [
92
+ (self .beta , pre_mitigation_days ),
93
+ (self .beta_t , post_mitigation_days ),
94
+ ]
95
+ )
84
96
85
97
# Don't fit against results that put the peak before the present day
86
- if raw ["census_hospitalized" ].argmax () < i_day :
98
+ if raw ["census_hospitalized" ].argmax () < self . i_day :
87
99
continue
88
100
89
- loss = get_loss (raw ["census_hospitalized" ][i_day ], p .current_hospitalized )
101
+ loss = get_loss (raw ["census_hospitalized" ][self . i_day ], p .current_hospitalized )
90
102
if loss < best_i_day_loss :
91
103
best_i_day_loss = loss
92
- best_i_day = i_day
93
- self .raw = raw
104
+ best_i_day = self .i_day
94
105
95
106
self .i_day = best_i_day
107
+ self .raw = self .run_projection (p , self .get_policies (p ))
96
108
97
109
logger .info (
98
110
'Estimated date_first_hospitalized: %s; current_date: %s; i_day: %s' ,
@@ -127,7 +139,7 @@ def __init__(self, p: Parameters):
127
139
intrinsic_growth_rate = get_growth_rate (p .doubling_time )
128
140
self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
129
141
self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
130
- self .raw = self .run_projection (p , self .gen_policy (p ))
142
+ self .raw = self .run_projection (p , self .get_policies (p ))
131
143
132
144
self .population = p .population
133
145
else :
@@ -196,7 +208,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
196
208
self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
197
209
self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
198
210
199
- raw = self .run_projection (p , self .gen_policy (p ))
211
+ raw = self .run_projection (p , self .get_policies (p ))
200
212
201
213
# Skip values the would put the fit past peak
202
214
peak_admits_day = raw ["admits_hospitalized" ].argmax ()
@@ -210,7 +222,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
210
222
min_loss = pd .Series (losses ).argmin ()
211
223
return min_loss
212
224
213
- def gen_policy (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
225
+ def get_policies (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
214
226
if p .mitigation_date is not None :
215
227
mitigation_day = - (p .current_date - p .mitigation_date ).days
216
228
else :
0 commit comments