Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
87eba5d
Add numerical stability test for censored distributions
maresb Dec 15, 2025
81da946
Add _logccdf dispatcher for numerically stable log survival function
maresb Dec 15, 2025
15e5f64
Move try/except fallback into _logccdf_helper
maresb Dec 15, 2025
15806c0
Add _logccdf support to Truncated distribution
maresb Dec 15, 2025
063af42
Fix logccdf IR rewriting to match logcdf pattern
maresb Dec 15, 2025
19b9979
Fix test import style in test_censoring.py
maresb Dec 15, 2025
20322a1
Remove redundant test_logccdf_helper_numerical_stability
maresb Dec 16, 2025
93734c2
Add test for _logccdf_helper fallback to log1mexp
maresb Dec 16, 2025
2df5274
Use ±100 sigma in numerical stability tests
maresb Dec 16, 2025
63c9327
Enhance test docstrings with What/Why/How documentation
maresb Dec 16, 2025
628e6d5
Add test for logccdf IR graph rewriting path
maresb Dec 16, 2025
36b8672
Add test for logccdf with SymbolicRandomVariable extended_signature
maresb Dec 16, 2025
5a2f7cf
Import log1mexp directly
maresb Jan 6, 2026
504b371
Add tests for _logccdf on discrete distributions
maresb Jan 6, 2026
532a22e
Use _logccdf_helper also for discrete distributions
maresb Jan 6, 2026
7dff7e0
Remove verbose inline comments about numerical stability
maresb Jan 6, 2026
b1647ff
Simplify verbose test docstrings
maresb Jan 6, 2026
74a3933
Simplify graph_contains_log1mexp using pytensor.graph.traversal.ances…
maresb Jan 6, 2026
c73681c
Remove test_logccdf_transformed_argument (redundant pm.Model usage)
maresb Jan 6, 2026
64e482a
Remove _helper tests, keep only user-facing API tests
maresb Jan 6, 2026
1cf84fd
Move censored numerical stability test to tests/distributions/test_censored.py
maresb Jan 6, 2026
db8e3f1
Add logcdf tests for Erfc/Erfcx transforms
maresb Jan 6, 2026
7de533b
Explain the test assumption that Normal has a custom ccdf but Uniform…
maresb Jan 6, 2026
9444212
Move discrete transform logccdf tests to test_transforms.py
maresb Jan 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,13 @@ def logcdf(value, mu, sigma):
msg="sigma > 0",
)

def logccdf(value, mu, sigma):
    """Return the Normal log complementary CDF at ``value``.

    The expression built by ``normal_lccdf`` is wrapped in ``check_parameters``
    with the condition ``sigma > 0``.
    """
    log_survival = normal_lccdf(mu, sigma, value)
    sigma_is_positive = sigma > 0
    return check_parameters(log_survival, sigma_is_positive, msg="sigma > 0")

def icdf(value, mu, sigma):
res = mu + sigma * -np.sqrt(2.0) * pt.erfcinv(2 * value)
res = check_icdf_value(res, value)
Expand Down
13 changes: 12 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
rv_size_is_none,
shape_from_dims,
)
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableOp, _icdf, _logccdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.printing import str_for_dist
Expand Down Expand Up @@ -150,6 +150,17 @@ def logcdf(op, value, *dist_params, **kwargs):
dist_params = [dist_params[i] for i in params_idxs]
return class_logcdf(value, *dist_params)

class_logccdf = clsdict.get("logccdf")
if class_logccdf:

@_logccdf.register(rv_type)
def logccdf(op, value, *dist_params, **kwargs):
if isinstance(op, RandomVariable):
rng, size, *dist_params = dist_params
elif params_idxs:
dist_params = [dist_params[i] for i in params_idxs]
return class_logccdf(value, *dist_params)

class_icdf = clsdict.get("icdf")
if class_icdf:

Expand Down
23 changes: 20 additions & 3 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.logprob.basic import icdf, logcdf, logp
from pymc.logprob.basic import icdf, logccdf, logcdf, logp
from pymc.math import logdiffexp
from pymc.pytensorf import collect_default_updates
from pymc.util import check_dist_not_registered
Expand Down Expand Up @@ -211,6 +211,23 @@ def _create_logcdf_exprs(
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
return lower_logcdf, upper_logcdf

@staticmethod
def _create_lower_logccdf_expr(
    base_rv: TensorVariable,
    value: TensorVariable,
    lower: TensorVariable,
) -> TensorVariable:
    """Build the logccdf graph of ``base_rv`` evaluated at the lower bound.

    ``value`` only serves as the broadcasting template for the bound. Going
    through ``logccdf`` instead of ``log(1 - exp(logcdf))`` is numerically more
    stable for distributions that have a registered logccdf method.
    """
    if base_rv.type.dtype.startswith("int"):
        # For left truncated discrete RVs, we need to include the whole lower bound.
        bound = lower - 1
    else:
        bound = lower
    bound_filled = pt.full_like(value, bound, dtype=config.floatX)
    return logccdf(base_rv, bound_filled, warn_rvs=False)

def update(self, node: Apply):
"""Return the update mapping for the internal RNGs.

Expand Down Expand Up @@ -401,7 +418,7 @@ def truncated_logprob(op, values, *inputs, **kwargs):
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower)
elif is_upper_bounded:
lognorm = upper_logcdf

Expand Down Expand Up @@ -438,7 +455,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower)
elif is_upper_bounded:
lognorm = upper_logcdf

Expand Down
2 changes: 2 additions & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pymc.logprob.basic import (
conditional_logp,
icdf,
logccdf,
logcdf,
logp,
transformed_conditional_logp,
Expand All @@ -59,6 +60,7 @@

__all__ = (
"icdf",
"logccdf",
"logcdf",
"logp",
)
41 changes: 40 additions & 1 deletion pymc/logprob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

from pytensor.graph import Apply, Op, Variable
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
from pytensor.tensor import TensorVariable, log1mexp
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -108,6 +108,45 @@ def _logcdf_helper(rv, value, **kwargs):
return logcdf


@singledispatch
def _logccdf(
    op: Op,
    value: TensorVariable,
    *inputs: TensorVariable,
    **kwargs,
):
    """Create a graph for the log complementary CDF (log survival function) of a ``RandomVariable``.

    Dispatch happens on the type of ``op``, which should be a subclass of
    ``RandomVariable``. To provide a logccdf graph for a new ``RandomVariable``,
    register an implementation on this dispatcher.

    The log complementary CDF is log(1 - CDF(x)), i.e. the log survival
    function. Distributions with a numerically stable implementation should
    register it here rather than rely on log(1 - exp(logcdf)).
    """
    # Default for any Op without a registered implementation.
    message = f"LogCCDF method not implemented for {op}"
    raise NotImplementedError(message)


def _logccdf_helper(rv, value, **kwargs):
    """Call the `_logccdf` dispatcher, falling back to log1mexp(logcdf).

    A `_logccdf` implementation registered for the distribution is used when
    available. Otherwise the result is computed as `log(1 - exp(logcdf))`,
    which may be numerically unstable in the tails.
    """
    try:
        result = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
    except NotImplementedError:
        # No dedicated implementation: derive the survival function from the logcdf.
        result = log1mexp(_logcdf_helper(rv, value, **kwargs))

    if rv.name:
        result.name = f"{rv.name}_logccdf"

    return result


@singledispatch
def _icdf(
op: Op,
Expand Down
65 changes: 65 additions & 0 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pymc.logprob.abstract import (
MeasurableOp,
_icdf_helper,
_logccdf_helper,
_logcdf_helper,
_logprob,
_logprob_helper,
Expand Down Expand Up @@ -302,6 +303,70 @@ def normal_logcdf(value, mu, sigma):
return expr


def logccdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
"""Create a graph for the log complementary CDF (log survival function) of a random variable.

The log complementary CDF is defined as log(1 - CDF(x)), also known as the
log survival function. For distributions with a numerically stable implementation,
this is more accurate than computing log(1 - exp(logcdf)).

Parameters
----------
rv : TensorVariable
value : tensor_like
Should be the same type (shape and dtype) as the rv.
warn_rvs : bool, default True
Warn if RVs were found in the logccdf graph.
This can happen when a variable has other random variables as inputs.
In that case, those random variables should be replaced by their respective values.

Returns
-------
logccdf : TensorVariable

Raises
------
RuntimeError
If the logccdf cannot be derived.

Examples
--------
Create a compiled function that evaluates the logccdf of a variable

.. code-block:: python

import pymc as pm
import pytensor.tensor as pt

mu = pt.scalar("mu")
rv = pm.Normal.dist(mu, 1.0)

value = pt.scalar("value")
rv_logccdf = pm.logccdf(rv, value)

# Use .eval() for debugging
# log(1 - Phi(0.9)) = log(0.18406...) for a standard normal
print(rv_logccdf.eval({value: 0.9, mu: 0.0})) # approx. -1.6925

# Compile a function for repeated evaluations
rv_logccdf_fn = pm.compile_pymc([value, mu], rv_logccdf)
print(rv_logccdf_fn(value=0.9, mu=0.0)) # approx. -1.6925

"""
value = pt.as_tensor_variable(value, dtype=rv.dtype)
# First attempt: derive the logccdf directly from rv as given.
try:
return _logccdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv into its IR form and derive the logccdf from the rewritten graph
fgraph = construct_ir_fgraph({rv: value})
[ir_valued_rv] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
expr = _logccdf_helper(ir_rv, ir_value, **kwargs)
[expr] = cleanup_ir([expr])
# The RV warning is only issued on this rewrite path, not on the direct path above.
if warn_rvs:
_warn_rvs_in_inferred_graph([expr])
return expr


def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
"""Create a graph for the inverse CDF of a random variable.

Expand Down
3 changes: 2 additions & 1 deletion pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pymc.logprob.abstract import (
MeasurableElemwise,
_logccdf_helper,
_logcdf_helper,
_logprob,
_logprob_helper,
Expand Down Expand Up @@ -95,7 +96,7 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
base_rv_op = base_rv.owner.op

logcdf = _logcdf_helper(base_rv, operand, **kwargs)
logccdf = pt.log1mexp(logcdf)
logccdf = _logccdf_helper(base_rv, operand, **kwargs)

condn_exp = pt.eq(value, np.array(True))

Expand Down
5 changes: 3 additions & 2 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableElemwise, _logccdf_helper, _logcdf, _logprob
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables

Expand Down Expand Up @@ -119,7 +119,8 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
if not (isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))):
is_upper_bounded = True

logccdf = pt.log1mexp(logcdf)
logccdf = _logccdf_helper(base_rv, value, **kwargs)

# For right clipped discrete RVs, we need to add an extra term
# corresponding to the pmf at the upper bound
if base_rv.dtype.startswith("int"):
Expand Down
6 changes: 4 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
MeasurableOp,
_icdf,
_icdf_helper,
_logccdf_helper,
_logcdf,
_logcdf_helper,
_logprob,
Expand Down Expand Up @@ -248,9 +249,10 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg

logcdf = _logcdf_helper(measurable_input, backward_value)
if is_discrete:
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
# For discrete distributions, P(X >= t) = P(X > t-1)
logccdf = _logccdf_helper(measurable_input, backward_value - 1)
else:
logccdf = pt.log1mexp(logcdf)
logccdf = _logccdf_helper(measurable_input, backward_value)

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
Expand Down
24 changes: 24 additions & 0 deletions tests/distributions/test_censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,27 @@ def test_censored_logcdf_discrete(self):
logcdf(censored_cat, eval_points).eval(),
expected_interval,
)

@pytest.mark.parametrize(
    "censoring_side,bound_value",
    [
        ("right", 100.0),
        ("left", -100.0),
    ],
)
def test_censored_logp_numerical_stability(self, censoring_side, bound_value):
    """Check that the logp of a Normal censored 100 sigma out is finite and matches SciPy."""
    scipy_normal = sp.stats.norm(0, 1)
    base_dist = pm.Normal.dist(mu=0.0, sigma=1.0)

    censor_right = censoring_side == "right"
    if censor_right:
        bounds = {"lower": None, "upper": bound_value}
    else:
        bounds = {"lower": bound_value, "upper": None}
    censored_dist = pm.Censored.dist(base_dist, **bounds)

    # Right censoring is compared against the log survival function,
    # left censoring against the log CDF.
    reference_fn = scipy_normal.logsf if censor_right else scipy_normal.logcdf
    expected = reference_fn(bound_value)

    actual = logp(censored_dist, bound_value).eval()

    assert np.isfinite(actual)
    assert np.isclose(actual, expected, rtol=1e-6)
40 changes: 40 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,46 @@ def rv_op(cls, size=None, rng=None):
resized_rv = change_dist_size(rv, new_size=5, expand=True)
assert resized_rv.type.shape == (5,)

def test_logccdf_with_extended_signature(self):
"""Test logccdf registration for SymbolicRandomVariable with extended_signature."""
from pymc.distributions.dist_math import normal_lccdf
from pymc.distributions.distribution import Distribution

class TestDistWithLogccdf(Distribution):
# Create a SymbolicRandomVariable type with extended_signature
rv_type = type(
"TestRVWithLogccdf",
(SymbolicRandomVariable,),
{"extended_signature": "[rng],[size],(),()->[rng],()"},
)

@classmethod
def dist(cls, mu, sigma, **kwargs):
mu = pt.as_tensor(mu)
sigma = pt.as_tensor(sigma)
return super().dist([mu, sigma], **kwargs)

@classmethod
def rv_op(cls, mu, sigma, size=None, rng=None):
rng = normalize_rng_param(rng)
size = normalize_size_param(size)
# Internally uses Normal, but wrapped in SymbolicRandomVariable
next_rng, draws = Normal.dist(mu, sigma, size=size, rng=rng).owner.outputs
return cls.rv_type(
inputs=[rng, size, mu, sigma],
outputs=[next_rng, draws],
ndim_supp=0,
)(rng, size, mu, sigma)

# This logccdf will be registered via params_idxs path
def logccdf(value, mu, sigma):
return normal_lccdf(mu, sigma, value)

rv = TestDistWithLogccdf.dist(0, 1)
result = pm.logccdf(rv, 0.5).eval()
expected = st.norm(0, 1).logsf(0.5) # ≈ -1.176, i.e. log(1 - Phi(0.5)) = log(0.30854)
npt.assert_allclose(result, expected)


def test_distribution_op_registered():
"""Test that returned Ops are registered as virtual subclasses of the respective PyMC distributions."""
Expand Down
Loading