From 6fff9d5eed309e3faeba31bdbda7931075536b0d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 10 Feb 2026 14:17:17 -0500 Subject: [PATCH 1/4] Improve unit tests: replace coverage-driven tests with behavioral tests (#690) Remove 28 low-value tests (repr checks, no-op validates, shape-only assertions, private method tests) and add 12 behavioral tests that validate real-world behavior. Consolidate shared fixtures in conftest.py and extract reusable test helpers to test/test_helpers.py. Co-Authored-By: Claude Opus 4.6 --- test/conftest.py | 251 +++++------------- test/test_base_latent_process.py | 77 +----- test/test_helpers.py | 97 +++++++ test/test_hierarchical_infections.py | 150 +++++------ test/test_hierarchical_priors.py | 58 ++--- test/test_observation_counts.py | 341 ++++++------------------- test/test_observation_measurements.py | 352 ++++++++------------------ test/test_pyrenew_builder.py | 199 +++------------ test/test_temporal_processes.py | 211 ++++++--------- 9 files changed, 543 insertions(+), 1193 deletions(-) create mode 100644 test/test_helpers.py diff --git a/test/conftest.py b/test/conftest.py index 381ee056..22318c62 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,7 +10,13 @@ import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.observation import Counts, NegativeBinomialNoise +from pyrenew.latent import AR1, HierarchicalInfections, RandomWalk +from pyrenew.observation import ( + Counts, + HierarchicalNormalNoise, + NegativeBinomialNoise, + VectorizedRV, +) from pyrenew.randomvariable import DistributionalVariable # ============================================================================= @@ -44,19 +50,6 @@ def short_delay_pmf(): return jnp.array([0.5, 0.5]) -@pytest.fixture -def medium_delay_pmf(): - """ - Medium 4-day delay PMF. - - Returns - ------- - jnp.ndarray - A 4-element PMF array. - """ - return jnp.array([0.1, 0.3, 0.4, 0.2]) - - @pytest.fixture def realistic_delay_pmf(): """ @@ -83,19 +76,6 @@ def long_delay_pmf(): return jnp.array([0.05, 0.1, 0.15, 0.2, 0.2, 0.15, 0.1, 0.03, 0.01, 0.01]) -@pytest.fixture -def simple_shedding_pmf(): - """ - Simple 1-day shedding PMF (no delay). - - Returns - ------- - jnp.ndarray - A single-element PMF array representing no shedding delay. - """ - return jnp.array([1.0]) - - @pytest.fixture def short_shedding_pmf(): """ @@ -109,77 +89,98 @@ def short_shedding_pmf(): return jnp.array([0.3, 0.4, 0.3]) +# ============================================================================= +# Generation Interval Fixture +# ============================================================================= + + @pytest.fixture -def medium_shedding_pmf(): +def gen_int_rv(): """ - Medium 5-day shedding PMF. + COVID-like generation interval (7-day PMF). Returns ------- - jnp.ndarray - A 5-element PMF array. + DeterministicPMF + Generation interval random variable. """ - return jnp.array([0.1, 0.3, 0.3, 0.2, 0.1]) + pmf = jnp.array([0.16, 0.32, 0.25, 0.14, 0.07, 0.04, 0.02]) + return DeterministicPMF("gen_int", pmf) # ============================================================================= -# Sensor Prior Fixtures +# Noise Fixtures # ============================================================================= @pytest.fixture -def sensor_mode_rv(): +def hierarchical_normal_noise(): """ - Standard normal prior for sensor modes. + Standard HierarchicalNormalNoise with VectorizedRV wrappers. Returns ------- - DistributionalVariable - A normal prior with standard deviation 0.5. + HierarchicalNormalNoise + Noise model for continuous measurements. """ - return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)) + sensor_mode_rv = VectorizedRV( + DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)), + plate_name="sensor_mode", + ) + sensor_sd_rv = VectorizedRV( + DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(0.3, 0.15, low=0.10) + ), + plate_name="sensor_sd", + ) + return HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) @pytest.fixture -def sensor_mode_rv_tight(): +def hierarchical_normal_noise_tight(): """ - Tight normal prior for deterministic-like behavior. + Tight HierarchicalNormalNoise for near-deterministic testing. Returns ------- - DistributionalVariable - A normal prior with small standard deviation 0.01. + HierarchicalNormalNoise + Noise model with very small variance. """ - return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.01)) - + sensor_mode_rv = VectorizedRV( + DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.01)), + plate_name="sensor_mode", + ) + sensor_sd_rv = VectorizedRV( + DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(0.01, 0.005, low=0.001) + ), + plate_name="sensor_sd", + ) + return HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) -@pytest.fixture -def sensor_sd_rv(): - """ - Standard truncated normal prior for sensor standard deviations. - Returns - ------- - DistributionalVariable - A truncated normal prior for sensor standard deviations. - """ - return DistributionalVariable( - "ww_sensor_sd", dist.TruncatedNormal(0.3, 0.15, low=0.10) - ) +# ============================================================================= +# Hierarchical Infections Fixture +# ============================================================================= @pytest.fixture -def sensor_sd_rv_tight(): +def hierarchical_infections(gen_int_rv): """ - Tight truncated normal prior for deterministic-like behavior. + Pre-configured HierarchicalInfections instance. Returns ------- - DistributionalVariable - A truncated normal prior with small scale for tight behavior. - """ - return DistributionalVariable( - "ww_sensor_sd", dist.TruncatedNormal(0.01, 0.005, low=0.005) + HierarchicalInfections + Configured infection process with realistic parameters. + """ + return HierarchicalInfections( + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), + n_initialization_points=7, ) @@ -206,42 +207,6 @@ def counts_process(simple_delay_pmf): ) -@pytest.fixture -def counts_process_medium_delay(medium_delay_pmf): - """ - Counts observation process with medium delay. - - Returns - ------- - Counts - A Counts observation process with 4-day delay. - """ - return Counts( - name="test_counts", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", medium_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 50.0)), - ) - - -@pytest.fixture -def counts_process_realistic(realistic_delay_pmf): - """ - Counts observation process with realistic delay and ascertainment. - - Returns - ------- - Counts - A Counts observation process with realistic parameters. - """ - return Counts( - name="test_counts", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.005), - delay_distribution_rv=DeterministicPMF("delay", realistic_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 100.0)), - ) - - class CountsProcessFactory: """Factory for creating Counts processes with custom parameters.""" @@ -312,91 +277,3 @@ def constant_infections_2d(): A 2D array of shape (30, 2) with constant value 100. """ return jnp.ones((30, 2)) * 100 - - -def make_infections(n_days, n_subpops=None, value=100.0): - """ - Create infection arrays for testing. - - Parameters - ---------- - n_days : int - Number of days - n_subpops : int, optional - Number of subpopulations (None for 1D array) - value : float - Constant infection value - - Returns - ------- - jnp.ndarray - Infections array - """ - if n_subpops is None: - return jnp.ones(n_days) * value - return jnp.ones((n_days, n_subpops)) * value - - -def make_spike_infections(n_days, spike_day, spike_value=1000.0, n_subpops=None): - """ - Create spike infection arrays for testing. - - Parameters - ---------- - n_days : int - Number of days - spike_day : int - Day of the spike - spike_value : float - Value at spike - n_subpops : int, optional - Number of subpopulations - - Returns - ------- - jnp.ndarray - Infections array with spike - """ - if n_subpops is None: - infections = jnp.zeros(n_days) - return infections.at[spike_day].set(spike_value) - infections = jnp.zeros((n_days, n_subpops)) - return infections.at[spike_day, :].set(spike_value) - - -def create_mock_infections( - n_days: int, - peak_day: int = 10, - peak_value: float = 1000.0, - shape: str = "spike", -) -> jnp.ndarray: - """ - Create mock infection time series for testing. - - Parameters - ---------- - n_days : int - Number of days - peak_day : int - Day of peak infections - peak_value : float - Peak infection value - shape : str - Shape of the curve: "spike", "constant", or "decay" - - Returns - ------- - jnp.ndarray - Array of infections of shape (n_days,) - """ - if shape == "spike": - infections = jnp.zeros(n_days) - infections = infections.at[peak_day].set(peak_value) - elif shape == "constant": - infections = jnp.ones(n_days) * peak_value - elif shape == "decay": - infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) - else: - raise ValueError(f"Unknown shape: {shape}") - - return infections diff --git a/test/test_base_latent_process.py b/test/test_base_latent_process.py index 56a5ea4c..7e6a9461 100644 --- a/test/test_base_latent_process.py +++ b/test/test_base_latent_process.py @@ -5,9 +5,8 @@ import jax.numpy as jnp import pytest -from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.latent import HierarchicalInfections, RandomWalk -from pyrenew.latent.base import BaseLatentInfectionProcess, LatentSample +from pyrenew.deterministic import DeterministicPMF +from pyrenew.latent.base import BaseLatentInfectionProcess class TestPopulationStructureParsing: @@ -26,7 +25,7 @@ def test_rejects_fractions_not_summing_to_one(self): """Test that fractions not summing to 1 raises error.""" with pytest.raises(ValueError, match="must sum to 1.0"): BaseLatentInfectionProcess._parse_and_validate_fractions( - subpop_fractions=jnp.array([0.3, 0.25, 0.40]), # Sum is 0.95 + subpop_fractions=jnp.array([0.3, 0.25, 0.40]), ) def test_rejects_negative_fractions(self): @@ -45,7 +44,7 @@ def test_rejects_2d_fractions(self): """Test that 2D fraction arrays raise error.""" with pytest.raises(ValueError, match="must be a 1D array"): BaseLatentInfectionProcess._parse_and_validate_fractions( - subpop_fractions=jnp.array([[0.3, 0.25, 0.45]]), # 2D array + subpop_fractions=jnp.array([[0.3, 0.25, 0.45]]), ) def test_rejects_empty_subpopulations(self): @@ -62,7 +61,7 @@ class TestBaseLatentInfectionProcessInit: def test_rejects_missing_gen_int_rv(self): """Test that None gen_int_rv is rejected.""" with pytest.raises(ValueError, match="gen_int_rv is required"): - # Create a minimal concrete subclass for testing + class ConcreteLatent(BaseLatentInfectionProcess): def validate(self): pass @@ -87,49 +86,7 @@ def sample(self, n_days_post_init, **kwargs): ConcreteLatent( gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), - n_initialization_points=2, # gen_int has length 3 - ) - - -class TestGetRequiredLookback: - """Test get_required_lookback method.""" - - def test_get_required_lookback_returns_gen_int_length(self): - """Test that get_required_lookback returns generation interval length.""" - gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) - process = HierarchicalInfections( - gen_int_rv=gen_int, - I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - baseline_rt_process=RandomWalk(), - subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=3, - ) - - assert process.get_required_lookback() == 3 - - -class TestValidateOutputShapes: - """Test _validate_output_shapes method.""" - - def test_validate_output_shapes_raises_on_mismatch(self): - """Test that shape validation raises on incorrect shapes.""" - from pyrenew.latent.base import PopulationStructure - - pop = PopulationStructure( - fractions=jnp.array([0.5, 0.5]), - ) - - # Create arrays with wrong shapes - infections_aggregate = jnp.ones(10) # Correct: (10,) - infections_all = jnp.ones((10, 3)) # WRONG: should be (10, 2) - - with pytest.raises(ValueError, match="has incorrect shape"): - BaseLatentInfectionProcess._validate_output_shapes( - infections_aggregate, - infections_all, - n_total_days=10, - pop=pop, + n_initialization_points=2, ) @@ -146,28 +103,6 @@ def test_validate_I0_rejects_zero(self): with pytest.raises(ValueError, match="I0 must be positive"): BaseLatentInfectionProcess._validate_I0(jnp.array(0.0)) - def test_validate_I0_accepts_valid_values(self): - """Test that valid I0 values are accepted.""" - # Should not raise - BaseLatentInfectionProcess._validate_I0(jnp.array(0.001)) - BaseLatentInfectionProcess._validate_I0(jnp.array(1.0)) - BaseLatentInfectionProcess._validate_I0(jnp.array([0.001, 0.002])) - - -class TestLatentSample: - """Test LatentSample named tuple.""" - - def test_latent_sample_unpacking(self): - """Test that LatentSample can be unpacked correctly.""" - sample = LatentSample( - aggregate=jnp.ones(10), - all_subpops=jnp.ones((10, 2)), - ) - - agg, all_s = sample - assert agg.shape == (10,) - assert all_s.shape == (10, 2) - if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_helpers.py b/test/test_helpers.py new file mode 100644 index 00000000..f8748d39 --- /dev/null +++ b/test/test_helpers.py @@ -0,0 +1,97 @@ +""" +Shared test helper classes and functions. + +This module provides reusable classes and functions that are not pytest fixtures. +For pytest fixtures, see conftest.py. +""" + +import jax +import jax.numpy as jnp + +from pyrenew.observation import Measurements + + +class ConcreteMeasurements(Measurements): + """Concrete implementation of Measurements for testing.""" + + def __init__(self, name, temporal_pmf_rv, noise, log10_scale=9.0): + """Initialize the concrete measurements for testing.""" + super().__init__(name=name, temporal_pmf_rv=temporal_pmf_rv, noise=noise) + self.log10_scale = log10_scale + + def validate(self) -> None: + """Validate parameters.""" + pmf = self.temporal_pmf_rv() + self._validate_pmf(pmf, "temporal_pmf_rv") + + def lookback_days(self) -> int: + """ + Return required lookback days for this observation. + + Returns + ------- + int + Length of temporal PMF minus 1. + """ + return len(self.temporal_pmf_rv()) - 1 + + def _predicted_obs(self, infections): + """ + Simple predicted signal: log(convolution * scale). + + Returns + ------- + jnp.ndarray + Log-transformed predicted signal. + """ + pmf = self.temporal_pmf_rv() + + if infections.ndim == 1: + infections = infections[:, jnp.newaxis] + + def convolve_col(col): # numpydoc ignore=GL08 + return self._convolve_with_alignment(col, pmf, 1.0)[0] + + predicted = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + + log_predicted = jnp.log(predicted + 1e-10) + self.log10_scale * jnp.log(10) + + return log_predicted + + +def create_mock_infections( + n_days: int, + peak_day: int = 10, + peak_value: float = 1000.0, + shape: str = "spike", +) -> jnp.ndarray: + """ + Create mock infection time series for testing. + + Parameters + ---------- + n_days : int + Number of days + peak_day : int + Day of peak infections + peak_value : float + Peak infection value + shape : str + Shape of the curve: "spike", "constant", or "decay" + + Returns + ------- + jnp.ndarray + Array of infections of shape (n_days,) + """ + if shape == "spike": + infections = jnp.zeros(n_days) + infections = infections.at[peak_day].set(peak_value) + elif shape == "constant": + infections = jnp.ones(n_days) * peak_value + elif shape == "decay": + infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + else: + raise ValueError(f"Unknown shape: {shape}") + + return infections diff --git a/test/test_hierarchical_infections.py b/test/test_hierarchical_infections.py index 92056ef2..574f5b13 100644 --- a/test/test_hierarchical_infections.py +++ b/test/test_hierarchical_infections.py @@ -6,85 +6,19 @@ import numpyro import pytest -from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import HierarchicalInfections, RandomWalk -@pytest.fixture -def gen_int_rv(): - """ - Create a generation interval random variable. - - Returns - ------- - DeterministicPMF - Generation interval PMF. - """ - return DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) - - -@pytest.fixture -def process(gen_int_rv): - """ - Create a HierarchicalInfections with the new keyword-only API. - - Returns - ------- - HierarchicalInfections - Configured infection process. - """ - return HierarchicalInfections( - gen_int_rv=gen_int_rv, - I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - baseline_rt_process=RandomWalk(), - subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=3, - ) - - class TestHierarchicalInfectionsSample: """Test sample method with population structure at sample time.""" - def test_sample_returns_correct_shapes(self, process): - """Test that sample returns correct output shapes.""" - with numpyro.handlers.seed(rng_seed=42): - result = process.sample( - n_days_post_init=30, - subpop_fractions=jnp.array([0.3, 0.25, 0.45]), - ) - - inf_juris, inf_all = result - n_total = process.n_initialization_points + 30 - - assert inf_juris.shape == (n_total,) - assert inf_all.shape == (n_total, 3) - - def test_same_model_different_jurisdictions(self, process): - """Test that one model can fit different population structures.""" - # Jurisdiction A: 3 subpopulations - with numpyro.handlers.seed(rng_seed=42): - _, inf_all_a = process.sample( - n_days_post_init=30, - subpop_fractions=jnp.array([0.3, 0.25, 0.45]), - ) - - # Jurisdiction B: 5 subpopulations (different structure) - with numpyro.handlers.seed(rng_seed=42): - _, inf_all_b = process.sample( - n_days_post_init=30, - subpop_fractions=jnp.array([0.15, 0.20, 0.25, 0.10, 0.30]), - ) - - assert inf_all_a.shape[1] == 3 # K=3 - assert inf_all_b.shape[1] == 5 # K=5 - - def test_jurisdiction_total_is_weighted_sum(self, process): + def test_jurisdiction_total_is_weighted_sum(self, hierarchical_infections): """Test that jurisdiction total equals weighted sum of subpopulations.""" fractions = jnp.array([0.3, 0.25, 0.45]) with numpyro.handlers.seed(rng_seed=42): - inf_juris, inf_all = process.sample( + inf_juris, inf_all = hierarchical_infections.sample( n_days_post_init=30, subpop_fractions=fractions, ) @@ -93,11 +27,11 @@ def test_jurisdiction_total_is_weighted_sum(self, process): assert jnp.allclose(inf_juris, expected, atol=1e-6) - def test_deviations_sum_to_zero(self, process): + def test_deviations_sum_to_zero(self, hierarchical_infections): """Test that subpopulation deviations sum to zero (identifiability).""" with numpyro.handlers.seed(rng_seed=42): with numpyro.handlers.trace() as trace: - process.sample( + hierarchical_infections.sample( n_days_post_init=30, subpop_fractions=jnp.array([0.3, 0.25, 0.45]), ) @@ -107,6 +41,49 @@ def test_deviations_sum_to_zero(self, process): assert jnp.allclose(deviation_sums, 0.0, atol=1e-6) + def test_infections_are_positive(self, hierarchical_infections): + """Test that all infections are positive (epidemiological invariant).""" + with numpyro.handlers.seed(rng_seed=42): + inf_juris, inf_all = hierarchical_infections.sample( + n_days_post_init=30, + subpop_fractions=jnp.array([0.3, 0.25, 0.45]), + ) + + assert jnp.all(inf_juris > 0) + assert jnp.all(inf_all > 0) + + @pytest.mark.parametrize( + "fractions", + [ + jnp.array([1.0]), + jnp.array([0.3, 0.25, 0.45]), + jnp.array([0.10, 0.14, 0.21, 0.22, 0.07, 0.26]), + ], + ids=["K=1", "K=3", "K=6"], + ) + def test_shape_and_positivity_across_subpop_counts( + self, hierarchical_infections, fractions + ): + """Test correct shapes and positivity for varying numbers of subpops.""" + n_days_post_init = 30 + n_total = hierarchical_infections.n_initialization_points + n_days_post_init + n_subpops = len(fractions) + + with numpyro.handlers.seed(rng_seed=42): + inf_juris, inf_all = hierarchical_infections.sample( + n_days_post_init=n_days_post_init, + subpop_fractions=fractions, + ) + + assert inf_juris.shape == (n_total,) + assert inf_all.shape == (n_total, n_subpops) + assert jnp.all(inf_juris > 0) + assert jnp.all(inf_all > 0) + + # Weighted sum property + expected = jnp.sum(inf_all * fractions[jnp.newaxis, :], axis=1) + assert jnp.allclose(inf_juris, expected, atol=1e-6) + class TestHierarchicalInfectionsValidation: """Test validation of inputs.""" @@ -120,7 +97,7 @@ def test_rejects_missing_I0_rv(self, gen_int_rv): initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=3, + n_initialization_points=7, ) def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv): @@ -132,7 +109,7 @@ def test_rejects_missing_initial_log_rt_rv(self, gen_int_rv): initial_log_rt_rv=None, baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=3, + n_initialization_points=7, ) def test_rejects_missing_baseline_rt_process(self, gen_int_rv): @@ -144,7 +121,7 @@ def test_rejects_missing_baseline_rt_process(self, gen_int_rv): initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=None, subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=3, + n_initialization_points=7, ) def test_rejects_missing_subpop_rt_deviation_process(self, gen_int_rv): @@ -156,7 +133,7 @@ def test_rejects_missing_subpop_rt_deviation_process(self, gen_int_rv): initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=None, - n_initialization_points=3, + n_initialization_points=7, ) def test_rejects_invalid_I0(self, gen_int_rv): @@ -168,7 +145,7 @@ def test_rejects_invalid_I0(self, gen_int_rv): initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=3, + n_initialization_points=7, ) def test_rejects_insufficient_n_initialization_points(self, gen_int_rv): @@ -182,48 +159,45 @@ def test_rejects_insufficient_n_initialization_points(self, gen_int_rv): initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=2, # gen_int has length 3 + n_initialization_points=2, ) - def test_rejects_fractions_not_summing_to_one(self, process): + def test_rejects_fractions_not_summing_to_one(self, hierarchical_infections): """Test that fractions not summing to 1 raises error at sample time.""" with pytest.raises(ValueError, match="must sum to 1.0"): with numpyro.handlers.seed(rng_seed=42): - process.sample( + hierarchical_infections.sample( n_days_post_init=30, - subpop_fractions=jnp.array([0.3, 0.25, 0.40]), # Sum is 0.95 + subpop_fractions=jnp.array([0.3, 0.25, 0.40]), ) - def test_validate_method(self, process): - """Test that validate() method runs without error.""" - process.validate() - class TestHierarchicalInfectionsPerSubpopI0: """Test per-subpopulation I0 values.""" def test_per_subpop_I0_array(self, gen_int_rv): - """Test with per-subpopulation I0 values (array instead of scalar).""" + """Test with per-subpopulation I0 values and verify positivity.""" process = HierarchicalInfections( gen_int_rv=gen_int_rv, I0_rv=DeterministicVariable("I0", jnp.array([0.001, 0.002, 0.0015])), initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=3, + n_initialization_points=7, ) with numpyro.handlers.seed(rng_seed=42): - result = process.sample( + inf_juris, inf_all = process.sample( n_days_post_init=30, subpop_fractions=jnp.array([0.3, 0.25, 0.45]), ) - inf_juris, inf_all = result n_total = process.n_initialization_points + 30 assert inf_juris.shape == (n_total,) assert inf_all.shape == (n_total, 3) + assert jnp.all(inf_juris > 0) + assert jnp.all(inf_all > 0) if __name__ == "__main__": diff --git a/test/test_hierarchical_priors.py b/test/test_hierarchical_priors.py index 70f9a50f..fb99ca0e 100644 --- a/test/test_hierarchical_priors.py +++ b/test/test_hierarchical_priors.py @@ -17,16 +17,17 @@ class TestHierarchicalNormalPrior: """Test HierarchicalNormalPrior.""" - def test_sample_shape(self): - """Test that sample returns correct shape.""" + def test_sample_shape_and_centering(self): + """Test that sample returns correct shape and is centered near zero.""" prior = HierarchicalNormalPrior( "effect", sd_rv=DeterministicVariable("sd", 1.0) ) with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=5) + samples = prior.sample(n_groups=1000) - assert samples.shape == (5,) + assert samples.shape == (1000,) + assert jnp.abs(jnp.mean(samples)) < 0.2 def test_smaller_sd_produces_tighter_distribution(self): """Test that smaller sd produces samples closer to zero.""" @@ -43,7 +44,6 @@ def test_smaller_sd_produces_tighter_distribution(self): with numpyro.handlers.seed(rng_seed=43): samples_wide = prior_wide.sample(n_groups=n_samples) - # Tight prior should have smaller standard deviation assert jnp.std(samples_tight) < jnp.std(samples_wide) def test_rejects_non_random_variable_sd(self): @@ -61,20 +61,12 @@ def test_accepts_distributional_variable_for_sd(self): assert samples.shape == (5,) - def test_validate_method(self): - """Test that validate() method runs without error.""" - prior = HierarchicalNormalPrior( - "effect", sd_rv=DeterministicVariable("sd", 1.0) - ) - # Should not raise - validate is a no-op - prior.validate() - class TestGammaGroupSdPrior: """Test GammaGroupSdPrior.""" - def test_sample_shape(self): - """Test that sample returns correct shape.""" + def test_sample_shape_and_positivity(self): + """Test that sample returns correct shape with positive values.""" prior = GammaGroupSdPrior( "sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.5), @@ -82,9 +74,10 @@ def test_sample_shape(self): ) with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=5) + samples = prior.sample(n_groups=100) - assert samples.shape == (5,) + assert samples.shape == (100,) + assert jnp.all(samples > 0) def test_respects_sd_min(self): """Test that sd_min is enforced as lower bound.""" @@ -128,22 +121,12 @@ def test_rejects_negative_sd_min(self): sd_min=-0.1, ) - def test_validate_method(self): - """Test that validate() method runs without error.""" - prior = GammaGroupSdPrior( - "sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.5), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), - ) - # Should not raise - validate is a no-op - prior.validate() - class TestStudentTGroupModePrior: """Test StudentTGroupModePrior.""" - def test_sample_shape(self): - """Test that sample returns correct shape.""" + def test_sample_shape_and_centering(self): + """Test that sample returns correct shape and is centered near zero.""" prior = StudentTGroupModePrior( "mode", sd_rv=DeterministicVariable("sd", 1.0), @@ -151,13 +134,13 @@ def test_sample_shape(self): ) with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=5) + samples = prior.sample(n_groups=1000) - assert samples.shape == (5,) + assert samples.shape == (1000,) + assert jnp.abs(jnp.mean(samples)) < 0.3 def test_heavier_tails_than_normal(self): """Test Student-t produces more extreme values than Normal.""" - # df=2 gives very heavy tails student_prior = StudentTGroupModePrior( "s", sd_rv=DeterministicVariable("sd_s", 1.0), @@ -173,7 +156,6 @@ def test_heavier_tails_than_normal(self): with numpyro.handlers.seed(rng_seed=42): normal_samples = normal_prior.sample(n_groups=n_samples) - # Student-t should have more extreme values (higher max absolute value) assert jnp.max(jnp.abs(student_samples)) > jnp.max(jnp.abs(normal_samples)) def test_rejects_non_random_variable_params(self): @@ -192,16 +174,6 @@ def test_rejects_non_random_variable_params(self): df_rv=4.0, ) - def test_validate_method(self): - """Test that validate() method runs without error.""" - prior = StudentTGroupModePrior( - "mode", - sd_rv=DeterministicVariable("sd", 1.0), - df_rv=DeterministicVariable("df", 4.0), - ) - # Should not raise - validate is a no-op - prior.validate() - if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index e43b5272..4ae3e0b1 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -16,53 +16,15 @@ NegativeBinomialNoise, PoissonNoise, ) -from pyrenew.observation.count_observations import _CountBase from pyrenew.randomvariable import DistributionalVariable - - -def create_mock_infections( - n_days: int, - peak_day: int = 10, - peak_value: float = 1000.0, - shape: str = "spike", -) -> jnp.ndarray: - """ - Create mock infection time series for testing. - - Parameters - ---------- - n_days : int - Number of days - peak_day : int - Day of peak infections - peak_value : float - Peak infection value - shape : str - Shape of the curve: "spike", "constant", or "decay" - - Returns - ------- - jnp.ndarray - Array of infections of shape (n_days,) - """ - if shape == "spike": - infections = jnp.zeros(n_days) - infections = infections.at[peak_day].set(peak_value) - elif shape == "constant": - infections = jnp.ones(n_days) * peak_value - elif shape == "decay": - infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) - else: - raise ValueError(f"Unknown shape: {shape}") - - return infections +from test.test_helpers import create_mock_infections class TestCountsBasics: """Test basic functionality of aggregated count observation process.""" - def test_sample_returns_correct_shape(self, counts_process): - """Test that sample returns correct shape.""" + def test_sample_returns_correct_shape_with_value_checks(self, counts_process): + """Test that sample returns correct shape with non-negative predicted counts.""" infections = jnp.ones(30) * 100 with numpyro.handlers.seed(rng_seed=42): @@ -74,6 +36,10 @@ def test_sample_returns_correct_shape(self, counts_process): assert result.observed.shape[0] > 0 assert result.observed.ndim == 1 assert result.predicted.shape == infections.shape + # Predicted counts must be non-negative + assert jnp.all(result.predicted >= 0) + # Observed counts must be non-negative (count data) + assert jnp.all(result.observed >= 0) def test_delay_convolution(self, counts_factory, short_delay_pmf): """Test that delay is properly applied.""" @@ -88,13 +54,11 @@ def test_delay_convolution(self, counts_factory, short_delay_pmf): obs=None, ) - # Timeline alignment: both predicted and observed have same length as input assert result.predicted.shape[0] == len(infections) assert result.observed.shape[0] == len(infections) # First len(pmf) - 1 entries in predicted are NaN (initialization period) assert jnp.all(jnp.isnan(result.predicted[:1])) assert jnp.all(~jnp.isnan(result.predicted[1:])) - # Observed is sampled for all entries (masked entries don't affect likelihood) assert jnp.all(result.observed >= 0) def test_ascertainment_scaling(self, counts_factory, simple_delay_pmf): @@ -115,12 +79,11 @@ def test_ascertainment_scaling(self, counts_factory, simple_delay_pmf): ) results.append(jnp.mean(result.observed)) - # Higher ascertainment rate should lead to more counts assert results[1] > results[0] assert results[2] > results[1] def test_negative_binomial_observation(self, counts_factory, simple_delay_pmf): - """Test that negative binomial observation is used.""" + """Test that negative binomial observation produces variability.""" process = counts_factory.create( delay_pmf=simple_delay_pmf, concentration=5.0, @@ -137,9 +100,49 @@ def test_negative_binomial_observation(self, counts_factory, simple_delay_pmf): ) samples.append(jnp.sum(result.observed)) - # Should have some variability due to negative binomial sampling assert jnp.std(jnp.array(samples)) > 0 + def test_convolution_hand_computable(self): + """Test convolution with hand-computable spike input. + + 1000 infections on day 10, delay PMF [0.5, 0.5], ascertainment 1.0. + Expected predicted: 500 on day 10, 500 on day 11. + """ + process = Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.5, 0.5])), + noise=PoissonNoise(), + ) + + infections = jnp.zeros(30) + infections = infections.at[10].set(1000.0) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample(infections=infections, obs=None) + + # Day 10: 1000 * 0.5 * 1.0 = 500 + # Day 11: 1000 * 0.5 * 1.0 = 500 + # Day 0 is NaN (initialization from 2-element PMF) + assert jnp.isclose(result.predicted[10], 500.0, atol=1.0) + assert jnp.isclose(result.predicted[11], 500.0, atol=1.0) + # Other days (post-init, not 10 or 11) should be near zero + assert jnp.isclose(result.predicted[5], 0.0, atol=1.0) + + def test_observation_passthrough(self, counts_process): + """Test that providing obs returns those exact values.""" + infections = jnp.ones(30) * 100 + known_obs = jnp.arange(30, dtype=jnp.float32) + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + obs=known_obs, + ) + + # When obs is provided, observed values should equal obs + assert jnp.allclose(result.observed, known_obs) + class TestCountsWithPriors: """Test aggregated count observation with uncertain parameters.""" @@ -198,7 +201,7 @@ class TestCountsEdgeCases: """Test edge cases and error handling.""" def test_zero_infections(self, counts_process): - """Test with zero infections.""" + """Test with zero infections produces zero predicted counts.""" infections = jnp.zeros(20) with numpyro.handlers.seed(rng_seed=42): @@ -209,9 +212,11 @@ def test_zero_infections(self, counts_process): assert result.observed.shape[0] > 0 assert jnp.all(result.observed >= 0) + # Zero infections should produce zero predicted counts + assert jnp.allclose(result.predicted, 0.0) def test_small_infections(self, counts_process): - """Test with small infection values.""" + """Test with small infection values produces plausible counts.""" infections = jnp.ones(20) * 10 with numpyro.handlers.seed(rng_seed=42): @@ -222,6 +227,9 @@ def test_small_infections(self, counts_process): assert result.observed.shape[0] > 0 assert jnp.all(result.observed >= 0) + # With ascertainment rate 0.01 and 10 infections, + # predicted should be ~0.1 per day + assert jnp.all(result.predicted <= 10) def test_long_delay_distribution(self, counts_factory, long_delay_pmf): """Test with longer delay distribution.""" @@ -247,9 +255,8 @@ def test_dense_observations_with_nan_padding(self, counts_process): n_days = 30 infections = jnp.ones(n_days) * 100 - # Dense observations with NaN for "missing" days obs = jnp.ones(n_days) * 10.0 - obs = obs.at[:5].set(jnp.nan) # First 5 days "missing" + obs = obs.at[:5].set(jnp.nan) with numpyro.handlers.seed(rng_seed=42): result = counts_process.sample( @@ -257,76 +264,18 @@ def test_dense_observations_with_nan_padding(self, counts_process): obs=obs, ) - # With masking, observed has same shape as input (masked entries - # are sampled but don't contribute to likelihood) assert result.observed.shape[0] == n_days assert result.predicted.shape[0] == n_days - def test_prior_sampling_dense(self, counts_process): - """Test prior sampling produces dense output.""" - n_days = 30 - infections = jnp.ones(n_days) * 100 - - with numpyro.handlers.seed(rng_seed=42): - result = counts_process.sample( - infections=infections, - obs=None, - ) - - # Prior sampling: observed excludes NaN predictions (init period) - assert result.observed.shape[0] == n_days # simple_delay_pmf has no init - assert result.predicted.shape == (n_days,) - assert jnp.all(~jnp.isnan(result.observed)) - class TestCountsBySubpop: """Test CountsBySubpop for subpopulation-level observations.""" - def test_sample_returns_correct_shape(self): - """Test that CountsBySubpop sample returns correct shape.""" - delay_pmf = jnp.array([0.3, 0.4, 0.3]) - process = CountsBySubpop( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), - delay_distribution_rv=DeterministicPMF("delay", delay_pmf), - noise=PoissonNoise(), - ) - - infections = jnp.ones((30, 3)) * 500 # 30 days, 3 subpops - # Times on shared axis (must be >= len(delay_pmf) - 1 to avoid NaN) - times = jnp.array([10, 15, 10, 15]) - subpop_indices = jnp.array([0, 0, 1, 1]) - - with numpyro.handlers.seed(rng_seed=42): - result = process.sample( - infections=infections, - times=times, - subpop_indices=subpop_indices, - obs=None, - ) - - assert result.observed.shape == times.shape - assert result.predicted.shape == infections.shape - - def test_infection_resolution(self): - """Test that CountsBySubpop returns 'subpop' resolution.""" - delay_pmf = jnp.array([1.0]) - process = CountsBySubpop( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", delay_pmf), - noise=PoissonNoise(), - ) - - assert process.infection_resolution() == "subpop" - def test_non_contiguous_subpop_indices(self): """Test that non-contiguous subpop_indices work correctly. - This verifies that observation processes can observe any subset - of subpopulations, not just contiguous indices starting from 0. - For example, with K=5 subpopulations, observations might only - cover indices {0, 2, 4} while indices {1, 3} are unobserved. + Verifies observation processes can observe any subset + of subpopulations with correct value proportionality. """ delay_pmf = jnp.array([0.3, 0.4, 0.3]) process = CountsBySubpop( @@ -337,15 +286,13 @@ def test_non_contiguous_subpop_indices(self): ) # 5 subpopulations with distinct infection levels - # Subpop 0: 100, Subpop 1: 200, Subpop 2: 300, Subpop 3: 400, Subpop 4: 500 n_days = 20 infections = jnp.zeros((n_days, 5)) for k in range(5): infections = infections.at[:, k].set((k + 1) * 100.0) - # Observe only subpops 0, 2, 4 (non-contiguous, skipping 1 and 3) times = jnp.array([10, 10, 10]) - subpop_indices = jnp.array([0, 2, 4]) # Non-contiguous! + subpop_indices = jnp.array([0, 2, 4]) with numpyro.handlers.seed(rng_seed=42): result = process.sample( @@ -355,19 +302,11 @@ def test_non_contiguous_subpop_indices(self): obs=None, ) - # Verify correct shape assert result.observed.shape == (3,) - # Verify the predicted values correspond to the correct subpopulations - # predicted[t, k] should reflect infections from subpop k - # At time 10, predicted counts should be proportional to infection levels + # Predicted values should be proportional to infection levels predicted_at_obs = result.predicted[10, subpop_indices] - - # Subpop 0 has 100 infections, subpop 2 has 300, subpop 4 has 500 - # So predicted[10, 0] < predicted[10, 2] < predicted[10, 4] assert predicted_at_obs[0] < predicted_at_obs[1] < predicted_at_obs[2] - - # Verify the ratios match the infection ratios (100:300:500 = 1:3:5) assert jnp.isclose(predicted_at_obs[1] / predicted_at_obs[0], 3.0, atol=0.01) assert jnp.isclose(predicted_at_obs[2] / predicted_at_obs[0], 5.0, atol=0.01) @@ -375,8 +314,8 @@ def test_non_contiguous_subpop_indices(self): class TestPoissonNoise: """Test PoissonNoise model.""" - def test_poisson_counts(self, simple_delay_pmf): - """Test Counts with Poisson noise.""" + def test_poisson_mean_approximation(self, simple_delay_pmf): + """Test that Poisson samples have mean close to predicted rate.""" process = Counts( name="test", ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), @@ -384,57 +323,22 @@ def test_poisson_counts(self, simple_delay_pmf): noise=PoissonNoise(), ) - infections = jnp.ones(20) * 1000 + infections = jnp.ones(20) * 10000 # Large enough for stable mean - with numpyro.handlers.seed(rng_seed=42): - result = process.sample( - infections=infections, - obs=None, - ) - - assert result.observed.shape[0] == 20 - assert jnp.all(result.observed >= 0) - - -class TestCountBaseInternalMethods: - """Test internal _CountBase methods for coverage.""" - - def test_count_base_infection_resolution_raises(self, simple_delay_pmf): - """Test that _CountBase.infection_resolution() raises NotImplementedError.""" - - # Create a subclass that doesn't override infection_resolution - class IncompleteCountProcess(_CountBase): - """Incomplete count process for testing.""" - - def sample(self, **kwargs): - """Sample method stub.""" - pass - - process = IncompleteCountProcess( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), - ) - with pytest.raises( - NotImplementedError, match="Subclasses must implement infection_resolution" - ): - process.infection_resolution() + samples = [] + for seed in range(50): + with numpyro.handlers.seed(rng_seed=seed): + result = process.sample(infections=infections, obs=None) + samples.append(result.observed[10]) + sample_mean = jnp.mean(jnp.array(samples)) + expected_rate = 10000 * 0.01 # 100 + # Mean should be within 20% of expected rate + assert jnp.abs(sample_mean - expected_rate) / expected_rate < 0.2 -class TestValidationMethods: - """Test validation methods for coverage.""" - def test_validate_calls_all_validations(self, simple_delay_pmf): - """Test that validate() calls all necessary validations.""" - process = Counts( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), - ) - # Should not raise - process.validate() +class TestCountsValidation: + """Test validation methods.""" def test_validate_invalid_ascertainment_rate_negative(self, simple_delay_pmf): """Test that validate raises for negative ascertainment rate.""" @@ -460,101 +364,10 @@ def test_validate_invalid_ascertainment_rate_greater_than_one( with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): process.validate() - def test_lookback_days(self, simple_delay_pmf, long_delay_pmf): - """Test lookback_days returns PMF length minus 1 (0-indexed delays).""" - process_short = Counts( - name="test_short", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), - ) - # simple_delay_pmf has length 1, lookback = 1 - 1 = 0 - assert process_short.lookback_days() == 0 - - process_long = Counts( - name="test_long", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", long_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), - ) - # long_delay_pmf has length 10, lookback = 10 - 1 = 9 - assert process_long.lookback_days() == 9 - - def test_infection_resolution_counts(self, simple_delay_pmf): - """Test that Counts returns 'aggregate' resolution.""" - process = Counts( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), - ) - assert process.infection_resolution() == "aggregate" - - -class TestNoiseRepr: - """Test noise model __repr__ methods.""" - - def test_poisson_noise_repr(self): - """Test PoissonNoise __repr__ method.""" - noise = PoissonNoise() - assert repr(noise) == "PoissonNoise()" - - def test_negative_binomial_noise_repr(self): - """Test NegativeBinomialNoise __repr__ method.""" - conc_rv = DeterministicVariable("conc", 10.0) - noise = NegativeBinomialNoise(conc_rv) - repr_str = repr(noise) - assert "NegativeBinomialNoise" in repr_str - assert "concentration_rv" in repr_str - - -class TestCountsRepr: - """Test Counts and CountsBySubpop __repr__ methods.""" - - def test_counts_repr(self, simple_delay_pmf): - """Test Counts __repr__ method.""" - process = Counts( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), - noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), - ) - repr_str = repr(process) - assert "Counts" in repr_str - assert "name='test'" in repr_str - assert "ascertainment_rate_rv" in repr_str - assert "delay_distribution_rv" in repr_str - assert "noise" in repr_str - - def test_counts_by_subpop_repr(self, simple_delay_pmf): - """Test CountsBySubpop __repr__ method.""" - process = CountsBySubpop( - name="test_subpop", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), - delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), - noise=PoissonNoise(), - ) - repr_str = repr(process) - assert "CountsBySubpop" in repr_str - assert "name='test_subpop'" in repr_str - assert "ascertainment_rate_rv" in repr_str - class TestNoiseValidation: """Test noise model validation methods.""" - def test_poisson_noise_validate(self): - """Test PoissonNoise validate method.""" - noise = PoissonNoise() - # Should not raise - Poisson has no parameters to validate - noise.validate() - - def test_negative_binomial_noise_validate_success(self): - """Test NegativeBinomialNoise validate with valid concentration.""" - noise = NegativeBinomialNoise(DeterministicVariable("conc", 10.0)) - # Should not raise - noise.validate() - def test_negative_binomial_noise_validate_zero_concentration(self): """Test NegativeBinomialNoise validate with zero concentration.""" noise = NegativeBinomialNoise(DeterministicVariable("conc", 0.0)) @@ -591,7 +404,7 @@ def test_validate_pmf_sum_not_one(self, simple_delay_pmf): delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), ) - bad_pmf = jnp.array([0.3, 0.3, 0.3]) # sums to 0.9 + bad_pmf = jnp.array([0.3, 0.3, 0.3]) with pytest.raises(ValueError, match="must sum to 1.0"): process._validate_pmf(bad_pmf, "test_pmf") @@ -603,7 +416,7 @@ def test_validate_pmf_negative_values(self, simple_delay_pmf): delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), ) - bad_pmf = jnp.array([1.5, -0.5]) # sums to 1.0 but has negative + bad_pmf = jnp.array([1.5, -0.5]) with pytest.raises(ValueError, match="must have non-negative values"): process._validate_pmf(bad_pmf, "test_pmf") diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py index 671526e4..0bffebcb 100644 --- a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -1,7 +1,8 @@ """ Unit tests for Measurements (continuous measurement observations). -These tests validate the measurement observation process base class implementation. +These tests validate the measurement observation process implementation +using ConcreteMeasurements from conftest.py. """ import jax.numpy as jnp @@ -12,95 +13,10 @@ from pyrenew.deterministic import DeterministicPMF from pyrenew.observation import ( HierarchicalNormalNoise, - Measurements, VectorizedRV, ) -from pyrenew.observation.base import BaseObservationProcess from pyrenew.randomvariable import DistributionalVariable - - -class ConcreteMeasurements(Measurements): - """Concrete implementation of Measurements for testing.""" - - def __init__(self, name, temporal_pmf_rv, noise, log10_scale=9.0): - """Initialize the concrete measurements for testing.""" - super().__init__(name=name, temporal_pmf_rv=temporal_pmf_rv, noise=noise) - self.log10_scale = log10_scale - - def validate(self) -> None: - """Validate parameters.""" - pmf = self.temporal_pmf_rv() - self._validate_pmf(pmf, "temporal_pmf_rv") - - def lookback_days(self) -> int: - """ - Return required lookback days for this observation. - - Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF - of length L covers lags 0 to L-1, requiring L-1 initialization points. - - Returns - ------- - int - Length of temporal PMF minus 1. - """ - return len(self.temporal_pmf_rv()) - 1 - - def _predicted_obs(self, infections): - """ - Simple predicted signal: log(convolution * scale). - - Returns - ------- - jnp.ndarray - Log-transformed predicted signal. - """ - pmf = self.temporal_pmf_rv() - - # Handle 2D infections (n_days, n_subpops) - if infections.ndim == 1: - infections = infections[:, jnp.newaxis] - - def convolve_col(col): # numpydoc ignore=GL08 - return self._convolve_with_alignment(col, pmf, 1.0)[0] - - import jax - - predicted = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) - - # Apply log10 scaling (simplified from wastewater model) - log_predicted = jnp.log(predicted + 1e-10) + self.log10_scale * jnp.log(10) - - return log_predicted - - -class TestMeasurementsBase: - """Test Measurements abstract base class.""" - - def test_is_base_observation_process(self): - """Test that Measurements inherits from BaseObservationProcess.""" - assert issubclass(Measurements, BaseObservationProcess) - - def test_infection_resolution_is_subpop(self): - """Test that Measurements returns 'subpop' resolution.""" - shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) - - process = ConcreteMeasurements( - name="test", - temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), - noise=noise, - ) - - assert process.infection_resolution() == "subpop" +from test.test_helpers import ConcreteMeasurements class TestVectorizedRV: @@ -115,85 +31,42 @@ def test_init_and_sample(self): samples = vectorized.sample(n_groups=5) assert samples.shape == (5,) + # Verify samples are actually different (not degenerate) + assert jnp.std(samples) > 0 class TestHierarchicalNormalNoise: """Test HierarchicalNormalNoise model.""" - def test_repr(self): - """Test HierarchicalNormalNoise __repr__ method.""" - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) - repr_str = repr(noise) - assert "HierarchicalNormalNoise" in repr_str - assert "sensor_mode_rv" in repr_str - assert "sensor_sd_rv" in repr_str - - def test_validate(self): - """Test HierarchicalNormalNoise validate method.""" - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) - # Should not raise - validation is deferred to sample time - noise.validate() - - def test_sample_shape(self): - """Test that HierarchicalNormalNoise produces correct shape.""" - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) - - predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) - sensor_indices = jnp.array([0, 0, 1, 1]) + def test_sample_shape_and_sensor_variation(self, hierarchical_normal_noise): + """Test shape and that different sensors produce different biases.""" + predicted = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + sensor_indices = jnp.array([0, 0, 1, 1, 2, 2]) with numpyro.handlers.seed(rng_seed=42): - samples = noise.sample( - name="test", - predicted=predicted, - obs=None, - sensor_indices=sensor_indices, - n_sensors=2, - ) + with numpyro.handlers.trace() as trace: + samples = hierarchical_normal_noise.sample( + name="test", + predicted=predicted, + obs=None, + sensor_indices=sensor_indices, + n_sensors=3, + ) assert samples.shape == predicted.shape - def test_sample_with_observations(self): - """Test that HierarchicalNormalNoise conditions on observations.""" - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + # Verify sensor modes are sampled (exist in trace) + sensor_modes = trace["ww_sensor_mode"]["value"] + assert sensor_modes.shape == (3,) + def test_sample_with_observations(self, hierarchical_normal_noise): + """Test that HierarchicalNormalNoise conditions on observations.""" predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) obs = jnp.array([1.1, 2.1, 3.1, 4.1]) sensor_indices = jnp.array([0, 0, 1, 1]) with numpyro.handlers.seed(rng_seed=42): - samples = noise.sample( + samples = hierarchical_normal_noise.sample( name="test", predicted=predicted, obs=obs, @@ -208,69 +81,26 @@ def test_sample_with_observations(self): class TestConcreteMeasurements: """Test concrete Measurements implementation.""" - def test_lookback_days(self): + def test_lookback_days(self, hierarchical_normal_noise): """Test lookback_days returns len(pmf) - 1.""" - # PMF of length 3 should return 2 (covers lags 0, 1, 2) shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) process = ConcreteMeasurements( name="test", temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), - noise=noise, + noise=hierarchical_normal_noise, ) - assert process.lookback_days() == 2 # len(3) - 1 = 2 + assert process.lookback_days() == 2 - def test_repr(self): - """Test Measurements __repr__ method.""" + def test_sample_shape_and_log_scale(self, hierarchical_normal_noise): + """Test that sample returns correct shape and log-scale output.""" shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) process = ConcreteMeasurements( name="test", temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), - noise=noise, - ) - - repr_str = repr(process) - assert "ConcreteMeasurements" in repr_str - assert "temporal_pmf_rv" in repr_str - assert "noise" in repr_str - - def test_sample_shape(self): - """Test that sample returns correct shape.""" - shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.5)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) - - process = ConcreteMeasurements( - name="test", - temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), - noise=noise, + noise=hierarchical_normal_noise, ) infections = jnp.ones((30, 2)) * 1000 @@ -290,24 +120,17 @@ def test_sample_shape(self): assert result.observed.shape == times.shape assert result.predicted.shape == infections.shape + # Output should be in log-scale (large positive values due to log10_scale=9) + assert jnp.all(result.predicted[2:, :] > 0) - def test_predicted_obs_stored(self): + def test_predicted_obs_stored(self, hierarchical_normal_noise_tight): """Test that predicted values are stored as deterministic.""" shedding_pmf = jnp.array([0.5, 0.5]) - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.01)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.01, 0.005, low=0.001)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) process = ConcreteMeasurements( name="test", temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), - noise=noise, + noise=hierarchical_normal_noise_tight, ) infections = jnp.ones((20, 2)) * 1000 @@ -330,42 +153,29 @@ def test_predicted_obs_stored(self): assert "test_obs" in trace assert "test_predicted" in trace - def test_non_contiguous_subpop_indices(self): + def test_non_contiguous_subpop_indices(self, hierarchical_normal_noise_tight): """Test that non-contiguous subpop_indices work correctly. This verifies that observation processes can observe any subset of subpopulations, not just contiguous indices starting from 0. - For example, with K=5 subpopulations, observations might only - cover indices {0, 2, 4} while indices {1, 3} are unobserved. """ shedding_pmf = jnp.array([0.5, 0.5]) - sensor_mode_rv = VectorizedRV( - DistributionalVariable("mode", dist.Normal(0, 0.01)), - plate_name="sensor_mode", - ) - sensor_sd_rv = VectorizedRV( - DistributionalVariable("sd", dist.TruncatedNormal(0.01, 0.005, low=0.001)), - plate_name="sensor_sd", - ) - noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) process = ConcreteMeasurements( name="test", temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), - noise=noise, + noise=hierarchical_normal_noise_tight, ) # 5 subpopulations with distinct infection levels - # Subpop 0: 100, Subpop 1: 200, Subpop 2: 300, Subpop 3: 400, Subpop 4: 500 n_days = 20 infections = jnp.zeros((n_days, 5)) for k in range(5): infections = infections.at[:, k].set((k + 1) * 100.0) - # Observe only subpops 0, 2, 4 (non-contiguous, skipping 1 and 3) - # Each observation from a different sensor + # Observe only subpops 0, 2, 4 (non-contiguous) times = jnp.array([10, 10, 10]) - subpop_indices = jnp.array([0, 2, 4]) # Non-contiguous! + subpop_indices = jnp.array([0, 2, 4]) sensor_indices = jnp.array([0, 1, 2]) with numpyro.handlers.seed(rng_seed=42): @@ -378,23 +188,12 @@ def test_non_contiguous_subpop_indices(self): obs=None, ) - # Verify correct shape assert result.observed.shape == (3,) - # Verify the predicted values correspond to the correct subpopulations - # predicted[t, k] should reflect infections from subpop k - # At time 10, predicted values should be proportional to infection levels - # Note: ConcreteMeasurements returns LOG-scale values, so linear ratios - # become differences in log space + # Predicted values should be proportional to infection levels + # In log space: differences should match log of ratios predicted_at_obs = result.predicted[10, subpop_indices] - - # Subpop 0 has 100 infections, subpop 2 has 300, subpop 4 has 500 - # In log space: log(300) - log(100) = log(3), log(500) - log(100) = log(5) assert predicted_at_obs[0] < predicted_at_obs[1] < predicted_at_obs[2] - - # Verify the differences match log of the infection ratios - # diff[1] - diff[0] should equal log(3) ≈ 1.099 - # diff[2] - diff[0] should equal log(5) ≈ 1.609 assert jnp.isclose( predicted_at_obs[1] - predicted_at_obs[0], jnp.log(3.0), atol=0.01 ) @@ -402,6 +201,79 @@ def test_non_contiguous_subpop_indices(self): predicted_at_obs[2] - predicted_at_obs[0], jnp.log(5.0), atol=0.01 ) + def test_log_scale_correctness(self, hierarchical_normal_noise_tight): + """Test that output is log-scale of convolved infections times scale.""" + # Use simple PMF [1.0] so convolution is identity + process = ConcreteMeasurements( + name="test", + temporal_pmf_rv=DeterministicPMF("shedding", jnp.array([1.0])), + noise=hierarchical_normal_noise_tight, + log10_scale=0.0, # No scaling, so output = log(infections) + ) + + infections = jnp.ones((20, 1)) * 500.0 + times = jnp.array([10]) + subpop_indices = jnp.array([0]) + sensor_indices = jnp.array([0]) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + times=times, + subpop_indices=subpop_indices, + sensor_indices=sensor_indices, + n_sensors=1, + obs=None, + ) + + # With PMF=[1.0] and log10_scale=0, predicted should be log(500) + expected = jnp.log(500.0) + assert jnp.isclose(result.predicted[10, 0], expected, atol=0.01) + + def test_sensor_bias_differences(self): + """Test that hierarchical noise produces sensor-specific biases.""" + shedding_pmf = jnp.array([1.0]) + + # Use wide priors to ensure sensors get distinguishable biases + sensor_mode_rv = VectorizedRV( + DistributionalVariable("mode", dist.Normal(0, 2.0)), + plate_name="sensor_mode", + ) + sensor_sd_rv = VectorizedRV( + DistributionalVariable("sd", dist.TruncatedNormal(0.1, 0.05, low=0.01)), + plate_name="sensor_sd", + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + process = ConcreteMeasurements( + name="test", + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + infections = jnp.ones((30, 1)) * 1000.0 + # Same subpop, same time, 3 different sensors + times = jnp.array([15, 15, 15]) + subpop_indices = jnp.array([0, 0, 0]) + sensor_indices = jnp.array([0, 1, 2]) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + process.sample( + infections=infections, + times=times, + subpop_indices=subpop_indices, + sensor_indices=sensor_indices, + n_sensors=3, + obs=None, + ) + + # Sensor modes should be different (with wide prior) + sensor_modes = trace["mode"]["value"] + assert sensor_modes.shape == (3,) + # Not all modes should be identical + assert jnp.std(sensor_modes) > 0 + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index 5a947e30..916eeb17 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -3,7 +3,6 @@ """ import jax.numpy as jnp -import numpyro import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable @@ -65,7 +64,7 @@ def test_rejects_population_structure_at_configure_time(self): initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), - subpop_fractions=jnp.array([0.5, 0.5]), # Should fail + subpop_fractions=jnp.array([0.5, 0.5]), ) def test_rejects_n_initialization_points_at_configure_time(self): @@ -81,7 +80,7 @@ def test_rejects_n_initialization_points_at_configure_time(self): initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), baseline_rt_process=RandomWalk(), subpop_rt_deviation_process=RandomWalk(), - n_initialization_points=10, # Should fail + n_initialization_points=10, ) def test_rejects_reconfiguring_latent(self): @@ -112,7 +111,7 @@ def test_rejects_duplicate_observation_name(self, simple_builder): """Test that adding duplicate observation name raises ValueError.""" delay = DeterministicPMF("delay2", jnp.array([0.5, 0.5])) obs = Counts( - name="hospital", # Same name as existing observation + name="hospital", ascertainment_rate_rv=DeterministicVariable("ihr2", 0.02), delay_distribution_rv=delay, noise=NegativeBinomialNoise(DeterministicVariable("conc2", 20.0)), @@ -143,9 +142,8 @@ def test_compute_n_initialization_points_without_latent_raises(self): def test_compute_n_initialization_points_without_gen_int_raises(self): """Test that compute_n_initialization_points without gen_int_rv raises.""" builder = PyrenewBuilder() - # Configure latent without gen_int_rv builder.latent_class = HierarchicalInfections - builder.latent_params = {} # Missing gen_int_rv + builder.latent_params = {} with pytest.raises(ValueError, match="gen_int_rv is required"): builder.compute_n_initialization_points() @@ -155,8 +153,8 @@ def test_compute_n_initialization_points_returns_correct_value( ): """Test that compute_n_initialization_points returns max of lookbacks.""" n_init = simple_builder.compute_n_initialization_points() - # gen_int has 3 elements (1-indexed) -> 3 - # delay has 4 elements (0-indexed) -> lookback_days = 3 + # gen_int has 3 elements -> 3 + # delay has 4 elements -> lookback_days = 3 # max(3, 3) = 3 assert n_init == 3 @@ -164,36 +162,14 @@ def test_compute_n_initialization_points_returns_correct_value( class TestMultiSignalModelSampling: """Test MultiSignalModel sampling with population structure at sample time.""" - def test_sample_with_population_structure(self, simple_builder): - """Test that sample() works with population structure at sample time.""" - model = simple_builder.build() - n_days = 30 - n_total = model.latent.n_initialization_points + n_days - - with numpyro.handlers.seed(rng_seed=42): - with numpyro.handlers.trace() as tr: - model.sample( - n_days_post_init=n_days, - population_size=1_000_000, - subpop_fractions=SUBPOP_FRACTIONS, - hospital={"obs": None}, - ) - - inf_aggregate = tr["latent_infections"]["value"] - inf_all = tr["latent_infections_by_subpop"]["value"] - assert inf_aggregate.shape == (n_total,) - assert inf_all.shape == (n_total, 3) # K=3 - def test_run_with_population_structure(self, simple_builder): - """Test that run() works with population structure at sample time.""" + """Test that run() works and produces reasonable posterior samples.""" model = simple_builder.build() n_days = 10 n_total = model.latent.n_initialization_points + n_days - # Create dense observations with NaN padding for initialization period obs_values = jnp.array([10.0, 12.0, 15.0, 14.0, 11.0]) obs = model.pad_observations(obs_values) - # Pad with NaN for remaining days obs = jnp.concatenate([obs, jnp.full(n_days - len(obs_values), jnp.nan)]) model.run( @@ -208,25 +184,42 @@ def test_run_with_population_structure(self, simple_builder): samples = model.mcmc.get_samples() assert "latent_infections" in samples assert samples["latent_infections"].shape == (5, n_total) + # All infection samples should be positive + assert jnp.all(samples["latent_infections"] > 0) + def test_prior_predictive_multi_signal(self, simple_builder): + """Test prior predictive sampling from a builder-constructed model.""" + import jax.random + from numpyro.infer import Predictive -class TestMultiSignalModelValidation: - """Test data validation.""" - - def test_validate_data_requires_population_structure(self, simple_builder): - """Test that validate_data requires population structure.""" model = simple_builder.build() + n_days = 20 - # Should work with population structure - model.validate_data( - n_days_post_init=30, + predictive = Predictive( + model.sample, + num_samples=5, + ) + + rng_key = jax.random.PRNGKey(42) + prior_samples = predictive( + rng_key, + n_days_post_init=n_days, + population_size=1_000_000, subpop_fractions=SUBPOP_FRACTIONS, - hospital={ - "obs": jnp.array([10, 20]), - "times": jnp.array([5, 10]), - }, + hospital={"obs": None}, ) + n_total = model.latent.n_initialization_points + n_days + + assert "latent_infections" in prior_samples + assert prior_samples["latent_infections"].shape == (5, n_total) + # All prior predictive infections should be positive + assert jnp.all(prior_samples["latent_infections"] > 0) + + +class TestMultiSignalModelValidation: + """Test data validation.""" + def test_validate_data_rejects_out_of_bounds_times(self, simple_builder): """Test that times exceeding n_total_days raises error.""" model = simple_builder.build() @@ -279,17 +272,11 @@ def test_validate_data_rejects_mismatched_obs_times_length(self, simple_builder) n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, hospital={ - "obs": jnp.array([10, 20, 30]), # 3 elements - "times": jnp.array([5, 10]), # 2 elements + "obs": jnp.array([10, 20, 30]), + "times": jnp.array([5, 10]), }, ) - def test_validate_method_calls_internal_validate(self, simple_builder): - """Test that validate() method calls _validate().""" - model = simple_builder.build() - # Should not raise - model.validate() - def test_validate_data_rejects_negative_subpop_indices(self, simple_builder): """Test that negative subpop_indices raises error.""" model = simple_builder.build() @@ -308,124 +295,16 @@ def test_validate_data_rejects_out_of_bounds_subpop_indices(self, simple_builder """Test that subpop_indices >= K raises error.""" model = simple_builder.build() - # K is 3 (from SUBPOP_FRACTIONS = [0.3, 0.25, 0.45]) with pytest.raises(ValueError, match="subpop_indices contains"): model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, hospital={ - "subpop_indices": jnp.array([0, 1, 5]), # 5 >= 3 + "subpop_indices": jnp.array([0, 1, 5]), "times": jnp.array([5, 6, 7]), }, ) -class TestPyrenewBuilderErrorHandling: - """Test PyrenewBuilder error handling.""" - - def test_build_raises_on_construction_error(self): - """Test that build() raises TypeError on latent construction failure.""" - builder = PyrenewBuilder() - gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) - - # Configure with an invalid parameter that will cause construction to fail - builder.configure_latent( - HierarchicalInfections, - gen_int_rv=gen_int, - I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - baseline_rt_process=RandomWalk(), - subpop_rt_deviation_process=RandomWalk(), - invalid_extra_param="this will cause TypeError", # Invalid param - ) - - delay = DeterministicPMF("delay", jnp.array([0.5, 0.5])) - obs = Counts( - name="hospital", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=delay, - noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), - ) - builder.add_observation(obs) - - with pytest.raises(TypeError, match="unexpected keyword argument"): - builder.build() - - -class TestMultiSignalModelObservationValidation: - """Test observation process validation in MultiSignalModel.""" - - def test_rejects_observation_without_infection_resolution(self): - """Test that observations must implement infection_resolution().""" - from pyrenew.observation.base import BaseObservationProcess - - class BadObservation(BaseObservationProcess): - """Observation that raises NotImplementedError for infection_resolution.""" - - def __init__(self): - """Initialize without temporal_pmf_rv.""" - self.name = "bad" - self.temporal_pmf_rv = None - - def sample(self, **kwargs): - """Sample stub.""" - pass - - def validate(self): - """Validate stub.""" - pass - - def lookback_days(self): - """ - Return lookback. - - Returns - ------- - int - The lookback value of 1. - """ - return 1 - - def infection_resolution(self): - """ - Return an invalid resolution to simulate bad implementation. - - Returns - ------- - str - An invalid resolution string. - """ - return "invalid_resolution" - - def _predicted_obs(self, infections): - """ - Predicted obs stub. - - Returns - ------- - ArrayLike - The infections array unchanged. - """ - return infections - - builder = PyrenewBuilder() - gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) - - builder.configure_latent( - HierarchicalInfections, - gen_int_rv=gen_int, - I0_rv=DeterministicVariable("I0", 0.001), - initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), - baseline_rt_process=RandomWalk(), - subpop_rt_deviation_process=RandomWalk(), - ) - - bad_obs = BadObservation() - builder.add_observation(bad_obs) - - with pytest.raises(ValueError, match="invalid infection_resolution"): - builder.build() - - if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_temporal_processes.py b/test/test_temporal_processes.py index 12d75e2e..7248518e 100644 --- a/test/test_temporal_processes.py +++ b/test/test_temporal_processes.py @@ -1,5 +1,5 @@ """ -Unit tests for temporal processes innovation_sd behavior. +Unit tests for temporal processes. """ import jax.numpy as jnp @@ -9,138 +9,27 @@ from pyrenew.latent import AR1, DifferencedAR1, RandomWalk -class TestTemporalProcessRepr: - """Test __repr__ methods for temporal processes.""" +class TestTemporalProcessVectorizedSampling: + """Test vectorized sampling across all temporal process types.""" - def test_ar1_repr(self): - """Test AR1 __repr__ method.""" - ar1 = AR1(autoreg=0.7, innovation_sd=0.5) - repr_str = repr(ar1) - assert "AR1" in repr_str - assert "autoreg=0.7" in repr_str - assert "innovation_sd=0.5" in repr_str - - def test_differenced_ar1_repr(self): - """Test DifferencedAR1 __repr__ method.""" - dar1 = DifferencedAR1(autoreg=0.6, innovation_sd=0.3) - repr_str = repr(dar1) - assert "DifferencedAR1" in repr_str - assert "autoreg=0.6" in repr_str - assert "innovation_sd=0.3" in repr_str - - def test_random_walk_repr(self): - """Test RandomWalk __repr__ method.""" - rw = RandomWalk(innovation_sd=0.2) - repr_str = repr(rw) - assert "RandomWalk" in repr_str - assert "innovation_sd=0.2" in repr_str - - -class TestAR1VectorizedSampling: - """Test AR1 vectorized sampling.""" - - def test_ar1_vectorized_sample_shape(self): - """Test that AR1 vectorized sample returns correct shape.""" - n_timepoints = 30 - n_processes = 5 - - ar1 = AR1(autoreg=0.7, innovation_sd=0.3) - - with numpyro.handlers.seed(rng_seed=42): - trajectories = ar1.sample( - n_timepoints=n_timepoints, - n_processes=n_processes, - ) - - assert trajectories.shape == (n_timepoints, n_processes) - - def test_ar1_vectorized_with_initial_values_array(self): - """Test AR1 vectorized with array of initial values.""" - n_timepoints = 30 - n_processes = 4 - initial_values = jnp.array([0.0, 1.0, -1.0, 2.0]) - - ar1 = AR1(autoreg=0.7, innovation_sd=0.3) - - with numpyro.handlers.seed(rng_seed=42): - trajectories = ar1.sample( - n_timepoints=n_timepoints, - n_processes=n_processes, - initial_value=initial_values, - ) - - assert trajectories.shape == (n_timepoints, n_processes) - - def test_ar1_vectorized_with_scalar_initial_value(self): - """Test AR1 vectorized with scalar initial value (broadcast).""" - n_timepoints = 30 - n_processes = 3 - - ar1 = AR1(autoreg=0.7, innovation_sd=0.3) - - with numpyro.handlers.seed(rng_seed=42): - trajectories = ar1.sample( - n_timepoints=n_timepoints, - n_processes=n_processes, - initial_value=1.0, - ) - - assert trajectories.shape == (n_timepoints, n_processes) - - -class TestDifferencedAR1Sampling: - """Test DifferencedAR1 sampling methods.""" - - def test_differenced_ar1_single_sample_shape(self): - """Test that DifferencedAR1 single sample returns correct shape.""" - n_timepoints = 30 - - dar1 = DifferencedAR1(autoreg=0.6, innovation_sd=0.3) - - with numpyro.handlers.seed(rng_seed=42): - trajectory = dar1.sample(n_timepoints=n_timepoints) - - assert trajectory.shape == (n_timepoints, 1) - - def test_differenced_ar1_single_with_initial_value(self): - """Test DifferencedAR1 single sample with initial value.""" - n_timepoints = 30 - - dar1 = DifferencedAR1(autoreg=0.6, innovation_sd=0.3) - - with numpyro.handlers.seed(rng_seed=42): - trajectory = dar1.sample( - n_timepoints=n_timepoints, - initial_value=1.5, - ) - - assert trajectory.shape == (n_timepoints, 1) - - def test_differenced_ar1_vectorized_sample_shape(self): - """Test that DifferencedAR1 vectorized sample returns correct shape.""" - n_timepoints = 30 - n_processes = 5 - - dar1 = DifferencedAR1(autoreg=0.6, innovation_sd=0.3) - - with numpyro.handlers.seed(rng_seed=42): - trajectories = dar1.sample( - n_timepoints=n_timepoints, - n_processes=n_processes, - ) - - assert trajectories.shape == (n_timepoints, n_processes) - - def test_differenced_ar1_vectorized_with_initial_values_array(self): - """Test DifferencedAR1 vectorized with array of initial values.""" + @pytest.mark.parametrize( + "process_cls,kwargs", + [ + (AR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (DifferencedAR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (RandomWalk, {"innovation_sd": 0.05}), + ], + ) + def test_vectorized_shape_and_initial_values_array(self, process_cls, kwargs): + """Test shape and initial value handling with array initial values.""" n_timepoints = 30 n_processes = 4 initial_values = jnp.array([0.0, 1.0, -1.0, 2.0]) - dar1 = DifferencedAR1(autoreg=0.6, innovation_sd=0.3) + process = process_cls(**kwargs) with numpyro.handlers.seed(rng_seed=42): - trajectories = dar1.sample( + trajectories = process.sample( n_timepoints=n_timepoints, n_processes=n_processes, initial_value=initial_values, @@ -148,15 +37,23 @@ def test_differenced_ar1_vectorized_with_initial_values_array(self): assert trajectories.shape == (n_timepoints, n_processes) - def test_differenced_ar1_vectorized_with_scalar_initial_value(self): - """Test DifferencedAR1 vectorized with scalar initial value.""" + @pytest.mark.parametrize( + "process_cls,kwargs", + [ + (AR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (DifferencedAR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (RandomWalk, {"innovation_sd": 0.05}), + ], + ) + def test_vectorized_shape_with_scalar_initial_value(self, process_cls, kwargs): + """Test shape with scalar initial value broadcast.""" n_timepoints = 30 n_processes = 3 - dar1 = DifferencedAR1(autoreg=0.6, innovation_sd=0.3) + process = process_cls(**kwargs) with numpyro.handlers.seed(rng_seed=42): - trajectories = dar1.sample( + trajectories = process.sample( n_timepoints=n_timepoints, n_processes=n_processes, initial_value=1.0, @@ -165,11 +62,11 @@ def test_differenced_ar1_vectorized_with_scalar_initial_value(self): assert trajectories.shape == (n_timepoints, n_processes) -class TestRandomWalkVectorizedSampling: - """Test RandomWalk vectorized sampling.""" +class TestRandomWalkInitialValues: + """Test that RandomWalk preserves initial values.""" def test_random_walk_vectorized_with_initial_values_array(self): - """Test RandomWalk vectorized with array of initial values.""" + """Test RandomWalk first row equals initial values.""" n_timepoints = 30 n_processes = 4 initial_values = jnp.array([0.0, 1.0, -1.0, 2.0]) @@ -184,11 +81,10 @@ def test_random_walk_vectorized_with_initial_values_array(self): ) assert trajectories.shape == (n_timepoints, n_processes) - # First row should be close to initial values (may vary by implementation) assert jnp.allclose(trajectories[0, :], initial_values) def test_random_walk_vectorized_with_scalar_initial_value(self): - """Test RandomWalk vectorized with scalar initial value.""" + """Test RandomWalk first row equals broadcast scalar.""" n_timepoints = 30 n_processes = 3 @@ -202,7 +98,6 @@ def test_random_walk_vectorized_with_scalar_initial_value(self): ) assert trajectories.shape == (n_timepoints, n_processes) - # First row should be all 1.0 assert jnp.allclose(trajectories[0, :], 1.0) @@ -223,7 +118,6 @@ def test_random_walk_smaller_innovation_sd_produces_smoother_trajectory( rw_large = RandomWalk(innovation_sd=1.0) trajectory_large = rw_large.sample(n_timepoints=n_timepoints) - # Smaller innovation_sd should produce smaller step sizes steps_small = jnp.abs(jnp.diff(trajectory_small[:, 0])) steps_large = jnp.abs(jnp.diff(trajectory_large[:, 0])) @@ -243,7 +137,6 @@ def test_ar1_smaller_innovation_sd_produces_lower_variance(self): ar_large = AR1(autoreg=autoreg, innovation_sd=1.0) trajectory_large = ar_large.sample(n_timepoints=n_timepoints) - # After burn-in, smaller innovation_sd should have lower variance burn_in = 20 var_small = jnp.var(trajectory_small[burn_in:, 0]) var_large = jnp.var(trajectory_large[burn_in:, 0]) @@ -265,7 +158,6 @@ def test_differenced_ar1_smaller_innovation_sd_produces_smoother_changes( dar_large = DifferencedAR1(autoreg=autoreg, innovation_sd=0.8) trajectory_large = dar_large.sample(n_timepoints=n_timepoints) - # Growth rates (differences) should have lower variance diffs_small = jnp.diff(trajectory_small[:, 0]) diffs_large = jnp.diff(trajectory_large[:, 0]) @@ -288,7 +180,6 @@ def test_vectorized_sampling_respects_innovation_sd(self): n_timepoints=n_timepoints, n_processes=n_processes ) - # Check that smaller innovation_sd produces smaller step sizes across all processes steps_small = jnp.abs(jnp.diff(trajs_small, axis=0)) steps_large = jnp.abs(jnp.diff(trajs_large, axis=0)) @@ -307,3 +198,43 @@ def test_validation_rejects_non_positive_innovation_sd(self): with pytest.raises(ValueError, match="innovation_sd must be positive"): DifferencedAR1(autoreg=0.5, innovation_sd=0.0) + + +class TestTemporalProcessBehavior: + """Test behavioral properties of temporal processes.""" + + def test_ar1_mean_reversion(self): + """Test that AR1 reverts toward zero from a displaced initial value.""" + ar1 = AR1(autoreg=0.95, innovation_sd=0.05) + + with numpyro.handlers.seed(rng_seed=42): + trajectory = ar1.sample( + n_timepoints=100, + initial_value=2.0, + ) + + # The trajectory mean should be closer to 0 than the initial value + # due to mean-reverting dynamics + trajectory_mean = jnp.mean(trajectory[:, 0]) + assert jnp.abs(trajectory_mean) < jnp.abs(2.0) + + def test_differenced_ar1_trend_persistence(self): + """Test that DifferencedAR1 produces persistent trends.""" + dar1 = DifferencedAR1(autoreg=0.95, innovation_sd=0.01) + + with numpyro.handlers.seed(rng_seed=42): + trajectory = dar1.sample( + n_timepoints=50, + initial_value=0.1, + ) + + # With positive initial rate and high autoreg, the differences + # (growth rates) should remain predominantly positive, + # producing a persistent upward trend + diffs = jnp.diff(trajectory[:, 0]) + fraction_positive = jnp.mean(diffs > 0) + assert fraction_positive > 0.5 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From ca97e0b6788a3161a29f788d362edeec8ebde989 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 10 Feb 2026 15:42:30 -0500 Subject: [PATCH 2/4] more unit tests --- test/test_interface_coverage.py | 276 ++++++++++++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 test/test_interface_coverage.py diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py new file mode 100644 index 00000000..7f451288 --- /dev/null +++ b/test/test_interface_coverage.py @@ -0,0 +1,276 @@ +""" +Interface contract tests for coverage recovery. + +These parametrized tests exercise __repr__, validate(), infection_resolution(), +and get_required_lookback() across all classes that implement them, ensuring +the interface contracts are covered without per-class boilerplate tests. +""" + +import jax.numpy as jnp +import numpyro.distributions as dist +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.latent import ( + AR1, + DifferencedAR1, + GammaGroupSdPrior, + HierarchicalInfections, + HierarchicalNormalPrior, + RandomWalk, + StudentTGroupModePrior, +) +from pyrenew.observation import ( + Counts, + CountsBySubpop, + HierarchicalNormalNoise, + NegativeBinomialNoise, + PoissonNoise, + VectorizedRV, +) +from pyrenew.randomvariable import DistributionalVariable +from test.test_helpers import ConcreteMeasurements + +# ============================================================================= +# Shared instance builders +# ============================================================================= + + +def _make_counts(): + """ + Build a Counts instance. + + Returns + ------- + instantiated object + """ + return Counts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + + +def _make_counts_by_subpop(): + """ + Build a CountsBySubpop instance. + + Returns + ------- + instantiated object + """ + return CountsBySubpop( + name="test_subpop", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), + noise=PoissonNoise(), + ) + + +def _make_measurements(): + """ + Build a ConcreteMeasurements instance. + + Returns + ------- + instantiated object + """ + sensor_mode_rv = VectorizedRV( + DistributionalVariable("mode", dist.Normal(0, 0.5)), + plate_name="sensor_mode", + ) + sensor_sd_rv = VectorizedRV( + DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.1)), + plate_name="sensor_sd", + ) + return ConcreteMeasurements( + name="test_ww", + temporal_pmf_rv=DeterministicPMF("shed", jnp.array([0.3, 0.4, 0.3])), + noise=HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv), + ) + + +def _make_hierarchical_normal_noise(): + """ + Build a HierarchicalNormalNoise instance. + + Returns + ------- + instantiated object + """ + sensor_mode_rv = VectorizedRV( + DistributionalVariable("mode", dist.Normal(0, 0.5)), + plate_name="sensor_mode", + ) + sensor_sd_rv = VectorizedRV( + DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.1)), + plate_name="sensor_sd", + ) + return HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + +# ============================================================================= +# __repr__ coverage +# ============================================================================= + + +@pytest.mark.parametrize( + "instance", + [ + pytest.param(AR1(autoreg=0.9, innovation_sd=0.1), id="AR1"), + pytest.param( + DifferencedAR1(autoreg=0.8, innovation_sd=0.2), id="DifferencedAR1" + ), + pytest.param(RandomWalk(innovation_sd=0.5), id="RandomWalk"), + pytest.param(_make_counts(), id="Counts"), + pytest.param(_make_counts_by_subpop(), id="CountsBySubpop"), + pytest.param(_make_measurements(), id="Measurements"), + pytest.param(PoissonNoise(), id="PoissonNoise"), + pytest.param( + NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + id="NegativeBinomialNoise", + ), + pytest.param(_make_hierarchical_normal_noise(), id="HierarchicalNormalNoise"), + ], +) +def test_repr_returns_nonempty_string(instance): + """All classes with __repr__ return a non-empty string containing the class name.""" + result = repr(instance) + assert isinstance(result, str) + assert len(result) > 0 + assert type(instance).__name__ in result + + +# ============================================================================= +# validate() coverage (no-op and real) +# ============================================================================= + + +@pytest.mark.parametrize( + "instance", + [ + pytest.param( + HierarchicalNormalPrior( + name="test", sd_rv=DeterministicVariable("sd", 1.0) + ), + id="HierarchicalNormalPrior", + ), + pytest.param( + GammaGroupSdPrior( + name="test", + sd_mean_rv=DeterministicVariable("mean", 0.5), + sd_concentration_rv=DeterministicVariable("conc", 10.0), + ), + id="GammaGroupSdPrior", + ), + pytest.param( + StudentTGroupModePrior( + name="test", + sd_rv=DeterministicVariable("sd", 1.0), + df_rv=DeterministicVariable("df", 5.0), + ), + id="StudentTGroupModePrior", + ), + pytest.param(PoissonNoise(), id="PoissonNoise"), + pytest.param(_make_hierarchical_normal_noise(), id="HierarchicalNormalNoise"), + pytest.param(_make_counts(), id="Counts"), + ], +) +def test_validate_does_not_raise(instance): + """validate() completes without error on well-formed instances.""" + instance.validate() + + +# ============================================================================= +# infection_resolution() coverage +# ============================================================================= + + +def test_counts_by_subpop_infection_resolution(): + """CountsBySubpop.infection_resolution() returns 'subpop'.""" + counts = _make_counts_by_subpop() + assert counts.infection_resolution() == "subpop" + + +def test_measurements_infection_resolution(): + """ConcreteMeasurements.infection_resolution() returns 'subpop'.""" + m = _make_measurements() + assert m.infection_resolution() == "subpop" + + +def test_base_count_observation_infection_resolution_raises(): + """Base _CountBase.infection_resolution() raises NotImplementedError.""" + from pyrenew.observation.count_observations import _CountBase + + class _MinimalCounts(_CountBase): + """Minimal subclass that inherits infection_resolution unchanged.""" + + def sample(self, *args, **kwargs): # numpydoc ignore=GL08 + pass + + obs = _MinimalCounts( + name="test_base", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), + noise=PoissonNoise(), + ) + with pytest.raises(NotImplementedError): + obs.infection_resolution() + + +# ============================================================================= +# get_required_lookback() coverage +# ============================================================================= + + +def test_get_required_lookback(gen_int_rv): + """get_required_lookback returns generation interval PMF length.""" + infections = HierarchicalInfections( + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), + n_initialization_points=7, + ) + expected_length = len(gen_int_rv()) + assert infections.get_required_lookback() == expected_length + + +# ============================================================================= +# HierarchicalInfections.validate() coverage +# ============================================================================= + + +def test_hierarchical_infections_validate(gen_int_rv): + """HierarchicalInfections.validate() runs without error on valid PMF.""" + infections = HierarchicalInfections( + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + subpop_rt_deviation_process=RandomWalk(innovation_sd=0.025), + n_initialization_points=7, + ) + infections.validate() + + +# ============================================================================= +# MultiSignalModel._validate_observation_resolutions error path +# ============================================================================= + + +def test_multisignal_model_rejects_invalid_resolution(): + """MultiSignalModel rejects observation with invalid infection_resolution.""" + from pyrenew.model.multisignal_model import MultiSignalModel + + class BadObservation: # numpydoc ignore=GL08 + def infection_resolution(self): # numpydoc ignore=GL08 + return "invalid" + + with pytest.raises(ValueError, match="invalid infection_resolution"): + MultiSignalModel( + latent_process=None, + observations={"bad": BadObservation()}, + ) From f02efb0fd9b0435f3d968100a3a9332555c5eed6 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 11 Feb 2026 12:34:06 -0500 Subject: [PATCH 3/4] improve test coverage --- test/test_helpers.py | 11 ----------- test/test_interface_coverage.py | 3 +++ 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index f8748d39..d3c026d3 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -24,17 +24,6 @@ def validate(self) -> None: pmf = self.temporal_pmf_rv() self._validate_pmf(pmf, "temporal_pmf_rv") - def lookback_days(self) -> int: - """ - Return required lookback days for this observation. - - Returns - ------- - int - Length of temporal PMF minus 1. - """ - return len(self.temporal_pmf_rv()) - 1 - def _predicted_obs(self, infections): """ Simple predicted signal: log(convolution * scale). diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py index 7f451288..0086a974 100644 --- a/test/test_interface_coverage.py +++ b/test/test_interface_coverage.py @@ -209,6 +209,9 @@ class _MinimalCounts(_CountBase): def sample(self, *args, **kwargs): # numpydoc ignore=GL08 pass + def validate_data(self, n_total, n_subpops, **obs_data): # numpydoc ignore=GL08 + pass + obs = _MinimalCounts( name="test_base", ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), From 4f5534aa74eb4989a02edc5a84899f51f47b3e86 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 11 Feb 2026 14:01:21 -0500 Subject: [PATCH 4/4] changes by bot per code review by another bot - see PR conversation --- test/conftest.py | 39 ------------------- test/test_infection_initialization_process.py | 21 ++++++++-- test/test_latent_infections.py | 6 --- test/test_observation_counts.py | 2 + test/test_observation_negativebinom.py | 6 ++- test/test_pyrenew_builder.py | 33 +++++++++++++++- 6 files changed, 57 insertions(+), 50 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 22318c62..3c72055b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -50,19 +50,6 @@ def short_delay_pmf(): return jnp.array([0.5, 0.5]) -@pytest.fixture -def realistic_delay_pmf(): - """ - Realistic 10-day delay PMF (shifted gamma-like). - - Returns - ------- - jnp.ndarray - A 10-element PMF array with gamma-like shape. - """ - return jnp.array([0.01, 0.05, 0.10, 0.15, 0.20, 0.20, 0.15, 0.08, 0.04, 0.02]) - - @pytest.fixture def long_delay_pmf(): """ @@ -251,29 +238,3 @@ def counts_factory(): # ============================================================================= # Infection Fixtures # ============================================================================= - - -@pytest.fixture -def constant_infections(): - """ - Constant infections array (30 days, 100 infections/day). - - Returns - ------- - jnp.ndarray - A 1D array of shape (30,) with constant value 100. - """ - return jnp.ones(30) * 100 - - -@pytest.fixture -def constant_infections_2d(): - """ - Constant infections array for 2 subpopulations. - - Returns - ------- - jnp.ndarray - A 2D array of shape (30, 2) with constant value 100. - """ - return jnp.ones((30, 2)) * 100 diff --git a/test/test_infection_initialization_process.py b/test/test_infection_initialization_process.py index fcf94708..d017eae4 100644 --- a/test/test_infection_initialization_process.py +++ b/test/test_infection_initialization_process.py @@ -38,9 +38,24 @@ def test_infection_initialization_process(): InitializeInfectionsFromVec(n_timepoints), ) - for model in [zero_pad_model, exp_model, vec_model]: - with numpyro.handlers.seed(rng_seed=1): - model() + with numpyro.handlers.seed(rng_seed=1): + zero_pad_result = zero_pad_model() + exp_result = exp_model() + vec_result = vec_model() + + # All results should have shape (n_timepoints,) + assert zero_pad_result.shape == (n_timepoints,) + assert exp_result.shape == (n_timepoints,) + assert vec_result.shape == (n_timepoints,) + + # Zero-pad: all but last element should be zero + assert jnp.all(zero_pad_result[:-1] == 0) + + # Exponential growth: all values should be positive (LogNormal I0) + assert jnp.all(exp_result > 0) + + # Vec (identity passthrough): should equal jnp.arange(n_timepoints) + assert jnp.array_equal(vec_result, jnp.arange(n_timepoints)) # Check that the InfectionInitializationProcess class raises an error when the wrong type of I0 is passed with pytest.raises(TypeError): diff --git a/test/test_latent_infections.py b/test/test_latent_infections.py index 0c33f597..2f26d94c 100755 --- a/test/test_latent_infections.py +++ b/test/test_latent_infections.py @@ -30,12 +30,6 @@ def test_infections_as_deterministic(): gen_int=gen_int, ) with numpyro.handlers.seed(rng_seed=223): - Infections()( - Rt=sim_rt, - I0=jnp.zeros(gen_int.size), - gen_int=gen_int, - ) - inf_sampled1 = inf1(**obs) inf_sampled2 = inf1(**obs) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 4ae3e0b1..831a3f6f 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -266,6 +266,8 @@ def test_dense_observations_with_nan_padding(self, counts_process): assert result.observed.shape[0] == n_days assert result.predicted.shape[0] == n_days + # Predicted values should be non-NaN (predictions exist for all days) + assert jnp.all(~jnp.isnan(result.predicted)) class TestCountsBySubpop: diff --git a/test/test_observation_negativebinom.py b/test/test_observation_negativebinom.py index 7c9ef519..b2546d2e 100644 --- a/test/test_observation_negativebinom.py +++ b/test/test_observation_negativebinom.py @@ -19,7 +19,7 @@ def test_negativebinom_deterministic_obs(): concentration_rv=DeterministicVariable(name="concentration", value=10), ) - rates = np.random.randint(1, 5, size=10) + rates = np.array([3, 1, 4, 2, 3, 1, 4, 2, 3, 1]) with numpyro.handlers.seed(rng_seed=223): sim_nb1 = negb(mu=rates, obs=rates) sim_nb2 = negb(mu=rates, obs=rates) @@ -56,3 +56,7 @@ def test_negativebinom_random_obs(): np.mean(sim_nb2), decimal=1, ) + + # Sample mean should be close to the expected rate (5.0) + testing.assert_almost_equal(np.mean(sim_nb1), 5.0, decimal=0) + testing.assert_almost_equal(np.mean(sim_nb2), 5.0, decimal=0) diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index 08ed736e..885cff83 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -344,7 +344,7 @@ def test_validate_data_rejects_mismatched_obs_times_length( ) def test_validate_method_calls_internal_validate(self, simple_builder): - """Test that validate() method calls _validate().""" + """Test that validate() succeeds on a valid model.""" model = simple_builder.build() # Should not raise model.validate() @@ -394,5 +394,36 @@ def test_validate_data_rejects_wrong_length_dense_obs(self, validation_builder): ) +class TestMultiSignalModelHelpers: + """Test MultiSignalModel helper methods.""" + + def test_pad_observations_prepends_nans(self, simple_builder): + """Test that pad_observations prepends correct NaN padding.""" + model = simple_builder.build() + n_init = model.latent.n_initialization_points + + obs = jnp.array([10, 20, 30]) + padded = model.pad_observations(obs) + + # Shape should include initialization period + assert padded.shape == (n_init + 3,) + # First n_init values should be NaN + assert jnp.all(jnp.isnan(padded[:n_init])) + # Remaining values should match input + assert jnp.array_equal(padded[n_init:], jnp.array([10.0, 20.0, 30.0])) + # Integer input should be converted to float + assert jnp.issubdtype(padded.dtype, jnp.floating) + + def test_shift_times_adds_offset(self, simple_builder): + """Test that shift_times shifts by n_initialization_points.""" + model = simple_builder.build() + n_init = model.latent.n_initialization_points + + times = jnp.array([0, 5, 10]) + shifted = model.shift_times(times) + + assert jnp.array_equal(shifted, times + n_init) + + if __name__ == "__main__": pytest.main([__file__, "-v"])