Skip to content

Commit d7e56cb

Browse files
committed
Add sim_sir_df and eliminate comparison float casts
1 parent ce50a79 commit d7e56cb

File tree

2 files changed

+93
-35
lines changed

2 files changed

+93
-35
lines changed

penn_chime/models.py

Lines changed: 74 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,93 @@
1-
from typing import Tuple
1+
from typing import Generator, Tuple
22

33
import numpy as np
4+
import pandas as pd
45
import streamlit as st
56

67

7-
# The SIR model, one time step
88
@st.cache
9-
def sir(y, beta, gamma, N):
10-
S, I, R = y
11-
Sn = (-beta * S * I) + S
12-
In = (beta * S * I - gamma * I) + I
13-
Rn = gamma * I + R
14-
if Sn < 0:
15-
Sn = 0
16-
if In < 0:
17-
In = 0
18-
if Rn < 0:
19-
Rn = 0
20-
21-
scale = N / (Sn + In + Rn)
22-
return Sn * scale, In * scale, Rn * scale
23-
24-
25-
# Run the SIR model forward in time
9+
def sir(
10+
s: float, i: float, r: float,
11+
beta: float, gamma: float, n: float
12+
) -> Tuple[float, float, float]:
13+
"""The SIR model, one time step."""
14+
s_n = (-beta * s * i) + s
15+
i_n = (beta * s * i - gamma * i) + i
16+
r_n = gamma * i + r
17+
if s_n < 0.0:
18+
s_n = 0.0
19+
if i_n < 0.0:
20+
i_n = 0.0
21+
if r_n < 0.0:
22+
r_n = 0.0
23+
24+
scale = n / (s_n + i_n + r_n)
25+
return s_n * scale, i_n * scale, r_n * scale
26+
27+
28+
def gen_sir(
29+
s: float, i: float, r: float,
30+
beta: float, gamma: float, n_days: int, beta_decay: float = 0.0
31+
) -> Generator:
32+
"""Simulate SIR model forward in time yielding tuples."""
33+
s, i, r, beta_decay = (float(v) for v in (s, i, r, beta_decay))
34+
n = s + i + r
35+
for _ in range(n_days + 1):
36+
yield s, i, r
37+
s, i, r = sir(s, i, r, beta, gamma, n)
38+
# okay even if beta_decay is 0.0
39+
beta = beta * (1.0 - beta_decay)
40+
41+
2642
@st.cache
2743
def sim_sir(
28-
S, I, R, beta, gamma, n_days, beta_decay=0
44+
s: float, i: float, r: float,
45+
beta: float, gamma: float, n_days: int, beta_decay: float = 0.0
2946
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
30-
N = S + I + R
31-
s, i, r = [S], [I], [R]
47+
"""Simulate the SIR model forward in time."""
48+
s, i, r, beta_decay = (float(v) for v in (s, i, r, beta_decay))
49+
n = s + i + r
50+
s_v, i_v, r_v = [s], [i], [r]
3251
for day in range(n_days):
33-
y = S, I, R
34-
S, I, R = sir(y, beta, gamma, N)
35-
beta = beta * (1 - beta_decay) # okay even if beta_decay is 0
36-
s.append(S)
37-
i.append(I)
38-
r.append(R)
52+
s, i, r = sir(s, i, r, beta, gamma, n)
53+
# okay even if beta_decay is 0.0
54+
beta = beta * (1.0 - beta_decay)
55+
s_v.append(s)
56+
i_v.append(i)
57+
r_v.append(r)
58+
59+
return (
60+
np.array(s_v),
61+
np.array(i_v),
62+
np.array(r_v),
63+
)
64+
65+
66+
@st.cache
67+
def sim_sir_df(
68+
s: float, i: float, r: float,
69+
beta: float, gamma: float, n_days: int, beta_decay: float = 0.0
70+
) -> pd.DataFrame:
71+
"""Simulate the SIR model forward in time."""
72+
return pd.DataFrame(
73+
data=gen_sir(s, i, r, beta, gamma, n_days, beta_decay),
74+
columns=("S", "I", "R"),
75+
)
76+
77+
78+
@st.cache
79+
def get_hospitalizations2(
80+
infected: np.ndarray, rates: Tuple[float, ...], market_share: float = 1.0
81+
) -> Tuple[np.ndarray, ...]:
82+
"""Get hopitalizations adjusted by rate and market_share."""
83+
return (*(infected * rate * market_share for rate in rates),)
3984

40-
s, i, r = np.array(s), np.array(i), np.array(r)
41-
return s, i, r
4285

4386
@st.cache
4487
def get_hospitalizations(
4588
infected: np.ndarray, rates: Tuple[float, float, float], market_share: float
4689
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
90+
"""Get hopitalizations adjusted by rate and market_share."""
4791
hosp_rate, icu_rate, vent_rate = rates
4892

4993
hosp = infected * hosp_rate * market_share

test_app.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
hosp_rate, icu_rate, vent_rate, hosp_los, icu_los, vent_los, market_share, S, initial_infections,
66
detection_prob, hospitalization_rates, I, R, beta, gamma, n_days, beta_decay,
77
projection_admits, alt)
8-
from penn_chime.models import sir, sim_sir
8+
from penn_chime.models import sir, sim_sir, sim_sir_df
99
from penn_chime.presentation import display_header, new_admissions_chart
1010

1111

@@ -70,12 +70,11 @@ def test_header_fail():
7070

7171
# Test the math
7272

73-
7473
def test_sir():
7574
"""
7675
Someone who is good at testing, help
7776
"""
78-
assert sir((100, 1, 0), 0.2, 0.5, 1) == (
77+
assert sir(100, 1, 0, 0.2, 0.5, 1) == (
7978
0.7920792079207921,
8079
0.20297029702970298,
8180
0.0049504950495049506,
@@ -88,13 +87,28 @@ def test_sim_sir():
8887
"""
8988
s, i, r = sim_sir(S, I, R, beta, gamma, n_days, beta_decay=beta_decay)
9089
assert round(s[0], 0) == 4119405
91-
assert round(s[-1], 2) == 3421436.31
9290
assert round(i[0], 2) == 533.33
91+
assert round(r[0], 0) == 0.0
92+
assert round(s[-1], 2) == 3421436.31
9393
assert round(i[-1], 2) == 418157.62
94-
assert round(r[0], 0) == 0
9594
assert round(r[-1], 2) == 280344.40
9695

9796

97+
def test_sim_sir_df():
98+
"""
99+
Rounding to move fast past decimal place issues
100+
"""
101+
df = sim_sir_df(S, I, R, beta, gamma, n_days, beta_decay=beta_decay)
102+
first = df.iloc[0]
103+
last = df.iloc[-1]
104+
assert round(first[0], 0) == 4119405
105+
assert round(first[1], 2) == 533.33
106+
assert round(first[2], 0) == 0.0
107+
assert round(last[0], 2) == 3421436.31
108+
assert round(last[1], 2) == 418157.62
109+
assert round(last[2], 2) == 280344.40
110+
111+
98112
def test_initial_conditions():
99113
"""
100114
Note: For the rates (ie hosp_rate) - just change the value, leave the "100" alone.

0 commit comments

Comments
 (0)