Skip to content

Commit 99c84a3

Browse files
committed
Utilize pytest fixtures for reused test objects
1 parent 81e3762 commit 99c84a3

File tree

1 file changed

+80
-97
lines changed

1 file changed

+80
-97
lines changed

tests/test_app.py

Lines changed: 80 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Tests."""
22

3-
from copy import copy
43
from math import ceil # type: ignore
54
from datetime import date, datetime # type: ignore
65
import pytest # type: ignore
@@ -29,8 +28,29 @@
2928

3029
EPSILON = 1.e-7
3130

31+
# set up
32+
33+
# we just want to verify that st _attempted_ to render the right stuff
34+
# so we store the input, and make sure that it matches what we expect
35+
class MockStreamlit:
36+
def __init__(self):
37+
self.render_store = []
38+
self.markdown = self.just_store_instead_of_rendering
39+
self.latex = self.just_store_instead_of_rendering
40+
self.subheader = self.just_store_instead_of_rendering
41+
42+
def just_store_instead_of_rendering(self, inp, *args, **kwargs):
43+
self.render_store.append(inp)
44+
return None
45+
46+
@pytest.fixture
47+
def mock_st():
48+
return MockStreamlit()
49+
3250
# The defaults in settings will change and break the tests
33-
DEFAULTS = Parameters(
51+
@pytest.fixture
52+
def DEFAULTS():
53+
return Parameters(
3454
region=Regions(
3555
delaware=564696,
3656
chester=519293,
@@ -50,7 +70,9 @@
5070
ventilated=Disposition(0.005, 10),
5171
)
5272

53-
PARAM = Parameters(
73+
@pytest.fixture
74+
def param():
75+
return Parameters(
5476
current_date=datetime(year=2020, month=3, day=28),
5577
current_hospitalized=100,
5678
doubling_time=6.0,
@@ -63,7 +85,9 @@
6385
n_days=60,
6486
)
6587

66-
HALVING_PARAM = Parameters(
88+
@pytest.fixture
89+
def halving_param():
90+
return Parameters(
6791
current_date=datetime(year=2020, month=3, day=28),
6892
current_hospitalized=100,
6993
doubling_time=6.0,
@@ -76,93 +100,78 @@
76100
n_days=60,
77101
)
78102

79-
MODEL = Model(copy(PARAM))
80-
HALVING_MODEL = Model(copy(HALVING_PARAM))
81-
82-
83-
# set up
84-
85-
# we just want to verify that st _attempted_ to render the right stuff
86-
# so we store the input, and make sure that it matches what we expect
87-
class MockStreamlit:
88-
def __init__(self):
89-
self.render_store = []
90-
self.markdown = self.just_store_instead_of_rendering
91-
self.latex = self.just_store_instead_of_rendering
92-
self.subheader = self.just_store_instead_of_rendering
93-
94-
def just_store_instead_of_rendering(self, inp, *args, **kwargs):
95-
self.render_store.append(inp)
96-
return None
103+
@pytest.fixture
104+
def model(param):
105+
return Model(param)
97106

98-
def cleanup(self):
99-
"""
100-
Call this after every test, unless you intentionally want to accumulate stuff-to-render
101-
"""
102-
self.render_store = []
107+
@pytest.fixture
108+
def halving_model(halving_param):
109+
return Model(halving_param)
103110

111+
@pytest.fixture
112+
def admits_df():
113+
return pd.read_csv('tests/by_doubling_time/2020-03-28_projected_admits.csv', parse_dates=['date'])
104114

105-
st = MockStreamlit()
115+
@pytest.fixture
116+
def census_df():
117+
return pd.read_csv('tests/by_doubling_time/2020-03-28_projected_census.csv', parse_dates=['date'])
106118

107119

108120
# test presentation
121+
def header_test_helper(expected_str, model, param, mock_st):
122+
display_header(mock_st, model, param)
123+
assert [s for s in mock_st.render_store if expected_str in s],\
124+
f"Expected the string '{expected_str}' in the display header"
109125

110-
111-
def header_test_helper(expected_str, model, param):
112-
st.cleanup()
113-
display_header(st, model, param)
114-
assert [s for s in st.render_store if expected_str in s],\
115-
"Expected the string '{expected}' in the display header".format(expected=expected_str)
116-
st.cleanup()
117-
118-
119-
def test_penn_logo_in_header():
126+
def test_penn_logo_in_header(model, param, mock_st):
120127
penn_css = '<link rel="stylesheet" href="https://www1.pennmedicine.org/styles/shared/penn-medicine-header.css">'
121-
header_test_helper(penn_css, MODEL, PARAM)
128+
header_test_helper(penn_css, model, param, mock_st)
122129

123130

124-
def test_the_rest_of_header_shows_up():
131+
def test_the_rest_of_header_shows_up(model, param, mock_st):
125132
random_part_of_header = "implying an effective $R_t$ of"
126-
header_test_helper(random_part_of_header, MODEL, PARAM)
133+
header_test_helper(random_part_of_header, model, param, mock_st)
127134

128135

129-
def test_mitigation_statement():
136+
def test_mitigation_statement(model, param, mock_st):
130137
expected_doubling = "outbreak **reduces the doubling time to 7.8** days"
138+
header_test_helper(expected_doubling, model, param, mock_st)
139+
140+
def test_mitigation_statement_halving(halving_model, halving_param, mock_st):
131141
expected_halving = "outbreak **halves the infections every 51.9** days"
132-
header_test_helper(expected_doubling, MODEL, PARAM)
133-
header_test_helper(expected_halving, HALVING_MODEL, HALVING_PARAM)
142+
header_test_helper(expected_halving, halving_model, halving_param, mock_st)
134143

135144

136-
def test_growth_rate():
145+
def test_growth_rate(model, param, mock_st):
137146
initial_growth = "and daily growth rate of **12.25%**."
147+
header_test_helper(initial_growth, model, param, mock_st)
148+
138149
mitigated_growth = "and daily growth rate of **9.34%**."
139-
mitigated_halving = "and daily growth rate of **-1.33%**."
140-
header_test_helper(initial_growth, MODEL, PARAM)
141-
header_test_helper(mitigated_growth, MODEL, PARAM)
142-
header_test_helper(mitigated_halving, HALVING_MODEL, HALVING_PARAM)
150+
header_test_helper(mitigated_growth, model, param, mock_st)
143151

144-
145-
st.cleanup()
152+
def test_growth_rate_halving(halving_model, halving_param, mock_st):
153+
mitigated_halving = "and daily growth rate of **-1.33%**."
154+
header_test_helper(mitigated_halving, halving_model, halving_param, mock_st)
146155

147156

148157
@pytest.mark.xfail()
149-
def test_header_fail():
158+
def test_header_fail(mock_st):
150159
"""
151160
Just proving to myself that these tests work
152161
"""
153162
some_garbage = "ajskhlaeHFPIQONOI8QH34TRNAOP8ESYAW4"
154-
display_header(st, PARAM)
163+
display_header(mock_st, param)
155164
assert len(
156-
list(filter(lambda s: some_garbage in s, st.render_store))
165+
list(filter(lambda s: some_garbage in s, mock_st.render_store))
157166
), "This should fail"
158-
st.cleanup()
159167

160168

161169
def test_defaults_repr():
162170
"""
163171
Test DEFAULTS.repr
164172
"""
165173
repr(DEFAULTS)
174+
# TODO: Add assertions here
166175

167176

168177
# Test the math
@@ -240,9 +249,7 @@ def test_sim_sir():
240249

241250
assert isinstance(raw_df, pd.DataFrame)
242251

243-
244-
def test_admits_chart():
245-
admits_df = pd.read_csv("tests/by_doubling_time/2020-03-28_projected_admits.csv")
252+
def test_admits_chart(admits_df):
246253
chart = build_admits_chart(alt=alt, admits_df=admits_df)
247254
assert isinstance(chart, (alt.Chart, alt.LayerChart))
248255
assert round(chart.data.iloc[40].icu, 0) == 39
@@ -251,9 +258,7 @@ def test_admits_chart():
251258
with pytest.raises(TypeError):
252259
build_admits_chart()
253260

254-
255-
def test_census_chart():
256-
census_df = pd.read_csv("tests/by_doubling_time/2020-03-28_projected_census.csv")
261+
def test_census_chart(census_df):
257262
chart = build_census_chart(alt=alt, census_df=census_df)
258263
assert isinstance(chart, (alt.Chart, alt.LayerChart))
259264
assert chart.data.iloc[1].hospitalized == 3
@@ -264,10 +269,8 @@ def test_census_chart():
264269
build_census_chart()
265270

266271

267-
def test_model():
272+
def test_model(model, param):
268273
# test the Model
269-
param = copy(PARAM)
270-
model = Model(param)
271274

272275
assert round(model.infected, 0) == 45810.0
273276
assert isinstance(model.infected, float) # based off note in models.py
@@ -282,9 +285,7 @@ def test_model():
282285
assert model.doubling_time_t == 7.764405988534983
283286

284287

285-
def test_model_raw_start():
286-
param = copy(PARAM)
287-
model = Model(param)
288+
def test_model_raw_start(model, param):
288289
raw_df = model.raw_df
289290

290291
# test the things n_days creates, which in turn tests sim_sir, sir, and get_dispositions
@@ -305,47 +306,37 @@ def test_model_raw_start():
305306
assert [round(v, 0) for v in (d, s, i, r)] == [17, 549.0, 220.0, 110.0]
306307

307308

308-
def test_model_conservation():
309-
p = copy(PARAM)
310-
m = Model(p)
311-
raw_df = m.raw_df
309+
def test_model_conservation(param, model):
310+
raw_df = model.raw_df
312311

313312
assert (0.0 <= raw_df.susceptible).all()
314313
assert (0.0 <= raw_df.infected).all()
315314
assert (0.0 <= raw_df.recovered).all()
316315

317-
diff = raw_df.susceptible + raw_df.infected + raw_df.recovered - p.population
316+
diff = raw_df.susceptible + raw_df.infected + raw_df.recovered - param.population
318317
assert (diff < 0.1).all()
319318

320-
assert (raw_df.susceptible <= p.population).all()
321-
assert (raw_df.infected <= p.population).all()
322-
assert (raw_df.recovered <= p.population).all()
319+
assert (raw_df.susceptible <= param.population).all()
320+
assert (raw_df.infected <= param.population).all()
321+
assert (raw_df.recovered <= param.population).all()
323322

324323

325-
def test_model_raw_end():
326-
param = copy(PARAM)
327-
model = Model(param)
324+
def test_model_raw_end(param, model):
328325
raw_df = model.raw_df
329-
330326
last = raw_df.iloc[-1, :]
331327
assert round(last.susceptible, 0) == 83391.0
332328

333329

334-
def test_model_monotonicity():
335-
param = copy(PARAM)
336-
model = Model(param)
330+
def test_model_monotonicity(param, model):
337331
raw_df = model.raw_df
338332

339333
# Susceptible population should be non-increasing, and Recovered non-decreasing
340334
assert (raw_df.susceptible[1:] - raw_df.susceptible.shift(1)[1:] <= 0).all()
341335
assert (raw_df.recovered [1:] - raw_df.recovered. shift(1)[1:] >= 0).all()
342336

343337

344-
def test_model_cumulative_census():
338+
def test_model_cumulative_census(param, model):
345339
# test that census is being properly calculated
346-
param = copy(PARAM)
347-
model = Model(param)
348-
349340
raw_df = model.raw_df
350341
admits_df = model.admits_df
351342
df = pd.DataFrame({
@@ -368,13 +359,7 @@ def test_growth_rate():
368359
assert np.round(get_growth_rate(-4) * 100.0, decimals=4) == -15.9104
369360

370361

371-
def test_build_descriptions():
372-
param = copy(PARAM)
373-
374-
admits_file = 'tests/by_doubling_time/2020-03-28_projected_admits.csv'
375-
census_file = 'tests/by_doubling_time/2020-03-28_projected_census.csv'
376-
377-
admits_df = pd.read_csv(admits_file, parse_dates=['date'])
362+
def test_build_descriptions(admits_df, param):
378363
chart = build_admits_chart(alt=alt, admits_df=admits_df)
379364
description = build_descriptions(chart=chart, labels=param.labels)
380365

@@ -385,17 +370,15 @@ def test_build_descriptions():
385370

386371
# TODO add test for asterisk
387372

388-
# test no asterisk
373+
def test_no_asterisk(admits_df, param):
389374
param.n_days = 600
390375

391-
admits_df = pd.read_csv(admits_file, parse_dates=['date'])
392376
chart = build_admits_chart(alt=alt, admits_df=admits_df)
393377
description = build_descriptions(chart=chart, labels=param.labels)
394378
assert "*" not in description
395379

396380

397-
# census chart
398-
census_df = pd.read_csv(census_file, parse_dates=['date'])
381+
def test_census(census_df, param):
399382
chart = build_census_chart(alt=alt, census_df=census_df)
400383
description = build_descriptions(chart=chart, labels=param.labels)
401384

0 commit comments

Comments
 (0)