Skip to content

Commit a59376c

Browse files
committed
Added covid data loader; use random number seed
1 parent 18c5513 commit a59376c

File tree

3 files changed

+1322
-136
lines changed

3 files changed

+1322
-136
lines changed

examples/case_studies/bayesian_workflow.ipynb

Lines changed: 1218 additions & 128 deletions
Large diffs are not rendered by default.

examples/case_studies/bayesian_workflow.myst.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ import seaborn as sns
4444
warnings.simplefilter("ignore")
4545
4646
sns.set_context("talk")
47-
# plt.style.use('seaborn-whitegrid')
4847
49-
sampler_kwargs = {"chains": 4, "cores": 4, "tune": 2000}
48+
RANDOM_SEED = 8451997
49+
sampler_kwargs = {"chains": 4, "cores": 4, "tune": 2000, "random_seed": RANDOM_SEED}
5050
```
5151

5252
Strengths of Bayesian statistics that are critical here:
@@ -226,7 +226,7 @@ with pm.Model() as model_exp2:
226226

227227
```{code-cell} ipython3
228228
with model_exp2:
229-
prior_pred = pm.sample_prior_predictive()
229+
prior_pred = pm.sample_prior_predictive(random_seed=RANDOM_SEED)
230230
231231
fig, ax = plt.subplots(figsize=(12, 8))
232232
ax.plot(prior_pred.prior_predictive["obs"].values.squeeze().T, color="0.5", alpha=0.1)
@@ -269,7 +269,7 @@ with pm.Model() as model_exp3:
269269
# Likelihood
270270
pm.NegativeBinomial("obs", growth, alpha=alpha, observed=confirmed)
271271
272-
prior_pred = pm.sample_prior_predictive()
272+
prior_pred = pm.sample_prior_predictive(random_seed=RANDOM_SEED)
273273
```
274274

275275
```{code-cell} ipython3
@@ -354,7 +354,7 @@ Similar to the prior predictive, we can also generate new data by repeatedly tak
354354
```{code-cell} ipython3
355355
with model_exp3:
356356
# Draw samples from posterior predictive
357-
post_pred = pm.sample_posterior_predictive(trace_exp3.posterior)
357+
post_pred = pm.sample_posterior_predictive(trace_exp3.posterior, random_seed=RANDOM_SEED)
358358
```
359359

360360
```{code-cell} ipython3
@@ -426,7 +426,7 @@ with model_exp4:
426426
# the shape.
427427
pm.set_data({"t": np.arange(60), "confirmed": np.zeros(60, dtype="int")})
428428
429-
post_pred = pm.sample_posterior_predictive(trace_exp4.posterior)
429+
post_pred = pm.sample_posterior_predictive(trace_exp4.posterior, random_seed=RANDOM_SEED)
430430
```
431431

432432
As we held data back before, we can now see how the predictions of the model
@@ -482,7 +482,7 @@ with pm.Model() as logistic_model:
482482
"obs", growth, alpha=pm.Gamma("alpha", mu=6, sigma=1), observed=confirmed_data
483483
)
484484
485-
prior_pred = pm.sample_prior_predictive()
485+
prior_pred = pm.sample_prior_predictive(random_seed=RANDOM_SEED)
486486
```
487487

488488
```{code-cell} ipython3
@@ -612,7 +612,9 @@ plt.tight_layout();
612612

613613
```{code-cell} ipython3
614614
with logistic_model:
615-
pm.sample_posterior_predictive(trace_logistic_us, extend_inferencedata=True)
615+
pm.sample_posterior_predictive(
616+
trace_logistic_us, extend_inferencedata=True, random_seed=RANDOM_SEED
617+
)
616618
```
617619

618620
```{code-cell} ipython3
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import pandas as pd
2+
import numpy as np
3+
4+
5+
def load_individual_timeseries(name):
    """Download one JHU CSSE COVID-19 global time series and return it in long form.

    Parameters
    ----------
    name : str
        Dataset name as used in the CSSE file naming scheme, e.g.
        ``"confirmed"`` or ``"deaths"``.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by ``date`` (datetime) with columns ``country``,
        ``state``, ``type`` and ``cases``.  Hong Kong is promoted from a
        Chinese province to its own country entry, and every country that is
        reported per state additionally receives an aggregated
        ``"<country> (total)"`` row.
    """
    base_url = (
        "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master"
        "/csse_covid_19_data/csse_covid_19_time_series"
    )
    url = f"{base_url}/time_series_covid19_{name}_global.csv"

    raw = pd.read_csv(url, index_col=["Country/Region", "Province/State", "Lat", "Long"])
    raw["type"] = name.lower()
    raw.columns.name = "date"

    # Wide (one column per date) -> long (one row per location/date/type).
    tidy = (
        raw.set_index("type", append=True)
        .reset_index(["Lat", "Long"], drop=True)
        .stack()
        .reset_index()
        .set_index("date")
    )
    tidy.index = pd.to_datetime(tidy.index)
    tidy.columns = ["country", "state", "type", "cases"]

    # Hong Kong is reported as a state of China; move it to country level.
    is_hk = tidy.state == "Hong Kong"
    tidy.loc[is_hk, "country"] = "Hong Kong"
    tidy.loc[is_hk, "state"] = np.nan

    # For countries split into states, append a country-wide "(total)" row
    # summed over all states for each date/type.
    per_state = tidy.loc[~tidy.state.isna()]
    totals = (
        per_state.groupby(["country", "date", "type"])
        .sum()
        .rename(index=lambda country: country + " (total)", level=0)
        .reset_index(level=["country", "type"])
    )
    return pd.concat([tidy, totals])
40+
41+
42+
def load_data(drop_states=False, p_crit=0.05, filter_n_days_100=None):
    """Load JHU CSSE confirmed-case and death counts with derived columns.

    Parameters
    ----------
    drop_states : bool, default False
        If True, keep only country-level rows (drop per-state rows).
    p_crit : float, default 0.05
        Fraction of confirmed cases assumed to become critical; used to
        compute the ``critical_estimate`` column.
    filter_n_days_100 : int or None, default None
        If given, keep only countries that have been observed for at least
        this many days after crossing 100 confirmed cases.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by ``date`` with columns ``country``, ``state``,
        ``confirmed``, ``critical_estimate``, ``days_since_100`` and
        ``deaths``.
    """
    df = load_individual_timeseries("confirmed").rename(columns={"cases": "confirmed"})

    if drop_states:
        # Country-level rows have a null state.
        df = df.loc[df.state.isnull()]

    # Rough estimate of critically ill cases.
    df = df.assign(critical_estimate=df.confirmed * p_crit)

    # Day count relative to crossing 100 confirmed cases: negative before the
    # crossing, zero on the first day at/above 100, positive afterwards.
    # Assumes rows within each country/state are ordered by date — TODO confirm.
    df.loc[:, "days_since_100"] = np.nan
    for country in df.country.unique():
        in_country = df.country == country
        states = df.loc[in_country, "state"]
        if states.isnull().all():
            # Country reported as a single series.
            n_below = len(df.loc[in_country & (df.confirmed < 100)])
            n_above = len(df.loc[in_country & (df.confirmed >= 100)])
            df.loc[in_country, "days_since_100"] = np.arange(-n_below, n_above)
        else:
            # Country reported per state: number each state independently.
            for state in states.unique():
                in_state = in_country & (df.state == state)
                n_below = len(df.loc[in_state & (df.confirmed < 100)])
                n_above = len(df.loc[in_state & (df.confirmed >= 100)])
                df.loc[in_state, "days_since_100"] = np.arange(-n_below, n_above)

    # Join death counts on (date, country, state).  Recovered cases could be
    # joined the same way via load_individual_timeseries("recovered").
    deaths = load_individual_timeseries("deaths")
    deaths = deaths.set_index(["country", "state"], append=True)[["cases"]]
    deaths.columns = ["deaths"]

    df = (
        df.set_index(["country", "state"], append=True)
        .join(deaths)
        .reset_index(["country", "state"])
    )

    if filter_n_days_100 is not None:
        # Keep only countries with enough post-100-case history.
        keep = pd.Series(df.loc[df.days_since_100 >= filter_n_days_100].country.unique())
        df = df.loc[lambda frame: frame.country.isin(keep)]

    return df

0 commit comments

Comments
 (0)