Skip to content

Commit 7ed5b71

Browse files
ricardoV94twiecki
authored andcommitted
Cleanup RV in graph checks
1 parent 13f9894 commit 7ed5b71

File tree

8 files changed

+88
-66
lines changed

8 files changed

+88
-66
lines changed

pymc/logprob/basic.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from typing import Dict, List, Optional, Sequence, Union
4141

4242
import numpy as np
43-
import pytensor
4443
import pytensor.tensor as pt
4544

4645
from pytensor import config
@@ -53,7 +52,6 @@
5352
)
5453
from pytensor.graph.op import compute_test_value
5554
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
56-
from pytensor.tensor.random.op import RandomVariable
5755
from pytensor.tensor.var import TensorVariable
5856
from typing_extensions import TypeAlias
5957

@@ -66,11 +64,22 @@
6664
)
6765
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
6866
from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
69-
from pymc.logprob.utils import rvs_to_value_vars
67+
from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars
7068

7169
TensorLike: TypeAlias = Union[Variable, float, np.ndarray]
7270

7371

72+
def _find_unallowed_rvs_in_graph(graph):
73+
from pymc.data import MinibatchIndexRV
74+
from pymc.distributions.simulator import SimulatorRV
75+
76+
return {
77+
rv
78+
for rv in find_rvs_in_graph(graph)
79+
if not isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV))
80+
}
81+
82+
7483
def _warn_rvs_in_inferred_graph(graph: Union[TensorVariable, Sequence[TensorVariable]]):
7584
"""Issue warning if any RVs are found in graph.
7685
@@ -81,13 +90,11 @@ def _warn_rvs_in_inferred_graph(graph: Union[TensorVariable, Sequence[TensorVari
8190
This makes it impossible (or difficult) to replace it by the respective values afterward,
8291
so we instruct users to do it beforehand.
8392
"""
84-
from pymc.testing import assert_no_rvs
8593

86-
try:
87-
assert_no_rvs(graph)
88-
except AssertionError:
94+
rvs_in_graph = _find_unallowed_rvs_in_graph(graph)
95+
if rvs_in_graph:
8996
warnings.warn(
90-
"RandomVariables were found in the derived graph. "
97+
f"RandomVariables {rvs_in_graph} were found in the derived graph. "
9198
"These variables are a clone and do not match the original ones on identity.\n"
9299
"If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. For example: "
93100
"`logp(model.replace_rvs_by_values([rv])[0], value)`",
@@ -149,6 +156,13 @@ def icdf(
149156
return expr
150157

151158

159+
RVS_IN_JOINT_LOGP_GRAPH_MSG = (
160+
"Random variables detected in the logp graph: %s.\n"
161+
"This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n"
162+
"or when not all rvs have a corresponding value variable."
163+
)
164+
165+
152166
def factorized_joint_logprob(
153167
rv_values: Dict[TensorVariable, TensorVariable],
154168
warn_missing_rvs: bool = True,
@@ -316,34 +330,13 @@ def factorized_joint_logprob(
316330
cleanup_ir(logprob_expressions)
317331

318332
if warn_missing_rvs:
319-
_warn_rvs_in_inferred_graph(logprob_expressions)
333+
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions)
334+
if rvs_in_logp_expressions:
335+
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)
320336

321337
return logprob_vars
322338

323339

324-
def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
325-
# Raise if there are unexpected RandomVariables in the logp graph
326-
# Only SimulatorRVs MinibatchIndexRVs are allowed
327-
from pymc.data import MinibatchIndexRV
328-
from pymc.distributions.simulator import SimulatorRV
329-
330-
unexpected_rv_nodes = [
331-
node
332-
for node in pytensor.graph.ancestors(logp_terms)
333-
if (
334-
node.owner
335-
and isinstance(node.owner.op, RandomVariable)
336-
and not isinstance(node.owner.op, (SimulatorRV, MinibatchIndexRV))
337-
)
338-
]
339-
if unexpected_rv_nodes:
340-
raise ValueError(
341-
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
342-
"This can happen when DensityDist logp or Interval transform functions "
343-
"reference nonlocal variables."
344-
)
345-
346-
347340
def joint_logp(
348341
rvs: Sequence[TensorVariable],
349342
*,
@@ -381,5 +374,10 @@ def joint_logp(
381374
value_var = rvs_to_values[rv]
382375
logp_terms[value_var] = temp_logp_terms[value_var]
383376

384-
_check_no_rvs(list(logp_terms.values()))
385-
return list(logp_terms.values())
377+
logp_terms_list = list(logp_terms.values())
378+
379+
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
380+
if rvs_in_logp_expressions:
381+
raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
382+
383+
return logp_terms_list

pymc/logprob/utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,34 @@
3636

3737
import warnings
3838

39-
from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple
39+
from typing import (
40+
Callable,
41+
Dict,
42+
Generator,
43+
Iterable,
44+
List,
45+
Optional,
46+
Sequence,
47+
Set,
48+
Tuple,
49+
Union,
50+
)
4051

4152
import numpy as np
4253

54+
from pytensor import Variable
4355
from pytensor import tensor as pt
4456
from pytensor.graph import Apply, Op
4557
from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk
4658
from pytensor.graph.fg import FunctionGraph
59+
from pytensor.graph.op import HasInnerGraph
4760
from pytensor.link.c.type import CType
4861
from pytensor.raise_op import CheckAndRaise
62+
from pytensor.tensor.random.op import RandomVariable
4963
from pytensor.tensor.var import TensorVariable
5064

5165
from pymc.logprob.abstract import MeasurableVariable, _logprob
66+
from pymc.util import makeiter
5267

5368

5469
def walk_model(
@@ -273,3 +288,23 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
273288
(const_value,) = inputs
274289
values, const_value = pt.broadcast_arrays(values, const_value)
275290
return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)
291+
292+
293+
def find_rvs_in_graph(vars: Union[Variable, Sequence[Variable]]) -> Set[Variable]:
294+
"""Assert that there are no `MeasurableVariable` nodes in a graph."""
295+
296+
def expand(r):
297+
owner = r.owner
298+
if owner:
299+
inputs = list(reversed(owner.inputs))
300+
301+
if isinstance(owner.op, HasInnerGraph):
302+
inputs += owner.op.inner_outputs
303+
304+
return inputs
305+
306+
return {
307+
node
308+
for node in walk(makeiter(vars), expand, False)
309+
if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable))
310+
}

pymc/pytensorf.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from pymc.exceptions import NotConstantValueError
6666
from pymc.logprob.transforms import RVTransform
6767
from pymc.logprob.utils import CheckParameterValue
68+
from pymc.util import makeiter
6869
from pymc.vartypes import continuous_types, isgenerator, typefilter
6970

7071
PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
@@ -550,13 +551,6 @@ def hessian_diag(f, vars=None):
550551
return empty_gradient
551552

552553

553-
def makeiter(a):
554-
if isinstance(a, (tuple, list)):
555-
return a
556-
else:
557-
return [a]
558-
559-
560554
class IdentityOp(scalar.UnaryScalarOp):
561555
@staticmethod
562556
def st_impl(x):

pymc/testing.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@
2424
from numpy import random as nr
2525
from numpy import testing as npt
2626
from pytensor.compile.mode import Mode
27-
from pytensor.graph.basic import Variable, walk
28-
from pytensor.graph.op import HasInnerGraph
27+
from pytensor.graph.basic import Variable
2928
from pytensor.graph.rewriting.basic import in2out
3029
from pytensor.tensor import TensorVariable
31-
from pytensor.tensor.random.op import RandomVariable
3230
from scipy import special as sp
3331
from scipy import stats as st
3432

@@ -37,16 +35,14 @@
3735
from pymc.distributions.distribution import Distribution
3836
from pymc.distributions.shape_utils import change_dist_size
3937
from pymc.initial_point import make_initial_point_fn
40-
from pymc.logprob.abstract import MeasurableVariable
4138
from pymc.logprob.basic import icdf, joint_logp, logcdf, logp
42-
from pymc.logprob.utils import ParameterValueError
39+
from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph
4340
from pymc.pytensorf import (
4441
compile_pymc,
4542
floatX,
4643
inputvars,
4744
intX,
4845
local_check_parameter_to_ninf_switch,
49-
makeiter,
5046
)
5147

5248
# This mode can be used for tests where model compilations takes the bulk of the runtime
@@ -964,19 +960,9 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
964960
)
965961

966962

967-
def assert_no_rvs(vars: Union[Variable, Sequence[Variable]]):
963+
def assert_no_rvs(vars: Sequence[Variable]) -> None:
968964
"""Assert that there are no `MeasurableVariable` nodes in a graph."""
969965

970-
def expand(r):
971-
owner = r.owner
972-
if owner:
973-
inputs = list(reversed(owner.inputs))
974-
975-
if isinstance(owner.op, HasInnerGraph):
976-
inputs += owner.op.inner_outputs
977-
978-
return inputs
979-
980-
for v in walk(makeiter(vars), expand, False):
981-
if v.owner and isinstance(v.owner.op, (RandomVariable, MeasurableVariable)):
982-
raise AssertionError(f"RV found in graph: {v}")
966+
rvs = find_rvs_in_graph(vars)
967+
if rvs:
968+
raise AssertionError(f"RV found in graph: {rvs}")

pymc/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,3 +510,10 @@ def _add_future_warning_tag(var) -> None:
510510
for k, v in old_tag.__dict__.items():
511511
new_tag.__dict__.setdefault(k, v)
512512
var.tag = new_tag
513+
514+
515+
def makeiter(a):
516+
if isinstance(a, (tuple, list)):
517+
return a
518+
else:
519+
return [a]

pymc/variational/approximations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from pymc.blocking import DictToArrayBijection
2727
from pymc.distributions.dist_math import rho2sigma
28-
from pymc.pytensorf import makeiter
28+
from pymc.util import makeiter
2929
from pymc.variational import opvi
3030
from pymc.variational.opvi import (
3131
Approximation,

pymc/variational/opvi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,14 @@
7272
compile_pymc,
7373
find_rng_nodes,
7474
identity,
75-
makeiter,
7675
reseed_rngs,
7776
)
7877
from pymc.util import (
7978
RandomState,
8079
WithMemoization,
8180
_get_seeds_per_chain,
8281
locally_cachedmethod,
82+
makeiter,
8383
)
8484
from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling
8585
from pymc.variational.updates import adagrad_window

tests/logprob/test_basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_warn_random_found_factorized_joint_logprob():
223223

224224
y_vv = y_rv.clone()
225225

226-
with pytest.warns(UserWarning, match="RandomVariables were found in the derived graph"):
226+
with pytest.warns(UserWarning, match="Random variables detected in the logp graph: {x}"):
227227
factorized_joint_logprob({y_rv: y_vv})
228228

229229
with warnings.catch_warnings():
@@ -443,7 +443,9 @@ def test_warn_random_found_probability_inference(func, scipy_func, test_value):
443443
# In which case the inference should either return that or fail explicitly
444444
# For now, the lopgrob submodule treats the input as a stochastic value.
445445
rv = pt.exp(pm.Normal.dist(input_rv))
446-
with pytest.warns(UserWarning, match="RandomVariables were found in the derived graph"):
446+
with pytest.warns(
447+
UserWarning, match="RandomVariables {input} were found in the derived graph"
448+
):
447449
assert func(rv, 0.0)
448450

449451
res = func(rv, 0.0, warn_missing_rvs=False)

0 commit comments

Comments
 (0)