Skip to content

Commit 38c2901

Browse files
committed
Copy params for tests
1 parent d2d06f9 commit 38c2901

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

tests/test_app.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests."""
22

3+
from copy import copy
34
from math import ceil # type: ignore
45
from datetime import date, datetime # type: ignore
56
import pytest # type: ignore
@@ -264,8 +265,10 @@ def test_census_chart():
264265
build_census_chart()
265266

266267

267-
def test_model(model=MODEL, param=PARAM):
268+
def test_model():
268269
# test the Model
270+
param = copy(PARAM)
271+
model = Model(param)
269272

270273
assert round(model.infected, 0) == 45810.0
271274
assert isinstance(model.infected, float) # based off note in models.py
@@ -280,7 +283,9 @@ def test_model(model=MODEL, param=PARAM):
280283
assert model.doubling_time_t == 7.764405988534983
281284

282285

283-
def test_model_raw_start(model=MODEL, param=PARAM):
286+
def test_model_raw_start():
287+
param = copy(PARAM)
288+
model = Model(param)
284289
raw_df = model.raw_df
285290

286291
# test the things n_days creates, which in turn tests sim_sir, sir, and get_dispositions
@@ -302,16 +307,21 @@ def test_model_raw_start(model=MODEL, param=PARAM):
302307
assert [round(v, 0) for v in (d, s, i, r)] == [22, 1101.0, 441.0, 220.0]
303308

304309

305-
def test_model_raw_end(model=MODEL, param=PARAM):
310+
def test_model_raw_end():
311+
param = copy(PARAM)
312+
model = Model(param)
306313
raw_df = model.raw_df
307314

308315
last = raw_df.iloc[-1, :]
309316
assert last.susceptible + last.infected + last.recovered == param.population
310317
assert round(last.susceptible, 0) == 83391.0
311318

312319

313-
def test_model_cumulative_census(model=MODEL):
320+
def test_model_cumulative_census():
314321
# test that census is being properly calculated
322+
param = copy(PARAM)
323+
model = Model(param)
324+
315325
raw_df = model.raw_df
316326
admits_df = model.admits_df
317327
df = pd.DataFrame({
@@ -334,13 +344,15 @@ def test_growth_rate():
334344
assert np.round(get_growth_rate(-4) * 100.0, decimals=4) == -15.9104
335345

336346

337-
def test_build_descriptions(p=PARAM):
347+
def test_build_descriptions():
348+
param = copy(PARAM)
349+
338350
admits_file = 'tests/by_doubling_time/2020-03-28_projected_admits.csv'
339351
census_file = 'tests/by_doubling_time/2020-03-28_projected_census.csv'
340352

341353
admits_df = pd.read_csv(admits_file, parse_dates=['date'])
342354
chart = build_admits_chart(alt=alt, admits_df=admits_df)
343-
description = build_descriptions(chart=chart, labels=p.labels)
355+
description = build_descriptions(chart=chart, labels=param.labels)
344356

345357
hosp, icu, vent = description.split("\n\n") # break out the description into lines
346358

@@ -349,22 +361,19 @@ def test_build_descriptions(p=PARAM):
349361

350362
# TODO add test for asterisk
351363

352-
353364
# test no asterisk
354-
param = PARAM
355365
param.n_days = 600
356366

357367
admits_df = pd.read_csv(admits_file, parse_dates=['date'])
358368
chart = build_admits_chart(alt=alt, admits_df=admits_df)
359-
description = build_descriptions(chart=chart, labels=p.labels)
369+
description = build_descriptions(chart=chart, labels=param.labels)
360370
assert "*" not in description
361371

362372

363373
# census chart
364374
census_df = pd.read_csv(census_file, parse_dates=['date'])
365-
PARAM.as_date = True
366375
chart = build_census_chart(alt=alt, census_df=census_df)
367-
description = build_descriptions(chart=chart, labels=p.labels)
376+
description = build_descriptions(chart=chart, labels=param.labels)
368377

369378
assert str(ceil(chart.data['ventilated'].max())) in description
370379
assert str(chart.data['icu'].idxmax()) not in description

0 commit comments

Comments
 (0)