Skip to content

Commit 160d6d7

Browse files
committed
Add refinement of doubling time fit from first hospitalization date
1 parent 8556bad commit 160d6d7

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

src/penn_chime/models.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,31 +85,25 @@ def __init__(self, p: Parameters):
8585
elif p.date_first_hospitalized is not None and p.doubling_time is None:
8686
# Fitting spread parameter to observed hospital census (dates of 1 patient and today)
8787
self.i_day = (p.current_date - p.date_first_hospitalized).days
88+
self.current_hospitalized = p.current_hospitalized
8889
logger.info(
89-
'Using date_first_hospitalized: %s; current_date: %s; i_day: %s',
90+
'Using date_first_hospitalized: %s; current_date: %s; i_day: %s, current_hospitalized: %s',
9091
p.date_first_hospitalized,
9192
p.current_date,
92-
self.i_day)
93-
min_loss = 2.0**99
94-
dts = np.linspace(1, 15, 29)
95-
losses = np.full(dts.shape[0], np.inf)
96-
self.current_hospitalized = p.current_hospitalized
97-
for i, i_dt in enumerate(dts):
98-
intrinsic_growth_rate = get_growth_rate(i_dt)
99-
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
100-
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
93+
self.i_day,
94+
p.current_hospitalized,
95+
)
10196

102-
self.run_projection(p, self.gen_policy(p))
97+
# Make an initial coarse estimate
98+
dts = np.linspace(1, 15, 29)
99+
min_loss = self.get_argmin_doubling_time(p, dts)
103100

104-
# Skip values the would put the fit past peak
105-
peak_admits_day = self.admits_df.hospitalized.argmax()
106-
if peak_admits_day < 0:
107-
continue
101+
# Refine the coarse estimate
102+
dts = np.linspace(dts[min_loss-1], dts[min_loss+1], 29)
103+
min_loss = self.get_argmin_doubling_time(p, dts)
108104

109-
loss = self.get_loss()
110-
losses[i] = loss
105+
p.doubling_time = dts[min_loss]
111106

112-
p.doubling_time = dts[pd.Series(losses).argmin()]
113107
logger.info('Estimated doubling_time: %s', p.doubling_time)
114108
intrinsic_growth_rate = get_growth_rate(p.doubling_time)
115109
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
@@ -151,6 +145,26 @@ def __init__(self, p: Parameters):
151145
self.daily_growth_rate = get_growth_rate(p.doubling_time)
152146
self.daily_growth_rate_t = get_growth_rate(self.doubling_time_t)
153147

148+
def get_argmin_doubling_time(self, p: Parameters, dts):
149+
losses = np.full(dts.shape[0], np.inf)
150+
for i, i_dt in enumerate(dts):
151+
intrinsic_growth_rate = get_growth_rate(i_dt)
152+
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
153+
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
154+
155+
self.run_projection(p, self.gen_policy(p))
156+
157+
# Skip values the would put the fit past peak
158+
peak_admits_day = self.admits_df.hospitalized.argmax()
159+
if peak_admits_day < 0:
160+
continue
161+
162+
loss = self.get_loss()
163+
losses[i] = loss
164+
165+
min_loss = pd.Series(losses).argmin()
166+
return min_loss
167+
154168
def gen_policy(self, p: Parameters) -> Sequence[Tuple[float, int]]:
155169
if p.mitigation_date is not None:
156170
mitigation_day = -(p.current_date - p.mitigation_date).days

tests/penn_chime/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ def test_model_first_hosp_fit(param):
111111

112112
my_model = SimSirModel(param)
113113

114-
assert my_model.intrinsic_growth_rate == 0.12246204830937302
114+
assert my_model.intrinsic_growth_rate == 0.1232387953882732
115115
assert abs(my_model.beta - 4.21501347256401e-07) < EPSILON
116-
assert my_model.r_t == 2.307298374881539
117-
assert my_model.r_naught == 2.7144686763312222
118-
assert my_model.doubling_time_t == 7.764405988534983
116+
assert my_model.r_t == 2.316541665120451
117+
assert my_model.r_naught == 2.7253431354358253
118+
assert my_model.doubling_time_t == 7.712255171528787
119119

120120

121121
def test_model_raw_start(model, param):

0 commit comments

Comments
 (0)