Skip to content

Commit 90c7b1b

Browse files
committed
Factor out mitigation policy generation
1 parent 90d8288 commit 90c7b1b

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

src/penn_chime/models.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@
2828

2929
class SimSirModel:
3030

31+
def gen_policy(self, p: Parameters) -> List[Tuple[float, int]]:
32+
if p.mitigation_date is not None:
33+
mitigation_day = -(p.current_date - p.mitigation_date).days
34+
else:
35+
mitigation_day = 0
36+
37+
total_days = self.i_day + p.n_days
38+
39+
if mitigation_day < -self.i_day:
40+
mitigation_day = -self.i_day
41+
42+
pre_mitigation_days = self.i_day + mitigation_day
43+
post_mitigation_days = total_days - pre_mitigation_days
44+
45+
return [
46+
(self.beta, pre_mitigation_days),
47+
(self.beta_t, post_mitigation_days),
48+
]
49+
3150
def __init__(self, p: Parameters):
3251

3352
self.rates = {
@@ -66,14 +85,13 @@ def __init__(self, p: Parameters):
6685
intrinsic_growth_rate = get_growth_rate(p.doubling_time)
6786

6887
self.beta = get_beta(intrinsic_growth_rate, gamma, self.susceptible, 0.0)
88+
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
6989

7090
self.i_day = 0 # seed to the full length
71-
self.beta_t = self.beta
72-
self.run_projection(p)
91+
self.run_projection(p, [(self.beta, p.n_days)])
7392
self.i_day = i_day = int(get_argmin_ds(self.census_df, p.current_hospitalized))
7493

75-
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
76-
self.run_projection(p)
94+
self.run_projection(p, self.gen_policy(p))
7795

7896
logger.info('Set i_day = %s', i_day)
7997
p.date_first_hospitalized = p.current_date - timedelta(days=i_day)
@@ -100,7 +118,7 @@ def __init__(self, p: Parameters):
100118
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
101119
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
102120

103-
self.run_projection(p)
121+
self.run_projection(p, self.gen_policy(p))
104122
loss = self.get_loss()
105123
losses[i] = loss
106124

@@ -109,7 +127,7 @@ def __init__(self, p: Parameters):
109127
intrinsic_growth_rate = get_growth_rate(p.doubling_time)
110128
self.beta = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, 0.0)
111129
self.beta_t = get_beta(intrinsic_growth_rate, self.gamma, self.susceptible, p.relative_contact_rate)
112-
self.run_projection(p)
130+
self.run_projection(p, self.gen_policy(p))
113131

114132
self.population = p.population
115133
else:
@@ -146,30 +164,14 @@ def __init__(self, p: Parameters):
146164
self.daily_growth_rate = get_growth_rate(p.doubling_time)
147165
self.daily_growth_rate_t = get_growth_rate(self.doubling_time_t)
148166

149-
def run_projection(self, p):
150-
if p.mitigation_date is not None:
151-
mitigation_day = -(p.current_date - p.mitigation_date).days
152-
else:
153-
mitigation_day = 0
154-
155-
total_days = self.i_day + p.n_days
156-
157-
if mitigation_day < -self.i_day:
158-
mitigation_day = -self.i_day
159-
160-
pre_mitigation_days = self.i_day + mitigation_day
161-
post_mitigation_days = total_days - pre_mitigation_days
162-
167+
def run_projection(self, p: Parameters, policy: List[Tuple[float, int]]):
163168
self.raw_df = sim_sir_df(
164169
self.susceptible,
165170
self.infected,
166171
p.recovered,
167172
self.gamma,
168173
-self.i_day,
169-
[
170-
(self.beta, pre_mitigation_days),
171-
(self.beta_t, post_mitigation_days),
172-
]
174+
policy
173175
)
174176

175177
self.dispositions_df = build_dispositions_df(self.raw_df, self.rates, p.market_share, p.current_date)

0 commit comments

Comments
 (0)