Skip to content

Conversation

@cdc-mitzimorris
Copy link
Collaborator

This PR adds work that was done in https://github.com/cdcent/cfa-pyrenew-hierarchical/pull/4 to PyRenew.

It adds the base observation process class, concrete implementations for Count processes and the abstract base class for Measurement processes, together with unit tests and two new tutorials for count and measurement observation processes respectively.

Once this PR and the work done in https://github.com/cdcent/cfa-pyrenew-hierarchical/pull/5 have been added to PyRenew, subsequent PRs will deprecate unused features and harmonize the documentation and tutorials.

@codecov
Copy link

codecov bot commented Dec 23, 2025

Codecov Report

❌ Patch coverage is 98.38710% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 97.26%. Comparing base (02446c5) to head (7e48001).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
pyrenew/observation/count_observations.py 96.55% 2 Missing ⚠️
pyrenew/observation/noise.py 98.38% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #644      +/-   ##
==========================================
+ Coverage   96.98%   97.26%   +0.28%     
==========================================
  Files          42       47       +5     
  Lines        1094     1280     +186     
==========================================
+ Hits         1061     1245     +184     
- Misses         33       35       +2     
Flag Coverage Δ
unittests 97.26% <98.38%> (+0.28%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@github-actions
Copy link

github-actions bot commented Dec 23, 2025

Thank you for your contribution @cdc-mitzimorris 🚀! Your github-pages is ready for download 👉 here 👈!
(The artifact expires on 2026-02-03T16:32:57Z. You can re-generate it by re-running the workflow here.)

Copy link
Collaborator

@dylanhmorris dylanhmorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @cdc-mitzimorris! A few things to address and then I can re-review.

@cdc-mitzimorris
Copy link
Collaborator Author

@dylanhmorris - ready for re-review - made all suggested changes to the tutorials and added arg "name" to observation processes so that user specifies the signal name.

Copy link
Collaborator

@dylanhmorris dylanhmorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @cdc-mitzimorris! Just a couple remaining questions.

@cdc-mitzimorris
Copy link
Collaborator Author

changes made.

@cdc-mitzimorris
Copy link
Collaborator Author

@dylanhmorris - conversation resolved.

Copy link
Collaborator

@dylanhmorris dylanhmorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still need the separate noise module.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cdc-mitzimorris this still needs to be implemented.

Comment on lines +155 to +156
Output preserves input timeline. First len(delay_pmf)-1 days return
-1 or ~0 (depending on noise model) due to NaN padding.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is meant by ~0? and why does the padding change depending on the noise model?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should discuss this f2f in our upcoming meeting.

Comment on lines +198 to +199
times : ArrayLike | None
Day indices for sparse observations. None for dense observations.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index relative to what vector?

@cdc-mitzimorris
Copy link
Collaborator Author

@dylanhmorris ready for re-re-re-review!

Comment on lines +273 to +274
site_name = f"{self.name}_{suffix}"
numpyro.deterministic(site_name, value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the helper function

Suggested change
site_name = f"{self.name}_{suffix}"
numpyro.deterministic(site_name, value)
numpyro.deterministic(self._sample_site_name(suffix), value)

Comment on lines +144 to +184
def _validate_pmf(
self,
pmf: ArrayLike,
param_name: str,
atol: float = 1e-6,
) -> None:
"""
Validate that an array is a valid probability mass function.
Checks:
- Non-empty array
- Sums to 1.0 (within tolerance)
- All non-negative values
Parameters
----------
pmf : ArrayLike
The PMF array to validate
param_name : str
Name of the parameter (for error messages)
atol : float, default 1e-6
Absolute tolerance for sum-to-one check
Raises
------
ValueError
If PMF is empty, doesn't sum to 1.0 (within tolerance),
or contains negative values.
"""
if pmf.size == 0:
raise ValueError(f"{param_name} must return non-empty array")

pmf_sum = jnp.sum(pmf)
if not jnp.isclose(pmf_sum, 1.0, atol=atol):
raise ValueError(
f"{param_name} must sum to 1.0 (±{atol}), got {float(pmf_sum):.6f}"
)

if jnp.any(pmf < 0):
raise ValueError(f"{param_name} must have non-negative values")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should land this PR and then address, but this duplicates distutil.validate_discrete_dist_vector.

That said, I prefer the implementation here, so this should replace that.

We could also (in #645) consider having discrete PMFs as a strong abstraction. We could then require certain RVs to be of class PMF and deferring the validation to those RVs' constructors. Given how fundamental discrete PMFs are to discrete-time renewal processes I favor this.

Comment on lines +186 to +201
def get_minimum_observation_day(self) -> int:
"""
Get the first day with valid (non-NaN) convolution results.
Due to the convolution operation requiring a history window,
the first ``len(pmf) - 1`` days will have NaN values in the
output. This method returns the index of the first valid day.
Returns
-------
int
Day index (0-based) of first valid observation.
Equal to ``len(pmf) - 1``.
"""
pmf = self.temporal_pmf_rv()
return int(len(pmf) - 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used anywhere in the class. Would be better to implement in terms of a PMF abstraction (could then also apply to computing the offset in compute_delay_ascertained_incidence. I think refrain from implementing for now.

Suggested change
def get_minimum_observation_day(self) -> int:
"""
Get the first day with valid (non-NaN) convolution results.
Due to the convolution operation requiring a history window,
the first ``len(pmf) - 1`` days will have NaN values in the
output. This method returns the index of the first valid day.
Returns
-------
int
Day index (0-based) of first valid observation.
Equal to ``len(pmf) - 1``.
"""
pmf = self.temporal_pmf_rv()
return int(len(pmf) - 1)

Comment on lines +66 to +71
ascertainment_rate = self.ascertainment_rate_rv()
if jnp.any(ascertainment_rate < 0) or jnp.any(ascertainment_rate > 1):
raise ValueError(
"ascertainment_rate_rv must be in [0, 1], "
"got value(s) outside this range"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A single sample doesn't suffice to validate this for stochastic ascertainment rates. I think remove until we have actual support handling for RVs (also some argument that the scaling factor for counts does not have to be <1, though in practice it usually is (since it's usually modeling the quantity P(event reported | infection) = P(event occurs | infection) * P(event reported | event occurs)).

Suggested change
ascertainment_rate = self.ascertainment_rate_rv()
if jnp.any(ascertainment_rate < 0) or jnp.any(ascertainment_rate > 1):
raise ValueError(
"ascertainment_rate_rv must be in [0, 1], "
"got value(s) outside this range"
)

Comment on lines +63 to +64
delay_pmf = self.temporal_pmf_rv()
self._validate_pmf(delay_pmf, "delay_distribution_rv")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, remove for now, implement when RVs have attributes that allow strict validation.0

Suggested change
delay_pmf = self.temporal_pmf_rv()
self._validate_pmf(delay_pmf, "delay_distribution_rv")

int
Length of delay distribution PMF.
"""
return len(self.temporal_pmf_rv())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But high priority to replace with something that doesn't require a sampling call.

Suggested change
return len(self.temporal_pmf_rv())
return jnp.shape(self.temporal_pmf_rv())[0]

Comment on lines +121 to +130
is_1d = infections.ndim == 1
if is_1d:
infections = infections[:, jnp.newaxis]

def convolve_col(col): # numpydoc ignore=GL08
return self._convolve_with_alignment(col, delay_pmf, ascertainment_rate)[0]

predicted_counts = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections)

return predicted_counts[:, 0] if is_1d else predicted_counts
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's going to be a switch for 1d input, I think it's cleaner and clearer to vmap only in the >2d case, and not pad the input only to unpad the output

Suggested change
is_1d = infections.ndim == 1
if is_1d:
infections = infections[:, jnp.newaxis]
def convolve_col(col): # numpydoc ignore=GL08
return self._convolve_with_alignment(col, delay_pmf, ascertainment_rate)[0]
predicted_counts = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections)
return predicted_counts[:, 0] if is_1d else predicted_counts
is_1d = infections.ndim == 1
if is_1d:
predicted_counts = self._convolve_with_alignment(infections, delay_pmf, ascertainment_rate)[0]
else:
predicted_counts = jax.vmap(lambda col: self._convolve_with_alignment(col, delay_pmf, ascertainment_rate)[0], in_axes=1, out_axes=1)(infections)
return predicted_counts

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants