Skip to content

Commit 61be336

Browse files
shreyas3156ricardoV94
authored andcommitted
Add logprob derivation for >= and <= operations
1 parent 9b712bf commit 61be336

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

pymc/logprob/binary.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from pytensor.graph.basic import Node
2020
from pytensor.graph.fg import FunctionGraph
2121
from pytensor.graph.rewriting.basic import node_rewriter
22-
from pytensor.scalar.basic import GT, LT
23-
from pytensor.tensor.math import gt, lt
22+
from pytensor.scalar.basic import GE, GT, LE, LT
23+
from pytensor.tensor.math import ge, gt, le, lt
2424

2525
from pymc.logprob.abstract import (
2626
MeasurableElemwise,
@@ -36,10 +36,10 @@
3636
class MeasurableComparison(MeasurableElemwise):
3737
"""A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph."""
3838

39-
valid_scalar_types = (GT, LT)
39+
valid_scalar_types = (GT, LT, GE, LE)
4040

4141

42-
@node_rewriter(tracks=[gt, lt])
42+
@node_rewriter(tracks=[gt, lt, ge, le])
4343
def find_measurable_comparisons(
4444
fgraph: FunctionGraph, node: Node
4545
) -> Optional[List[MeasurableComparison]]:
@@ -92,18 +92,21 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
9292

9393
condn_exp = pt.eq(value, np.array(True))
9494

95-
if isinstance(op.scalar_op, GT):
95+
if isinstance(op.scalar_op, (GT, GE)):
9696
logprob = pt.switch(condn_exp, logccdf, logcdf)
97-
elif isinstance(op.scalar_op, LT):
98-
if base_rv.dtype.startswith("int"):
99-
logpmf = _logprob_helper(base_rv, operand, **kwargs)
100-
logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs)
101-
logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf))
102-
else:
103-
logprob = pt.switch(condn_exp, logcdf, logccdf)
97+
elif isinstance(op.scalar_op, (LT, LE)):
98+
logprob = pt.switch(condn_exp, logcdf, logccdf)
10499
else:
105100
raise TypeError(f"Unsupported scalar_op {op.scalar_op}")
106101

102+
if base_rv.dtype.startswith("int"):
103+
logpmf = _logprob_helper(base_rv, operand, **kwargs)
104+
logcdf_prev = _logcdf_helper(base_rv, operand - 1, **kwargs)
105+
if isinstance(op.scalar_op, LT):
106+
return pt.switch(condn_exp, logcdf_prev, pt.logaddexp(logccdf, logpmf))
107+
elif isinstance(op.scalar_op, GE):
108+
return pt.switch(condn_exp, pt.logaddexp(logccdf, logpmf), logcdf_prev)
109+
107110
if base_rv_op.name:
108111
logprob.name = f"{base_rv_op}_logprob"
109112
logcdf.name = f"{base_rv_op}_logcdf"

tests/logprob/test_binary.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,25 @@
2727
@pytest.mark.parametrize(
2828
"comparison_op, exp_logp_true, exp_logp_false",
2929
[
30-
(pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf),
31-
(pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf),
30+
((pt.lt, pt.le), "logcdf", "logsf"),
31+
((pt.gt, pt.ge), "logsf", "logcdf"),
3232
],
3333
)
3434
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
3535
x_rv = pt.random.normal(0, 1)
36-
comp_x_rv = comparison_op(x_rv, 0.5)
36+
for op in comparison_op:
37+
comp_x_rv = op(x_rv, 0.5)
3738

38-
comp_x_vv = comp_x_rv.clone()
39+
comp_x_vv = comp_x_rv.clone()
3940

40-
logprob = logp(comp_x_rv, comp_x_vv)
41-
assert_no_rvs(logprob)
41+
logprob = logp(comp_x_rv, comp_x_vv)
42+
assert_no_rvs(logprob)
4243

43-
logp_fn = pytensor.function([comp_x_vv], logprob)
44+
logp_fn = pytensor.function([comp_x_vv], logprob)
45+
ref_scipy = st.norm(0, 1)
4446

45-
assert np.isclose(logp_fn(0), exp_logp_false(0.5))
46-
assert np.isclose(logp_fn(1), exp_logp_true(0.5))
47+
assert np.isclose(logp_fn(0), getattr(ref_scipy, exp_logp_false)(0.5))
48+
assert np.isclose(logp_fn(1), getattr(ref_scipy, exp_logp_true)(0.5))
4749

4850

4951
@pytest.mark.parametrize(
@@ -54,11 +56,21 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
5456
lambda x: st.poisson(2).logcdf(x - 1),
5557
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
5658
),
59+
(
60+
pt.ge,
61+
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
62+
lambda x: st.poisson(2).logcdf(x - 1),
63+
),
5764
(
5865
pt.gt,
5966
st.poisson(2).logsf,
6067
st.poisson(2).logcdf,
6168
),
69+
(
70+
pt.le,
71+
st.poisson(2).logcdf,
72+
st.poisson(2).logsf,
73+
),
6274
],
6375
)
6476
def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):

0 commit comments

Comments
 (0)