@@ -73,28 +73,38 @@ def __init__(self, p: Parameters):
73
73
(self .beta , p .n_days + PREFIT_ADDITIONAL_DAYS )])
74
74
self .i_day = i_day = int (get_argmin_ds (raw ["census_hospitalized" ], p .current_hospitalized ))
75
75
76
- self .raw = self .run_projection (p , self .gen_policy (p ))
76
+ self .raw = self .run_projection (p , self .get_policies (p ))
77
77
78
78
logger .info ('Set i_day = %s' , i_day )
79
79
else :
80
- projections = {}
81
80
best_i_day = - 1
82
81
best_i_day_loss = float ('inf' )
83
- for i_day in range (p .n_days ):
84
- self .i_day = i_day
85
- raw = self .run_projection (p , self .gen_policy (p ))
82
+ for self .i_day in range (p .n_days + PREFIT_ADDITIONAL_DAYS ):
83
+ mitigation_day = - (p .current_date - p .mitigation_date ).days
84
+ if mitigation_day < - self .i_day :
85
+ mitigation_day = - self .i_day
86
+
87
+ total_days = self .i_day + p .n_days + PREFIT_ADDITIONAL_DAYS
88
+ pre_mitigation_days = self .i_day + mitigation_day
89
+ post_mitigation_days = total_days - pre_mitigation_days
90
+
91
+ raw = self .run_projection (p , [
92
+ (self .beta , pre_mitigation_days ),
93
+ (self .beta_t , post_mitigation_days ),
94
+ ]
95
+ )
86
96
87
97
# Don't fit against results that put the peak before the present day
88
- if raw ["census_hospitalized" ].argmax () < i_day :
98
+ if raw ["census_hospitalized" ].argmax () < self . i_day :
89
99
continue
90
100
91
- loss = get_loss (raw ["census_hospitalized" ][i_day ], p .current_hospitalized )
101
+ loss = get_loss (raw ["census_hospitalized" ][self . i_day ], p .current_hospitalized )
92
102
if loss < best_i_day_loss :
93
103
best_i_day_loss = loss
94
- best_i_day = i_day
95
- self .raw = raw
104
+ best_i_day = self .i_day
96
105
97
106
self .i_day = best_i_day
107
+ self .raw = self .run_projection (p , self .get_policies (p ))
98
108
99
109
logger .info (
100
110
'Estimated date_first_hospitalized: %s; current_date: %s; i_day: %s' ,
@@ -129,7 +139,7 @@ def __init__(self, p: Parameters):
129
139
intrinsic_growth_rate = get_growth_rate (p .doubling_time )
130
140
self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
131
141
self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
132
- self .raw = self .run_projection (p , self .gen_policy (p ))
142
+ self .raw = self .run_projection (p , self .get_policies (p ))
133
143
134
144
self .population = p .population
135
145
else :
@@ -198,7 +208,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
198
208
self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
199
209
self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
200
210
201
- raw = self .run_projection (p , self .gen_policy (p ))
211
+ raw = self .run_projection (p , self .get_policies (p ))
202
212
203
213
# Skip values the would put the fit past peak
204
214
peak_admits_day = raw ["admits_hospitalized" ].argmax ()
@@ -212,7 +222,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
212
222
min_loss = pd .Series (losses ).argmin ()
213
223
return min_loss
214
224
215
- def gen_policy (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
225
+ def get_policies (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
216
226
if p .mitigation_date is not None :
217
227
mitigation_day = - (p .current_date - p .mitigation_date ).days
218
228
else :
0 commit comments