Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions packages/importation/src/importation/perkins_et_al_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,23 @@ def prob_undetected_infections(
"n_undetected_infections": n_undetected,
"weight": pmf_prob,
}
).with_columns(
(pl.col("weight") / pl.col("weight").sum()).alias("probability")
)

if prob_data.select(pl.sum("weight").eq(0)).item():
return prob_data.with_columns(
pl.lit(1.0 / prob_data.height).alias("probability")
)
else:
return prob_data.with_columns(
(pl.col("weight").log() - pl.sum("weight").log())
.exp()
.alias("probability")
)
else:
raise ValueError(
"Calculating the probability of observing n undetected infections given known cases and deaths requires one parameter set in prop_ascf."
)

return prob_data


def sample_undetected_infections(
known_cases: int,
Expand Down
81 changes: 81 additions & 0 deletions packages/importation/tests/test_perkins_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,70 @@ def test_prob_undetected_infections_list(dummy_values):
)


def test_prob_undetected_infections_zero_probability_value(dummy_values):
prop_ascf = get_prop_ascf(importation_parameters=dummy_values)
known_cases = 500
known_deaths = 2
n_undetected = 1

prob_data = prob_undetected_infections(
n_undetected, known_cases, known_deaths, prop_ascf
)

assert prob_data.shape[0] == 1
assert all(
col in prob_data.columns
for col in ["n_undetected_infections", "probability"]
)
assert prob_data.item(0, "probability") == 1.0
assert prob_data.item(0, "weight") == 0.0


def test_prob_undetected_infections_rounding_zero_probability_list(
dummy_values,
):
# Expect a large number of undetected, such that a low number of undetected is prob 0 but a slightly higher number of undetected p > 0 and p << 1e-6
prop_ascf = get_prop_ascf(
importation_parameters={
"symptomatic_reporting_prob": 0.002,
"case_fatality_ratio": 0.002,
"proportion_asymptomatic": 0.99,
}
)
known_cases = 100
known_deaths = 2
n_undetected = list(range(20_000))

prob_data = prob_undetected_infections(
n_undetected, known_cases, known_deaths, prop_ascf
)

print(prob_data)

assert prob_data.shape[0] == 20_000
assert all(
col in prob_data.columns
for col in ["n_undetected_infections", "probability"]
)

zero_undetected_infections_info = prob_data.filter(
pl.col("n_undetected_infections") == 0
)
max_undetected_infections_info = prob_data.filter(
pl.col("n_undetected_infections") == pl.max("n_undetected_infections")
)

assert zero_undetected_infections_info.select("probability").item() == 0.0
assert max_undetected_infections_info.select("probability").item() > 0.0

assert zero_undetected_infections_info.select("weight").item() == 0.0
assert max_undetected_infections_info.select("weight").item() > 0.0
assert max_undetected_infections_info.select("weight").item() < 1e-12

assert prob_data.select(pl.sum("weight")).item() < 1e-12
assert prob_data.select(pl.sum("probability")).item() == pytest.approx(1.0)


def test_sample_undetected_infections(dummy_values):
prop_ascf = get_prop_ascf(importation_parameters=dummy_values)
known_cases = 5
Expand All @@ -177,6 +241,23 @@ def test_sample_undetected_infections(dummy_values):
)


def test_sample_undetected_infections_zero_handling(dummy_values):
prop_ascf = get_prop_ascf(importation_parameters=dummy_values)
known_cases = 500
known_deaths = 2
max_infections = 502
seed = 42

sampled_data = sample_undetected_infections(
known_cases, known_deaths, prop_ascf, max_infections, seed
)

assert sampled_data.shape[0] == 1
assert all(
col in sampled_data.columns for col in ["n_undetected_infections"]
)


@pytest.fixture
def mock_sample_undetected_infections(dummy_values):
with patch(
Expand Down