@@ -73,28 +73,38 @@ def __init__(self, p: Parameters):
7373 (self .beta , p .n_days + PREFIT_ADDITIONAL_DAYS )])
7474 self .i_day = i_day = int (get_argmin_ds (raw ["census_hospitalized" ], p .current_hospitalized ))
7575
76- self .raw = self .run_projection (p , self .gen_policy (p ))
76+ self .raw = self .run_projection (p , self .get_policies (p ))
7777
7878 logger .info ('Set i_day = %s' , i_day )
7979 else :
80- projections = {}
8180 best_i_day = - 1
8281 best_i_day_loss = float ('inf' )
83- for i_day in range (p .n_days ):
84- self .i_day = i_day
85- raw = self .run_projection (p , self .gen_policy (p ))
82+ for self .i_day in range (p .n_days + PREFIT_ADDITIONAL_DAYS ):
83+ mitigation_day = - (p .current_date - p .mitigation_date ).days
84+ if mitigation_day < - self .i_day :
85+ mitigation_day = - self .i_day
86+
87+ total_days = self .i_day + p .n_days + PREFIT_ADDITIONAL_DAYS
88+ pre_mitigation_days = self .i_day + mitigation_day
89+ post_mitigation_days = total_days - pre_mitigation_days
90+
91+ raw = self .run_projection (p , [
92+ (self .beta , pre_mitigation_days ),
93+ (self .beta_t , post_mitigation_days ),
94+ ]
95+ )
8696
8797 # Don't fit against results that put the peak before the present day
88- if raw ["census_hospitalized" ].argmax () < i_day :
98+ if raw ["census_hospitalized" ].argmax () < self . i_day :
8999 continue
90100
91- loss = get_loss (raw ["census_hospitalized" ][i_day ], p .current_hospitalized )
101+ loss = get_loss (raw ["census_hospitalized" ][self . i_day ], p .current_hospitalized )
92102 if loss < best_i_day_loss :
93103 best_i_day_loss = loss
94- best_i_day = i_day
95- self .raw = raw
104+ best_i_day = self .i_day
96105
97106 self .i_day = best_i_day
107+ self .raw = self .run_projection (p , self .get_policies (p ))
98108
99109 logger .info (
100110 'Estimated date_first_hospitalized: %s; current_date: %s; i_day: %s' ,
@@ -129,7 +139,7 @@ def __init__(self, p: Parameters):
129139 intrinsic_growth_rate = get_growth_rate (p .doubling_time )
130140 self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
131141 self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
132- self .raw = self .run_projection (p , self .gen_policy (p ))
142+ self .raw = self .run_projection (p , self .get_policies (p ))
133143
134144 self .population = p .population
135145 else :
@@ -198,7 +208,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
198208 self .beta = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , 0.0 )
199209 self .beta_t = get_beta (intrinsic_growth_rate , self .gamma , self .susceptible , p .relative_contact_rate )
200210
201- raw = self .run_projection (p , self .gen_policy (p ))
211+ raw = self .run_projection (p , self .get_policies (p ))
202212
203213 # Skip values the would put the fit past peak
204214 peak_admits_day = raw ["admits_hospitalized" ].argmax ()
@@ -212,7 +222,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
212222 min_loss = pd .Series (losses ).argmin ()
213223 return min_loss
214224
215- def gen_policy (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
225+ def get_policies (self , p : Parameters ) -> Sequence [Tuple [float , int ]]:
216226 if p .mitigation_date is not None :
217227 mitigation_day = - (p .current_date - p .mitigation_date ).days
218228 else :
0 commit comments