Skip to content

Commit 010d3e8

Browse files
committed
Speed up by 20x by never copying arrays and only constructing a dataframe after the model has run
1 parent 174a150 commit 010d3e8

File tree

8 files changed

+132
-96
lines changed

8 files changed

+132
-96
lines changed

src/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
st.markdown("Projected number of **daily** COVID-19 admissions. \n\n _NOTE: Now including estimates of prior admissions for comparison._")
3838
admits_chart = build_admits_chart(alt=alt, admits_floor_df=m.admits_floor_df, max_y_axis=p.max_y_axis)
3939
st.altair_chart(admits_chart, use_container_width=True)
40-
st.markdown(build_descriptions(chart=admits_chart, labels=p.labels, suffix=" Admissions"))
40+
st.markdown(build_descriptions(chart=admits_chart, labels=p.labels, prefix="admits_", suffix=" Admissions"))
4141
display_download_link(
4242
st,
4343
filename=f"{p.current_date}_projected_admits.csv",
@@ -59,7 +59,7 @@
5959
st.markdown("Projected **census** of COVID-19 patients, accounting for arrivals and discharges \n\n _NOTE: Now including estimates of prior census for comparison._")
6060
census_chart = build_census_chart(alt=alt, census_floor_df=m.census_floor_df, max_y_axis=p.max_y_axis)
6161
st.altair_chart(census_chart, use_container_width=True)
62-
st.markdown(build_descriptions(chart=census_chart, labels=p.labels, suffix=" Census"))
62+
st.markdown(build_descriptions(chart=census_chart, labels=p.labels, prefix="census_", suffix=" Census"))
6363
display_download_link(
6464
st,
6565
filename=f"{p.current_date}_projected_census.csv",

src/penn_chime/charts.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def build_admits_chart(
2626
# TODO fix the fold to allow any number of dispositions
2727
points = (
2828
alt.Chart()
29-
.transform_fold(fold=["hospitalized", "icu", "ventilated"])
29+
.transform_fold(fold=["admits_hospitalized", "admits_icu", "admits_ventilated"])
3030
.encode(x=alt.X(**x), y=alt.Y(**y), color=color, tooltip=tooltip)
3131
.mark_line(point=True)
3232
.encode(
@@ -65,7 +65,7 @@ def build_census_chart(
6565
# TODO fix the fold to allow any number of dispositions
6666
points = (
6767
alt.Chart()
68-
.transform_fold(fold=["hospitalized", "icu", "ventilated"])
68+
.transform_fold(fold=["census_hospitalized", "census_icu", "census_ventilated"])
6969
.encode(x=alt.X(**x), y=alt.Y(**y), color=color, tooltip=tooltip)
7070
.mark_line(point=True)
7171
.encode(
@@ -128,7 +128,11 @@ def build_sim_sir_w_date_chart(
128128

129129

130130
def build_descriptions(
131-
*, chart: Chart, labels: Dict[str, str], suffix: str = ""
131+
*,
132+
chart: Chart,
133+
labels: Dict[str, str],
134+
prefix: str = "",
135+
suffix: str = ""
132136
) -> str:
133137
"""
134138
@@ -145,17 +149,17 @@ def build_descriptions(
145149
day = "date" if "date" in chart.data.columns else "day"
146150

147151
for col in cols:
148-
if chart.data[col].idxmax() + 1 == len(chart.data):
152+
if chart.data[prefix+col].idxmax() + 1 == len(chart.data):
149153
asterisk = True
150154

151155
# todo: bring this to an optional arg / i18n
152-
on = datetime.strftime(chart.data[day][chart.data[col].idxmax()], "%b %d")
156+
on = datetime.strftime(chart.data[day][chart.data[prefix+col].idxmax()], "%b %d")
153157

154158
messages.append(
155159
"{}{} peaks at {:,} on {}{}".format(
156160
labels[col],
157161
suffix,
158-
ceil(chart.data[col].max()),
162+
ceil(chart.data[prefix+col].max()),
159163
on,
160164
"*" if asterisk else "",
161165
)

src/penn_chime/models.py

Lines changed: 83 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self, p: Parameters):
7070

7171
self.i_day = 0 # seed to the full length
7272
self.run_projection(p, [(self.beta, p.n_days)])
73-
self.i_day = i_day = int(get_argmin_ds(self.census_df, p.current_hospitalized))
73+
self.i_day = i_day = int(get_argmin_ds(self.raw["census_hospitalized"], p.current_hospitalized))
7474

7575
self.run_projection(p, self.gen_policy(p))
7676

@@ -120,6 +120,13 @@ def __init__(self, p: Parameters):
120120
)
121121
raise AssertionError('doubling_time or date_first_hospitalized must be provided.')
122122

123+
self.raw["date"] = self.raw["day"].astype("timedelta64[D]") + np.datetime64(p.current_date)
124+
125+
self.raw_df = pd.DataFrame(data=self.raw)
126+
self.dispositions_df = self.raw_df
127+
self.admits_df = self.raw_df
128+
self.census_df = self.raw_df
129+
123130
logger.info('len(np.arange(-i_day, n_days+1)): %s', len(np.arange(-self.i_day, p.n_days+1)))
124131
logger.info('len(raw_df): %s', len(self.raw_df))
125132

@@ -139,9 +146,9 @@ def __init__(self, p: Parameters):
139146

140147
self.sim_sir_w_date_df = build_sim_sir_w_date_df(self.raw_df, p.current_date, self.keys)
141148

142-
self.sim_sir_w_date_floor_df = build_floor_df(self.sim_sir_w_date_df, self.keys)
143-
self.admits_floor_df = build_floor_df(self.admits_df, p.dispositions.keys())
144-
self.census_floor_df = build_floor_df(self.census_df, p.dispositions.keys())
149+
self.sim_sir_w_date_floor_df = build_floor_df(self.sim_sir_w_date_df, self.keys, "")
150+
self.admits_floor_df = build_floor_df(self.admits_df, p.dispositions.keys(), "admits_")
151+
self.census_floor_df = build_floor_df(self.census_df, p.dispositions.keys(), "census_")
145152

146153
self.daily_growth_rate = get_growth_rate(p.doubling_time)
147154
self.daily_growth_rate_t = get_growth_rate(self.doubling_time_t)
@@ -156,7 +163,7 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
156163
self.run_projection(p, self.gen_policy(p))
157164

158165
# Skip values that would put the fit past peak
159-
peak_admits_day = self.admits_df.hospitalized.argmax()
166+
peak_admits_day = self.raw["admits_hospitalized"].argmax()
160167
if peak_admits_day < 0:
161168
continue
162169

@@ -186,7 +193,7 @@ def gen_policy(self, p: Parameters) -> Sequence[Tuple[float, int]]:
186193
]
187194

188195
def run_projection(self, p: Parameters, policy: Sequence[Tuple[float, int]]):
189-
self.raw_df = sim_sir_df(
196+
self.raw = sim_sir(
190197
self.susceptible,
191198
self.infected,
192199
p.recovered,
@@ -195,23 +202,24 @@ def run_projection(self, p: Parameters, policy: Sequence[Tuple[float, int]]):
195202
policy
196203
)
197204

198-
self.dispositions_df = build_dispositions_df(self.raw_df, self.rates, p.market_share, p.current_date)
199-
self.admits_df = build_admits_df(self.dispositions_df)
200-
self.census_df = build_census_df(self.admits_df, self.days)
201-
self.current_infected = self.raw_df.infected.loc[self.i_day]
205+
calculate_dispositions(self.raw, self.rates, p.market_share)
206+
calculate_admits(self.rates, self.raw)
207+
calculate_census(self.raw, self.days)
208+
209+
self.current_infected = self.raw["infected"][self.i_day]
202210

203211
def get_loss(self) -> float:
204212
"""Squared error: predicted vs. actual current hospitalized."""
205-
predicted = self.census_df.hospitalized.loc[self.i_day]
213+
predicted = self.raw["census_hospitalized"][self.i_day]
206214
return (self.current_hospitalized - predicted) ** 2.0
207215

208216

209-
def get_argmin_ds(census_df: pd.DataFrame, current_hospitalized: float) -> float:
217+
def get_argmin_ds(census, current_hospitalized: float) -> float:
210218
# By design, this forbids choosing a day after the peak
211219
# If that's a problem, see #381
212-
peak_day = census_df.hospitalized.argmax()
213-
losses_df = (census_df.hospitalized[:peak_day] - current_hospitalized) ** 2.0
214-
return losses_df.argmin()
220+
peak_day = census.argmax()
221+
losses = (census[:peak_day] - current_hospitalized) ** 2.0
222+
return losses.argmin()
215223

216224

217225
def get_beta(
@@ -259,31 +267,56 @@ def sir(
259267

260268
def gen_sir(
261269
s: float, i: float, r: float, gamma: float, i_day: int, policies: Sequence[Tuple[float, int]]
262-
) -> Generator[Tuple[int, float, float, float], None, None]:
270+
):
263271
"""Simulate SIR model forward in time, returning a dict of per-day arrays.
264272
Parameter order has changed to allow multiple (beta, n_days)
265273
to reflect multiple changing social distancing policies.
266274
"""
267275
s, i, r = (float(v) for v in (s, i, r))
268276
n = s + i + r
269277
d = i_day
278+
279+
total_days = 1
280+
for beta, days in policies:
281+
total_days += days
282+
283+
d_a = np.empty(total_days, "int")
284+
s_a = np.empty(total_days, "float")
285+
i_a = np.empty(total_days, "float")
286+
r_a = np.empty(total_days, "float")
287+
288+
index = 0
270289
for beta, n_days in policies:
271290
for _ in range(n_days):
272-
yield d, s, i, r
291+
d_a[index] = d
292+
s_a[index] = s
293+
i_a[index] = i
294+
r_a[index] = r
295+
index += 1
296+
273297
s, i, r = sir(s, i, r, beta, gamma, n)
274298
d += 1
275-
yield d, s, i, r
276299

300+
d_a[index] = d
301+
s_a[index] = s
302+
i_a[index] = i
303+
r_a[index] = r
304+
return {
305+
"day": d_a,
306+
"susceptible": s_a,
307+
"infected": i_a,
308+
"recovered": r_a,
309+
"ever_infected": i_a + r_a
310+
}
277311

278-
def sim_sir_df(
312+
313+
def sim_sir(
279314
s: float, i: float, r: float,
280315
gamma: float, i_day: int, policies: Sequence[Tuple[float, int]]
281316
) -> pd.DataFrame:
282317
"""Simulate the SIR model forward in time."""
283-
return pd.DataFrame(
284-
data=gen_sir(s, i, r, gamma, i_day, policies),
285-
columns=("day", "susceptible", "infected", "recovered"),
286-
)
318+
data = gen_sir(s, i, r, gamma, i_day, policies)
319+
return data
287320

288321

289322
def build_sim_sir_w_date_df(
@@ -302,58 +335,50 @@ def build_sim_sir_w_date_df(
302335
})
303336

304337

305-
def build_floor_df(df, keys):
338+
def build_floor_df(df, keys, prefix):
306339
"""Build floor sim sir w date."""
307340
return pd.DataFrame({
308341
"day": df.day,
309342
"date": df.date,
310343
**{
311-
key: np.floor(df[key])
344+
prefix + key: np.floor(df[prefix+key])
312345
for key in keys
313346
}
314347
})
315348

316349

317-
def build_dispositions_df(
318-
raw_df: pd.DataFrame,
350+
def calculate_dispositions(
351+
raw: Dict,
319352
rates: Dict[str, float],
320353
market_share: float,
321-
current_date: datetime,
322-
) -> pd.DataFrame:
354+
):
323355
"""Add disposition arrays (ever-infected scaled by rate and market_share) to raw in place."""
324-
patients = raw_df.infected + raw_df.recovered
325-
day = raw_df.day
326-
return pd.DataFrame({
327-
"day": day,
328-
"date": day.astype('timedelta64[D]') + np.datetime64(current_date),
329-
**{
330-
key: patients * rate * market_share
331-
for key, rate in rates.items()
332-
}
333-
})
356+
for key, rate in rates.items():
357+
raw["ever_" + key] = raw["ever_infected"] * rate * market_share
358+
raw[key] = raw["ever_infected"] * rate * market_share
334359

335360

336-
def build_admits_df(dispositions_df: pd.DataFrame) -> pd.DataFrame:
361+
def calculate_admits(rates, raw: Dict):
337362
"""Compute daily admits as the day-over-day difference of each disposition and store them in raw."""
338-
admits_df = dispositions_df - dispositions_df.shift(1)
339-
admits_df.day = dispositions_df.day
340-
admits_df.date = dispositions_df.date
341-
return admits_df
363+
for key in rates.keys():
364+
ever = raw["ever_" + key]
365+
admit = np.empty_like(ever)
366+
admit[0] = np.nan
367+
admit[1:] = ever[1:] - ever[:-1]
368+
raw["admits_"+key] = admit
369+
raw[key] = admit
342370

343371

344-
def build_census_df(
345-
admits_df: pd.DataFrame,
372+
def calculate_census(
373+
raw: Dict,
346374
lengths_of_stay: Dict[str, int],
347-
) -> pd.DataFrame:
375+
):
348376
"""Average Length of Stay for each disposition of COVID-19 case (total guesses)"""
349-
return pd.DataFrame({
350-
'day': admits_df.day,
351-
'date': admits_df.date,
352-
**{
353-
key: (
354-
admits_df[key].cumsum()
355-
- admits_df[key].cumsum().shift(los).fillna(0)
356-
)
357-
for key, los in lengths_of_stay.items()
358-
}
359-
})
377+
n_days = raw["day"].shape[0]
378+
for key, los in lengths_of_stay.items():
379+
cumsum = np.empty(n_days + los)
380+
cumsum[:los+1] = 0.0
381+
cumsum[los+1:] = raw["admits_" + key][1:].cumsum()
382+
383+
census = cumsum[los:] - cumsum[:-los]
384+
raw["census_" + key] = census

tests/by_doubling_time/2020-03-28_projected_admits.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
,day,date,hospitalized,icu,ventilated
1+
,day,date,admits_hospitalized,admits_icu,admits_ventilated
22
0,-4,2020-03-24,,,
33
1,-3,2020-03-25,2.5542297270266676,0.7662689181079996,0.5108459454053333
44
2,-2,2020-03-26,2.8373214956844457,0.8511964487053332,0.5674642991368888

tests/by_doubling_time/2020-03-28_projected_census.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
,day,date,hospitalized,icu,ventilated
1+
,day,date,census_hospitalized,census_icu,census_ventilated
22
0,-4,2020-03-24,,,
33
1,-3,2020-03-25,3.0,1.0,1.0
44
2,-2,2020-03-26,6.0,2.0,2.0

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def admits_df():
112112

113113
@pytest.fixture
114114
def admits_floor_df(param, admits_df):
115-
return build_floor_df(admits_df, param.dispositions.keys())
115+
return build_floor_df(admits_df, param.dispositions.keys(), "admits_")
116116

117117

118118
@pytest.fixture
@@ -123,5 +123,5 @@ def census_df():
123123

124124
@pytest.fixture
125125
def census_floor_df(param, census_df):
126-
return build_floor_df(census_df, param.dispositions.keys())
126+
return build_floor_df(census_df, param.dispositions.keys(), "census_")
127127

tests/penn_chime/test_charts.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
def test_admits_chart(admits_floor_df):
2020
chart = build_admits_chart(alt=alt, admits_floor_df=admits_floor_df)
2121
assert isinstance(chart, (alt.Chart, alt.LayerChart))
22-
assert round(chart.data.iloc[40].icu, 0) == 38
22+
assert round(chart.data.iloc[40].admits_icu, 0) == 38
2323

2424
# test fx call with no params
2525
with pytest.raises(TypeError):
@@ -28,39 +28,39 @@ def test_admits_chart(admits_floor_df):
2828

2929
def test_build_descriptions(admits_floor_df, param):
3030
chart = build_admits_chart(alt=alt, admits_floor_df=admits_floor_df)
31-
description = build_descriptions(chart=chart, labels=param.labels)
31+
description = build_descriptions(chart=chart, labels=param.labels, prefix="admits_")
3232

3333
hosp, icu, vent = description.split("\n\n") # break out the description into lines
3434

35-
max_hosp = chart.data["hospitalized"].max()
35+
max_hosp = chart.data["admits_hospitalized"].max()
3636
assert str(ceil(max_hosp)) in hosp
3737

3838

3939
def test_no_asterisk(admits_floor_df, param):
4040
param.n_days = 600
4141

4242
chart = build_admits_chart(alt=alt, admits_floor_df=admits_floor_df)
43-
description = build_descriptions(chart=chart, labels=param.labels)
43+
description = build_descriptions(chart=chart, labels=param.labels, prefix="admits_")
4444
assert "*" not in description
4545

4646

4747
def test_census(census_floor_df, param):
4848
chart = build_census_chart(alt=alt, census_floor_df=census_floor_df)
49-
description = build_descriptions(chart=chart, labels=param.labels)
49+
description = build_descriptions(chart=chart, labels=param.labels, prefix="census_")
5050

51-
assert str(ceil(chart.data["ventilated"].max())) in description
52-
assert str(chart.data["icu"].idxmax()) not in description
51+
assert str(ceil(chart.data["census_ventilated"].max())) in description
52+
assert str(chart.data["census_icu"].idxmax()) not in description
5353
assert (
54-
datetime.strftime(chart.data.iloc[chart.data["icu"].idxmax()].date, "%b %d")
54+
datetime.strftime(chart.data.iloc[chart.data["census_icu"].idxmax()].date, "%b %d")
5555
in description
5656
)
5757

5858

5959
def test_census_chart(census_floor_df):
6060
chart = build_census_chart(alt=alt, census_floor_df=census_floor_df)
6161
assert isinstance(chart, (alt.Chart, alt.LayerChart))
62-
assert chart.data.iloc[1].hospitalized == 3
63-
assert chart.data.iloc[49].ventilated == 365
62+
assert chart.data.iloc[1].census_hospitalized == 3
63+
assert chart.data.iloc[49].census_ventilated == 365
6464

6565
# test fx call with no params
6666
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)