Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
680bb1e
Merge branch 'main' of https://github.com/CDCgov/PyRenew
cdc-mitzimorris Sep 15, 2025
2cb876b
update
cdc-mitzimorris Sep 18, 2025
60db8df
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Sep 22, 2025
32a5314
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 5, 2025
d6213f2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 8, 2025
96f27c9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 17, 2025
1cb6fa2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 24, 2025
f62e1e4
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 4, 2025
0c6785d
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 22, 2025
1ee62b9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Jan 29, 2026
0629461
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 4, 2026
efeadee
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
371ba98
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
0304bed
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 6, 2026
ffeea65
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
50e7261
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
dae6af8
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 10, 2026
6fff9d5
Improve unit tests: replace coverage-driven tests with behavioral tes…
cdc-mitzimorris Feb 10, 2026
ca97e0b
more unit tests
cdc-mitzimorris Feb 10, 2026
50f894b
Merge main (PR 698) into mem_690_unit_test_improvements
cdc-mitzimorris Feb 11, 2026
f02efb0
improve test coverage
cdc-mitzimorris Feb 11, 2026
4f5534a
changes by bot per code review by another bot - see PR conversation
cdc-mitzimorris Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 63 additions & 225 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
import pytest

from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.observation import Counts, NegativeBinomialNoise
from pyrenew.latent import AR1, HierarchicalInfections, RandomWalk
from pyrenew.observation import (
Counts,
HierarchicalNormalNoise,
NegativeBinomialNoise,
VectorizedRV,
)
from pyrenew.randomvariable import DistributionalVariable

# =============================================================================
Expand Down Expand Up @@ -44,32 +50,6 @@ def short_delay_pmf():
return jnp.array([0.5, 0.5])


@pytest.fixture
def medium_delay_pmf():
"""
Medium 4-day delay PMF.

Returns
-------
jnp.ndarray
A 4-element PMF array.
"""
return jnp.array([0.1, 0.3, 0.4, 0.2])


@pytest.fixture
def realistic_delay_pmf():
"""
Realistic 10-day delay PMF (shifted gamma-like).

Returns
-------
jnp.ndarray
A 10-element PMF array with gamma-like shape.
"""
return jnp.array([0.01, 0.05, 0.10, 0.15, 0.20, 0.20, 0.15, 0.08, 0.04, 0.02])


@pytest.fixture
def long_delay_pmf():
"""
Expand All @@ -83,19 +63,6 @@ def long_delay_pmf():
return jnp.array([0.05, 0.1, 0.15, 0.2, 0.2, 0.15, 0.1, 0.03, 0.01, 0.01])


@pytest.fixture
def simple_shedding_pmf():
"""
Simple 1-day shedding PMF (no delay).

Returns
-------
jnp.ndarray
A single-element PMF array representing no shedding delay.
"""
return jnp.array([1.0])


@pytest.fixture
def short_shedding_pmf():
"""
Expand All @@ -109,77 +76,98 @@ def short_shedding_pmf():
return jnp.array([0.3, 0.4, 0.3])


# =============================================================================
# Generation Interval Fixture
# =============================================================================


@pytest.fixture
def medium_shedding_pmf():
def gen_int_rv():
"""
Medium 5-day shedding PMF.
COVID-like generation interval (7-day PMF).

Returns
-------
jnp.ndarray
A 5-element PMF array.
DeterministicPMF
Generation interval random variable.
"""
return jnp.array([0.1, 0.3, 0.3, 0.2, 0.1])
pmf = jnp.array([0.16, 0.32, 0.25, 0.14, 0.07, 0.04, 0.02])
return DeterministicPMF("gen_int", pmf)


# =============================================================================
# Sensor Prior Fixtures
# Noise Fixtures
# =============================================================================


@pytest.fixture
def sensor_mode_rv():
def hierarchical_normal_noise():
"""
Standard normal prior for sensor modes.
Standard HierarchicalNormalNoise with VectorizedRV wrappers.

Returns
-------
DistributionalVariable
A normal prior with standard deviation 0.5.
HierarchicalNormalNoise
Noise model for continuous measurements.
"""
return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5))
sensor_mode_rv = VectorizedRV(
DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)),
plate_name="sensor_mode",
)
sensor_sd_rv = VectorizedRV(
DistributionalVariable(
"ww_sensor_sd", dist.TruncatedNormal(0.3, 0.15, low=0.10)
),
plate_name="sensor_sd",
)
return HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv)


@pytest.fixture
def sensor_mode_rv_tight():
def hierarchical_normal_noise_tight():
"""
Tight normal prior for deterministic-like behavior.
Tight HierarchicalNormalNoise for near-deterministic testing.

Returns
-------
DistributionalVariable
A normal prior with small standard deviation 0.01.
HierarchicalNormalNoise
Noise model with very small variance.
"""
return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.01))
sensor_mode_rv = VectorizedRV(
DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.01)),
plate_name="sensor_mode",
)
sensor_sd_rv = VectorizedRV(
DistributionalVariable(
"ww_sensor_sd", dist.TruncatedNormal(0.01, 0.005, low=0.001)
),
plate_name="sensor_sd",
)
return HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv)


@pytest.fixture
def sensor_sd_rv():
"""
Standard truncated normal prior for sensor standard deviations.

Returns
-------
DistributionalVariable
A truncated normal prior for sensor standard deviations.
"""
return DistributionalVariable(
"ww_sensor_sd", dist.TruncatedNormal(0.3, 0.15, low=0.10)
)
# =============================================================================
# Hierarchical Infections Fixture
# =============================================================================


@pytest.fixture
def sensor_sd_rv_tight():
def hierarchical_infections(gen_int_rv):
"""
Tight truncated normal prior for deterministic-like behavior.
Pre-configured HierarchicalInfections instance.

Returns
-------
DistributionalVariable
A truncated normal prior with small scale for tight behavior.
HierarchicalInfections
Configured infection process with realistic parameters.
"""
return DistributionalVariable(
"ww_sensor_sd", dist.TruncatedNormal(0.01, 0.005, low=0.005)
return HierarchicalInfections(
gen_int_rv=gen_int_rv,
I0_rv=DeterministicVariable("I0", 0.001),
initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0),
baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05),
subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025),
n_initialization_points=7,
)


Expand All @@ -206,42 +194,6 @@ def counts_process(simple_delay_pmf):
)


@pytest.fixture
def counts_process_medium_delay(medium_delay_pmf):
"""
Counts observation process with medium delay.

Returns
-------
Counts
A Counts observation process with 4-day delay.
"""
return Counts(
name="test_counts",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
delay_distribution_rv=DeterministicPMF("delay", medium_delay_pmf),
noise=NegativeBinomialNoise(DeterministicVariable("conc", 50.0)),
)


@pytest.fixture
def counts_process_realistic(realistic_delay_pmf):
"""
Counts observation process with realistic delay and ascertainment.

Returns
-------
Counts
A Counts observation process with realistic parameters.
"""
return Counts(
name="test_counts",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.005),
delay_distribution_rv=DeterministicPMF("delay", realistic_delay_pmf),
noise=NegativeBinomialNoise(DeterministicVariable("conc", 100.0)),
)


class CountsProcessFactory:
"""Factory for creating Counts processes with custom parameters."""

Expand Down Expand Up @@ -286,117 +238,3 @@ def counts_factory():
# =============================================================================
# Infection Fixtures
# =============================================================================


@pytest.fixture
def constant_infections():
"""
Constant infections array (30 days, 100 infections/day).

Returns
-------
jnp.ndarray
A 1D array of shape (30,) with constant value 100.
"""
return jnp.ones(30) * 100


@pytest.fixture
def constant_infections_2d():
"""
Constant infections array for 2 subpopulations.

Returns
-------
jnp.ndarray
A 2D array of shape (30, 2) with constant value 100.
"""
return jnp.ones((30, 2)) * 100


def make_infections(n_days, n_subpops=None, value=100.0):
"""
Create infection arrays for testing.

Parameters
----------
n_days : int
Number of days
n_subpops : int, optional
Number of subpopulations (None for 1D array)
value : float
Constant infection value

Returns
-------
jnp.ndarray
Infections array
"""
if n_subpops is None:
return jnp.ones(n_days) * value
return jnp.ones((n_days, n_subpops)) * value


def make_spike_infections(n_days, spike_day, spike_value=1000.0, n_subpops=None):
"""
Create spike infection arrays for testing.

Parameters
----------
n_days : int
Number of days
spike_day : int
Day of the spike
spike_value : float
Value at spike
n_subpops : int, optional
Number of subpopulations

Returns
-------
jnp.ndarray
Infections array with spike
"""
if n_subpops is None:
infections = jnp.zeros(n_days)
return infections.at[spike_day].set(spike_value)
infections = jnp.zeros((n_days, n_subpops))
return infections.at[spike_day, :].set(spike_value)


def create_mock_infections(
n_days: int,
peak_day: int = 10,
peak_value: float = 1000.0,
shape: str = "spike",
) -> jnp.ndarray:
"""
Create mock infection time series for testing.

Parameters
----------
n_days : int
Number of days
peak_day : int
Day of peak infections
peak_value : float
Peak infection value
shape : str
Shape of the curve: "spike", "constant", or "decay"

Returns
-------
jnp.ndarray
Array of infections of shape (n_days,)
"""
if shape == "spike":
infections = jnp.zeros(n_days)
infections = infections.at[peak_day].set(peak_value)
elif shape == "constant":
infections = jnp.ones(n_days) * peak_value
elif shape == "decay":
infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0)
else:
raise ValueError(f"Unknown shape: {shape}")

return infections
Loading
Loading