1
1
"""Tests."""
2
2
3
+ from copy import copy
3
4
from math import ceil # type: ignore
4
5
from datetime import date , datetime # type: ignore
5
6
import pytest # type: ignore
@@ -264,8 +265,10 @@ def test_census_chart():
264
265
build_census_chart ()
265
266
266
267
267
- def test_model (model = MODEL , param = PARAM ):
268
+ def test_model ():
268
269
# test the Model
270
+ param = copy (PARAM )
271
+ model = Model (param )
269
272
270
273
assert round (model .infected , 0 ) == 45810.0
271
274
assert isinstance (model .infected , float ) # based off note in models.py
@@ -280,7 +283,9 @@ def test_model(model=MODEL, param=PARAM):
280
283
assert model .doubling_time_t == 7.764405988534983
281
284
282
285
283
- def test_model_raw_start (model = MODEL , param = PARAM ):
286
+ def test_model_raw_start ():
287
+ param = copy (PARAM )
288
+ model = Model (param )
284
289
raw_df = model .raw_df
285
290
286
291
# test the things n_days creates, which in turn tests sim_sir, sir, and get_dispositions
@@ -302,16 +307,21 @@ def test_model_raw_start(model=MODEL, param=PARAM):
302
307
assert [round (v , 0 ) for v in (d , s , i , r )] == [22 , 1101.0 , 441.0 , 220.0 ]
303
308
304
309
305
- def test_model_raw_end (model = MODEL , param = PARAM ):
310
+ def test_model_raw_end ():
311
+ param = copy (PARAM )
312
+ model = Model (param )
306
313
raw_df = model .raw_df
307
314
308
315
last = raw_df .iloc [- 1 , :]
309
316
assert last .susceptible + last .infected + last .recovered == param .population
310
317
assert round (last .susceptible , 0 ) == 83391.0
311
318
312
319
313
- def test_model_cumulative_census (model = MODEL ):
320
+ def test_model_cumulative_census ():
314
321
# test that census is being properly calculated
322
+ param = copy (PARAM )
323
+ model = Model (param )
324
+
315
325
raw_df = model .raw_df
316
326
admits_df = model .admits_df
317
327
df = pd .DataFrame ({
@@ -334,13 +344,15 @@ def test_growth_rate():
334
344
assert np .round (get_growth_rate (- 4 ) * 100.0 , decimals = 4 ) == - 15.9104
335
345
336
346
337
- def test_build_descriptions (p = PARAM ):
347
+ def test_build_descriptions ():
348
+ param = copy (PARAM )
349
+
338
350
admits_file = 'tests/by_doubling_time/2020-03-28_projected_admits.csv'
339
351
census_file = 'tests/by_doubling_time/2020-03-28_projected_census.csv'
340
352
341
353
admits_df = pd .read_csv (admits_file , parse_dates = ['date' ])
342
354
chart = build_admits_chart (alt = alt , admits_df = admits_df )
343
- description = build_descriptions (chart = chart , labels = p .labels )
355
+ description = build_descriptions (chart = chart , labels = param .labels )
344
356
345
357
hosp , icu , vent = description .split ("\n \n " ) # break out the description into lines
346
358
@@ -349,22 +361,19 @@ def test_build_descriptions(p=PARAM):
349
361
350
362
# TODO add test for asterisk
351
363
352
-
353
364
# test no asterisk
354
- param = PARAM
355
365
param .n_days = 600
356
366
357
367
admits_df = pd .read_csv (admits_file , parse_dates = ['date' ])
358
368
chart = build_admits_chart (alt = alt , admits_df = admits_df )
359
- description = build_descriptions (chart = chart , labels = p .labels )
369
+ description = build_descriptions (chart = chart , labels = param .labels )
360
370
assert "*" not in description
361
371
362
372
363
373
# census chart
364
374
census_df = pd .read_csv (census_file , parse_dates = ['date' ])
365
- PARAM .as_date = True
366
375
chart = build_census_chart (alt = alt , census_df = census_df )
367
- description = build_descriptions (chart = chart , labels = p .labels )
376
+ description = build_descriptions (chart = chart , labels = param .labels )
368
377
369
378
assert str (ceil (chart .data ['ventilated' ].max ())) in description
370
379
assert str (chart .data ['icu' ].idxmax ()) not in description
0 commit comments