diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 47c9cd7d..d9aad03d 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -202,52 +202,11 @@ def validate_data( f"Available: {list(self.observations.keys())}" ) - obs = obs_data.get("obs") - times = obs_data.get("times") - - if times is not None: - # Sparse observations: times on shared axis [0, n_total) - times = jnp.asarray(times) - if jnp.any(times < 0): - raise ValueError(f"Observation '{name}': times cannot be negative") - max_time = jnp.max(times) - if max_time >= n_total: - raise ValueError( - f"Observation '{name}': times index {int(max_time)} " - f">= n_total ({n_total} = {n_init} init + " - f"{n_days_post_init} days). " - f"Times must be on shared axis [0, {n_total})." - ) - if obs is not None and len(obs) != len(times): - raise ValueError( - f"Observation '{name}': obs length {len(obs)} " - f"must match times length {len(times)}" - ) - elif obs is not None: - # Dense observations: length must equal n_total - obs = jnp.asarray(obs) - if obs.shape[0] != n_total: - raise ValueError( - f"Observation '{name}': obs length {obs.shape[0]} " - f"must equal n_total ({n_total} = {n_init} init + " - f"{n_days_post_init} days). " - f"Pad with NaN for initialization period." - ) - - # Validate subpop_indices if present - subpop_indices = obs_data.get("subpop_indices") - if subpop_indices is not None: - subpop_indices = jnp.asarray(subpop_indices) - if jnp.any(subpop_indices < 0): - raise ValueError( - f"Observation '{name}': subpop_indices cannot be negative" - ) - max_idx = jnp.max(subpop_indices) - if max_idx >= pop.n_subpops: - raise ValueError( - f"Observation '{name}': subpop_indices contains " - f"{int(max_idx)} >= {pop.n_subpops} (n_subpops)" - ) + self.observations[name].validate_data( + n_total=n_total, + n_subpops=pop.n_subpops, + **obs_data, + ) def sample( self, diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 764aeb66..3eeb2c63 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -312,6 +312,165 @@ def _predicted_obs( """ pass # pragma: no cover + @abstractmethod + def validate_data( + self, + n_total: int, + n_subpops: int, + **obs_data, + ) -> None: + """ + Validate observation data before running inference. + + Each observation process validates its own data requirements. + Called by the model's ``validate_data()`` method with concrete + (non-traced) values before JAX tracing begins. + + Parameters + ---------- + n_total : int + Total number of time steps (n_init + n_days_post_init). + n_subpops : int + Number of subpopulations. + **obs_data + Observation-specific data kwargs (same as passed to ``sample()``, + minus ``infections`` which comes from the latent process). + + Raises + ------ + ValueError + If any data fails validation. + """ + pass # pragma: no cover + + def _validate_index_array( + self, indices: ArrayLike, upper_bound: int, param_name: str + ) -> None: + """ + Validate an index array has non-negative values within bounds. + + Checks that all values are non-negative integers in ``[0, upper_bound)``. + + Parameters + ---------- + indices : ArrayLike + Index array to validate. + upper_bound : int + Exclusive upper bound for valid indices. + param_name : str + Name of the parameter (for error messages). + + Raises + ------ + ValueError + If indices contains negative values or values >= upper_bound. + """ + indices = jnp.asarray(indices) + if jnp.any(indices < 0): + raise ValueError( + f"Observation '{self.name}': {param_name} cannot be negative" + ) + max_val = jnp.max(indices) + if max_val >= upper_bound: + raise ValueError( + f"Observation '{self.name}': {param_name} contains " + f"{int(max_val)} >= {upper_bound} ({param_name} upper bound)" + ) + + def _validate_times(self, times: ArrayLike, n_total: int) -> None: + """ + Validate a times index array. + + Checks that all values are non-negative and within ``[0, n_total)``. + + Parameters + ---------- + times : ArrayLike + Time indices on the shared time axis. + n_total : int + Total number of time steps. + + Raises + ------ + ValueError + If times contains negative values or values >= n_total. + """ + self._validate_index_array(times, n_total, "times") + + def _validate_subpop_indices( + self, subpop_indices: ArrayLike, n_subpops: int + ) -> None: + """ + Validate a subpopulation index array. + + Checks that all values are non-negative and within ``[0, n_subpops)``. + + Parameters + ---------- + subpop_indices : ArrayLike + Subpopulation indices (0-indexed). + n_subpops : int + Number of subpopulations. + + Raises + ------ + ValueError + If subpop_indices contains negative values or values >= n_subpops. + """ + self._validate_index_array(subpop_indices, n_subpops, "subpop_indices") + + def _validate_obs_times_shape(self, obs: ArrayLike, times: ArrayLike) -> None: + """ + Validate that obs and times arrays have matching shapes. + + Parameters + ---------- + obs : ArrayLike + Observed data array. + times : ArrayLike + Times index array. + + Raises + ------ + ValueError + If obs and times have different shapes. + """ + obs = jnp.asarray(obs) + times = jnp.asarray(times) + if obs.shape != times.shape: + raise ValueError( + f"Observation '{self.name}': obs shape {obs.shape} " + f"must match times shape {times.shape}" + ) + + def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: + """ + Validate that obs covers the full shared time axis. + + For dense observations on the shared time axis ``[0, n_total)``, + obs must have length equal to ``n_total``. Use NaN to mark + unobserved timepoints (initialization period or missing data). + + Parameters + ---------- + obs : ArrayLike + Observed data array on the shared time axis. + n_total : int + Total number of time steps (n_init + n_days_post_init). + + Raises + ------ + ValueError + If obs length doesn't equal n_total. + """ + obs = jnp.asarray(obs) + if obs.shape[0] != n_total: + raise ValueError( + f"Observation '{self.name}': obs length {obs.shape[0]} " + f"must equal n_total ({n_total}). " + f"Pad with NaN for initialization period." + ) + @abstractmethod def sample( self, diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 80912143..bcbcaf99 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -175,6 +175,35 @@ def __repr__(self) -> str: f"noise={self.noise!r})" ) + def validate_data( + self, + n_total: int, + n_subpops: int, + obs: ArrayLike | None = None, + **kwargs, + ) -> None: + """ + Validate aggregated count observation data. + + Parameters + ---------- + n_total : int + Total number of time steps (n_init + n_days_post_init). + n_subpops : int + Number of subpopulations (unused for aggregate observations). + obs : ArrayLike | None + Observed counts on shared time axis. Shape: (n_total,). + **kwargs + Additional keyword arguments (ignored). + + Raises + ------ + ValueError + If obs length doesn't match n_total. + """ + if obs is not None: + self._validate_obs_dense(obs, n_total) + def sample( self, infections: ArrayLike, @@ -275,6 +304,46 @@ def infection_resolution(self) -> str: """ return "subpop" + def validate_data( + self, + n_total: int, + n_subpops: int, + times: ArrayLike | None = None, + subpop_indices: ArrayLike | None = None, + obs: ArrayLike | None = None, + **kwargs, + ) -> None: + """ + Validate subpopulation-level count observation data. + + Parameters + ---------- + n_total : int + Total number of time steps (n_init + n_days_post_init). + n_subpops : int + Number of subpopulations. + times : ArrayLike | None + Day index for each observation on the shared time axis. + subpop_indices : ArrayLike | None + Subpopulation index for each observation (0-indexed). + obs : ArrayLike | None + Observed counts (n_obs,). + **kwargs + Additional keyword arguments (ignored). + + Raises + ------ + ValueError + If times or subpop_indices are out of bounds, or if + obs and times have mismatched lengths. + """ + if times is not None: + self._validate_times(times, n_total) + if obs is not None: + self._validate_obs_times_shape(obs, times) + if subpop_indices is not None: + self._validate_subpop_indices(subpop_indices, n_subpops) + def sample( self, infections: ArrayLike, diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py index ab1b8ff1..82e50003 100644 --- a/pyrenew/observation/measurements.py +++ b/pyrenew/observation/measurements.py @@ -104,6 +104,54 @@ def infection_resolution(self) -> str: """ return "subpop" + def validate_data( + self, + n_total: int, + n_subpops: int, + times: ArrayLike | None = None, + subpop_indices: ArrayLike | None = None, + sensor_indices: ArrayLike | None = None, + n_sensors: int | None = None, + obs: ArrayLike | None = None, + **kwargs, + ) -> None: + """ + Validate measurement observation data. + + Parameters + ---------- + n_total : int + Total number of time steps (n_init + n_days_post_init). + n_subpops : int + Number of subpopulations. + times : ArrayLike | None + Day index for each observation on the shared time axis. + subpop_indices : ArrayLike | None + Subpopulation index for each observation (0-indexed). + sensor_indices : ArrayLike | None + Sensor index for each observation (0-indexed). + n_sensors : int | None + Total number of measurement sensors. + obs : ArrayLike | None + Observed measurements (n_obs,). + **kwargs + Additional keyword arguments (ignored). + + Raises + ------ + ValueError + If times, subpop_indices, or sensor_indices are out of bounds, + or if obs and times have mismatched lengths. + """ + if times is not None: + self._validate_times(times, n_total) + if obs is not None: + self._validate_obs_times_shape(obs, times) + if subpop_indices is not None: + self._validate_subpop_indices(subpop_indices, n_subpops) + if sensor_indices is not None and n_sensors is not None: + self._validate_index_array(sensor_indices, n_sensors, "sensor_indices") + def sample( self, infections: ArrayLike, diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index e43b5272..4351721e 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -410,6 +410,10 @@ def sample(self, **kwargs): """Sample method stub.""" pass + def validate_data(self, n_total, n_subpops, **obs_data): + """Validate data stub.""" + pass + process = IncompleteCountProcess( name="test", ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), diff --git a/test/test_observation_validation.py b/test/test_observation_validation.py new file mode 100644 index 00000000..4363290f --- /dev/null +++ b/test/test_observation_validation.py @@ -0,0 +1,559 @@ +""" +Unit tests for observation data validation functions. + +Tests the refactored validation helpers on BaseObservationProcess +and the validate_data() methods on Counts, CountsBySubpop, and +Measurements. +""" + +import jax.numpy as jnp +import numpyro.distributions as dist +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.observation import ( + Counts, + CountsBySubpop, + HierarchicalNormalNoise, + PoissonNoise, + VectorizedRV, +) +from pyrenew.observation.measurements import Measurements +from pyrenew.randomvariable import DistributionalVariable + +# --------------------------------------------------------------------------- +# Helpers – minimal concrete subclass of Measurements for testing +# --------------------------------------------------------------------------- + + +class StubMeasurements(Measurements): + """Minimal concrete Measurements for testing validate_data().""" + + def validate(self) -> None: + """ + Validate parameters. + + Raises + ------ + ValueError + If PMF is invalid. + """ + pmf = self.temporal_pmf_rv() + self._validate_pmf(pmf, "temporal_pmf_rv") + + def _predicted_obs(self, infections): + """ + Return infections unchanged (identity transform). + + Parameters + ---------- + infections : ArrayLike + Input infections. + + Returns + ------- + ArrayLike + Same as input. + """ + return infections + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def counts_proc(): + """ + Counts process used to access base validation helpers. + + Returns + ------- + Counts + A Counts observation process. + """ + return Counts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.3, 0.5, 0.2])), + noise=PoissonNoise(), + ) + + +@pytest.fixture() +def subpop_proc(): + """ + CountsBySubpop process for validate_data() tests. + + Returns + ------- + CountsBySubpop + A CountsBySubpop observation process. + """ + return CountsBySubpop( + name="subpop_hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.4, 0.4, 0.2])), + noise=PoissonNoise(), + ) + + +@pytest.fixture() +def measurements_proc(): + """ + Measurements process for validate_data() tests. + + Returns + ------- + StubMeasurements + A StubMeasurements observation process. + """ + sensor_mode_rv = VectorizedRV( + DistributionalVariable("mode", dist.Normal(0, 0.5)), + plate_name="sensor_mode", + ) + sensor_sd_rv = VectorizedRV( + DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), + plate_name="sensor_sd", + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + return StubMeasurements( + name="ww", + temporal_pmf_rv=DeterministicPMF("shedding", jnp.array([0.3, 0.4, 0.3])), + noise=noise, + ) + + +# =================================================================== +# _validate_index_array +# =================================================================== + + +class TestValidateIndexArray: + """Tests for BaseObservationProcess._validate_index_array.""" + + def test_valid_indices(self, counts_proc): + """Valid indices within bounds should not raise.""" + counts_proc._validate_index_array( + jnp.array([0, 1, 2, 3]), upper_bound=5, param_name="test_idx" + ) + + def test_single_valid_index(self, counts_proc): + """A single valid index should not raise.""" + counts_proc._validate_index_array( + jnp.array([0]), upper_bound=1, param_name="test_idx" + ) + + def test_negative_index_raises(self, counts_proc): + """Negative indices should raise ValueError.""" + with pytest.raises(ValueError, match="cannot be negative"): + counts_proc._validate_index_array( + jnp.array([0, -1, 2]), + upper_bound=5, + param_name="my_param", + ) + + def test_index_at_upper_bound_raises(self, counts_proc): + """Index exactly equal to upper_bound should raise.""" + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_index_array( + jnp.array([0, 1, 5]), + upper_bound=5, + param_name="my_param", + ) + + def test_index_above_upper_bound_raises(self, counts_proc): + """Index above upper_bound should raise.""" + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_index_array( + jnp.array([10]), + upper_bound=5, + param_name="my_param", + ) + + def test_error_message_includes_name_and_value(self, counts_proc): + """Error message should include the observation name and param name.""" + with pytest.raises(ValueError, match="hosp.*my_param"): + counts_proc._validate_index_array( + jnp.array([-1]), + upper_bound=5, + param_name="my_param", + ) + + def test_non_contiguous_valid_indices(self, counts_proc): + """Non-contiguous but valid indices should not raise.""" + counts_proc._validate_index_array( + jnp.array([0, 3, 7, 9]), + upper_bound=10, + param_name="test_idx", + ) + + +# =================================================================== +# _validate_times +# =================================================================== + + +class TestValidateTimes: + """Tests for BaseObservationProcess._validate_times.""" + + def test_valid_times(self, counts_proc): + """Valid time indices should not raise.""" + counts_proc._validate_times(jnp.array([0, 5, 19]), n_total=20) + + def test_negative_time_raises(self, counts_proc): + """Negative time index should raise ValueError.""" + with pytest.raises(ValueError, match="cannot be negative"): + counts_proc._validate_times(jnp.array([0, -1]), n_total=20) + + def test_time_at_n_total_raises(self, counts_proc): + """Time index equal to n_total should raise (0-indexed).""" + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_times(jnp.array([0, 20]), n_total=20) + + def test_time_above_n_total_raises(self, counts_proc): + """Time index above n_total should raise.""" + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_times(jnp.array([50]), n_total=20) + + +# =================================================================== +# _validate_subpop_indices +# =================================================================== + + +class TestValidateSubpopIndices: + """Tests for BaseObservationProcess._validate_subpop_indices.""" + + def test_valid_subpop_indices(self, counts_proc): + """Valid subpop indices should not raise.""" + counts_proc._validate_subpop_indices(jnp.array([0, 1, 2]), n_subpops=3) + + def test_non_contiguous_subpop_indices(self, counts_proc): + """Non-contiguous but valid subpop indices should not raise.""" + counts_proc._validate_subpop_indices(jnp.array([0, 2, 4]), n_subpops=5) + + def test_negative_subpop_index_raises(self, counts_proc): + """Negative subpop index should raise ValueError.""" + with pytest.raises(ValueError, match="cannot be negative"): + counts_proc._validate_subpop_indices(jnp.array([0, -1]), n_subpops=3) + + def test_subpop_index_at_n_subpops_raises(self, counts_proc): + """Subpop index equal to n_subpops should raise.""" + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_subpop_indices(jnp.array([0, 3]), n_subpops=3) + + def test_subpop_index_above_n_subpops_raises(self, counts_proc): + """Subpop index well above n_subpops should raise.""" + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_subpop_indices(jnp.array([10]), n_subpops=3) + + +# =================================================================== +# _validate_obs_times_shape +# =================================================================== + + +class TestValidateObsTimesShape: + """Tests for BaseObservationProcess._validate_obs_times_shape.""" + + def test_matching_1d_shapes(self, counts_proc): + """Obs and times with matching 1D shapes should not raise.""" + obs = jnp.array([1.0, 2.0, 3.0]) + times = jnp.array([0, 5, 10]) + counts_proc._validate_obs_times_shape(obs, times) + + def test_mismatched_lengths_raises(self, counts_proc): + """Obs and times with different lengths should raise ValueError.""" + obs = jnp.array([1.0, 2.0, 3.0]) + times = jnp.array([0, 5]) + with pytest.raises(ValueError, match="must match times shape"): + counts_proc._validate_obs_times_shape(obs, times) + + def test_empty_arrays_match(self, counts_proc): + """Two empty arrays should not raise.""" + obs = jnp.array([]) + times = jnp.array([]) + counts_proc._validate_obs_times_shape(obs, times) + + def test_scalar_arrays_match(self, counts_proc): + """Single-element arrays should not raise.""" + obs = jnp.array([5.0]) + times = jnp.array([0]) + counts_proc._validate_obs_times_shape(obs, times) + + def test_error_includes_both_shapes(self, counts_proc): + """Error message should report both shapes.""" + obs = jnp.array([1.0, 2.0]) + times = jnp.array([0, 1, 2]) + with pytest.raises(ValueError, match=r"\(2,\).*\(3,\)"): + counts_proc._validate_obs_times_shape(obs, times) + + +# =================================================================== +# _validate_obs_dense +# =================================================================== + + +class TestValidateObsDense: + """Tests for BaseObservationProcess._validate_obs_dense.""" + + def test_correct_length(self, counts_proc): + """Obs with length equal to n_total should not raise.""" + obs = jnp.ones(30) + counts_proc._validate_obs_dense(obs, n_total=30) + + def test_obs_with_nan_correct_length(self, counts_proc): + """Obs with NaN padding but correct length should not raise.""" + obs = jnp.ones(30).at[:5].set(jnp.nan) + counts_proc._validate_obs_dense(obs, n_total=30) + + def test_obs_too_short_raises(self, counts_proc): + """Obs shorter than n_total should raise ValueError.""" + obs = jnp.ones(20) + with pytest.raises(ValueError, match="must equal n_total"): + counts_proc._validate_obs_dense(obs, n_total=30) + + def test_obs_too_long_raises(self, counts_proc): + """Obs longer than n_total should raise ValueError.""" + obs = jnp.ones(40) + with pytest.raises(ValueError, match="must equal n_total"): + counts_proc._validate_obs_dense(obs, n_total=30) + + def test_error_includes_lengths(self, counts_proc): + """Error message should include actual and expected lengths.""" + obs = jnp.ones(15) + with pytest.raises(ValueError, match="15.*30"): + counts_proc._validate_obs_dense(obs, n_total=30) + + +# =================================================================== +# Counts.validate_data() +# =================================================================== + + +class TestCountsValidateData: + """Tests for Counts.validate_data().""" + + def test_none_obs_passes(self, counts_proc): + """validate_data with obs=None should not raise.""" + counts_proc.validate_data(n_total=30, n_subpops=1, obs=None) + + def test_correct_obs_passes(self, counts_proc): + """validate_data with correctly shaped obs should not raise.""" + obs = jnp.ones(30) * 5.0 + counts_proc.validate_data(n_total=30, n_subpops=1, obs=obs) + + def test_nan_padded_obs_passes(self, counts_proc): + """validate_data with NaN-padded obs of correct length should not raise.""" + obs = jnp.ones(30).at[:3].set(jnp.nan) + counts_proc.validate_data(n_total=30, n_subpops=1, obs=obs) + + def test_wrong_length_obs_raises(self, counts_proc): + """validate_data with obs of wrong length should raise ValueError.""" + obs = jnp.ones(20) + with pytest.raises(ValueError, match="must equal n_total"): + counts_proc.validate_data(n_total=30, n_subpops=1, obs=obs) + + def test_extra_kwargs_ignored(self, counts_proc): + """validate_data should ignore extra keyword arguments.""" + counts_proc.validate_data( + n_total=30, n_subpops=1, obs=None, extra_param="ignored" + ) + + +# =================================================================== +# CountsBySubpop.validate_data() +# =================================================================== + + +class TestCountsBySubpopValidateData: + """Tests for CountsBySubpop.validate_data().""" + + def test_all_none_passes(self, subpop_proc): + """validate_data with all optional args None should not raise.""" + subpop_proc.validate_data(n_total=30, n_subpops=3) + + def test_valid_data_passes(self, subpop_proc): + """validate_data with valid times, subpop_indices, obs should not raise.""" + times = jnp.array([5, 10, 15, 20]) + subpop_indices = jnp.array([0, 1, 2, 0]) + obs = jnp.array([10.0, 20.0, 30.0, 15.0]) + subpop_proc.validate_data( + n_total=30, + n_subpops=3, + times=times, + subpop_indices=subpop_indices, + obs=obs, + ) + + def test_invalid_times_raises(self, subpop_proc): + """validate_data with out-of-bounds times should raise.""" + times = jnp.array([5, 30]) # 30 == n_total, out of bounds + with pytest.raises(ValueError, match="upper bound"): + subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + + def test_negative_times_raises(self, subpop_proc): + """validate_data with negative times should raise.""" + times = jnp.array([-1, 5]) + with pytest.raises(ValueError, match="cannot be negative"): + subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + + def test_invalid_subpop_indices_raises(self, subpop_proc): + """validate_data with out-of-bounds subpop_indices should raise.""" + subpop_indices = jnp.array([0, 3]) # 3 == n_subpops, out of bounds + with pytest.raises(ValueError, match="upper bound"): + subpop_proc.validate_data( + n_total=30, n_subpops=3, subpop_indices=subpop_indices + ) + + def test_negative_subpop_indices_raises(self, subpop_proc): + """validate_data with negative subpop_indices should raise.""" + subpop_indices = jnp.array([0, -1]) + with pytest.raises(ValueError, match="cannot be negative"): + subpop_proc.validate_data( + n_total=30, n_subpops=3, subpop_indices=subpop_indices + ) + + def test_mismatched_obs_times_raises(self, subpop_proc): + """validate_data with obs/times shape mismatch should raise.""" + times = jnp.array([5, 10, 15]) + obs = jnp.array([1.0, 2.0]) # length 2 != length 3 + with pytest.raises(ValueError, match="must match times shape"): + subpop_proc.validate_data(n_total=30, n_subpops=3, times=times, obs=obs) + + def test_obs_without_times_skips_shape_check(self, subpop_proc): + """validate_data with obs but no times should not check shapes.""" + obs = jnp.array([1.0, 2.0]) + # times is None, so shape check is skipped + subpop_proc.validate_data(n_total=30, n_subpops=3, obs=obs) + + def test_times_without_obs_skips_shape_check(self, subpop_proc): + """validate_data with times but no obs should validate times only.""" + times = jnp.array([5, 10, 15]) + subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + + def test_non_contiguous_subpop_indices_valid(self, subpop_proc): + """validate_data with non-contiguous but valid subpop_indices passes.""" + subpop_indices = jnp.array([0, 2]) # skip index 1 + subpop_proc.validate_data( + n_total=30, n_subpops=3, subpop_indices=subpop_indices + ) + + +# =================================================================== +# Measurements.validate_data() +# =================================================================== + + +class TestMeasurementsValidateData: + """Tests for Measurements.validate_data().""" + + def test_all_none_passes(self, measurements_proc): + """validate_data with all optional args None should not raise.""" + measurements_proc.validate_data(n_total=30, n_subpops=3) + + def test_valid_data_passes(self, measurements_proc): + """validate_data with fully valid data should not raise.""" + times = jnp.array([5, 10, 15, 20]) + subpop_indices = jnp.array([0, 1, 2, 0]) + sensor_indices = jnp.array([0, 1, 0, 1]) + obs = jnp.array([1.1, 2.2, 3.3, 1.5]) + measurements_proc.validate_data( + n_total=30, + n_subpops=3, + times=times, + subpop_indices=subpop_indices, + sensor_indices=sensor_indices, + n_sensors=2, + obs=obs, + ) + + def test_invalid_times_raises(self, measurements_proc): + """validate_data with out-of-bounds times should raise.""" + times = jnp.array([5, 30]) + with pytest.raises(ValueError, match="upper bound"): + measurements_proc.validate_data(n_total=30, n_subpops=3, times=times) + + def test_negative_times_raises(self, measurements_proc): + """validate_data with negative times should raise.""" + times = jnp.array([-1, 5]) + with pytest.raises(ValueError, match="cannot be negative"): + measurements_proc.validate_data(n_total=30, n_subpops=3, times=times) + + def test_invalid_subpop_indices_raises(self, measurements_proc): + """validate_data with out-of-bounds subpop_indices should raise.""" + subpop_indices = jnp.array([0, 5]) + with pytest.raises(ValueError, match="upper bound"): + measurements_proc.validate_data( + n_total=30, n_subpops=3, subpop_indices=subpop_indices + ) + + def test_invalid_sensor_indices_raises(self, measurements_proc): + """validate_data with out-of-bounds sensor_indices should raise.""" + sensor_indices = jnp.array([0, 4]) + with pytest.raises(ValueError, match="upper bound"): + measurements_proc.validate_data( + n_total=30, + n_subpops=3, + sensor_indices=sensor_indices, + n_sensors=3, + ) + + def test_negative_sensor_indices_raises(self, measurements_proc): + """validate_data with negative sensor_indices should raise.""" + sensor_indices = jnp.array([-1, 0]) + with pytest.raises(ValueError, match="cannot be negative"): + measurements_proc.validate_data( + n_total=30, + n_subpops=3, + sensor_indices=sensor_indices, + n_sensors=3, + ) + + def test_sensor_indices_without_n_sensors_skips(self, measurements_proc): + """validate_data with sensor_indices but no n_sensors skips check.""" + sensor_indices = jnp.array([0, 99]) # would be invalid with n_sensors + # n_sensors is None so sensor_indices validation is skipped + measurements_proc.validate_data( + n_total=30, + n_subpops=3, + sensor_indices=sensor_indices, + n_sensors=None, + ) + + def test_n_sensors_without_sensor_indices_skips(self, measurements_proc): + """validate_data with n_sensors but no sensor_indices skips check.""" + measurements_proc.validate_data( + n_total=30, + n_subpops=3, + sensor_indices=None, + n_sensors=5, + ) + + def test_mismatched_obs_times_raises(self, measurements_proc): + """validate_data with obs/times shape mismatch should raise.""" + times = jnp.array([5, 10, 15]) + obs = jnp.array([1.0, 2.0]) + with pytest.raises(ValueError, match="must match times shape"): + measurements_proc.validate_data( + n_total=30, n_subpops=3, times=times, obs=obs + ) + + def test_matching_obs_times_passes(self, measurements_proc): + """validate_data with matching obs/times shapes should not raise.""" + times = jnp.array([5, 10]) + obs = jnp.array([1.0, 2.0]) + measurements_proc.validate_data(n_total=30, n_subpops=3, times=times, obs=obs) + + def test_non_contiguous_subpop_indices_valid(self, measurements_proc): + """validate_data with non-contiguous but valid subpop_indices passes.""" + subpop_indices = jnp.array([0, 2]) + measurements_proc.validate_data( + n_total=30, n_subpops=3, subpop_indices=subpop_indices + ) + + def test_extra_kwargs_ignored(self, measurements_proc): + """validate_data should ignore extra keyword arguments.""" + measurements_proc.validate_data(n_total=30, n_subpops=3, foo="bar") diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index 5a947e30..59e9a6f4 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -9,7 +9,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HierarchicalInfections, RandomWalk from pyrenew.model import MultiSignalModel, PyrenewBuilder -from pyrenew.observation import Counts, NegativeBinomialNoise +from pyrenew.observation import Counts, CountsBySubpop, NegativeBinomialNoise # Standard population structure for tests (3 subpopulations) SUBPOP_FRACTIONS = jnp.array([0.3, 0.25, 0.45]) @@ -49,6 +49,53 @@ def simple_builder(): return builder +@pytest.fixture +def validation_builder(): + """ + Create a builder with both aggregate and subpop observations. + + Used for testing validate_data() delegation to different + observation types. + + Returns + ------- + PyrenewBuilder + Builder with Counts ("hospital") and CountsBySubpop + ("hospital_subpop") observations. + """ + builder = PyrenewBuilder() + gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) + + builder.configure_latent( + HierarchicalInfections, + gen_int_rv=gen_int, + I0_rv=DeterministicVariable("I0", 0.001), + initial_log_rt_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=RandomWalk(), + subpop_rt_deviation_process=RandomWalk(), + ) + + delay = DeterministicPMF("delay", jnp.array([0.1, 0.3, 0.4, 0.2])) + builder.add_observation( + Counts( + name="hospital", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + ) + builder.add_observation( + CountsBySubpop( + name="hospital_subpop", + ascertainment_rate_rv=DeterministicVariable("ihr_subpop", 0.01), + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(DeterministicVariable("conc_subpop", 10.0)), + ) + ) + + return builder + + class TestPyrenewBuilderConfiguration: """Test PyrenewBuilder configuration.""" @@ -213,54 +260,57 @@ def test_run_with_population_structure(self, simple_builder): class TestMultiSignalModelValidation: """Test data validation.""" - def test_validate_data_requires_population_structure(self, simple_builder): - """Test that validate_data requires population structure.""" - model = simple_builder.build() + def test_validate_data_accepts_valid_data(self, validation_builder): + """Test that validate_data accepts valid dense and sparse data.""" + model = validation_builder.build() + n_total = model.latent.n_initialization_points + 30 - # Should work with population structure model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, hospital={ + "obs": jnp.full(n_total, jnp.nan), + }, + hospital_subpop={ "obs": jnp.array([10, 20]), "times": jnp.array([5, 10]), }, ) - def test_validate_data_rejects_out_of_bounds_times(self, simple_builder): + def test_validate_data_rejects_out_of_bounds_times(self, validation_builder): """Test that times exceeding n_total_days raises error.""" - model = simple_builder.build() + model = validation_builder.build() n_total = model.latent.n_initialization_points + 30 - with pytest.raises(ValueError, match="times index"): + with pytest.raises(ValueError, match="times"): model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, - hospital={ + hospital_subpop={ "obs": jnp.array([10]), "times": jnp.array([n_total + 10]), }, ) - def test_validate_data_rejects_negative_times(self, simple_builder): + def test_validate_data_rejects_negative_times(self, validation_builder): """Test that negative times raises error.""" - model = simple_builder.build() + model = validation_builder.build() - with pytest.raises(ValueError, match="times cannot be negative"): + with pytest.raises(ValueError, match="times.*negative"): model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, - hospital={ + hospital_subpop={ "obs": jnp.array([10]), "times": jnp.array([-1]), }, ) - def test_validate_data_rejects_unknown_observation(self, simple_builder): + def test_validate_data_rejects_unknown_observation(self, validation_builder): """Test that unknown observation name raises error.""" - model = simple_builder.build() + model = validation_builder.build() - with pytest.raises(ValueError, match="Unknown observation"): + with pytest.raises(ValueError, match="Unknown"): model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, @@ -270,15 +320,17 @@ def test_validate_data_rejects_unknown_observation(self, simple_builder): }, ) - def test_validate_data_rejects_mismatched_obs_times_length(self, simple_builder): + def test_validate_data_rejects_mismatched_obs_times_length( + self, validation_builder + ): """Test that mismatched obs and times lengths raises error.""" - model = simple_builder.build() + model = validation_builder.build() - with pytest.raises(ValueError, match=r"obs length.*must match times length"): + with pytest.raises(ValueError, match="obs.*times"): model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, - hospital={ + hospital_subpop={ "obs": jnp.array([10, 20, 30]), # 3 elements "times": jnp.array([5, 10]), # 2 elements }, @@ -290,35 +342,50 @@ def test_validate_method_calls_internal_validate(self, simple_builder): # Should not raise model.validate() - def test_validate_data_rejects_negative_subpop_indices(self, simple_builder): + def test_validate_data_rejects_negative_subpop_indices(self, validation_builder): """Test that negative subpop_indices raises error.""" - model = simple_builder.build() + model = validation_builder.build() - with pytest.raises(ValueError, match="subpop_indices cannot be negative"): + with pytest.raises(ValueError, match="subpop_indices.*negative"): model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, - hospital={ + hospital_subpop={ "subpop_indices": jnp.array([-1, 0, 1]), "times": jnp.array([5, 6, 7]), }, ) - def test_validate_data_rejects_out_of_bounds_subpop_indices(self, simple_builder): + def test_validate_data_rejects_out_of_bounds_subpop_indices( + self, validation_builder + ): """Test that subpop_indices >= K raises error.""" - model = simple_builder.build() + model = validation_builder.build() # K is 3 (from SUBPOP_FRACTIONS = [0.3, 0.25, 0.45]) - with pytest.raises(ValueError, match="subpop_indices contains"): + with pytest.raises(ValueError, match="subpop_indices"): model.validate_data( n_days_post_init=30, subpop_fractions=SUBPOP_FRACTIONS, - hospital={ + hospital_subpop={ "subpop_indices": jnp.array([0, 1, 5]), # 5 >= 3 "times": jnp.array([5, 6, 7]), }, ) + def test_validate_data_rejects_wrong_length_dense_obs(self, validation_builder): + """Test that dense obs with wrong length raises error.""" + model = validation_builder.build() + + with pytest.raises(ValueError, match="obs.*n_total"): + model.validate_data( + n_days_post_init=30, + subpop_fractions=SUBPOP_FRACTIONS, + hospital={ + "obs": jnp.array([10, 20, 30]), # wrong length + }, + ) + class TestPyrenewBuilderErrorHandling: """Test PyrenewBuilder error handling.""" @@ -408,6 +475,10 @@ def _predicted_obs(self, infections): """ return infections + def validate_data(self, n_total, n_subpops, **obs_data): + """Validate data stub.""" + pass + builder = PyrenewBuilder() gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3]))