Commit 13f9894

ricardoV94 authored and twiecki committed
New logprob inference logic
This commit changes the logic used for logprob inference. Instead of eager bottom-up conversion to measurable variables in the IR rewrites, we only convert nodes whose outputs were marked as "needs_measuring". This is achieved with the new `PreserveRVMappings.request_measurable` method. This strategy obviates the need to undo unnecessary conversions.

It also obviates a subtle need for graph cloning via the `ignore_logprob` helper, which prevented intermediate measurable rewrites from being reversed when the intermediate variables were needed to derive the logprob of valued variables but were not directly valued themselves. This indirect role of `ignore_logprob` is now played more explicitly and efficiently by the `request_measurable` method. All other uses of `ignore_logprob` (and `reconsider_logprob`) were removed from the codebase.

The `get_measurable_outputs` dispatching was also abandoned in favor of only considering outputs associated with value variables. A new MergeOptimizerRewrite was written to further restrict local rewrites to only those nodes whose variables have been marked as `needs_measuring`.
1 parent 0fa051d · commit 13f9894

29 files changed: +294 -957 lines
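The new strategy is demand-driven: rewrites ask for a variable to be measured, and conversion happens only for marked nodes. Below is a minimal sketch of the idea behind `PreserveRVMappings.request_measurable`; the attribute names `rv_values` and `needs_measuring` follow the commit description above, and the actual implementation in the codebase may differ.

from typing import List, Sequence

from pytensor.graph.basic import Variable

from pymc.logprob.abstract import MeasurableVariable


class PreserveRVMappings:  # simplified stand-in for the real FunctionGraph feature
    def __init__(self, rv_values: dict):
        self.rv_values = rv_values    # maps random variables to their value variables
        self.needs_measuring = set()  # variables waiting to be converted

    def request_measurable(self, vars: Sequence[Variable]) -> List[Variable]:
        measurable = []
        for var in vars:
            # Variables already converted by a previous rewrite can be
            # used by the caller immediately
            if var.owner and isinstance(var.owner.op, MeasurableVariable):
                measurable.append(var)
            else:
                # Otherwise mark them, so that targeted local rewrites
                # convert them on demand
                self.needs_measuring.add(var)
        return measurable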

pymc/data.py

Lines changed: 0 additions & 6 deletions
@@ -36,7 +36,6 @@
 
 import pymc as pm
 
-from pymc.logprob.abstract import _get_measurable_outputs
 from pymc.pytensorf import convert_observed_data
 
 __all__ = [
@@ -135,11 +134,6 @@ def make_node(self, rng, *args, **kwargs):
         return super().make_node(rng, *args, **kwargs)
 
 
-@_get_measurable_outputs.register(MinibatchIndexRV)
-def minibatch_index_rv_measuarable_outputs(op, node):
-    return []
-
-
 minibatch_index = MinibatchIndexRV()
 
 
pymc/distributions/bound.py

Lines changed: 0 additions & 3 deletions
@@ -26,7 +26,6 @@
 from pymc.distributions.shape_utils import to_tuple
 from pymc.distributions.transforms import _default_transform
 from pymc.logprob.basic import logp
-from pymc.logprob.utils import ignore_logprob
 from pymc.model import modelcontext
 from pymc.pytensorf import floatX, intX
 from pymc.util import check_dist_not_registered
@@ -202,7 +201,6 @@ def __new__(
             raise ValueError("Given dims do not exist in model coordinates.")
 
         lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
-        dist = ignore_logprob(dist)
 
         if isinstance(dist.owner.op, Continuous):
             res = _ContinuousBounded(
@@ -236,7 +234,6 @@ def dist(
     ):
         cls._argument_checks(dist, **kwargs)
        lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
-        dist = ignore_logprob(dist)
         if isinstance(dist.owner.op, Continuous):
             res = _ContinuousBounded.dist(
                 [dist, lower, upper],

pymc/distributions/distribution.py

Lines changed: 1 addition & 22 deletions
@@ -32,7 +32,6 @@
 from pytensor.graph.utils import MetaType
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.var import TensorVariable
 from typing_extensions import TypeAlias
@@ -49,13 +48,7 @@
     shape_from_dims,
 )
 from pymc.exceptions import BlockModelAccessError
-from pymc.logprob.abstract import (
-    MeasurableVariable,
-    _get_measurable_outputs,
-    _icdf,
-    _logcdf,
-    _logprob,
-)
+from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
 from pymc.logprob.rewriting import logprob_rewrites_db
 from pymc.model import BlockModelAccess
 from pymc.printing import str_for_dist
@@ -401,20 +394,6 @@ def dist(
 MeasurableVariable.register(SymbolicRandomVariable)
 
 
-@_get_measurable_outputs.register(SymbolicRandomVariable)
-def _get_measurable_outputs_symbolic_random_variable(op, node):
-    # This tells PyMC that any non RandomType outputs are measurable
-
-    # Assume that if there is one default_output, that's the only one that is measurable
-    # In the rare case this is not what one wants, a specialized _get_measuarable_outputs
-    # can dispatch for a subclassed Op
-    if op.default_output is not None:
-        return [node.default_output()]
-
-    # Otherwise assume that any outputs that are not of RandomType are measurable
-    return [out for out in node.outputs if not isinstance(out.type, RandomType)]
-
-
 @node_rewriter([SymbolicRandomVariable])
 def inline_symbolic_random_variable(fgraph, node):
     """

pymc/distributions/mixture.py

Lines changed: 4 additions & 8 deletions
@@ -35,9 +35,9 @@
 from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
 from pymc.distributions.transforms import _default_transform
 from pymc.distributions.truncated import Truncated
-from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob, _logprob_helper
+from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob
+from pymc.logprob.basic import logp
 from pymc.logprob.transforms import IntervalTransform
-from pymc.logprob.utils import ignore_logprob
 from pymc.pytensorf import floatX
 from pymc.util import check_dist_not_registered
 from pymc.vartypes import continuous_types, discrete_types
@@ -267,10 +267,6 @@ def rv_op(cls, weights, *components, size=None):
 
         assert weights_ndim_batch == 0
 
-        # Component RVs terms are accounted by the Mixture logprob, so they can be
-        # safely ignored in the logprob graph
-        components = [ignore_logprob(component) for component in components]
-
         # Create a OpFromGraph that encapsulates the random generating process
         # Create dummy input variables with the same type as the ones provided
         weights_ = weights.type()
@@ -350,10 +346,10 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
     if len(components) == 1:
         # Need to broadcast value across mixture axis
         mix_axis = -components[0].owner.op.ndim_supp - 1
-        components_logp = _logprob_helper(components[0], pt.expand_dims(value, mix_axis))
+        components_logp = logp(components[0], pt.expand_dims(value, mix_axis))
     else:
         components_logp = pt.stack(
-            [_logprob_helper(component, value) for component in components],
+            [logp(component, value) for component in components],
             axis=-1,
         )
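The swap from the internal `_logprob_helper` to the public `logp` means each component's log-probability graph is derived with the same entry point users call. A small sanity check of that call (the values here are chosen for illustration only):

import pymc as pm

# logp builds the log-probability graph of a distribution at a value;
# this is the per-component call the mixture logprob now makes.
comp = pm.Normal.dist(mu=0.0, sigma=1.0)
comp_logp = pm.logp(comp, 0.5)
print(comp_logp.eval())  # ≈ -1.044, i.e. log N(0.5 | mu=0, sigma=1)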

pymc/distributions/multivariate.py

Lines changed: 1 addition & 5 deletions
@@ -67,7 +67,6 @@
 )
 from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.utils import ignore_logprob
 from pymc.math import kron_diag, kron_dot
 from pymc.pytensorf import floatX, intX
 from pymc.util import check_dist_not_registered
@@ -1191,9 +1190,6 @@ def dist(cls, n, eta, sd_dist, **kwargs):
             raise TypeError("sd_dist must be a scalar or vector distribution variable")
 
         check_dist_not_registered(sd_dist)
-        # sd_dist is part of the generative graph, but should be completely ignored
-        # by the logp graph, since the LKJ logp explicitly includes these terms.
-        sd_dist = ignore_logprob(sd_dist)
         return super().dist([n, eta, sd_dist], **kwargs)
 
     @classmethod
@@ -2527,7 +2523,7 @@ def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
     @classmethod
     def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
         shape = to_tuple(size) + tuple(support_shape)
-        normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
+        normal_dist = pm.Normal.dist(sigma=sigma, shape=shape)
 
         if n_zerosum_axes > normal_dist.ndim:
             raise ValueError("Shape of distribution is too small for the number of zerosum axes")

pymc/distributions/timeseries.py

Lines changed: 2 additions & 17 deletions
@@ -43,7 +43,6 @@
 from pymc.exceptions import NotConstantValueError
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import logp
-from pymc.logprob.utils import ignore_logprob, reconsider_logprob
 from pymc.pytensorf import constant_fold, floatX, intX
 from pymc.util import check_dist_not_registered
 
@@ -111,11 +110,6 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> pt.TensorVariable:
         if init_dist in ancestors([innovation_dist]) or innovation_dist in ancestors([init_dist]):
             raise ValueError("init_dist and innovation_dist must be completely independent")
 
-        # PyMC should not be concerned that these variables don't have values, as they will be
-        # accounted for in the logp of RandomWalk
-        init_dist = ignore_logprob(init_dist)
-        innovation_dist = ignore_logprob(innovation_dist)
-
         steps = cls.get_steps(
             innovation_dist=innovation_dist,
             steps=steps,
@@ -235,14 +229,12 @@ def random_walk_moment(op, rv, init_dist, innovation_dist, steps):
 
 
 @_logprob.register(RandomWalkRV)
-def random_walk_logp(op, values, init_dist, innovation_dist, steps, **kwargs):
+def random_walk_logp(op, values, *inputs, **kwargs):
     # Although we can derive the logprob of random walks, it does not collapse
     # what we consider the core dimension of steps. We do it manually here.
     (value,) = values
     # Recreate RV and obtain inner graph
-    rv_node = op.make_node(
-        reconsider_logprob(init_dist), reconsider_logprob(innovation_dist), steps
-    )
+    rv_node = op.make_node(*inputs)
     rv = clone_replace(
         op.inner_outputs, replace={u: v for u, v in zip(op.inner_inputs, rv_node.inputs)}
     )[op.default_output]
@@ -571,9 +563,6 @@ def dist(
             )
             init_dist = Normal.dist(0, 100, shape=(*sigma.shape, ar_order))
 
-        # We can ignore init_dist, as it will be accounted for in the logp term
-        init_dist = ignore_logprob(init_dist)
-
         return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs)
 
     @classmethod
@@ -789,8 +778,6 @@ def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs):
         initial_vol = pt.as_tensor_variable(initial_vol)
 
         init_dist = Normal.dist(0, initial_vol)
-        # We can ignore init_dist, as it will be accounted for in the logp term
-        init_dist = ignore_logprob(init_dist)
 
         return super().dist([omega, alpha_1, beta_1, initial_vol, init_dist, steps], **kwargs)
 
@@ -973,8 +960,6 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
                 UserWarning,
             )
             init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape)
-            # We can ignore init_dist, as it will be accounted for in the logp term
-            init_dist = ignore_logprob(init_dist)
 
         return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs)
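Because `random_walk_logp` now recreates the node from whatever inputs it receives, no `reconsider_logprob` round-trip is needed before requesting the logp. A quick smoke test, assuming the `GaussianRandomWalk` API (parameter values are illustrative):

import numpy as np

import pymc as pm

# The registered logp collapses the core steps dimension manually, so a
# trajectory of length steps + 1 yields a single log-probability value.
grw = pm.GaussianRandomWalk.dist(
    mu=0.0, sigma=1.0, init_dist=pm.Normal.dist(0, 1), steps=3
)
print(pm.logp(grw, np.zeros(4)).eval())  # scalar logp of the trajectory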

pymc/logprob/abstract.py

Lines changed: 1 addition & 104 deletions
@@ -36,11 +36,9 @@
 
 import abc
 
-from copy import copy
 from functools import singledispatch
-from typing import Callable, List, Sequence, Tuple
+from typing import Sequence, Tuple
 
-from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.graph.utils import MetaType
 from pytensor.tensor import TensorVariable
@@ -135,107 +133,6 @@ class MeasurableVariable(abc.ABC):
 MeasurableVariable.register(RandomVariable)
 
 
-class UnmeasurableMeta(MetaType):
-    def __new__(cls, name, bases, dict):
-        if "id_obj" not in dict:
-            dict["id_obj"] = None
-
-        return super().__new__(cls, name, bases, dict)
-
-    def __eq__(self, other):
-        if isinstance(other, UnmeasurableMeta):
-            return hash(self.id_obj) == hash(other.id_obj)
-        return False
-
-    def __hash__(self):
-        return hash(self.id_obj)
-
-
-class UnmeasurableVariable(metaclass=UnmeasurableMeta):
-    """
-    id_obj is an attribute, i.e. tuple of length two, of the unmeasurable class object.
-    e.g. id_obj = (NormalRV, noop_measurable_outputs_fn)
-    """
-
-
-def get_measurable_outputs(op: Op, node: Apply) -> List[Variable]:
-    """Return only the outputs that are measurable."""
-    if isinstance(op, MeasurableVariable):
-        return _get_measurable_outputs(op, node)
-    else:
-        return []
-
-
-@singledispatch
-def _get_measurable_outputs(op, node):
-    return node.outputs
-
-
-@_get_measurable_outputs.register(RandomVariable)
-def _get_measurable_outputs_RandomVariable(op, node):
-    return node.outputs[1:]
-
-
-def noop_measurable_outputs_fn(*args, **kwargs):
-    return []
-
-
-def assign_custom_measurable_outputs(
-    node: Apply,
-    measurable_outputs_fn: Callable = noop_measurable_outputs_fn,
-    type_prefix: str = "Unmeasurable",
-) -> Apply:
-    """Assign a custom ``_get_measurable_outputs`` dispatch function to a measurable variable instance.
-
-    The node is cloned and a custom `Op` that's a copy of the original node's
-    `Op` is created. That custom `Op` replaces the old `Op` in the cloned
-    node, and then a custom dispatch implementation is created for the clone
-    `Op` in `_get_measurable_outputs`.
-
-    If `measurable_outputs_fn` isn't specified, a no-op is used; the result is
-    a clone of `node` that will effectively be ignored by
-    `factorized_joint_logprob`.
-
-    Parameters
-    ----------
-    node
-        The node to recreate with a new cloned `Op`.
-    measurable_outputs_fn
-        The function that will be assigned to the new cloned `Op` in the
-        `_get_measurable_outputs` dispatcher.
-        The default is a no-op function (i.e. no measurable outputs)
-    type_prefix
-        The prefix used for the new type's name.
-        The default is ``"Unmeasurable"``, which matches the default
-        ``"measurable_outputs_fn"``.
-    """
-
-    new_node = node.clone()
-    op_type = type(new_node.op)
-
-    if op_type in _get_measurable_outputs.registry.keys() and isinstance(op_type, UnmeasurableMeta):
-        if _get_measurable_outputs.registry[op_type] != measurable_outputs_fn:
-            raise ValueError(
-                f"The type {op_type.__name__} with hash value {hash(op_type)} "
-                "has already been dispatched a measurable outputs function."
-            )
-        return node
-
-    new_op_dict = op_type.__dict__.copy()
-    new_op_dict["id_obj"] = (new_node.op, measurable_outputs_fn)
-    new_op_dict.setdefault("original_op_type", op_type)
-
-    new_op_type = type(
-        f"{type_prefix}{op_type.__name__}", (op_type, UnmeasurableVariable), new_op_dict
-    )
-    new_node.op = copy(new_node.op)
-    new_node.op.__class__ = new_op_type
-
-    _get_measurable_outputs.register(new_op_type)(measurable_outputs_fn)
-
-    return new_node
-
-
 class MeasurableElemwise(Elemwise):
     """Base class for Measurable Elemwise variables"""
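With this dispatch machinery gone, measurability of an output is decided by whether a value variable was provided for it. A hypothetical helper illustrating the replacement rule (`rv_values` is the RV-to-value mapping maintained by `PreserveRVMappings`; the real code may organize this differently):

from typing import Dict, List

from pytensor.graph.basic import Apply, Variable


def valued_outputs(node: Apply, rv_values: Dict[Variable, Variable]) -> List[Variable]:
    # Only outputs with an associated value variable are considered,
    # replacing the per-Op _get_measurable_outputs dispatch.
    return [out for out in node.outputs if out in rv_values]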
