Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
680bb1e
Merge branch 'main' of https://github.com/CDCgov/PyRenew
cdc-mitzimorris Sep 15, 2025
2cb876b
update
cdc-mitzimorris Sep 18, 2025
60db8df
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Sep 22, 2025
32a5314
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 5, 2025
d6213f2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 8, 2025
96f27c9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 17, 2025
1cb6fa2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 24, 2025
f62e1e4
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 4, 2025
0c6785d
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 22, 2025
1ee62b9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Jan 29, 2026
0629461
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 4, 2026
efeadee
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
371ba98
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
0304bed
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 6, 2026
ffeea65
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
50e7261
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
dae6af8
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 10, 2026
181fa24
refactored (claude code)
cdc-mitzimorris Feb 10, 2026
655f719
cleanup observation validation logic
cdc-mitzimorris Feb 11, 2026
31c1bc8
more unit tests
cdc-mitzimorris Feb 11, 2026
00d6541
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
d126d33
lint fix
cdc-mitzimorris Feb 11, 2026
db7c53d
Merge branch 'mem_678_delegate_validation' of github-bf06:CDCgov/PyRe…
cdc-mitzimorris Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 5 additions & 46 deletions pyrenew/model/multisignal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,52 +202,11 @@ def validate_data(
f"Available: {list(self.observations.keys())}"
)

obs = obs_data.get("obs")
times = obs_data.get("times")

if times is not None:
# Sparse observations: times on shared axis [0, n_total)
times = jnp.asarray(times)
if jnp.any(times < 0):
raise ValueError(f"Observation '{name}': times cannot be negative")
max_time = jnp.max(times)
if max_time >= n_total:
raise ValueError(
f"Observation '{name}': times index {int(max_time)} "
f">= n_total ({n_total} = {n_init} init + "
f"{n_days_post_init} days). "
f"Times must be on shared axis [0, {n_total})."
)
if obs is not None and len(obs) != len(times):
raise ValueError(
f"Observation '{name}': obs length {len(obs)} "
f"must match times length {len(times)}"
)
elif obs is not None:
# Dense observations: length must equal n_total
obs = jnp.asarray(obs)
if obs.shape[0] != n_total:
raise ValueError(
f"Observation '{name}': obs length {obs.shape[0]} "
f"must equal n_total ({n_total} = {n_init} init + "
f"{n_days_post_init} days). "
f"Pad with NaN for initialization period."
)

# Validate subpop_indices if present
subpop_indices = obs_data.get("subpop_indices")
if subpop_indices is not None:
subpop_indices = jnp.asarray(subpop_indices)
if jnp.any(subpop_indices < 0):
raise ValueError(
f"Observation '{name}': subpop_indices cannot be negative"
)
max_idx = jnp.max(subpop_indices)
if max_idx >= pop.n_subpops:
raise ValueError(
f"Observation '{name}': subpop_indices contains "
f"{int(max_idx)} >= {pop.n_subpops} (n_subpops)"
)
self.observations[name].validate_data(
n_total=n_total,
n_subpops=pop.n_subpops,
**obs_data,
)

def sample(
self,
Expand Down
114 changes: 114 additions & 0 deletions pyrenew/observation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,120 @@ def _predicted_obs(
"""
pass # pragma: no cover

@abstractmethod
def validate_data(
self,
n_total: int,
n_subpops: int,
**obs_data,
) -> None:
"""
Validate observation data before running inference.

Each observation process validates its own data requirements.
Called by the model's ``validate_data()`` method with concrete
(non-traced) values before JAX tracing begins.

Parameters
----------
n_total : int
Total number of time steps (n_init + n_days_post_init).
n_subpops : int
Number of subpopulations.
**obs_data
Observation-specific data kwargs (same as passed to ``sample()``,
minus ``infections`` which comes from the latent process).

Raises
------
ValueError
If any data fails validation.
"""
pass # pragma: no cover

def _validate_times(self, times: ArrayLike, n_total: int) -> None:
"""
Validate a times index array.

Checks that all values are non-negative and within ``[0, n_total)``.

Parameters
----------
times : ArrayLike
Time indices on the shared time axis.
n_total : int
Total number of time steps.

Raises
------
ValueError
If times contains negative values or values >= n_total.
"""
times = jnp.asarray(times)
if jnp.any(times < 0):
raise ValueError(f"Observation '{self.name}': times cannot be negative")
max_time = jnp.max(times)
if max_time >= n_total:
raise ValueError(
f"Observation '{self.name}': times index {int(max_time)} "
f">= n_total ({n_total}). "
f"Times must be on shared axis [0, {n_total})."
)

def _validate_subpop_indices(
self, subpop_indices: ArrayLike, n_subpops: int
) -> None:
"""
Validate a subpopulation index array.

Checks that all values are non-negative and within ``[0, n_subpops)``.

Parameters
----------
subpop_indices : ArrayLike
Subpopulation indices (0-indexed).
n_subpops : int
Number of subpopulations.

Raises
------
ValueError
If subpop_indices contains negative values or values >= n_subpops.
"""
subpop_indices = jnp.asarray(subpop_indices)
if jnp.any(subpop_indices < 0):
raise ValueError(
f"Observation '{self.name}': subpop_indices cannot be negative"
)
max_idx = jnp.max(subpop_indices)
if max_idx >= n_subpops:
raise ValueError(
f"Observation '{self.name}': subpop_indices contains "
f"{int(max_idx)} >= {n_subpops} (n_subpops)"
)

def _validate_obs_times_length(self, obs: ArrayLike, times: ArrayLike) -> None:
"""
Validate that obs and times arrays have matching lengths.

Parameters
----------
obs : ArrayLike
Observed data array.
times : ArrayLike
Times index array.

Raises
------
ValueError
If obs and times have different lengths.
"""
if len(obs) != len(times):
raise ValueError(
f"Observation '{self.name}': obs length {len(obs)} "
f"must match times length {len(times)}"
)

@abstractmethod
def sample(
self,
Expand Down
75 changes: 75 additions & 0 deletions pyrenew/observation/count_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,41 @@ def __repr__(self) -> str:
f"noise={self.noise!r})"
)

def validate_data(
self,
n_total: int,
n_subpops: int,
obs: ArrayLike | None = None,
**kwargs,
) -> None:
"""
Validate aggregated count observation data.

Parameters
----------
n_total : int
Total number of time steps (n_init + n_days_post_init).
n_subpops : int
Number of subpopulations (unused for aggregate observations).
obs : ArrayLike | None
Observed counts on shared time axis. Shape: (n_total,).
**kwargs
Additional keyword arguments (ignored).

Raises
------
ValueError
If obs length doesn't match n_total.
"""
if obs is not None:
obs = jnp.asarray(obs)
if obs.shape[0] != n_total:
raise ValueError(
f"Observation '{self.name}': obs length {obs.shape[0]} "
f"must equal n_total ({n_total}). "
f"Pad with NaN for initialization period."
)

def sample(
self,
infections: ArrayLike,
Expand Down Expand Up @@ -275,6 +310,46 @@ def infection_resolution(self) -> str:
"""
return "subpop"

def validate_data(
self,
n_total: int,
n_subpops: int,
times: ArrayLike | None = None,
subpop_indices: ArrayLike | None = None,
obs: ArrayLike | None = None,
**kwargs,
) -> None:
"""
Validate subpopulation-level count observation data.

Parameters
----------
n_total : int
Total number of time steps (n_init + n_days_post_init).
n_subpops : int
Number of subpopulations.
times : ArrayLike | None
Day index for each observation on the shared time axis.
subpop_indices : ArrayLike | None
Subpopulation index for each observation (0-indexed).
obs : ArrayLike | None
Observed counts (n_obs,).
**kwargs
Additional keyword arguments (ignored).

Raises
------
ValueError
If times or subpop_indices are out of bounds, or if
obs and times have mismatched lengths.
"""
if times is not None:
self._validate_times(times, n_total)
if obs is not None:
self._validate_obs_times_length(obs, times)
if subpop_indices is not None:
self._validate_subpop_indices(subpop_indices, n_subpops)

def sample(
self,
infections: ArrayLike,
Expand Down
59 changes: 59 additions & 0 deletions pyrenew/observation/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
air quality, serology, etc.) with signal-specific processing.
"""

import jax.numpy as jnp
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable
Expand Down Expand Up @@ -104,6 +105,64 @@ def infection_resolution(self) -> str:
"""
return "subpop"

def validate_data(
self,
n_total: int,
n_subpops: int,
times: ArrayLike | None = None,
subpop_indices: ArrayLike | None = None,
sensor_indices: ArrayLike | None = None,
n_sensors: int | None = None,
obs: ArrayLike | None = None,
**kwargs,
) -> None:
"""
Validate measurement observation data.

Parameters
----------
n_total : int
Total number of time steps (n_init + n_days_post_init).
n_subpops : int
Number of subpopulations.
times : ArrayLike | None
Day index for each observation on the shared time axis.
subpop_indices : ArrayLike | None
Subpopulation index for each observation (0-indexed).
sensor_indices : ArrayLike | None
Sensor index for each observation (0-indexed).
n_sensors : int | None
Total number of measurement sensors.
obs : ArrayLike | None
Observed measurements (n_obs,).
**kwargs
Additional keyword arguments (ignored).

Raises
------
ValueError
If times, subpop_indices, or sensor_indices are out of bounds,
or if obs and times have mismatched lengths.
"""
if times is not None:
self._validate_times(times, n_total)
if obs is not None:
self._validate_obs_times_length(obs, times)
if subpop_indices is not None:
self._validate_subpop_indices(subpop_indices, n_subpops)
if sensor_indices is not None and n_sensors is not None:
sensor_indices = jnp.asarray(sensor_indices)
if jnp.any(sensor_indices < 0):
raise ValueError(
f"Observation '{self.name}': sensor_indices cannot be negative"
)
max_sensor = jnp.max(sensor_indices)
if max_sensor >= n_sensors:
raise ValueError(
f"Observation '{self.name}': sensor_indices contains "
f"{int(max_sensor)} >= {n_sensors} (n_sensors)"
)

def sample(
self,
infections: ArrayLike,
Expand Down
4 changes: 4 additions & 0 deletions test/test_observation_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ def sample(self, **kwargs):
"""Sample method stub."""
pass

def validate_data(self, n_total, n_subpops, **obs_data):
"""Validate data stub."""
pass

process = IncompleteCountProcess(
name="test",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
Expand Down
Loading
Loading