|
19 | 19 | from pytensor.graph.basic import Node
|
20 | 20 | from pytensor.graph.fg import FunctionGraph
|
21 | 21 | from pytensor.graph.rewriting.basic import node_rewriter
|
22 |
| -from pytensor.scalar.basic import GT, LT |
23 |
| -from pytensor.tensor.math import gt, lt |
| 22 | +from pytensor.scalar.basic import GE, GT, LE, LT |
| 23 | +from pytensor.tensor.math import ge, gt, le, lt |
24 | 24 |
|
25 | 25 | from pymc.logprob.abstract import (
|
26 | 26 | MeasurableElemwise,
|
|
36 | 36 | class MeasurableComparison(MeasurableElemwise):
|
37 | 37 | """A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph."""
|
38 | 38 |
|
39 |
| - valid_scalar_types = (GT, LT) |
| 39 | + valid_scalar_types = (GT, LT, GE, LE) |
40 | 40 |
|
41 | 41 |
|
42 |
| -@node_rewriter(tracks=[gt, lt]) |
| 42 | +@node_rewriter(tracks=[gt, lt, ge, le]) |
43 | 43 | def find_measurable_comparisons(
|
44 | 44 | fgraph: FunctionGraph, node: Node
|
45 | 45 | ) -> Optional[List[MeasurableComparison]]:
|
@@ -92,18 +92,21 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
|
92 | 92 |
|
93 | 93 | condn_exp = pt.eq(value, np.array(True))
|
94 | 94 |
|
95 |
| - if isinstance(op.scalar_op, GT): |
| 95 | + if isinstance(op.scalar_op, (GT, GE)): |
96 | 96 | logprob = pt.switch(condn_exp, logccdf, logcdf)
|
97 |
| - elif isinstance(op.scalar_op, LT): |
98 |
| - if base_rv.dtype.startswith("int"): |
99 |
| - logpmf = _logprob_helper(base_rv, operand, **kwargs) |
100 |
| - logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs) |
101 |
| - logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf)) |
102 |
| - else: |
103 |
| - logprob = pt.switch(condn_exp, logcdf, logccdf) |
| 97 | + elif isinstance(op.scalar_op, (LT, LE)): |
| 98 | + logprob = pt.switch(condn_exp, logcdf, logccdf) |
104 | 99 | else:
|
105 | 100 | raise TypeError(f"Unsupported scalar_op {op.scalar_op}")
|
106 | 101 |
|
| 102 | + if base_rv.dtype.startswith("int"): |
| 103 | + logpmf = _logprob_helper(base_rv, operand, **kwargs) |
| 104 | + logcdf_prev = _logcdf_helper(base_rv, operand - 1, **kwargs) |
| 105 | + if isinstance(op.scalar_op, LT): |
| 106 | + return pt.switch(condn_exp, logcdf_prev, pt.logaddexp(logccdf, logpmf)) |
| 107 | + elif isinstance(op.scalar_op, GE): |
| 108 | + return pt.switch(condn_exp, pt.logaddexp(logccdf, logpmf), logcdf_prev) |
| 109 | + |
107 | 110 | if base_rv_op.name:
|
108 | 111 | logprob.name = f"{base_rv_op}_logprob"
|
109 | 112 | logcdf.name = f"{base_rv_op}_logcdf"
|
|
0 commit comments