Skip to content

Commit fa8b692

Browse files
authored
Merge pull request #471 from CodeForPhilly/PhilMiller/model-performance
Speed up by 20x by never copying arrays and only constructing a dataframe after the model has run
2 parents fe8f41a + c4f218c commit fa8b692

File tree

12 files changed

+227
-119
lines changed

12 files changed

+227
-119
lines changed

src/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
st.markdown("Projected number of **daily** COVID-19 admissions. \n\n _NOTE: Now including estimates of prior admissions for comparison._")
3737
admits_chart = build_admits_chart(alt=alt, admits_floor_df=m.admits_floor_df, max_y_axis=p.max_y_axis)
3838
st.altair_chart(admits_chart, use_container_width=True)
39-
st.markdown(build_descriptions(chart=admits_chart, labels=p.labels, suffix=" Admissions"))
39+
st.markdown(build_descriptions(chart=admits_chart, labels=p.labels, prefix="admits_", suffix=" Admissions"))
4040
display_download_link(
4141
st,
4242
filename=f"{p.current_date}_projected_admits.csv",
@@ -58,7 +58,7 @@
5858
st.markdown("Projected **census** of COVID-19 patients, accounting for arrivals and discharges \n\n _NOTE: Now including estimates of prior census for comparison._")
5959
census_chart = build_census_chart(alt=alt, census_floor_df=m.census_floor_df, max_y_axis=p.max_y_axis)
6060
st.altair_chart(census_chart, use_container_width=True)
61-
st.markdown(build_descriptions(chart=census_chart, labels=p.labels, suffix=" Census"))
61+
st.markdown(build_descriptions(chart=census_chart, labels=p.labels, prefix="census_", suffix=" Census"))
6262
display_download_link(
6363
st,
6464
filename=f"{p.current_date}_projected_census.csv",

src/chime_dash/app/services/callbacks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@ def handle_model_change(i, sidebar_data):
3838
if sidebar_data:
3939
pars = parameters_deserializer(sidebar_data["parameters"])
4040
model = SimSirModel(pars)
41+
vis = i.components.get("visualizations", None) if i else None
42+
vis_content = vis.content if vis else None
43+
4144
viz_kwargs = dict(
4245
labels=pars.labels,
4346
table_mod=7,
4447
max_y_axis=pars.max_y_axis,
48+
content=vis_content
4549
)
4650
result.extend(i.components["intro"].build(model, pars))
4751
for df_key in ["admits_df", "census_df", "sim_sir_w_date_df"]:

src/chime_dash/app/services/plotting.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,17 @@
77
from pandas import DataFrame
88

99

10-
def plot_dataframe(dataframe: DataFrame, max_y_axis: int = None,) -> Dict[str, Any]:
11-
"""
10+
def plot_dataframe(
11+
dataframe: DataFrame,
12+
max_y_axis: int = None,
13+
) -> Dict[str, Any]:
14+
"""Returns dictionary used for plotly graphs
15+
16+
Arguments:
17+
dataframe: The dataframe to plot. Plots all columns as y, index is x.
18+
max_y_axis: Maximal value on y-axis.
1219
"""
20+
1321
if max_y_axis is None:
1422
yaxis = {}
1523
else:

src/chime_dash/app/templates/en/visualizations.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,20 @@ hosp: Hospitalized
1616
icu: ICU
1717
vent: Ventalized
1818

19+
# Admits df column translations
20+
admits_hospitalized: Hospitalized
21+
admits_icu: ICU
22+
admits_ventilated: Ventilated
23+
24+
# Admits df column translations
25+
census_hospitalized: Hospitalized
26+
census_icu: ICU
27+
census_ventilated: Ventilated
28+
29+
# SIR df column translations
30+
susceptible: Susceptible
31+
infected: Infected
32+
recovered: Recovered
33+
1934
# Date Localization
2035
date-format: "%m%d%Y"

src/chime_dash/app/utils/__init__.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,18 +116,50 @@ def get_n_switch_values(input_value, elements_to_update) -> List[bool]:
116116

117117
def prepare_visualization_group(df: DataFrame = None, **kwargs) -> List[Any]:
118118
"""Creates plot, table and download link for data frame.
119+
120+
Arguments:
121+
df: The Dataframe to plot
122+
content: Dict[str, str]
123+
Mapping for translating columns and index.
124+
max_y_axis: int
125+
Maximal value on y-axis
126+
labels: List[str]
127+
Columns to display
128+
table_mod: int
129+
Displays only each `table_mod` row in table
130+
119131
"""
120132
result = [{}, None, None]
121133
if df is not None and isinstance(df, DataFrame):
134+
135+
date_column = "date"
136+
day_column = "day"
137+
138+
# Translate column and index if specified
139+
content = kwargs.get("content", None)
140+
if content:
141+
columns = {col: content[col] for col in df.columns if col in content}
142+
index = (
143+
{df.index.name: content[df.index.name]}
144+
if df.index.name and df.index.name in content
145+
else None
146+
)
147+
df = df.rename(columns=columns, index=index)
148+
date_column = content.get(date_column, date_column)
149+
day_column = content.get(day_column, day_column)
150+
122151
plot_data = plot_dataframe(
123-
df.dropna().set_index("date").drop(columns=["day"]),
152+
df.dropna().set_index(date_column).drop(columns=[day_column]),
124153
max_y_axis=kwargs.get("max_y_axis", None),
125154
)
126155

156+
157+
# translate back for backwards compability of build_table
158+
column_map = {day_column: "day", date_column: "date"}
127159
table = (
128160
df_to_html_table(
129161
build_table(
130-
df=df,
162+
df=df.rename(columns=column_map),
131163
labels=kwargs.get("labels", df.columns),
132164
modulo=kwargs.get("table_mod", 7),
133165
),
@@ -140,7 +172,9 @@ def prepare_visualization_group(df: DataFrame = None, **kwargs) -> List[Any]:
140172
# else None
141173
)
142174

143-
csv = build_csv_download(df)
175+
# Convert columnnames to lowercase
176+
column_map = {col: col.lower() for col in df.columns}
177+
csv = build_csv_download(df.rename(columns=column_map))
144178
result = [plot_data, table, csv]
145179

146180
return result
@@ -153,4 +187,5 @@ def get_instance(*args, **kwargs):
153187
if class_ not in instances:
154188
instances[class_] = class_(*args, **kwargs)
155189
return instances[class_]
190+
156191
return get_instance

src/penn_chime/charts.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def build_admits_chart(
2626
# TODO fix the fold to allow any number of dispositions
2727
points = (
2828
alt.Chart()
29-
.transform_fold(fold=["hospitalized", "icu", "ventilated"])
29+
.transform_fold(fold=["admits_hospitalized", "admits_icu", "admits_ventilated"])
3030
.encode(x=alt.X(**x), y=alt.Y(**y), color=color, tooltip=tooltip)
3131
.mark_line(point=True)
3232
.encode(
@@ -65,7 +65,7 @@ def build_census_chart(
6565
# TODO fix the fold to allow any number of dispositions
6666
points = (
6767
alt.Chart()
68-
.transform_fold(fold=["hospitalized", "icu", "ventilated"])
68+
.transform_fold(fold=["census_hospitalized", "census_icu", "census_ventilated"])
6969
.encode(x=alt.X(**x), y=alt.Y(**y), color=color, tooltip=tooltip)
7070
.mark_line(point=True)
7171
.encode(
@@ -128,7 +128,11 @@ def build_sim_sir_w_date_chart(
128128

129129

130130
def build_descriptions(
131-
*, chart: Chart, labels: Dict[str, str], suffix: str = ""
131+
*,
132+
chart: Chart,
133+
labels: Dict[str, str],
134+
prefix: str = "",
135+
suffix: str = ""
132136
) -> str:
133137
"""
134138
@@ -145,17 +149,17 @@ def build_descriptions(
145149
day = "date" if "date" in chart.data.columns else "day"
146150

147151
for col in cols:
148-
if chart.data[col].idxmax() + 1 == len(chart.data):
152+
if chart.data[prefix+col].idxmax() + 1 == len(chart.data):
149153
asterisk = True
150154

151155
# todo: bring this to an optional arg / i18n
152-
on = datetime.strftime(chart.data[day][chart.data[col].idxmax()], "%b %d")
156+
on = datetime.strftime(chart.data[day][chart.data[prefix+col].idxmax()], "%b %d")
153157

154158
messages.append(
155159
"{}{} peaks at {:,} on {}{}".format(
156160
labels[col],
157161
suffix,
158-
ceil(chart.data[col].max()),
162+
ceil(chart.data[prefix+col].max()),
159163
on,
160164
"*" if asterisk else "",
161165
)

0 commit comments

Comments
 (0)