Skip to content

Commit a933752

Browse files
authored
Iteratively refined fitting
Iteratively refined fitting
2 parents 19fe675 + 6e950f8 commit a933752

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

src/penn_chime/models.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,31 +85,26 @@ def __init__(self, p: Parameters):
8585
elif p.date_first_hospitalized is not None and p.doubling_time is None:
8686
# Fitting spread parameter to observed hospital census (dates of 1 patient and today)
8787
self.i_day = (p.current_date - p.date_first_hospitalized).days
88+
self.current_hospitalized = p.current_hospitalized
8889
logger.info(
89-
'Using date_first_hospitalized: %s; current_date: %s; i_day: %s',
90+
'Using date_first_hospitalized: %s; current_date: %s; i_day: %s, current_hospitalized: %s',
9091
p.date_first_hospitalized,
9192
p.current_date,
92-
self.i_day)
93-
min_loss = 2.0**99
94-
dts = np.linspace(1, 15, 29)
95-
losses = np.full(dts.shape[0], np.inf)
96-
self.current_hospitalized = p.current_hospitalized
97-
for i, i_dt in enumerate(dts):
98-
intrinsic_growth_rate = get_growth_rate(i_dt)
99-
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
100-
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
93+
self.i_day,
94+
p.current_hospitalized,
95+
)
10196

102-
self.run_projection(p, self.gen_policy(p))
97+
# Make an initial coarse estimate
98+
dts = np.linspace(1, 15, 15)
99+
min_loss = self.get_argmin_doubling_time(p, dts)
103100

104-
# Skip values the would put the fit past peak
105-
peak_admits_day = self.admits_df.hospitalized.argmax()
106-
if peak_admits_day < 0:
107-
continue
101+
# Refine the coarse estimate
102+
for iteration in range(4):
103+
dts = np.linspace(dts[min_loss-1], dts[min_loss+1], 15)
104+
min_loss = self.get_argmin_doubling_time(p, dts)
108105

109-
loss = self.get_loss()
110-
losses[i] = loss
106+
p.doubling_time = dts[min_loss]
111107

112-
p.doubling_time = dts[pd.Series(losses).argmin()]
113108
logger.info('Estimated doubling_time: %s', p.doubling_time)
114109
intrinsic_growth_rate = get_growth_rate(p.doubling_time)
115110
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
@@ -151,6 +146,26 @@ def __init__(self, p: Parameters):
151146
self.daily_growth_rate = get_growth_rate(p.doubling_time)
152147
self.daily_growth_rate_t = get_growth_rate(self.doubling_time_t)
153148

149+
def get_argmin_doubling_time(self, p: Parameters, dts):
150+
losses = np.full(dts.shape[0], np.inf)
151+
for i, i_dt in enumerate(dts):
152+
intrinsic_growth_rate = get_growth_rate(i_dt)
153+
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
154+
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
155+
156+
self.run_projection(p, self.gen_policy(p))
157+
158+
# Skip values the would put the fit past peak
159+
peak_admits_day = self.admits_df.hospitalized.argmax()
160+
if peak_admits_day < 0:
161+
continue
162+
163+
loss = self.get_loss()
164+
losses[i] = loss
165+
166+
min_loss = pd.Series(losses).argmin()
167+
return min_loss
168+
154169
def gen_policy(self, p: Parameters) -> Sequence[Tuple[float, int]]:
155170
if p.mitigation_date is not None:
156171
mitigation_day = -(p.current_date - p.mitigation_date).days

tests/penn_chime/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ def test_model_first_hosp_fit(param):
111111

112112
my_model = SimSirModel(param)
113113

114-
assert my_model.intrinsic_growth_rate == 0.12246204830937302
114+
assert abs(my_model.intrinsic_growth_rate - 0.123) / 0.123 < 0.01
115115
assert abs(my_model.beta - 4.21501347256401e-07) < EPSILON
116-
assert my_model.r_t == 2.307298374881539
117-
assert my_model.r_naught == 2.7144686763312222
118-
assert my_model.doubling_time_t == 7.764405988534983
116+
assert abs(my_model.r_t - 2.32) / 2.32 < 0.01
117+
assert abs(my_model.r_naught - 2.72) / 2.72 < 0.01
118+
assert abs(my_model.doubling_time_t - 7.71)/7.71 < 0.01
119119

120120

121121
def test_model_raw_start(model, param):

0 commit comments

Comments
 (0)