diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 9d578db74..befe97998 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -33,11 +33,10 @@ moment, ) from pymc.distributions.shape_utils import _change_dist_size, change_dist_size -from pymc.distributions.transforms import _default_transform +from pymc.distributions.transforms import IntervalTransform, _default_transform from pymc.distributions.truncated import Truncated from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob from pymc.logprob.basic import logp -from pymc.logprob.transforms import IntervalTransform from pymc.pytensorf import floatX from pymc.util import check_dist_not_registered from pymc.vartypes import continuous_types, discrete_types @@ -417,15 +416,15 @@ def marginal_mixture_moment(op, rv, rng, weights, *components): # special handling or because we have custom logic to enable them. If new default # transforms are implemented, this list and function should be updated allowed_default_mixture_transforms = ( - transforms.CholeskyCovPacked, + transforms.CholeskyCovPackedTransform, transforms.CircularTransform, transforms.IntervalTransform, transforms.LogTransform, - transforms.LogExpM1, + transforms.LogExpM1Transform, transforms.LogOddsTransform, - transforms.Ordered, + transforms.OrderedTransform, transforms.SimplexTransform, - transforms.SumTo1, + transforms.SumTo1Transform, ) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 42a2a9f62..432f9e4c2 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1241,7 +1241,7 @@ def _LKJCholeksyCovRV_moment(op, rv, rng, n, eta, sd_dist): @_default_transform.register(_LKJCholeskyCovRV) def _LKJCholeksyCovRV_default_transform(op, rv): _, n, _, _ = rv.owner.inputs - return transforms.CholeskyCovPacked(n) + return transforms.CholeskyCovPackedTransform(n) @_logprob.register(_LKJCholeskyCovRV) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 2b828d5c6..e6221683b 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -11,32 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
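For context (an editor's aside, not part of the patch): the `allowed_default_mixture_transforms` tuple earlier in this diff controls whether a marginal mixture inherits its components' default transform. A minimal sketch of the user-facing effect, assuming the public PyMC API; the model and variable names are illustrative:

    import pymc as pm

    with pm.Model() as model:
        w = pm.Dirichlet("w", a=[1.0, 1.0])
        # Both HalfNormal components default to a log transform, which is on
        # the allow-list, so the marginal mixture keeps it.
        mix = pm.Mixture("mix", w=w, comp_dists=pm.HalfNormal.dist(sigma=[1.0, 2.0]))

    print(model.rvs_to_transforms[mix])  # expected: a LogTransform instance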
+import abc import warnings from functools import singledispatch +from typing import Callable, Optional, Tuple, Union import numpy as np -import pytensor.tensor as pt # ignore mypy error because it somehow considers that # "numpy.core.numeric has no attribute normalize_axis_tuple" from numpy.core.numeric import normalize_axis_tuple # type: ignore -from pytensor.graph import Op +from pytensor import scan +from pytensor import tensor as pt +from pytensor.gradient import jacobian +from pytensor.graph import Op, Variable from pytensor.tensor import TensorVariable -import pymc as pm - -from pymc.logprob.transforms import ( - CircularTransform, - IntervalTransform, - LogOddsTransform, - LogTransform, - RVTransform, - SimplexTransform, -) - __all__ = [ - "RVTransform", + "Transform", "simplex", "logodds", "Interval", @@ -50,6 +43,8 @@ "ZeroSumTransform", ] +from pymc.pytensorf import floatX + def __getattr__(name): if name in ("univariate_ordered", "multivariate_ordered"): @@ -60,6 +55,10 @@ def __getattr__(name): warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning) return sum_to_1 + if name == "RVTransform": + warnings.warn("RVTransform has been renamed to Transform", FutureWarning) + return Transform + raise AttributeError(f"module {__name__} has no attribute {name}") @@ -69,152 +68,128 @@ def _default_transform(op: Op, rv: TensorVariable): return None -class LogExpM1(RVTransform): - name = "log_exp_m1" +class Transform(abc.ABC): + ndim_supp: Optional[int] = None - def backward(self, value, *inputs): - return pt.softplus(value) + @abc.abstractmethod + def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: + """Apply the transformation.""" - def forward(self, value, *inputs): - """Inverse operation of softplus. + @abc.abstractmethod + def backward( + self, value: TensorVariable, *inputs: Variable + ) -> Union[TensorVariable, Tuple[TensorVariable, ...]]: + """Invert the transformation. 
Multiple values may be returned when the
+        transformation is not 1-to-1."""

-        y = Log(Exp(x) - 1)
-          = Log(1 - Exp(-x)) + x
-        """
-        return pt.log(1.0 - pt.exp(-value)) + value
+    def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
+        """Construct the log of the absolute value of the Jacobian determinant."""
+        if self.ndim_supp not in (0, 1):
+            raise NotImplementedError(
+                f"Transform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}"
+            )
+        if self.ndim_supp == 0:
+            jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape)
+            return pt.log(pt.abs(jac))
+        else:
+            phi_inv = self.backward(value, *inputs)
+            return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))

-    def log_jac_det(self, value, *inputs):
-        return -pt.softplus(-value)
+    def __str__(self):
+        return f"{self.__class__.__name__}"


-class Ordered(RVTransform):
-    name = "ordered"
+class LocTransform(Transform):
+    name = "loc"

-    def __init__(self, ndim_supp=None):
-        if ndim_supp is not None:
-            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
-
-    def backward(self, value, *inputs):
-        x = pt.zeros(value.shape)
-        x = pt.set_subtensor(x[..., 0], value[..., 0])
-        x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
-        return pt.cumsum(x, axis=-1)
+    def __init__(self, transform_args_fn):
+        self.transform_args_fn = transform_args_fn

     def forward(self, value, *inputs):
-        y = pt.zeros(value.shape)
-        y = pt.set_subtensor(y[..., 0], value[..., 0])
-        y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
-        return y
+        loc = self.transform_args_fn(*inputs)
+        return value + loc
+
+    def backward(self, value, *inputs):
+        loc = self.transform_args_fn(*inputs)
+        return value - loc

     def log_jac_det(self, value, *inputs):
-        return pt.sum(value[..., 1:], axis=-1)
+        return pt.zeros_like(value)


-class SumTo1(RVTransform):
-    """
-    Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1]
-    This Transformation operates on the last dimension of the input tensor.
- """ +class ScaleTransform(Transform): + name = "scale" - name = "sumto1" + def __init__(self, transform_args_fn): + self.transform_args_fn = transform_args_fn - def __init__(self, ndim_supp=None): - if ndim_supp is not None: - warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) + def forward(self, value, *inputs): + scale = self.transform_args_fn(*inputs) + return value * scale def backward(self, value, *inputs): - remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True) - return pt.concatenate([value[..., :], remaining], axis=-1) - - def forward(self, value, *inputs): - return value[..., :-1] + scale = self.transform_args_fn(*inputs) + return value / scale def log_jac_det(self, value, *inputs): - y = pt.zeros(value.shape) - return pt.sum(y, axis=-1) + scale = self.transform_args_fn(*inputs) + return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape))) -class CholeskyCovPacked(RVTransform): - """ - Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the - log scale - """ - - name = "cholesky-cov-packed" +class IntervalTransform(Transform): + name = "interval" - def __init__(self, n): + def __init__(self, args_fn: Callable[..., Tuple[Optional[Variable], Optional[Variable]]]): """ Parameters ---------- - n: int - Number of diagonal entries in the LKJCholeskyCov distribution + args_fn + Function that expects inputs of RandomVariable and returns the lower + and upper bounds for the interval transformation. If one of these is + None, the RV is considered to be unbounded on the respective edge. """ - self.diag_idxs = pt.arange(1, n + 1).cumsum() - 1 - - def backward(self, value, *inputs): - return pt.set_subtensor(value[..., self.diag_idxs], pt.exp(value[..., self.diag_idxs])) + self.args_fn = args_fn def forward(self, value, *inputs): - return pt.set_subtensor(value[..., self.diag_idxs], pt.log(value[..., self.diag_idxs])) - - def log_jac_det(self, value, *inputs): - return pt.sum(value[..., self.diag_idxs], axis=-1) - + a, b = self.args_fn(*inputs) -class Chain(RVTransform): - __slots__ = ("param_extract_fn", "transform_list", "name") - - def __init__(self, transform_list): - self.transform_list = transform_list - self.name = "+".join([transf.name for transf in self.transform_list]) - - def forward(self, value, *inputs): - y = value - for transf in self.transform_list: - # TODO:Needs proper discussion as to what should be - # passed as inputs here - y = transf.forward(y, *inputs) - return y + if a is not None and b is not None: + return pt.log(value - a) - pt.log(b - value) + elif a is not None: + return pt.log(value - a) + elif b is not None: + return pt.log(b - value) + else: + raise ValueError("Both edges of IntervalTransform cannot be None") def backward(self, value, *inputs): - x = value - for transf in reversed(self.transform_list): - x = transf.backward(x, *inputs) - return x + a, b = self.args_fn(*inputs) + + if a is not None and b is not None: + sigmoid_x = pt.sigmoid(value) + return sigmoid_x * b + (1 - sigmoid_x) * a + elif a is not None: + return pt.exp(value) + a + elif b is not None: + return b - pt.exp(value) + else: + raise ValueError("Both edges of IntervalTransform cannot be None") def log_jac_det(self, value, *inputs): - y = pt.as_tensor_variable(value) - det_list = [] - ndim0 = y.ndim - for transf in reversed(self.transform_list): - det_ = transf.log_jac_det(y, *inputs) - det_list.append(det_) - y = transf.backward(y, *inputs) - ndim0 = min(ndim0, det_.ndim) - # match the shape of the smallest log_jac_det - det 
= 0.0
-        for det_ in det_list:
-            if det_.ndim > ndim0:
-                det += det_.sum(axis=-1)
-            else:
-                det += det_
-        return det
-
-
-simplex = SimplexTransform()
-simplex.__doc__ = """
-Instantiation of :class:`pymc.logprob.transforms.SimplexTransform`
-for use in the ``transform`` argument of a random variable."""
+        a, b = self.args_fn(*inputs)

-logodds = LogOddsTransform()
-logodds.__doc__ = """
-Instantiation of :class:`pymc.logprob.transforms.LogOddsTransform`
-for use in the ``transform`` argument of a random variable."""
+        if a is not None and b is not None:
+            s = pt.softplus(-value)
+            return pt.log(b - a) - 2 * s - value
+        elif a is None and b is None:
+            raise ValueError("Both edges of IntervalTransform cannot be None")
+        else:
+            return value


 class Interval(IntervalTransform):
-    """Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use in the
+    """Wrapper around :class:`IntervalTransform` for use in the
     ``transform`` argument of a random variable.

     Parameters
@@ -297,7 +272,413 @@ def bounds_fn(*rv_inputs):
         super().__init__(args_fn=bounds_fn)


-class ZeroSumTransform(RVTransform):
+class LogOddsTransform(Transform):
+    name = "logodds"
+
+    def backward(self, value, *inputs):
+        return pt.expit(value)
+
+    def forward(self, value, *inputs):
+        return pt.log(value / (1 - value))
+
+    def log_jac_det(self, value, *inputs):
+        sigmoid_value = pt.sigmoid(value)
+        return pt.log(sigmoid_value) + pt.log1p(-sigmoid_value)
+
+
+logodds = LogOddsTransform()
+logodds.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.LogOddsTransform`
+for use in the ``transform`` argument of a random variable."""
+
+
+class SimplexTransform(Transform):
+    name = "simplex"
+
+    def forward(self, value, *inputs):
+        value = pt.as_tensor(value)
+        log_value = pt.log(value)
+        N = value.shape[-1].astype(value.dtype)
+        shift = pt.sum(log_value, -1, keepdims=True) / N
+        return log_value[..., :-1] - shift
+
+    def backward(self, value, *inputs):
+        value = pt.concatenate([value, -pt.sum(value, -1, keepdims=True)], axis=-1)
+        exp_value_max = pt.exp(value - pt.max(value, -1, keepdims=True))
+        return exp_value_max / pt.sum(exp_value_max, -1, keepdims=True)
+
+    def log_jac_det(self, value, *inputs):
+        value = pt.as_tensor(value)
+        N = value.shape[-1] + 1
+        N = N.astype(value.dtype)
+        sum_value = pt.sum(value, -1, keepdims=True)
+        value_sum_expanded = value + sum_value
+        value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)
+        logsumexp_value_expanded = pt.logsumexp(value_sum_expanded, -1, keepdims=True)
+        res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded)
+        return pt.sum(res, -1)
+
+
+simplex = SimplexTransform()
+simplex.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.SimplexTransform`
+for use in the ``transform`` argument of a random variable."""
+
+
+class LogTransform(Transform):
+    name = "log"
+
+    def forward(self, value, *inputs):
+        return pt.log(value)
+
+    def backward(self, value, *inputs):
+        return pt.exp(value)
+
+    def log_jac_det(self, value, *inputs):
+        return value
+
+
+log = LogTransform()
+log.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.LogTransform`
+for use in the ``transform`` argument of a random variable."""
+
+
+class ExpTransform(Transform):
+    name = "exp"
+
+    def forward(self, value, *inputs):
+        return pt.exp(value)
+
+    def backward(self, value, *inputs):
+        return pt.log(value)
+
+    def log_jac_det(self, value, *inputs):
+        return -pt.log(value)
+
+
+class AbsTransform(Transform):
+    name = "abs"
+
+    def 
forward(self, value, *inputs): + return pt.abs(value) + + def backward(self, value, *inputs): + value = pt.switch(value >= 0, value, np.nan) + return -value, value + + def log_jac_det(self, value, *inputs): + return pt.switch(value >= 0, 0, np.nan) + + +class PowerTransform(Transform): + name = "power" + + def __init__(self, power=None): + if not isinstance(power, (int, float)): + raise TypeError(f"Power must be integer or float, got {type(power)}") + if power == 0: + raise ValueError("Power cannot be 0") + self.power = power + super().__init__() + + def forward(self, value, *inputs): + return pt.power(value, self.power) + + def backward(self, value, *inputs): + inv_power = 1 / self.power + + # Powers that don't admit negative values + if (np.abs(self.power) < 1) or (self.power % 2 == 0): + backward_value = pt.switch(value >= 0, pt.power(value, inv_power), np.nan) + # Powers that admit negative values require special logic, because (-1)**(1/3) returns `nan` in PyTensor + else: + backward_value = pt.power(pt.abs(value), inv_power) * pt.switch(value >= 0, 1, -1) + + # In this case the transform is not 1-to-1 + if self.power % 2 == 0: + return -backward_value, backward_value + else: + return backward_value + + def log_jac_det(self, value, *inputs): + inv_power = 1 / self.power + + # Note: This fails for value==0 + res = np.log(np.abs(inv_power)) + (inv_power - 1) * pt.log(pt.abs(value)) + + # Powers that don't admit negative values + if (np.abs(self.power) < 1) or (self.power % 2 == 0): + res = pt.switch(value >= 0, res, np.nan) + + return res + + +class ArcsinhTransform(Transform): + name = "arcsinh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.arcsinh(value) + + def backward(self, value, *inputs): + return pt.sinh(value) + + +class ArccoshTransform(Transform): + name = "arccosh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.arccosh(value) + + def backward(self, value, *inputs): + return pt.cosh(value) + + +class ArctanhTransform(Transform): + name = "arctanh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.arctanh(value) + + def backward(self, value, *inputs): + return pt.tanh(value) + + +class SinhTransform(Transform): + name = "sinh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.sinh(value) + + def backward(self, value, *inputs): + return pt.arcsinh(value) + + +class CoshTransform(Transform): + name = "cosh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.cosh(value) + + def backward(self, value, *inputs): + back_value = pt.arccosh(value) + return (-back_value, back_value) + + def log_jac_det(self, value, *inputs): + return pt.switch( + value < 1, + np.nan, + -pt.log(pt.sqrt(value**2 - 1)), + ) + + +class TanhTransform(Transform): + name = "tanh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.tanh(value) + + def backward(self, value, *inputs): + return pt.arctanh(value) + + +class ErfTransform(Transform): + name = "erf" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.erf(value) + + def backward(self, value, *inputs): + return pt.erfinv(value) + + +class ErfcTransform(Transform): + name = "erfc" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.erfc(value) + + def backward(self, value, *inputs): + return pt.erfcinv(value) + + +class ErfcxTransform(Transform): + name = "erfcx" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.erfcx(value) + + def backward(self, value, *inputs): + # computes the inverse of erfcx, this 
was adapted from
+        # https://tinyurl.com/4mxfd3cz
+        x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value)))
+
+        def calc_delta_x(value, prior_result):
+            return prior_result - (pt.erfcx(prior_result) - value) / (
+                2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi)
+            )
+
+        result, updates = scan(
+            fn=calc_delta_x,
+            outputs_info=pt.ones_like(x),
+            non_sequences=value,
+            n_steps=10,
+        )
+        return result[-1]
+
+
+class ChainTransform(Transform):
+    name = "chain"
+
+    def __init__(self, transform_list):
+        self.transform_list = transform_list
+
+    def forward(self, value, *inputs):
+        for transform in self.transform_list:
+            value = transform.forward(value, *inputs)
+        return value
+
+    def backward(self, value, *inputs):
+        for transform in reversed(self.transform_list):
+            value = transform.backward(value, *inputs)
+        return value
+
+    def log_jac_det(self, value, *inputs):
+        value = pt.as_tensor_variable(value)
+        det_list = []
+        ndim0 = value.ndim
+        for transform in reversed(self.transform_list):
+            det_ = transform.log_jac_det(value, *inputs)
+            det_list.append(det_)
+            ndim0 = min(ndim0, det_.ndim)
+            value = transform.backward(value, *inputs)
+        # match the shape of the smallest log_jac_det
+        det = 0.0
+        for det_ in det_list:
+            if det_.ndim > ndim0:
+                ndim_diff = det_.ndim - ndim0
+                det += det_.sum(axis=tuple(range(-ndim_diff, 0)))
+            else:
+                det += det_
+        return det
+
+
+# For backwards compatibility
+Chain = ChainTransform
+
+
+class CircularTransform(Transform):
+    name = "circular"
+
+    def backward(self, value, *inputs):
+        return pt.arctan2(pt.sin(value), pt.cos(value))
+
+    def forward(self, value, *inputs):
+        return pt.as_tensor_variable(value)
+
+    def log_jac_det(self, value, *inputs):
+        return pt.zeros(value.shape)
+
+
+circular = CircularTransform()
+circular.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.CircularTransform`
+for use in the ``transform`` argument of a random variable."""
+
+
+class OrderedTransform(Transform):
+    name = "ordered"
+
+    def __init__(self, ndim_supp=None):
+        if ndim_supp is not None:
+            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
+
+    def backward(self, value, *inputs):
+        x = pt.zeros(value.shape)
+        x = pt.set_subtensor(x[..., 0], value[..., 0])
+        x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
+        return pt.cumsum(x, axis=-1)
+
+    def forward(self, value, *inputs):
+        y = pt.zeros(value.shape)
+        y = pt.set_subtensor(y[..., 0], value[..., 0])
+        y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
+        return y
+
+    def log_jac_det(self, value, *inputs):
+        return pt.sum(value[..., 1:], axis=-1)
+
+
+ordered = OrderedTransform()
+ordered.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.OrderedTransform`
+for use in the ``transform`` argument of a random variable."""
+
+
+class LogExpM1Transform(Transform):
+    name = "log_exp_m1"
+
+    def backward(self, value, *inputs):
+        return pt.softplus(value)
+
+    def forward(self, value, *inputs):
+        """Inverse operation of softplus.
+
+        y = Log(Exp(x) - 1)
+          = Log(1 - Exp(-x)) + x
+        """
+        return pt.log(1.0 - pt.exp(-value)) + value
+
+    def log_jac_det(self, value, *inputs):
+        return -pt.softplus(-value)
+
+
+log_exp_m1 = LogExpM1Transform()
+log_exp_m1.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.LogExpM1Transform`
+for use in the ``transform`` argument of a random variable."""
+
+
+class SumTo1Transform(Transform):
+    """
+    Transforms a (K-1)-dimensional simplex (K values in [0, 1] that sum to 1) to a vector of K - 1 values in [0, 1].
+    This transformation operates on the last dimension of the input tensor.
+    """
+
+    name = "sumto1"
+
+    def __init__(self, ndim_supp=None):
+        if ndim_supp is not None:
+            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
+
+    def backward(self, value, *inputs):
+        remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True)
+        return pt.concatenate([value[..., :], remaining], axis=-1)
+
+    def forward(self, value, *inputs):
+        return value[..., :-1]
+
+    def log_jac_det(self, value, *inputs):
+        y = pt.zeros(value.shape)
+        return pt.sum(y, axis=-1)
+
+
+sum_to_1 = SumTo1Transform()
+sum_to_1.__doc__ = """
+Instantiation of :class:`pymc.distributions.transforms.SumTo1Transform`
+for use in the ``transform`` argument of a random variable."""
+
+
+class ZeroSumTransform(Transform):
     """
     Constrains any random samples to sum to zero along the user-provided
     ``zerosum_axes``.
@@ -314,65 +695,69 @@ class ZeroSumTransform(RVTransform):
     def __init__(self, zerosum_axes):
         self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes)

+    @staticmethod
+    def extend_axis_rev(array, axis):
+        normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
+
+        n = floatX(array.shape[normalized_axis])
+        last = pt.take(array, [-1], axis=normalized_axis)
+
+        sum_vals = -last * pt.sqrt(n)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        slice_before = (slice(None, None),) * normalized_axis
+
+        return array[slice_before + (slice(None, -1),)] + norm
+
+    @staticmethod
+    def extend_axis(array, axis):
+        n = floatX(array.shape[axis] + 1)
+        sum_vals = array.sum(axis, keepdims=True)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        fill_val = norm - sum_vals / pt.sqrt(n)
+
+        out = pt.concatenate([array, fill_val], axis=axis)
+        return out - norm
+
     def forward(self, value, *rv_inputs):
         for axis in self.zerosum_axes:
-            value = extend_axis_rev(value, axis=axis)
+            value = self.extend_axis_rev(value, axis=axis)
         return value

     def backward(self, value, *rv_inputs):
         for axis in self.zerosum_axes:
-            value = extend_axis(value, axis=axis)
+            value = self.extend_axis(value, axis=axis)
         return value

     def log_jac_det(self, value, *rv_inputs):
         return pt.constant(0.0)


-def extend_axis(array, axis):
-    n = pm.floatX(array.shape[axis] + 1)
-    sum_vals = array.sum(axis, keepdims=True)
-    norm = sum_vals / (pt.sqrt(n) + n)
-    fill_val = norm - sum_vals / pt.sqrt(n)
-
-    out = pt.concatenate([array, fill_val], axis=axis)
-    return out - norm
-
-
-def extend_axis_rev(array, axis):
-    normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
-
-    n = pm.floatX(array.shape[normalized_axis])
-    last = pt.take(array, [-1], axis=normalized_axis)
+class CholeskyCovPackedTransform(Transform):
+    """
+    Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
+    log scale
+    """

-    sum_vals = -last * pt.sqrt(n)
-    norm = sum_vals / (pt.sqrt(n) + n)
-    slice_before = (slice(None, None),) * normalized_axis
+    name = "cholesky-cov-packed"

-    return array[slice_before + (slice(None, -1),)] + norm
+    def __init__(self, n):
+        """
Parameters + ---------- + n: int + Number of diagonal entries in the LKJCholeskyCov distribution + """ + self.diag_idxs = pt.arange(1, n + 1).cumsum() - 1 -log_exp_m1 = LogExpM1() -log_exp_m1.__doc__ = """ -Instantiation of :class:`pymc.distributions.transforms.LogExpM1` -for use in the ``transform`` argument of a random variable.""" + def backward(self, value, *inputs): + return pt.set_subtensor(value[..., self.diag_idxs], pt.exp(value[..., self.diag_idxs])) -# Deprecated -ordered = Ordered() -ordered.__doc__ = """ -Instantiation of :class:`pymc.distributions.transforms.Ordered` -for use in the ``transform`` argument of a random variable.""" + def forward(self, value, *inputs): + return pt.set_subtensor(value[..., self.diag_idxs], pt.log(value[..., self.diag_idxs])) -log = LogTransform() -log.__doc__ = """ -Instantiation of :class:`pymc.logprob.transforms.LogTransform` -for use in the ``transform`` argument of a random variable.""" + def log_jac_det(self, value, *inputs): + return pt.sum(value[..., self.diag_idxs], axis=-1) -sum_to_1 = SumTo1() -sum_to_1.__doc__ = """ -Instantiation of :class:`pymc.distributions.transforms.SumTo1` -for use in the ``transform`` argument of a random variable.""" -circular = CircularTransform() -circular.__doc__ = """ -Instantiation of :class:`pymc.logprob.transforms.CircularTransform` -for use in the ``transform`` argument of a random variable.""" +CholeskyCovPacked = CholeskyCovPackedTransform diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index d711d25a4..268b67850 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -48,6 +48,8 @@ "Kron", ] +from pymc.pytensorf import constant_fold + TensorLike = Union[np.ndarray, TensorVariable] IntSequence = Union[np.ndarray, Sequence[int]] @@ -183,9 +185,6 @@ def n_dims(self) -> int: def _slice(self, X, Xs=None): xdims = X.shape[-1] if isinstance(xdims, Variable): - # Circular dependency - from pymc.pytensorf import constant_fold - [xdims] = constant_fold([xdims]) if self.input_dim != xdims: warnings.warn( diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 0e683a5d3..876f26b8c 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -18,13 +18,14 @@ import pytensor.tensor as pt from pytensor.compile import SharedVariable +from pytensor.graph import ancestors from pytensor.tensor.variable import TensorConstant from scipy.cluster.vq import kmeans # Avoid circular dependency when importing modelcontext from pymc.distributions.distribution import Distribution from pymc.model import modelcontext -from pymc.pytensorf import compile_pymc, walk_model +from pymc.pytensorf import compile_pymc _ = Distribution # keep both pylint and black happy @@ -48,7 +49,7 @@ def replace_with_values(vars_needed, replacements=None, model=None): model = modelcontext(model) inputs, input_names = [], [] - for rv in walk_model(vars_needed): + for rv in ancestors(vars_needed): if rv in model.named_vars.values() and not isinstance(rv, SharedVariable): inputs.append(rv) input_names.append(rv.name) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index b4248e7ed..09ce1cfbf 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -24,7 +24,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.tensor.variable import TensorVariable -from pymc.logprob.transforms import RVTransform +from pymc.distributions.transforms import Transform from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name @@ -177,7 +177,7 @@ def 
inner(seed, *args, **kwargs):
 def make_initial_point_expression(
     *,
     free_rvs: Sequence[TensorVariable],
-    rvs_to_transforms: Dict[TensorVariable, RVTransform],
+    rvs_to_transforms: Dict[TensorVariable, Transform],
     initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
     jitter_rvs: Set[TensorVariable] = None,
     default_strategy: str = "moment",
diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py
index 01bf317ef..464d4eb76 100644
--- a/pymc/logprob/abstract.py
+++ b/pymc/logprob/abstract.py
@@ -39,7 +39,7 @@
 from functools import singledispatch
 from typing import Sequence, Tuple

-from pytensor.graph.op import Op
+from pytensor.graph import Apply, Op, Variable
 from pytensor.graph.utils import MetaType
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -153,3 +153,48 @@ def __init__(self, scalar_op, *args, **kwargs):

 MeasurableVariable.register(MeasurableElemwise)
+
+
+class ValuedRV(Op):
+    r"""Represents the association of a measurable variable and its value.
+
+    A `ValuedRV` node represents the pair :math:`(Y, y)`, where
+    :math:`Y` is a random variable and :math:`y \sim Y`.
+
+    Log-probabilities (densities) are functions over these pairs, which makes
+    these nodes in a graph an intermediate form that serves to construct a
+    log-probability from a model graph.
+    """
+
+    def make_node(self, rv, value):
+        assert isinstance(rv, Variable)
+        if value is not None:
+            assert isinstance(value, Variable)
+            assert rv.type.in_same_class(value.type)
+        return Apply(self, [rv, value], [rv.type(name=rv.name)])
+
+    def perform(self, node, inputs, out):
+        raise NotImplementedError("ValuedRV should not be present in the final graph!")
+
+    def infer_shape(self, fgraph, node, input_shapes):
+        return [input_shapes[0]]
+
+
+valued_rv = ValuedRV()
+
+
+class PromisedValuedRV(Op):
+    r"""Marks a variable as being promised a valued variable in the logprob method."""
+
+    def make_node(self, rv):
+        assert isinstance(rv, Variable)
+        return Apply(self, [rv], [rv.type(name=rv.name)])
+
+    def perform(self, node, inputs, out):
+        raise NotImplementedError("PromisedValuedRV should not be present in the final graph!")
+
+    def infer_shape(self, fgraph, node, input_shapes):
+        return [input_shapes[0]]
+
+
+promised_valued_rv = PromisedValuedRV()
diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py
index 86222fd58..6993bbef7 100644
--- a/pymc/logprob/basic.py
+++ b/pymc/logprob/basic.py
@@ -36,7 +36,6 @@
 import warnings

-from collections import deque
 from typing import Dict, List, Optional, Sequence, Union

 import numpy as np
@@ -55,6 +54,7 @@
 from pytensor.tensor.variable import TensorVariable
 from typing_extensions import TypeAlias

+from pymc.distributions.transforms import Transform
 from pymc.logprob.abstract import (
     MeasurableVariable,
     _icdf_helper,
@@ -63,8 +63,9 @@
     _logprob_helper,
 )
 from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
-from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
-from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars
+from pymc.logprob.transforms import TransformValuesRewrite
+from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph
+from pymc.pytensorf import replace_vars_in_graphs

 TensorLike: TypeAlias = Union[Variable, float, np.ndarray]
@@ -75,7 +76,7 @@ def _find_unallowed_rvs_in_graph(graph):
     return {
         rv
-        for rv in find_rvs_in_graph(graph)
+        for rv in rvs_in_graph(graph)
         if not 
isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV)) } @@ -209,7 +210,8 @@ def normal_logp(value, mu, sigma): return _logprob_helper(rv, value, **kwargs) except NotImplementedError: fgraph, _, _ = construct_ir_fgraph({rv: value}) - [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items() + [ir_valued_rv] = fgraph.outputs + [ir_rv, ir_value] = ir_valued_rv.owner.inputs expr = _logprob_helper(ir_rv, ir_value, **kwargs) cleanup_ir([expr]) if warn_rvs: @@ -307,8 +309,9 @@ def normal_logcdf(value, mu, sigma): except NotImplementedError: # Try to rewrite rv fgraph, rv_values, _ = construct_ir_fgraph({rv: value}) - [ir_rv] = fgraph.outputs - expr = _logcdf_helper(ir_rv, value, **kwargs) + [ir_valued_rv] = fgraph.outputs + [ir_rv, ir_value] = ir_valued_rv.owner.inputs + expr = _logcdf_helper(ir_rv, ir_value, **kwargs) cleanup_ir([expr]) if warn_rvs: _warn_rvs_in_inferred_graph(expr) @@ -389,8 +392,9 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens except NotImplementedError: # Try to rewrite rv fgraph, rv_values, _ = construct_ir_fgraph({rv: value}) - [ir_rv] = fgraph.outputs - expr = _icdf_helper(ir_rv, value, **kwargs) + [ir_valued_rv] = fgraph.outputs + [ir_rv, ir_value] = ir_valued_rv.owner.inputs + expr = _icdf_helper(ir_rv, ir_value, **kwargs) cleanup_ir([expr]) if warn_rvs: _warn_rvs_in_inferred_graph(expr) @@ -474,32 +478,14 @@ def conditional_logp( """ warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs) - fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) + fgraph, rv_values, memo = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) if extra_rewrites is not None: extra_rewrites.rewrite(fgraph) - rv_remapper = fgraph.preserve_rv_mappings - - # This is the updated random-to-value-vars map with the lifted/rewritten - # variables. The rewrites are supposed to produce new - # `MeasurableVariable`s that are amenable to `_logprob`. - updated_rv_values = rv_remapper.rv_values - - # Some rewrites also transform the original value variables. This is the - # updated map from the new value variables to the original ones, which - # we want to use as the keys in the final dictionary output - original_values = rv_remapper.original_values - - # When a `_logprob` has been produced for a `MeasurableVariable` node, all - # other references to it need to be replaced with its value-variable all - # throughout the `_logprob`-produced graphs. The following `dict` - # cumulatively maintains remappings for all the variables/nodes that needed - # to be recreated after replacing `MeasurableVariable`s with their - # value-variables. Since these replacements work in topological order, all - # the necessary value-variable replacements should be present for each - # node. - replacements = updated_rv_values.copy() + # Walk the graph from its inputs to its outputs and construct the + # log-probability + replacements = {} # To avoid cloning the value variables (or ancestors of value variables), # we map them to themselves in the `replacements` `dict` @@ -512,83 +498,91 @@ def conditional_logp( } ) - # Walk the graph from its inputs to its outputs and construct the - # log-probability - q = deque(fgraph.toposort()) - logprob_vars = {} - - while q: - node = q.popleft() + values_to_logprobs = {} + original_values = tuple(rv_values.values()) + # TODO: This seems too convoluted, can we just replace all RVs by their values, + # except for the fgraph outputs (for which we want to call _logprob on)? 
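+    # Overview of the loop below: walk the IR graph in topological order and,
+    # for each measurable node with associated ValuedRV outputs, replace the
+    # node's RV inputs by their value variables and dispatch `_logprob` on the
+    # node's Op, yielding one logprob expression per value variable.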
+ for node in fgraph.toposort(): if not isinstance(node.op, MeasurableVariable): continue - q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values] + valued_nodes = get_related_valued_nodes(node, fgraph) - if not q_values: + if not valued_nodes: continue + node_values = [valued_var.inputs[1] for valued_var in valued_nodes] + node_rvs = [valued_var.inputs[0] for valued_var in valued_nodes] + node_output_idxs = [ + fgraph.outputs.index(valued_var.outputs[0]) for valued_var in valued_nodes + ] + # Replace `RandomVariable`s in the inputs with value variables. - # Also, store the results in the `replacements` map for the nodes - # that follow. - remapped_vars, _ = rvs_to_value_vars( - q_values + list(node.inputs), - initial_replacements=replacements, + # Also, store the results in the `replacements` map for the nodes that follow. + for node_rv, node_value in zip(node_rvs, node_values): + replacements[node_rv] = node_value + + remapped_vars = replace_vars_in_graphs( + graphs=node_values + list(node.inputs), + replacements=replacements, ) - q_values = remapped_vars[: len(q_values)] - q_rv_inputs = remapped_vars[len(q_values) :] + node_values = remapped_vars[: len(node_values)] + node_inputs = remapped_vars[len(node_values) :] - q_logprob_vars = _logprob( + node_logprob_expressions = _logprob( node.op, - q_values, - *q_rv_inputs, + node_values, + *node_inputs, **kwargs, ) - if not isinstance(q_logprob_vars, (list, tuple)): - q_logprob_vars = [q_logprob_vars] - - for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars): - q_value_var = original_values[q_value_var] + if not isinstance(node_logprob_expressions, (list, tuple)): + node_logprob_expressions = [node_logprob_expressions] - if q_value_var.name: - q_logprob_var.name = f"{q_value_var.name}_logprob" + for node_output_idx, node_value, node_logprob_expression in zip( + node_output_idxs, node_values, node_logprob_expressions + ): + original_value = original_values[node_output_idx] + if original_value.name: + node_logprob_expression.name = f"{original_value.name}_logprob" - if q_value_var in logprob_vars: + if original_value in values_to_logprobs: raise ValueError( - f"More than one logprob term was assigned to the value var {q_value_var}" + f"More than one logprob term was assigned to the value var {original_value}" ) - logprob_vars[q_value_var] = q_logprob_var + values_to_logprobs[original_value] = node_logprob_expression - # Recompute test values for the changes introduced by the - # replacements above. + # Recompute test values for the changes introduced by the replacements above. 
if config.compute_test_value != "off": - for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars): + for node in io_toposort( + graph_inputs(node_logprob_expressions), node_logprob_expressions + ): compute_test_value(node) - missing_value_terms = set(original_values.values()) - set(logprob_vars.keys()) + missing_value_terms = set(original_values) - set(values_to_logprobs.keys()) if missing_value_terms: raise RuntimeError( f"The logprob terms of the following value variables could not be derived: {missing_value_terms}" ) - logprob_expressions = list(logprob_vars.values()) - cleanup_ir(logprob_expressions) + logprobs = list(values_to_logprobs.values()) + cleanup_ir(logprobs) if warn_rvs: - rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions) + rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs) if rvs_in_logp_expressions: warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning) - return logprob_vars + return values_to_logprobs def transformed_conditional_logp( rvs: Sequence[TensorVariable], *, rvs_to_values: Dict[TensorVariable, TensorVariable], - rvs_to_transforms: Dict[TensorVariable, RVTransform], + rvs_to_transforms: Dict[TensorVariable, Transform], jacobian: bool = True, **kwargs, ) -> List[TensorVariable]: diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index a344d8067..997e348bf 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -28,8 +28,11 @@ _logprob, _logprob_helper, ) -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import check_potential_measurability +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import ( + check_potential_measurability, + filter_measurable_variables, +) class MeasurableComparison(MeasurableElemwise): @@ -42,11 +45,7 @@ class MeasurableComparison(MeasurableElemwise): def find_measurable_comparisons( fgraph: FunctionGraph, node: Node ) -> Optional[List[MeasurableComparison]]: - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - - measurable_inputs = rv_map_feature.request_measurable(node.inputs) + measurable_inputs = filter_measurable_variables(node.inputs) if len(measurable_inputs) != 1: return None @@ -64,7 +63,7 @@ def find_measurable_comparisons( const = node.inputs[(measurable_var_idx + 1) % 2] # check for potential measurability of const - if check_potential_measurability([const], rv_map_feature.rv_values.keys()): + if check_potential_measurability([const]): return None node_scalar_op = node.op.scalar_op @@ -134,16 +133,12 @@ class MeasurableBitwise(MeasurableElemwise): @node_rewriter(tracks=[invert]) def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableBitwise]]: - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - base_var = node.inputs[0] if not base_var.dtype.startswith("bool"): raise None - if not rv_map_feature.request_measurable([base_var]): + if not filter_measurable_variables([base_var]): return None node_scalar_op = node.op.scalar_op diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 16f19c94a..bca01d08b 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -48,8 +48,8 @@ from pytensor.tensor.variable import TensorConstant from pymc.logprob.abstract import 
MeasurableElemwise, _logcdf, _logprob -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import CheckParameterValue +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables class MeasurableClip(MeasurableElemwise): @@ -65,11 +65,7 @@ class MeasurableClip(MeasurableElemwise): def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableClip]]: # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub) - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None base_var, lower_bound, upper_bound = node.inputs @@ -158,11 +154,7 @@ class MeasurableRound(MeasurableElemwise): @node_rewriter(tracks=[ceil, floor, round_half_to_even]) def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableRound]]: - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None [base_var] = node.inputs diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 1049fd7bb..14de77b75 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -43,7 +43,8 @@ from pytensor.tensor.shape import SpecifyShape from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import filter_measurable_variables, replace_rvs_by_values class MeasurableSpecifyShape(SpecifyShape): @@ -53,6 +54,9 @@ class MeasurableSpecifyShape(SpecifyShape): MeasurableVariable.register(MeasurableSpecifyShape) +measurable_specify_shape = MeasurableSpecifyShape() + + @_logprob.register(MeasurableSpecifyShape) def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): (value,) = values @@ -62,30 +66,18 @@ def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): @node_rewriter([SpecifyShape]) -def find_measurable_specify_shapes(fgraph, node) -> Optional[List[MeasurableSpecifyShape]]: +def find_measurable_specify_shapes(fgraph, node): r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed.""" if isinstance(node.op, MeasurableSpecifyShape): return None # pragma: no cover - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - rv = node.outputs[0] - base_rv, *shape = node.inputs - if not ( - base_rv.owner - and isinstance(base_rv.owner.op, MeasurableVariable) - and base_rv not in rv_map_feature.rv_values - ): - return None # pragma: no cover + if not filter_measurable_variables([base_rv]): + return None - new_op = MeasurableSpecifyShape() - new_rv = new_op.make_node(base_rv, *shape).default_output() + new_rv = measurable_specify_shape.make_node(base_rv, *shape).default_output() return [new_rv] @@ -107,8 +99,6 @@ class MeasurableCheckAndRaise(CheckAndRaise): @_logprob.register(MeasurableCheckAndRaise) def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): - from pymc.pytensorf import 
replace_rvs_by_values - (value,) = values # transfer assertion from rv to value assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value}) @@ -123,13 +113,9 @@ def find_measurable_check_and_raise(fgraph, node) -> Optional[List[MeasurableChe if isinstance(node.op, MeasurableCheckAndRaise): return None # pragma: no cover - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - base_rv, *conds = node.inputs - if not rv_map_feature.request_measurable([base_rv]): + + if not filter_measurable_variables([base_rv]): return None op = node.op diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index c68c9f4ae..073df3271 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -42,7 +42,8 @@ from pytensor.tensor.extra_ops import CumOp from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import filter_measurable_variables class MeasurableCumsum(CumOp): @@ -86,18 +87,13 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]: if isinstance(node.op, MeasurableCumsum): return None # pragma: no cover - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - base_rv = node.inputs[0] # Check that cumsum does not mix dimensions if base_rv.ndim > 1 and node.op.axis is None: return None - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None new_op = MeasurableCumsum(axis=node.op.axis or 0, mode="add") diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index b5c494793..b931bc355 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -70,15 +70,22 @@ MeasurableVariable, _logprob, _logprob_helper, + promised_valued_rv, ) from pymc.logprob.rewriting import ( - PreserveRVMappings, - assume_measured_ir_outputs, + assume_valued_outputs, + early_measurable_ir_rewrites_db, local_lift_DiracDelta, measurable_ir_rewrites_db, + remove_valued_rvs, subtensor_ops, ) -from pymc.logprob.utils import check_potential_measurability +from pymc.logprob.utils import ( + check_potential_measurability, + filter_measurable_variables, + replace_rvs_by_values, +) +from pymc.pytensorf import constant_fold def is_newaxis(x): @@ -255,9 +262,6 @@ def get_stack_mixture_vars( mixture_rvs = joined_rvs.owner.inputs elif isinstance(joined_rvs.owner.op, Join): - # TODO: Find better solution to avoid this circular dependency - from pymc.pytensorf import constant_fold - join_axis = joined_rvs.owner.inputs[0] # TODO: Support symbolic join axes. This will raise ValueError if it's not a constant (join_axis,) = constant_fold((join_axis,), raise_not_constant=False) @@ -265,7 +269,8 @@ def get_stack_mixture_vars( mixture_rvs = joined_rvs.owner.inputs[1:] - return mixture_rvs, join_axis + # TODO: Get rid of these 'PromisedValuedRVs' (the reason we have to get the input) + return [rv.owner.inputs[0] for rv in mixture_rvs], join_axis @node_rewriter(subtensor_ops) @@ -278,11 +283,6 @@ def find_measurable_index_mixture(fgraph, node): From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are created for each ``i`` in ``enumerate(mixture_comps)``. 
""" - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - mixing_indices = node.inputs[1:] # TODO: Add check / test case for Advanced Boolean indexing @@ -303,7 +303,7 @@ def find_measurable_index_mixture(fgraph, node): if mixture_rvs is None or not isinstance(join_axis, (NoneTypeT, Constant)): return None - if rv_map_feature.request_measurable(mixture_rvs) != mixture_rvs: + if set(filter_measurable_variables(mixture_rvs)) != set(mixture_rvs): return None # Replace this sub-graph with a `MixtureRV` @@ -351,9 +351,6 @@ def logprob_MixtureRV( comp_rvs = [comp[None] for comp in comp_rvs] original_shape = (len(comp_rvs),) else: - # TODO: Find better solution to avoid this circular dependency - from pymc.pytensorf import constant_fold - join_axis_val = constant_fold((join_axis,))[0].item() original_shape = shape_tuple(comp_rvs[0]) @@ -413,10 +410,8 @@ class MeasurableSwitchMixture(MeasurableElemwise): @node_rewriter([switch]) def find_measurable_switch_mixture(fgraph, node): - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover + if isinstance(node.op, MeasurableSwitchMixture): + return None switch_cond, *components = node.inputs @@ -427,12 +422,11 @@ def find_measurable_switch_mixture(fgraph, node): if any(comp.type.broadcastable != out_bcast for comp in components): return None - # Check that `switch_cond` is not potentially measurable - valued_rvs = rv_map_feature.rv_values.keys() - if check_potential_measurability([switch_cond], valued_rvs): + if set(filter_measurable_variables(components)) != set(components): return None - if rv_map_feature.request_measurable(components) != components: + # Check that `switch_cond` is not potentially measurable + if check_potential_measurability([switch_cond]): return None return [measurable_switch_mixture(switch_cond, *components)] @@ -504,28 +498,33 @@ def useless_ifelse_outputs(fgraph, node): @node_rewriter([IfElse]) def find_measurable_ifelse_mixture(fgraph, node): - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover + from pymc.pytensorf import toposort_replace op = node.op + + if isinstance(op, MeasurableIfElse): + return None + if_var, *base_rvs = node.inputs - valued_rvs = rv_map_feature.rv_values.keys() - if not all(check_potential_measurability([base_var], valued_rvs) for base_var in base_rvs): + if not all(check_potential_measurability([base_var]) for base_var in base_rvs): return None - base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs) + base_rvs = assume_valued_outputs(base_rvs) if len(base_rvs) != op.n_outs * 2: return None if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs): return None - return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs + replacements = [(base_rv, promised_valued_rv(base_rv)) for base_rv in base_rvs] + temp_fgraph = FunctionGraph(outputs=base_rvs, clone=False) + toposort_replace(temp_fgraph, replacements) + new_base_rvs = temp_fgraph.outputs + return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *new_base_rvs).outputs -measurable_ir_rewrites_db.register( + +early_measurable_ir_rewrites_db.register( "useless_ifelse_outputs", useless_ifelse_outputs, "basic", @@ -533,7 +532,7 @@ def find_measurable_ifelse_mixture(fgraph, node): ) 
-measurable_ir_rewrites_db.register( +early_measurable_ir_rewrites_db.register( "find_measurable_ifelse_mixture", find_measurable_ifelse_mixture, "basic", @@ -544,7 +543,10 @@ def find_measurable_ifelse_mixture(fgraph, node): @_logprob.register(MeasurableIfElse) def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs): """Compute the log-likelihood graph for an `IfElse`.""" - from pymc.pytensorf import replace_rvs_by_values + + temp_fgraph = FunctionGraph(outputs=base_rvs, clone=False) + remove_valued_rvs.apply(temp_fgraph) + base_rvs = temp_fgraph.outputs assert len(values) * 2 == len(base_rvs) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 35b84542d..9d4e795cb 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -56,6 +56,7 @@ _logprob_helper, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import filter_measurable_variables from pymc.math import logdiffexp from pymc.pytensorf import constant_fold @@ -76,11 +77,7 @@ class MeasurableMaxDiscrete(Max): @node_rewriter([Max]) def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]: - rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - - if isinstance(node.op, MeasurableMax): + if isinstance(node.op, (MeasurableMax, MeasurableMaxDiscrete)): return None # pragma: no cover base_var = node.inputs[0] @@ -88,7 +85,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if base_var.owner is None: return None - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None # Non-univariate distributions and non-RVs must be rejected @@ -170,11 +167,6 @@ class MeasurableMaxNeg(Max): @node_rewriter(tracks=[Max]) def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]: - rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - if isinstance(node.op, MeasurableMaxNeg): return None # pragma: no cover @@ -183,7 +175,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[ if base_var.owner is None: return None - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index db606469f..edd78437a 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -33,35 +33,18 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-import warnings -from collections import deque -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple -import pytensor.tensor as pt - -from pytensor import config from pytensor.compile.mode import optdb -from pytensor.graph.basic import ( - Constant, - Variable, - ancestors, - io_toposort, - truncated_graph_inputs, -) +from pytensor.graph.basic import Variable, ancestors, truncated_graph_inputs from pytensor.graph.features import Feature from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.basic import ( - ChangeTracker, - EquilibriumGraphRewriter, - GraphRewriter, - node_rewriter, - out2in, -) +from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter, out2in from pytensor.graph.rewriting.db import ( + EquilibriumDB, LocalGroupDB, - RewriteDatabase, RewriteDatabaseQuery, SequenceDB, TopoDB, @@ -83,204 +66,33 @@ ) from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableVariable -from pymc.logprob.utils import DiracDelta, indices_from_subtensor +from pymc.logprob.abstract import ( + MeasurableVariable, + PromisedValuedRV, + ValuedRV, + valued_rv, +) +from pymc.logprob.utils import DiracDelta +from pymc.pytensorf import toposort_replace inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1) subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor) -class MeasurableEquilibriumGraphRewriter(EquilibriumGraphRewriter): - """EquilibriumGraphRewriter focused on IR measurable rewrites. - - This is a stripped down version of the EquilibriumGraphRewriter, - which specifically targets nodes in `PreserveRVMAppings.needs_measuring` - that are not yet measurable. 
- - """ - - def apply(self, fgraph): - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - if not rv_map_feature: - return None - - change_tracker = ChangeTracker() - fgraph.attach_feature(change_tracker) - - changed = True - max_use_abort = False - rewriter_name = None - global_process_count = {} - - for rewriter in self.global_rewriters + list(self.get_node_rewriters()): - global_process_count.setdefault(rewriter, 0) - - while changed and not max_use_abort: - changed = False - max_nb_nodes = len(fgraph.apply_nodes) - max_use = max_nb_nodes * self.max_use_ratio - - # Apply global rewriters - for grewrite in self.global_rewriters: - change_tracker.reset() - grewrite.apply(fgraph) - if change_tracker.changed: - global_process_count[grewrite] += 1 - changed = True - if global_process_count[grewrite] > max_use: - max_use_abort = True - rewriter_name = getattr(grewrite, "name", None) or getattr( - grewrite, "__name__", "" - ) - - # Apply local node rewriters - q = deque(io_toposort(fgraph.inputs, fgraph.outputs)) - while q: - node = q.pop() - if node not in fgraph.apply_nodes: - continue - # This is where we filter only those nodes we care about: - # Nodes that have variables that we want to measure and are not yet measurable - if isinstance(node.op, MeasurableVariable): - continue - if not any(out in rv_map_feature.needs_measuring for out in node.outputs): - continue - for node_rewriter in self.node_tracker.get_trackers(node.op): - node_rewriter_change = self.process_node(fgraph, node, node_rewriter) - if not node_rewriter_change: - continue - global_process_count[node_rewriter] += 1 - changed = True - if global_process_count[node_rewriter] > max_use: - max_use_abort = True - rewriter_name = getattr(node_rewriter, "name", None) or getattr( - node_rewriter, "__name__", "" - ) - # If we converted to a MeasurableVariable we're done here! - if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableVariable): - # go to next node - break - - if max_use_abort: - msg = ( - f"{type(self).__name__} max'ed out by {rewriter_name}." - "You can safely raise the current threshold of " - f"{config.optdb__max_use_ratio} with the option `optdb__max_use_ratio`." - ) - if config.on_opt_error == "raise": - raise AssertionError(msg) - else: - warnings.warn(msg) - fgraph.remove_feature(change_tracker) - - -class MeasurableEquilibriumDB(RewriteDatabase): - """A database of rewrites that should be applied until equilibrium is reached. - - This will return a MeasurableEquilibriumGraphRewriter when queried. - - """ - - def query(self, *tags, **kwtags): - rewriters = super().query(*tags, **kwtags) - return MeasurableEquilibriumGraphRewriter( - rewriters, - max_use_ratio=config.optdb__max_use_ratio, - ) - - class PreserveRVMappings(Feature): - r"""Keeps track of random variables and their respective value variables during - graph rewrites in `rv_values` + pass - When a random variable is replaced in a rewrite, this `Feature` automatically - updates the `rv_values` mapping, so that the new variable is linked to the - original value variable. - In addition this `Feature` provides functionality to manually update a random - and/or value variable. A mapping from the transformed value variables to the - the original value variables is kept in `original_values`. +MeasurableVariable.register(ValuedRV) - Likewise, a `measurable_conversions` map is maintained, which holds - information about un-valued and un-measurable variables that were replaced - with measurable variables. 
This information can be used to revert these - rewrites. - """ +@node_rewriter([ValuedRV, PromisedValuedRV]) +def remove_ValuedRV(fgraph, node): + rv = node.inputs[0] + return [rv] + - def __init__(self, rv_values: Dict[TensorVariable, TensorVariable]): - """ - Parameters - ---------- - rv_values - Mappings between random variables and their value variables. - The keys of this map are what this `Feature` keeps updated. - The ``dict`` is updated in-place. - """ - self.rv_values = rv_values - self.original_values = {v: v for v in rv_values.values()} - self.needs_measuring = set(rv_values.keys()) - - def on_attach(self, fgraph): - if hasattr(fgraph, "preserve_rv_mappings"): - raise ValueError(f"{fgraph} already has the `PreserveRVMappings` feature attached.") - - fgraph.preserve_rv_mappings = self - - def update_rv_maps( - self, - old_rv: TensorVariable, - new_value: TensorVariable, - new_rv: Optional[TensorVariable] = None, - ): - """Update mappings for a random variable. - - It also creates/updates a map from new value variables to their - original value variables. - - Parameters - ---------- - old_rv - The random variable whose mappings will be updated. - new_value - The new value variable that will replace the current one assigned - to `old_rv`. - new_rv - When non-``None``, `old_rv` will also be replaced with `new_rv` in - the mappings, as well. - """ - old_value = self.rv_values.pop(old_rv) - original_value = self.original_values.pop(old_value) - - if new_rv is None: - new_rv = old_rv - - self.rv_values[new_rv] = new_value - self.original_values[new_value] = original_value - - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): - """ - Whenever a node is replaced during rewrite, we check if it had a value - variable associated with it and map it to the new node. - """ - r_value_var = self.rv_values.pop(r, None) - if r_value_var is not None: - self.rv_values[new_r] = r_value_var - self.needs_measuring.add(new_r) - if new_r.name is None: - new_r.name = r.name - - def request_measurable(self, vars: Sequence[Variable]) -> List[Variable]: - measurable = [] - for var in vars: - # Input vars or valued vars can't be measured for derived expressions - if not var.owner or var in self.rv_values: - continue - if isinstance(var.owner.op, MeasurableVariable): - measurable.append(var) - else: - self.needs_measuring.add(var) - return measurable +remove_valued_rvs = out2in(remove_ValuedRV) @register_canonicalize @@ -314,52 +126,23 @@ def remove_DiracDelta(fgraph, node): return [dd_val] -@node_rewriter(inc_subtensor_ops) -def incsubtensor_rv_replace(fgraph, node): - r"""Replace `*IncSubtensor*` `Op`\s and their value variables for log-probability calculations. - - This is used to derive the log-probability graph for ``Y[idx] = data``, where - ``Y`` is a `RandomVariable`, ``idx`` indices, and ``data`` some arbitrary data. - - To compute the log-probability of a statement like ``Y[idx] = data``, we must - first realize that our objective is equivalent to computing ``logprob(Y, z)``, - where ``z = pt.set_subtensor(y[idx], data)`` and ``y`` is the value variable - for ``Y``. - - In other words, the log-probability for an `*IncSubtensor*` is the log-probability - of the underlying `RandomVariable` evaluated at ``data`` for the indices - given by ``idx`` and at the value variable for ``~idx``. - - This provides a means of specifying "missing data", for instance. 
- """ - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - rv_var = node.outputs[0] - if rv_var not in rv_map_feature.rv_values: - return None # pragma: no cover - - base_rv_var = node.inputs[0] - - if not rv_map_feature.request_measurable([base_rv_var]): - return None - - data = node.inputs[1] - idx = indices_from_subtensor(getattr(node.op, "idx_list", None), node.inputs[2:]) - - # Create a new value variable with the indices `idx` set to `data` - value_var = rv_map_feature.rv_values[rv_var] - new_value_var = pt.set_subtensor(value_var[idx], data) - rv_map_feature.update_rv_maps(rv_var, new_value_var, base_rv_var) - - # Return the `RandomVariable` being indexed - return [base_rv_var] - - logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" + +# Rewrites that must run before any canonicalization or the main Equilibrium DB +early_measurable_ir_rewrites_db = LocalGroupDB() +early_measurable_ir_rewrites_db.name = "early_measurable_rewrites_db" +logprob_rewrites_db.register( + "early_ir_rewrites", + TopoDB( + early_measurable_ir_rewrites_db, + order="out_to_in", + ignore_newtrees=False, + failure_callback=None, + ), + "basic", +) + # Introduce sigmoid. We do it before canonicalization so that useless mul are removed next logprob_rewrites_db.register( "local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic" @@ -371,7 +154,7 @@ def incsubtensor_rv_replace(fgraph, node): # These rewrites convert un-measurable variables into their measurable forms, # but they need to be reapplied, because some of the measurable forms require # their inputs to be measurable. -measurable_ir_rewrites_db = MeasurableEquilibriumDB() +measurable_ir_rewrites_db = EquilibriumDB() measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db" logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic") @@ -380,7 +163,6 @@ def incsubtensor_rv_replace(fgraph, node): # (or eventually) the graph outputs. Often this is done by lifting other `Op`s # "up" through the random/measurable variables and into their inputs. measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic") -measurable_ir_rewrites_db.register("incsubtensor_lift", incsubtensor_rv_replace, "basic") logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic") @@ -394,6 +176,7 @@ def incsubtensor_rv_replace(fgraph, node): ) cleanup_ir_rewrites_db.register("remove_DiracDelta", remove_DiracDelta, "cleanup") +cleanup_ir_rewrites_db.register("remove_ValuedVar", remove_ValuedRV, "cleanup") def construct_ir_fgraph( @@ -438,16 +221,12 @@ def construct_ir_fgraph( """ # Since we're going to clone the entire graph, we need to keep a map from - # the old nodes to the new ones; otherwise, we won't be able to use - # `rv_values`. - # We start the `dict` with mappings from the value variables to themselves, - # to prevent them from being cloned. This also includes ancestors - memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)} + # the old nodes to the new ones; otherwise, we won't be able to use `rv_values`. 
+ memo = {} # We add `ShapeFeature` because it will get rid of references to the old # `RandomVariable`s that have been lifted; otherwise, it will be difficult - # to give good warnings when an unaccounted for `RandomVariable` is - # encountered + # to give good warnings when an unaccounted for `RandomVariable` is encountered fgraph = FunctionGraph( outputs=list(rv_values.keys()), clone=True, @@ -457,19 +236,22 @@ def construct_ir_fgraph( features=[ShapeFeature()], ) - # Update `rv_values` so that it uses the new cloned variables - rv_values = {memo[k]: v for k, v in rv_values.items()} + # Replace valued RVs by ValuedVar Ops so that rewrites are aware of conditioning points + # We use clones of the value variables so that they are not affected by rewrites + cloned_values = {v: v.clone() for v in rv_values.values()} + ir_rv_values = {memo[k]: cloned_values[v] for k, v in rv_values.items()} - # This `Feature` preserves the relationships between the original - # random variables (i.e. keys in `rv_values`) and the new ones - # produced when `Op`s are lifted through them. - rv_remapper = PreserveRVMappings(rv_values) - fgraph.attach_feature(rv_remapper) + replacements = tuple((rv, valued_rv(rv, value)) for rv, value in ir_rv_values.items()) + toposort_replace(fgraph, replacements, reverse=True) if ir_rewriter is None: ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])) ir_rewriter.rewrite(fgraph) + # Reintroduce original value variables + replacements = tuple((cloned_v, v) for v, cloned_v in cloned_values.items()) + toposort_replace(fgraph, replacements=replacements, reverse=True) + return fgraph, rv_values, memo @@ -479,8 +261,8 @@ def cleanup_ir(vars: Sequence[Variable]) -> None: ir_rewriter.rewrite(fgraph) -def assume_measured_ir_outputs( - inputs: Sequence[TensorVariable], outputs: Sequence[TensorVariable] +def assume_valued_outputs( + outputs: Sequence[TensorVariable], ) -> Sequence[TensorVariable]: """Run IR rewrite assuming each output is measured. @@ -490,7 +272,12 @@ def assume_measured_ir_outputs( This helper runs an inner ir rewrite after giving each output a dummy value variable. We replace inputs by dummies and then undo it so that any dependency on outer variables is preserved. """ - # Replace inputs by dummy variables + # Replace inputs by dummy variables (so they are not affected) + inputs = [ + valued_var + for valued_var in ancestors(outputs) + if (valued_var.owner and isinstance(valued_var.owner.op, ValuedRV)) + ] replaced_inputs = { var: var.type() for var in truncated_graph_inputs(outputs, ancestors_to_include=inputs) @@ -500,8 +287,9 @@ def assume_measured_ir_outputs( dummy_rv_values = {base_var: base_var.type() for base_var in cloned_outputs} fgraph, *_ = construct_ir_fgraph(dummy_rv_values) + remove_valued_rvs.apply(fgraph) - # Replace dummy variables by inputs + # Replace dummy variables by original inputs fgraph.replace_all( tuple((repl, orig) for orig, repl in replaced_inputs.items()), import_missing=True, diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 283dbd1c3..52d16533a 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -35,14 +35,12 @@ # SOFTWARE. 
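With `PreserveRVMappings` gone, conditioning points are encoded in the graph itself: `construct_ir_fgraph` now wraps each valued RV in a `ValuedRV` node and the wrappers are stripped again once a logprob graph has been derived. A minimal sketch of that round trip, assuming only the names and return signature introduced in the `rewriting.py` hunks above:

```python
import pytensor.tensor as pt

from pymc.logprob.rewriting import construct_ir_fgraph, remove_valued_rvs

# A single RV conditioned on a value variable
x = pt.random.normal(0, 1)
x.name = "x"
x_value = x.clone()
x_value.name = "x_value"

# The IR graph now contains explicit ValuedRV nodes marking conditioning points
fgraph, rv_values, memo = construct_ir_fgraph({x: x_value})

# Once the logprob rewrites have run, the markers can be stripped again
remove_valued_rvs.apply(fgraph)
```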
from copy import copy
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, cast
+from typing import Callable, Dict, Iterable, List, Tuple, cast
 
 import numpy as np
-import pytensor
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Variable
-from pytensor.graph.op import compute_test_value
+from pytensor.graph import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.scan.op import Scan
@@ -53,16 +51,17 @@
 from pytensor.tensor.variable import TensorVariable
 from pytensor.updates import OrderedUpdates
 
-from pymc.logprob.abstract import MeasurableVariable, _logprob
+from pymc.logprob.abstract import MeasurableVariable, _logprob, valued_rv
 from pymc.logprob.basic import conditional_logp
 from pymc.logprob.rewriting import (
-    PreserveRVMappings,
     construct_ir_fgraph,
     inc_subtensor_ops,
     logprob_rewrites_db,
     measurable_ir_rewrites_db,
+    remove_valued_rvs,
 )
-from pymc.pytensorf import replace_rvs_by_values
+from pymc.logprob.utils import get_related_valued_nodes, replace_rvs_by_values
 
 
 class MeasurableScan(Scan):
@@ -360,78 +359,60 @@ def find_measurable_scans(fgraph, node):
     random variable and value variable mappings that have been specified for
     parts of a `Scan`\s outputs (e.g. everything except the initial values).
     """
+    from pymc.pytensorf import toposort_replace
 
     if not hasattr(fgraph, "shape_feature"):
         return None  # pragma: no cover
 
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
-
     if isinstance(node.op, Subtensor):
         node = node.inputs[0].owner
         if not (node and isinstance(node.op, Scan)):
             return None
-    if isinstance(node.op, MeasurableScan):
-        return None
+
+    if isinstance(node.op, MeasurableScan):
+        return None
 
     curr_scanargs = ScanArgs.from_node(node)
 
-    # Find the un-output `MeasurableVariable`s created in the inner-graph
-    if not any(out in rv_map_feature.rv_values for out in node.outputs):
-        # TODO: T
-        # We need to remap user inputs that have been specified in terms of
-        # `Subtensor`s of this `Scan`'s node's outputs.
-        #
+    # TODO: Check what outputs are actually needed for ValuedRVs more than one node deep
+    # Find outputs of scan that are directly valued (mapping)
+    direct_valued_output_idxs = [
+        node.outputs.index(valued_node.inputs[0])
+        for valued_node in get_related_valued_nodes(node, fgraph)
+    ]
+
+    # Find outputs of scan, whose slices are valued (recurring outputs)
+    indirect_valued_output_idxs = []
+    if not get_related_valued_nodes(node, fgraph):
         # For example, the output that the user got was something like
         # `out[1:]` for `outputs_info = [{"initial": x0, "taps": [-1]}]`, so
         # they likely passed `{out[1:]: x_1T_vv}` to `joint_logprob`.
         # Since `out[1:]` isn't really the output of a `Scan`, but a
-        # `Subtensor` of the output `out` of a `Scan`, we need to account for
-        # that.
+        # `Subtensor` of the output `out` of a `Scan`, we need to account for that.
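The `out[1:]` situation described in the comment above is exactly what `scan` produces by default for recurring outputs: the user-facing variable is already a `Subtensor` of the full output buffer that holds the initial taps. A hedged sketch of such a graph, assuming `conditional_logp` accepts valued scan slices as in pymc's scan tests:

```python
import pytensor
import pytensor.tensor as pt

from pytensor.tensor.random.utils import RandomStream

from pymc.logprob.basic import conditional_logp

srng = RandomStream(seed=0)
x0 = pt.scalar("x0")

# A Gaussian random walk written as a scan; `xs` is really `buffer[1:]`,
# a Subtensor of the full scan output, which also contains the initial tap `x0`
xs, _ = pytensor.scan(
    fn=lambda x_prev: srng.normal(x_prev),
    outputs_info=[x0],
    n_steps=10,
)

# Valuing `xs` therefore values a *slice* of the scan output: the
# "indirect" case that the code below remaps to the full output
xs_value = xs.type()
logp_terms = conditional_logp({xs: xs_value})
```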
# Get any `Subtensor` outputs that have been applied to outputs of this - # `Scan` (and get the corresponding indices of the outputs from this - # `Scan`) - output_clients: List[Tuple[Variable, int]] = sum( - [ - [ - # This is expected to work for `Subtensor` `Op`s, - # because they only ever have one output - (cl.default_output(), i) - for cl, _ in fgraph.get_clients(out) - if isinstance(cl.op, Subtensor) - ] - for i, out in enumerate(node.outputs) - ], - [], - ) - - # The second items in these tuples are the value variables mapped to - # the *user-specified* measurable variables (i.e. the first items) that - # are `Subtensor`s of the outputs of this `Scan`. The second items are - # the index of the corresponding output of this `Scan` node. - indirect_rv_vars = [ - (out, rv_map_feature.rv_values[out], out_idx) - for out, out_idx in output_clients - if out in rv_map_feature.rv_values + # `Scan` (and get the corresponding indices of the outputs from this `Scan`) + indirect_valued_outputs = [ + (idx, client) + for idx, out in enumerate(node.outputs) + for client, _ in fgraph.clients[out] + if (isinstance(client.op, Subtensor) and get_related_valued_nodes(client, fgraph)) + ] + indirect_valued_output_idxs = [idx for idx, _ in indirect_valued_outputs] + indirect_valued_output_nodes = [ + get_related_valued_nodes(client, fgraph)[0] for _, client in indirect_valued_outputs ] - - if not indirect_rv_vars: - return None - - # We need this for the `clone` in the loop that follows - if pytensor.config.compute_test_value != "off": - compute_test_value(node) # We're going to replace the user's random variable/value variable mappings # with ones that map directly to outputs of this `Scan`. - for rv_var, val_var, out_idx in indirect_rv_vars: + old_indirect_scan_valued_nodes = [] + new_indirect_scan_values = [] + for out_idx, valued_node in zip(indirect_valued_output_idxs, indirect_valued_output_nodes): + [scan_rv, scan_value] = valued_node.inputs # The full/un-`Subtensor`ed `Scan` output that we need to use full_out = node.outputs[out_idx] - assert rv_var.owner.inputs[0] == full_out + assert scan_rv.owner.inputs[0] == full_out # A new value variable that spans the full output. # We don't want the old graph to appear in the new log-probability @@ -440,17 +421,18 @@ def find_measurable_scans(fgraph, node): full_out_shape = tuple( fgraph.shape_feature.get_shape(full_out, i) for i in range(full_out.ndim) ) - new_val_var = pt.empty(full_out_shape, dtype=full_out.dtype) + new_scan_value = pt.empty(full_out_shape, dtype=full_out.dtype) # Set the parts of this new value variable that applied to the # user-specified value variable to the user's value variable subtensor_indices = indices_from_subtensor( - rv_var.owner.inputs[1:], rv_var.owner.op.idx_list + scan_rv.owner.inputs[1:], + scan_rv.owner.op.idx_list, ) # E.g. for a single `-1` TAPS, `s_0T[1:] = s_1T` where `s_0T` is # `new_val_var` and `s_1T` is the user-specified value variable # that only spans times `t=1` to `t=T`. - new_val_var = pt.set_subtensor(new_val_var[subtensor_indices], val_var) + new_scan_value = pt.set_subtensor(new_scan_value[subtensor_indices], scan_value) # This is the outer-input that sets `s_0T[i] = taps[i]` where `i` # is a TAP index (e.g. a TAP of `-1` maps to index `0` in a vector @@ -466,64 +448,63 @@ def find_measurable_scans(fgraph, node): # We're going to set those values on our `new_val_var` so that it can # serve as a complete replacement for the old input `outer_input_var`. 
- new_val_var = outer_input_var.owner.clone_with_new_inputs( - [new_val_var] + outer_input_var.owner.inputs[1:] + new_scan_value = outer_input_var.owner.clone_with_new_inputs( + [new_scan_value] + outer_input_var.owner.inputs[1:] ).default_output() # Replace the mapping - rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out) - - op = MeasurableScan( - curr_scanargs.inner_inputs, - curr_scanargs.inner_outputs, - curr_scanargs.info, - mode=node.op.mode, - ) - new_node = op.make_node(*curr_scanargs.outer_inputs) - - return dict(zip(node.outputs, new_node.outputs)) - - -@node_rewriter([Scan, Subtensor]) -def add_opts_to_inner_graphs(fgraph, node): - """Update the `Mode`(s) used to compile the inner-graph of a `Scan` `Op`. - - This is how we add the measurable IR rewrites to the "body" - (i.e. inner-graph) of a `Scan` loop. - """ - - if isinstance(node.op, Subtensor): - node = node.inputs[0].owner - if not (node and isinstance(node.op, Scan)): - return None + old_indirect_scan_valued_nodes.append(valued_node) + new_indirect_scan_values.append(new_scan_value) - # TODO: This might not be needed now that we only target relevant nodes - # Avoid unnecessarily re-applying this rewrite - if getattr(node.op.mode, "had_logprob_rewrites", False): + if direct_valued_output_idxs and indirect_valued_output_idxs: + # TODO: Allow measuring mixed direct and indirect output types + return None + elif direct_valued_output_idxs: + valued_output_idxs = direct_valued_output_idxs + elif indirect_valued_output_idxs: + valued_output_idxs = indirect_valued_output_idxs + else: return None - inner_rv_values = {out: out.type() for out in node.op.inner_outputs} + # Make inner graph measurable + inner_rv_values = { + out: out.type() for i, out in enumerate(node.op.inner_outputs) if i in valued_output_idxs + } ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])) - inner_fgraph, rv_values, _ = construct_ir_fgraph(inner_rv_values, ir_rewriter=ir_rewriter) - - new_outputs = list(inner_fgraph.outputs) - - # TODO FIXME: This is pretty hackish. 
- new_mode = copy(node.op.mode) - new_mode.had_logprob_rewrites = True - - op = Scan(node.op.inner_inputs, new_outputs, node.op.info, mode=new_mode) - new_node = op.make_node(*node.inputs) - - return dict(zip(node.outputs, new_node.outputs)) + inner_fgraph, *_ = construct_ir_fgraph(inner_rv_values, ir_rewriter=ir_rewriter) + remove_valued_rvs(inner_fgraph) + new_rvs = list(inner_fgraph.outputs) + if not all(isinstance(new_out.owner.op, MeasurableVariable) for new_out in new_rvs): + return None + # Get new inner outs (replace measured vars in graph of non-valued outputs) + inner_outs = node.op.inner_outputs.copy() + old_to_new_inner_rvs = [] + for idx, rv in zip(valued_output_idxs, new_rvs): + old_rv = inner_outs[idx] + inner_outs[idx] = rv + old_to_new_inner_rvs.append((old_rv, rv)) + temp_fgraph = FunctionGraph( + outputs=inner_outs + [a for a, _ in old_to_new_inner_rvs], clone=False + ) + toposort_replace(temp_fgraph, old_to_new_inner_rvs) + inner_outs = temp_fgraph.outputs[: len(inner_outs)] + + op = MeasurableScan(node.op.inner_inputs, inner_outs, node.op.info, mode=copy(node.op.mode)) + new_rvs = op.make_node(*node.inputs).outputs + + replacements = {} + if indirect_valued_output_idxs: + # Create new ValuedRVs that sidestep the Subtensor + for new_rv, new_value, old_valued_node in zip( + new_rvs, new_indirect_scan_values, indirect_valued_output_nodes + ): + new_valued_rv = valued_rv(new_rv, new_value) + replacements[old_valued_node.outputs[0]] = new_valued_rv + else: + replacements.update(dict(zip(node.outputs, new_rvs))) + return replacements -measurable_ir_rewrites_db.register( - "add_opts_to_inner_graphs", - add_opts_to_inner_graphs, - "basic", - "scan", -) measurable_ir_rewrites_db.register( "find_measurable_scans", diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 1223b7091..8eabcdd42 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -36,88 +36,32 @@ from typing import List, Optional, Union -import pytensor - from pytensor import tensor as pt -from pytensor.graph.op import compute_test_value +from pytensor.graph import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.tensor.basic import Alloc, Join, MakeVector +from pytensor.tensor.basic import Join, MakeVector from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.random.rewriting import ( - local_dimshuffle_rv_lift, - local_rv_size_lift, -) +from pytensor.tensor.random.rewriting import local_dimshuffle_rv_lift -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableVariable, + _logprob, + _logprob_helper, + promised_valued_rv, +) from pymc.logprob.rewriting import ( - PreserveRVMappings, - assume_measured_ir_outputs, + assume_valued_outputs, + early_measurable_ir_rewrites_db, measurable_ir_rewrites_db, + remove_valued_rvs, ) -from pymc.logprob.utils import check_potential_measurability - - -@node_rewriter([Alloc]) -def naive_bcast_rv_lift(fgraph, node): - """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``. - - XXX: This implementation simply broadcasts the ``RandomVariable``'s - parameters, which won't always work (e.g. multivariate distributions). - - TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to - determine which dimensions of each parameter need to be broadcasted. - Also, this doesn't need to remove ``size`` to perform the lifting, like it - currently does. 
- """ - - if not ( - isinstance(node.op, Alloc) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, RandomVariable) - ): - return None # pragma: no cover - - bcast_shape = node.inputs[1:] - - rv_var = node.inputs[0] - rv_node = rv_var.owner - - if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars: - return None # pragma: no cover - - # Do not replace RV if it is associated with a value variable - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is not None and rv_var in rv_map_feature.rv_values: - return None - - if not bcast_shape: - # The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing) - assert rv_var.ndim == 0 - return [rv_var] - - size_lift_res = local_rv_size_lift.transform(fgraph, rv_node) - if size_lift_res is None: - lifted_node = rv_node - else: - _, lifted_rv = size_lift_res - lifted_node = lifted_rv.owner - - rng, size, dtype, *dist_params = lifted_node.inputs - - new_dist_params = [ - pt.broadcast_to( - param, - pt.broadcast_shape(tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True), - ) - for param in dist_params - ] - bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params) - - if pytensor.config.compute_test_value != "off": - compute_test_value(bcasted_node) - - return [bcasted_node.outputs[1]] +from pymc.logprob.utils import ( + check_potential_measurability, + filter_measurable_variables, + replace_rvs_by_values, +) +from pymc.pytensorf import constant_fold, toposort_replace class MeasurableMakeVector(MakeVector): @@ -131,10 +75,13 @@ class MeasurableMakeVector(MakeVector): def logprob_make_vector(op, values, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `MeasurableMakeVector`.""" # TODO: Sort out this circular dependency issue - from pymc.pytensorf import replace_rvs_by_values (value,) = values + temp_fgraph = FunctionGraph(outputs=base_rvs, clone=False) + remove_valued_rvs.apply(temp_fgraph) + base_rvs = temp_fgraph.outputs + base_rvs_to_values = {base_rv: value[i] for i, base_rv in enumerate(base_rvs)} for i, (base_rv, value) in enumerate(base_rvs_to_values.items()): base_rv.name = f"base_rv[{i}]" @@ -158,11 +105,12 @@ class MeasurableJoin(Join): @_logprob.register(MeasurableJoin) def logprob_join(op, values, axis, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `Join`.""" - # TODO: Find better way to avoid circular dependency - from pymc.pytensorf import constant_fold, replace_rvs_by_values - (value,) = values + temp_fgraph = FunctionGraph(outputs=base_rvs, clone=False) + remove_valued_rvs.apply(temp_fgraph) + base_rvs = temp_fgraph.outputs + base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs] # We don't need the graph to be constant, just to have RandomVariables removed @@ -203,12 +151,18 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs): def find_measurable_stacks( fgraph, node ) -> Optional[List[Union[MeasurableMakeVector, MeasurableJoin]]]: - r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed.""" + r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed. - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + Because base variables in the Join and MakeVector may be interdependent, + the IR graph of these Ops is almost like an inner IR, except it doesn't have "valued RVs". 
- if rv_map_feature is None: - return None # pragma: no cover + To circumvent this issue, we run this rewrite early on, and tag variables as being + "promised_valued_rvs". A perhaps more elegant alternative is to wrap the sub-graph in an OpFromGraph? + This might avoid the repeated boxing and unboxing done to the "base_vars" throughout the life of the IR + """ + + if isinstance(node.op, (MeasurableMakeVector, MeasurableJoin)): + return None is_join = isinstance(node.op, Join) @@ -217,18 +171,25 @@ def find_measurable_stacks( else: base_vars = node.inputs - valued_rvs = rv_map_feature.rv_values.keys() - if not all(check_potential_measurability([base_var], valued_rvs) for base_var in base_vars): + if not all(check_potential_measurability([base_var]) for base_var in base_vars): return None - base_vars = assume_measured_ir_outputs(valued_rvs, base_vars) + base_vars = assume_valued_outputs(base_vars) if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars): return None + # Each base var will be "valued" by the logprob method, so other rewrites shouldn't mess with it + # and potentially break interdependencies. For this reason, this rewrite should be applied early in + # the IR construction + replacements = [(base_var, promised_valued_rv(base_var)) for base_var in base_vars] + temp_fgraph = FunctionGraph(outputs=base_vars, clone=False) + toposort_replace(temp_fgraph, replacements) + new_base_vars = temp_fgraph.outputs + if is_join: - measurable_stack = MeasurableJoin()(axis, *base_vars) + measurable_stack = MeasurableJoin()(axis, *new_base_vars) else: - measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars) + measurable_stack = MeasurableMakeVector(node.op.dtype)(*new_base_vars) return [measurable_stack] @@ -279,12 +240,10 @@ def logprob_dimshuffle(op, values, base_var, **kwargs): def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuffle]]: r"""Finds `Dimshuffle`\s for which a `logprob` can be computed.""" - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover + if isinstance(node.op, MeasurableDimShuffle): + return None - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None base_var = node.inputs[0] @@ -315,7 +274,7 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf "find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor" ) -measurable_ir_rewrites_db.register( +early_measurable_ir_rewrites_db.register( "find_measurable_stacks", find_measurable_stacks, "basic", diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index c2f038ad5..16ede14d6 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -34,17 +34,14 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
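Moving `find_measurable_stacks` into `early_measurable_ir_rewrites_db` is what keeps interdependent base variables intact: each one is "promised" a value before later rewrites can decompose it. The user-facing behavior this preserves looks roughly like the following sketch, assuming `pymc.logprob.basic.logp` dispatches on derived measurable expressions such as stacks:

```python
import numpy as np
import pytensor.tensor as pt

from pymc.logprob.basic import logp

x = pt.random.normal(0.0, 1.0)
y = pt.random.normal(x, 1.0)  # depends on x: the interdependent case

# MakeVector of two RVs, rewritten to a MeasurableMakeVector internally
xy = pt.stack([x, y])

# logp of x at 0.0, and of y given x's value at 1.0
xy_logp = logp(xy, np.array([0.0, 1.0]))
print(xy_logp.eval())
```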
-import abc - from copy import copy -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import numpy as np import pytensor.tensor as pt -from pytensor import scan -from pytensor.gradient import DisconnectedType, jacobian -from pytensor.graph.basic import Apply, Node, Variable +from pytensor.gradient import DisconnectedType +from pytensor.graph.basic import Apply, Node from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op @@ -112,9 +109,28 @@ ) from pytensor.tensor.variable import TensorVariable +from pymc.distributions.transforms import ( + AbsTransform, + ArccoshTransform, + ArcsinhTransform, + ArctanhTransform, + CoshTransform, + ErfcTransform, + ErfcxTransform, + ErfTransform, + ExpTransform, + LocTransform, + LogTransform, + PowerTransform, + ScaleTransform, + SinhTransform, + TanhTransform, + Transform, +) from pymc.logprob.abstract import ( MeasurableElemwise, MeasurableVariable, + ValuedRV, _icdf, _icdf_helper, _logcdf, @@ -122,12 +138,13 @@ _logprob, _logprob_helper, ) -from pymc.logprob.rewriting import ( - PreserveRVMappings, - cleanup_ir_rewrites_db, - measurable_ir_rewrites_db, +from pymc.logprob.rewriting import cleanup_ir_rewrites_db, measurable_ir_rewrites_db +from pymc.logprob.utils import ( + CheckParameterValue, + check_potential_measurability, + filter_measurable_variables, + get_related_valued_nodes, ) -from pymc.logprob.utils import CheckParameterValue, check_potential_measurability class TransformedVariable(Op): @@ -168,39 +185,8 @@ def remove_TransformedVariables(fgraph, node): ) -class RVTransform(abc.ABC): - ndim_supp = None - - @abc.abstractmethod - def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: - """Apply the transformation.""" - - @abc.abstractmethod - def backward( - self, value: TensorVariable, *inputs: Variable - ) -> Union[TensorVariable, Tuple[TensorVariable, ...]]: - """Invert the transformation. Multiple values may be returned when the - transformation is not 1-to-1""" - - def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: - """Construct the log of the absolute value of the Jacobian determinant.""" - if self.ndim_supp not in (0, 1): - raise NotImplementedError( - f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}" - ) - if self.ndim_supp == 0: - jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape) - return pt.log(pt.abs(jac)) - else: - phi_inv = self.backward(value, *inputs) - return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) - - def __str__(self): - return f"{self.__class__.__name__}" - - -@node_rewriter(tracks=None) -def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: +@node_rewriter(tracks=[ValuedRV]) +def transform_values(fgraph: FunctionGraph, valued_node: Node) -> Optional[List[Node]]: """Apply transforms to value variables. It is assumed that the input value variables correspond to forward @@ -212,56 +198,45 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: ``Y`` on the natural scale. 
""" - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) values_to_transforms: Optional[TransformValuesMapping] = getattr( fgraph, "values_to_transforms", None ) - if rv_map_feature is None or values_to_transforms is None: + if values_to_transforms is None: return None # pragma: no cover - rv_vars = [] - value_vars = [] - - for out in node.outputs: - value = rv_map_feature.rv_values.get(out, None) - if value is None: - continue - rv_vars.append(out) - value_vars.append(value) - - if not value_vars: - return None - - transforms = [values_to_transforms.get(value_var, None) for value_var in value_vars] + rv_node = valued_node.inputs[0].owner + valued_nodes = get_related_valued_nodes(rv_node, fgraph) + rvs = [valued_var.inputs[0] for valued_var in valued_nodes] + values = [valued_var.inputs[1] for valued_var in valued_nodes] + transforms = [values_to_transforms.get(value, None) for value in values] if all(transform is None for transform in transforms): return None - new_op = _create_transformed_rv_op(node.op, transforms) + # Create a new RV Op whose logprob respects the transformed value variable + transformed_rv_op = _create_transformed_rv_op(rv_node.op, transforms) # Create a new `Apply` node and outputs - trans_node = node.clone() - trans_node.op = new_op + transformed_node = rv_node.clone() + transformed_node.op = transformed_rv_op + + replacements = dict(zip(rv_node.outputs, transformed_node.outputs)) # We now assume that the old value variable represents the *transformed space*. # This means that we need to replace all instance of the old value variable # with "inversely/un-" transformed versions of itself. - for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms): - rv_var_out_idx = node.outputs.index(rv_var) - + for rv, value, transform in zip(rvs, values, transforms): if transform is None: continue - new_value_var = transformed_variable( - transform.backward(value_var, *trans_node.inputs), value_var - ) + new_value = transformed_variable(transform.backward(value, *transformed_node.inputs), value) - if value_var.name and getattr(transform, "name", None): - new_value_var.name = f"{value_var.name}_{transform.name}" + if value.name and getattr(transform, "name", "transformed"): + new_value.name = f"{value.name}_{transform.name}" - rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx]) + replacements[value] = new_value - return trans_node.outputs + return replacements @node_rewriter(tracks=[Scan]) @@ -367,7 +342,7 @@ class TransformValuesRewrite(GraphRewriter): def __init__( self, - values_to_transforms: Dict[TensorVariable, Union[RVTransform, None]], + values_to_transforms: Dict[TensorVariable, Union[Transform, None]], ): """ Parameters @@ -414,10 +389,10 @@ class MeasurableTransform(MeasurableElemwise): # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` - transform_elemwise: RVTransform + transform_elemwise: Transform measurable_input_idx: int - def __init__(self, *args, transform: RVTransform, measurable_input_idx: int, **kwargs): + def __init__(self, *args, transform: Transform, measurable_input_idx: int, **kwargs): self.transform_elemwise = transform self.measurable_input_idx = measurable_input_idx super().__init__(*args, **kwargs) @@ -540,13 +515,17 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs) @node_rewriter([reciprocal]) def measurable_reciprocal_to_power(fgraph, node): """Convert reciprocal 
of `MeasurableVariable`s to power.""" - [inp] = node.inputs - return [pt.pow(inp, -1.0)] + if filter_measurable_variables(node.inputs): + [inp] = node.inputs + return [pt.pow(inp, -1.0)] @node_rewriter([sqr, sqrt]) def measurable_sqrt_sqr_to_power(fgraph, node): """Convert square root or square of `MeasurableVariable`s to power form.""" + if not filter_measurable_variables(node.inputs): + return None + [inp] = node.inputs if isinstance(node.op.scalar_op, Sqr): @@ -559,6 +538,9 @@ def measurable_sqrt_sqr_to_power(fgraph, node): @node_rewriter([true_div]) def measurable_div_to_product(fgraph, node): """Convert divisions involving `MeasurableVariable`s to products.""" + if not filter_measurable_variables(node.inputs): + return None + numerator, denominator = node.inputs # Check if numerator is 1 @@ -577,20 +559,25 @@ def measurable_div_to_product(fgraph, node): @node_rewriter([neg]) def measurable_neg_to_product(fgraph, node): """Convert negation of `MeasurableVariable`s to product with `-1`.""" - inp = node.inputs[0] - return [pt.mul(inp, -1.0)] + if filter_measurable_variables(node.inputs): + inp = node.inputs[0] + return [pt.mul(inp, -1)] @node_rewriter([sub]) def measurable_sub_to_neg(fgraph, node): """Convert subtraction involving `MeasurableVariable`s to addition with neg""" - minuend, subtrahend = node.inputs - return [pt.add(minuend, pt.neg(subtrahend))] + if filter_measurable_variables(node.inputs): + minuend, subtrahend = node.inputs + return [pt.add(minuend, pt.neg(subtrahend))] @node_rewriter([log1p, softplus, log1mexp, log2, log10]) def measurable_special_log_to_log(fgraph, node): """Convert log1p, log1mexp, softplus, log2, log10 of `MeasurableVariable`s to log form.""" + if not filter_measurable_variables(node.inputs): + return None + [inp] = node.inputs if isinstance(node.op.scalar_op, Log1p): @@ -608,7 +595,11 @@ def measurable_special_log_to_log(fgraph, node): @node_rewriter([expm1, sigmoid, exp2]) def measurable_special_exp_to_exp(fgraph, node): """Convert expm1, sigmoid, and exp2 of `MeasurableVariable`s to xp form.""" + if not filter_measurable_variables(node.inputs): + return None + [inp] = node.inputs + if isinstance(node.op.scalar_op, Exp2): return [pt.exp(pt.log(2) * inp)] if isinstance(node.op.scalar_op, Expm1): @@ -622,9 +613,12 @@ def measurable_power_exponent_to_exp(fgraph, node): """Convert power(base, rv) of `MeasurableVariable`s to exp(log(base) * rv) form.""" base, inp_exponent = node.inputs + if not filter_measurable_variables([inp_exponent]): + return None + # When the base is measurable we have `power(rv, exponent)`, which should be handled by `PowerTransform` and needs no further rewrite. 
# Here we change only the cases where exponent is measurable `power(base, rv)` which is not supported by the `PowerTransform` - if check_potential_measurability([base], fgraph.preserve_rv_mappings.rv_values.keys()): + if check_potential_measurability([base]): return None base = CheckParameterValue("base >= 0")(base, pt.all(pt.ge(base, 0.0))) @@ -658,12 +652,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li if isinstance(node.op, MeasurableVariable): return None # pragma: no cover - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - # Check that we have a single source of measurement - measurable_inputs = rv_map_feature.request_measurable(node.inputs) + measurable_inputs = filter_measurable_variables(node.inputs) if len(measurable_inputs) != 1: return None @@ -678,29 +668,14 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li # would be invalid other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input) - if check_potential_measurability(other_inputs, rv_map_feature.rv_values.keys()): + if check_potential_measurability(other_inputs): return None scalar_op = node.op.scalar_op measurable_input_idx = 0 transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,) - transform: RVTransform - - transform_dict = { - Exp: ExpTransform(), - Log: LogTransform(), - Abs: AbsTransform(), - Sinh: SinhTransform(), - Cosh: CoshTransform(), - Tanh: TanhTransform(), - ArcSinh: ArcsinhTransform(), - ArcCosh: ArccoshTransform(), - ArcTanh: ArctanhTransform(), - Erf: ErfTransform(), - Erfc: ErfcTransform(), - Erfcx: ErfcxTransform(), - } - transform = transform_dict.get(type(scalar_op), None) + transform: Transform + if isinstance(scalar_op, Pow): # We only allow for the base to be measurable if measurable_input_idx != 0: @@ -718,11 +693,26 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform = LocTransform( transform_args_fn=lambda *inputs: inputs[-1], ) - elif transform is None: + elif isinstance(scalar_op, Mul): transform_inputs = (measurable_input, pt.mul(*other_inputs)) transform = ScaleTransform( transform_args_fn=lambda *inputs: inputs[-1], ) + else: + transform = { + Exp: ExpTransform, + Log: LogTransform, + Abs: AbsTransform, + Sinh: SinhTransform, + Cosh: CoshTransform, + Tanh: TanhTransform, + ArcSinh: ArcsinhTransform, + ArcCosh: ArccoshTransform, + ArcTanh: ArctanhTransform, + Erf: ErfTransform, + Erfc: ErfcTransform, + Erfcx: ErfcxTransform, + }[type(scalar_op)]() transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform, @@ -799,394 +789,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li ) -class SinhTransform(RVTransform): - name = "sinh" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.sinh(value) - - def backward(self, value, *inputs): - return pt.arcsinh(value) - - -class CoshTransform(RVTransform): - name = "cosh" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.cosh(value) - - def backward(self, value, *inputs): - back_value = pt.arccosh(value) - return (-back_value, back_value) - - def log_jac_det(self, value, *inputs): - return pt.switch( - value < 1, - np.nan, - -pt.log(pt.sqrt(value**2 - 1)), - ) - - -class TanhTransform(RVTransform): - name = "tanh" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.tanh(value) - - def backward(self, value, 
*inputs): - return pt.arctanh(value) - - -class ArcsinhTransform(RVTransform): - name = "arcsinh" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.arcsinh(value) - - def backward(self, value, *inputs): - return pt.sinh(value) - - -class ArccoshTransform(RVTransform): - name = "arccosh" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.arccosh(value) - - def backward(self, value, *inputs): - return pt.cosh(value) - - -class ArctanhTransform(RVTransform): - name = "arctanh" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.arctanh(value) - - def backward(self, value, *inputs): - return pt.tanh(value) - - -class ErfTransform(RVTransform): - name = "erf" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.erf(value) - - def backward(self, value, *inputs): - return pt.erfinv(value) - - -class ErfcTransform(RVTransform): - name = "erfc" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.erfc(value) - - def backward(self, value, *inputs): - return pt.erfcinv(value) - - -class ErfcxTransform(RVTransform): - name = "erfcx" - ndim_supp = 0 - - def forward(self, value, *inputs): - return pt.erfcx(value) - - def backward(self, value, *inputs): - # computes the inverse of erfcx, this was adapted from - # https://tinyurl.com/4mxfd3cz - x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value))) - - def calc_delta_x(value, prior_result): - return prior_result - (pt.erfcx(prior_result) - value) / ( - 2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi) - ) - - result, updates = scan( - fn=calc_delta_x, - outputs_info=pt.ones_like(x), - non_sequences=value, - n_steps=10, - ) - return result[-1] - - -class LocTransform(RVTransform): - name = "loc" - - def __init__(self, transform_args_fn): - self.transform_args_fn = transform_args_fn - - def forward(self, value, *inputs): - loc = self.transform_args_fn(*inputs) - return value + loc - - def backward(self, value, *inputs): - loc = self.transform_args_fn(*inputs) - return value - loc - - def log_jac_det(self, value, *inputs): - return pt.zeros_like(value) - - -class ScaleTransform(RVTransform): - name = "scale" - - def __init__(self, transform_args_fn): - self.transform_args_fn = transform_args_fn - - def forward(self, value, *inputs): - scale = self.transform_args_fn(*inputs) - return value * scale - - def backward(self, value, *inputs): - scale = self.transform_args_fn(*inputs) - return value / scale - - def log_jac_det(self, value, *inputs): - scale = self.transform_args_fn(*inputs) - return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape))) - - -class LogTransform(RVTransform): - name = "log" - - def forward(self, value, *inputs): - return pt.log(value) - - def backward(self, value, *inputs): - return pt.exp(value) - - def log_jac_det(self, value, *inputs): - return value - - -class ExpTransform(RVTransform): - name = "exp" - - def forward(self, value, *inputs): - return pt.exp(value) - - def backward(self, value, *inputs): - return pt.log(value) - - def log_jac_det(self, value, *inputs): - return -pt.log(value) - - -class AbsTransform(RVTransform): - name = "abs" - - def forward(self, value, *inputs): - return pt.abs(value) - - def backward(self, value, *inputs): - value = pt.switch(value >= 0, value, np.nan) - return -value, value - - def log_jac_det(self, value, *inputs): - return pt.switch(value >= 0, 0, np.nan) - - -class PowerTransform(RVTransform): - name = "power" - - def __init__(self, power=None): - if not isinstance(power, (int, 
float)): - raise TypeError(f"Power must be integer or float, got {type(power)}") - if power == 0: - raise ValueError("Power cannot be 0") - self.power = power - super().__init__() - - def forward(self, value, *inputs): - return pt.power(value, self.power) - - def backward(self, value, *inputs): - inv_power = 1 / self.power - - # Powers that don't admit negative values - if (np.abs(self.power) < 1) or (self.power % 2 == 0): - backward_value = pt.switch(value >= 0, pt.power(value, inv_power), np.nan) - # Powers that admit negative values require special logic, because (-1)**(1/3) returns `nan` in PyTensor - else: - backward_value = pt.power(pt.abs(value), inv_power) * pt.switch(value >= 0, 1, -1) - - # In this case the transform is not 1-to-1 - if self.power % 2 == 0: - return -backward_value, backward_value - else: - return backward_value - - def log_jac_det(self, value, *inputs): - inv_power = 1 / self.power - - # Note: This fails for value==0 - res = np.log(np.abs(inv_power)) + (inv_power - 1) * pt.log(pt.abs(value)) - - # Powers that don't admit negative values - if (np.abs(self.power) < 1) or (self.power % 2 == 0): - res = pt.switch(value >= 0, res, np.nan) - - return res - - -class IntervalTransform(RVTransform): - name = "interval" - - def __init__(self, args_fn: Callable[..., Tuple[Optional[Variable], Optional[Variable]]]): - """ - - Parameters - ---------- - args_fn - Function that expects inputs of RandomVariable and returns the lower - and upper bounds for the interval transformation. If one of these is - None, the RV is considered to be unbounded on the respective edge. - """ - self.args_fn = args_fn - - def forward(self, value, *inputs): - a, b = self.args_fn(*inputs) - - if a is not None and b is not None: - return pt.log(value - a) - pt.log(b - value) - elif a is not None: - return pt.log(value - a) - elif b is not None: - return pt.log(b - value) - else: - raise ValueError("Both edges of IntervalTransform cannot be None") - - def backward(self, value, *inputs): - a, b = self.args_fn(*inputs) - - if a is not None and b is not None: - sigmoid_x = pt.sigmoid(value) - return sigmoid_x * b + (1 - sigmoid_x) * a - elif a is not None: - return pt.exp(value) + a - elif b is not None: - return b - pt.exp(value) - else: - raise ValueError("Both edges of IntervalTransform cannot be None") - - def log_jac_det(self, value, *inputs): - a, b = self.args_fn(*inputs) - - if a is not None and b is not None: - s = pt.softplus(-value) - return pt.log(b - a) - 2 * s - value - elif a is None and b is None: - raise ValueError("Both edges of IntervalTransform cannot be None") - else: - return value - - -class LogOddsTransform(RVTransform): - name = "logodds" - - def backward(self, value, *inputs): - return pt.expit(value) - - def forward(self, value, *inputs): - return pt.log(value / (1 - value)) - - def log_jac_det(self, value, *inputs): - sigmoid_value = pt.sigmoid(value) - return pt.log(sigmoid_value) + pt.log1p(-sigmoid_value) - - -class SimplexTransform(RVTransform): - name = "simplex" - - def forward(self, value, *inputs): - value = pt.as_tensor(value) - log_value = pt.log(value) - N = value.shape[-1].astype(value.dtype) - shift = pt.sum(log_value, -1, keepdims=True) / N - return log_value[..., :-1] - shift - - def backward(self, value, *inputs): - value = pt.concatenate([value, -pt.sum(value, -1, keepdims=True)], axis=-1) - exp_value_max = pt.exp(value - pt.max(value, -1, keepdims=True)) - return exp_value_max / pt.sum(exp_value_max, -1, keepdims=True) - - def log_jac_det(self, value, 
*inputs): - value = pt.as_tensor(value) - N = value.shape[-1] + 1 - N = N.astype(value.dtype) - sum_value = pt.sum(value, -1, keepdims=True) - value_sum_expanded = value + sum_value - value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1) - logsumexp_value_expanded = pt.logsumexp(value_sum_expanded, -1, keepdims=True) - res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded) - return pt.sum(res, -1) - - -class CircularTransform(RVTransform): - name = "circular" - - def backward(self, value, *inputs): - return pt.arctan2(pt.sin(value), pt.cos(value)) - - def forward(self, value, *inputs): - return pt.as_tensor_variable(value) - - def log_jac_det(self, value, *inputs): - return pt.zeros(value.shape) - - -class ChainedTransform(RVTransform): - name = "chain" - - def __init__(self, transform_list, base_op): - self.transform_list = transform_list - self.base_op = base_op - - def forward(self, value, *inputs): - for transform in self.transform_list: - value = transform.forward(value, *inputs) - return value - - def backward(self, value, *inputs): - for transform in reversed(self.transform_list): - value = transform.backward(value, *inputs) - return value - - def log_jac_det(self, value, *inputs): - value = pt.as_tensor_variable(value) - det_list = [] - ndim0 = value.ndim - for transform in reversed(self.transform_list): - det_ = transform.log_jac_det(value, *inputs) - det_list.append(det_) - ndim0 = min(ndim0, det_.ndim) - value = transform.backward(value, *inputs) - # match the shape of the smallest jacobian_det - det = 0.0 - for det_ in det_list: - if det_.ndim > ndim0: - ndim_diff = det_.ndim - ndim0 - det += det_.sum(axis=tuple(range(-ndim_diff, 0))) - else: - det += det_ - return det - - def _create_transformed_rv_op( rv_op: Op, - transforms: Union[RVTransform, Sequence[Union[None, RVTransform]]], + transforms: Union[Transform, Sequence[Union[None, Transform]]], *, cls_dict_extra: Optional[Dict] = None, ) -> Op: diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 783b9ad95..48985c6d6 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -36,158 +36,101 @@ import warnings -from typing import ( - Callable, - Container, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union import numpy as np +import pytensor from pytensor import Variable from pytensor import tensor as pt -from pytensor.graph import Apply, Op -from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk -from pytensor.graph.fg import FunctionGraph +from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter +from pytensor.graph.basic import walk from pytensor.graph.op import HasInnerGraph from pytensor.link.c.type import CType from pytensor.raise_op import CheckAndRaise from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.distributions.transforms import Transform +from pymc.logprob.abstract import MeasurableVariable, ValuedRV, _logprob +from pymc.pytensorf import replace_vars_in_graphs from pymc.util import makeiter -def walk_model( - graphs: Iterable[TensorVariable], - walk_past_rvs: bool = False, - stop_at_vars: Optional[Set[TensorVariable]] = None, - expand_fn: Callable[[TensorVariable], List[TensorVariable]] = lambda var: [], -) -> Generator[TensorVariable, None, None]: - """Walk model graphs and 
yield their nodes. - - By default, these walks will not go past ``MeasurableVariable`` nodes. +def replace_rvs_by_values( + graphs: Sequence[TensorVariable], + *, + rvs_to_values: Dict[TensorVariable, TensorVariable], + rvs_to_transforms: Optional[Dict[TensorVariable, "Transform"]] = None, +) -> List[TensorVariable]: + """Clone and replace random variables in graphs with their value variables. Parameters ---------- graphs - The graphs to walk. - walk_past_rvs - If ``True``, the walk will not terminate at ``MeasurableVariable``s. - stop_at_vars - A list of variables at which the walk will terminate. - expand_fn - A function that returns the next variable(s) to be traversed. + The graphs in which to perform the replacements. + rvs_to_values + Mapping between the original graph RVs and respective value variables + rvs_to_transforms, optional + Mapping between the original graph RVs and respective value transforms """ - if stop_at_vars is None: - stop_at_vars = set() - - def expand(var: TensorVariable, stop_at_vars=stop_at_vars) -> List[TensorVariable]: - new_vars = expand_fn(var) - - if ( - var.owner - and (walk_past_rvs or not isinstance(var.owner.op, MeasurableVariable)) - and (var not in stop_at_vars) - ): - new_vars.extend(reversed(var.owner.inputs)) - return new_vars - - yield from walk(graphs, expand, False) - - -def replace_rvs_in_graphs( - graphs: Iterable[TensorVariable], - replacement_fn: Callable[ - [TensorVariable, Dict[TensorVariable, TensorVariable]], - Dict[TensorVariable, TensorVariable], - ], - initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None, - **kwargs, -) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]: - """Replace random variables in graphs. - - This will *not* recompute test values. - - Parameters - ---------- - graphs - The graphs in which random variables are to be replaced. - - Returns - ------- - A ``tuple`` containing the transformed graphs and a ``dict`` of the - replacements that were made. - """ replacements = {} - if initial_replacements: - replacements.update(initial_replacements) - def expand_replace(var: TensorVariable) -> List[TensorVariable]: - new_nodes: List[TensorVariable] = [] - if var.owner and isinstance(var.owner.op, MeasurableVariable): - new_nodes.extend(replacement_fn(var, replacements)) - return new_nodes - - for var in walk_model(graphs, expand_fn=expand_replace, **kwargs): + def populate_replacements(var): + # Populate replacements dict with {rv: value} pairs indicating which graph + # RVs should be replaced by what value variables. + if not var.owner: + return [] + + next_vars = [] + value = rvs_to_values.get(var, None) + if value is not None: + rv = var + + if rvs_to_transforms is not None: + transform = rvs_to_transforms.get(rv, None) + if transform is not None: + # We want to replace uses of the RV by the back-transformation of its value + value = transform.backward(value, *rv.owner.inputs) + # The value may have a less precise type than the rv. 
In this case + # filter_variable will add a SpecifyShape to ensure they are consistent + value = rv.type.filter_variable(value, allow_convert=True) + value.name = rv.name + + replacements[rv] = value + # Also walk the graph of the value variable to make any additional + # replacements if that is not a simple input variable + next_vars.append(value) + + next_vars.extend(reversed(var.owner.inputs)) + return next_vars + + # Iterate over the generator to populate the replacements + for _ in walk(graphs, populate_replacements, bfs=False): pass - if replacements: - inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] - equiv = {k: k for k in replacements.keys()} - equiv = clone_get_equiv(inputs, graphs, False, False, equiv) - - fg = FunctionGraph( - [equiv[i] for i in inputs], - [equiv[o] for o in graphs], - clone=False, - ) - - fg.replace_all(replacements.items(), import_missing=True) - - graphs = list(fg.outputs) + return replace_vars_in_graphs(graphs, replacements) - return graphs, replacements +def rvs_in_graph(vars: Union[Variable, Sequence[Variable]]) -> Set[Variable]: + """Assert that there are no `MeasurableVariable` nodes in a graph.""" -def rvs_to_value_vars( - graphs: Iterable[TensorVariable], - initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None, - **kwargs, -) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]: - """Replace random variables in graphs with their value variables. - - This will *not* recompute test values in the resulting graphs. - - Parameters - ---------- - graphs - The graphs in which to perform the replacements. - initial_replacements - A ``dict`` containing the initial replacements to be made. + def expand(r): + owner = r.owner + if owner: + inputs = list(reversed(owner.inputs)) - """ + if isinstance(owner.op, HasInnerGraph): + inputs += owner.op.inner_outputs - def replace_fn(var, replacements): - rv_value_var = replacements.get(var, None) - if rv_value_var is not None: - replacements[var] = rv_value_var - # In case the value variable is itself a graph, we walk it for - # potential replacements - return [rv_value_var] - return [] + return inputs - return replace_rvs_in_graphs(graphs, replace_fn, initial_replacements, **kwargs) + return { + node + for node in walk(makeiter(vars), expand, False) + if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable)) + } def convert_indices(indices, entry): @@ -211,26 +154,59 @@ def indices_from_subtensor(idx_list, indices): ) -def check_potential_measurability( - inputs: Tuple[TensorVariable], valued_rvs: Container[TensorVariable] -) -> bool: +def filter_measurable_variables(inputs): + return [ + inp + for inp in inputs + if ( + inp.owner is not None + and not isinstance(inp.owner.op, ValuedRV) + and isinstance(inp.owner.op, MeasurableVariable) + ) + ] + + +def check_potential_measurability(inputs: Tuple[TensorVariable]) -> bool: + def expand_fn(var): + # expand_fn does not go beyond valued_rvs or any MeasurableVariable + if var.owner and not isinstance(var.owner.op, (ValuedRV, MeasurableVariable)): + return reversed(var.owner.inputs) + else: + return [] + if any( ancestor_var - for ancestor_var in walk_model( - inputs, - walk_past_rvs=False, - stop_at_vars=set(valued_rvs), - ) + for ancestor_var in walk(inputs, expand=expand_fn, bfs=False) if ( ancestor_var.owner + and not isinstance(ancestor_var.owner.op, ValuedRV) and isinstance(ancestor_var.owner.op, MeasurableVariable) - and ancestor_var not in valued_rvs ) ): return True return False +def 
get_related_valued_nodes(node: Apply, fgraph: FunctionGraph) -> list[Apply]:
+    """Get all ValuedRV nodes that use outputs of the same RV node.
+
+    Returns
+    -------
+    valued_nodes
+        The ``ValuedRV`` client nodes of ``node``'s outputs.
+    """
+    clients = fgraph.clients
+    valued_nodes = []
+    for out in node.outputs:
+        for client, _ in clients[out]:
+            if client == "output":
+                continue
+            if isinstance(client.op, ValuedRV):
+                valued_nodes.append(client)
+
+    return valued_nodes
+
+
 class ParameterValueError(ValueError):
     """Exception for invalid parameters values in logprob graphs"""
 
@@ -251,6 +227,48 @@ def __str__(self):
         return f"Check{{{self.msg}}}"
 
 
+@node_rewriter(tracks=[CheckParameterValue])
+def local_remove_check_parameter(fgraph, node):
+    """Rewrite that removes CheckParameterValue.
+
+    This is used, e.g., when a model is compiled with ``check_bounds=False``.
+    """
+    if isinstance(node.op, CheckParameterValue):
+        return [node.inputs[0]]
+
+
+@node_rewriter(tracks=[CheckParameterValue])
+def local_check_parameter_to_ninf_switch(fgraph, node):
+    if not node.op.can_be_replaced_by_ninf:
+        return None
+
+    logp_expr, *logp_conds = node.inputs
+    if len(logp_conds) > 1:
+        logp_cond = pt.all(logp_conds)
+    else:
+        (logp_cond,) = logp_conds
+    out = pt.switch(logp_cond, logp_expr, -np.inf)
+    out.name = node.op.msg
+
+    if out.dtype != node.outputs[0].dtype:
+        out = pt.cast(out, node.outputs[0].dtype)
+
+    return [out]
+
+
+pytensor.compile.optdb["canonicalize"].register(
+    "local_remove_check_parameter",
+    local_remove_check_parameter,
+    use_db_name_as_tag=False,
+)
+
+pytensor.compile.optdb["canonicalize"].register(
+    "local_check_parameter_to_ninf_switch",
+    local_check_parameter_to_ninf_switch,
+    use_db_name_as_tag=False,
+)
+
+
 class DiracDelta(Op):
     """An `Op` that represents a Dirac-delta distribution."""
 
@@ -291,23 +309,3 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
     (const_value,) = inputs
     values, const_value = pt.broadcast_arrays(values, const_value)
     return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)
-
-
-def find_rvs_in_graph(vars: Union[Variable, Sequence[Variable]]) -> Set[Variable]:
-    """Assert that there are no `MeasurableVariable` nodes in a graph."""
-
-    def expand(r):
-        owner = r.owner
-        if owner:
-            inputs = list(reversed(owner.inputs))
-
-            if isinstance(owner.op, HasInnerGraph):
-                inputs += owner.op.inner_outputs
-
-            return inputs
-
-    return {
-        node
-        for node in walk(makeiter(vars), expand, False)
-        if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable))
-    }
diff --git a/pymc/model/core.py b/pymc/model/core.py
index 65ad468f7..4a0ff7573 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -65,7 +65,7 @@
 )
 from pymc.initial_point import make_initial_point_fn
 from pymc.logprob.basic import transformed_conditional_logp
-from pymc.logprob.utils import ParameterValueError
+from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
 from pymc.model_graph import model_to_graphviz
 from pymc.pytensorf import (
     PointFunc,
@@ -75,7 +75,6 @@
     gradient,
     hessian,
     inputvars,
-    replace_rvs_by_values,
     rewrite_pregrad,
 )
 from pymc.util import (
diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py
index 9cfce57fc..b46d9688f 100644
--- a/pymc/model/fgraph.py
+++ b/pymc/model/fgraph.py
@@ -24,7 +24,7 @@
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.sharedvar import ScalarSharedVariable
 
-from pymc.logprob.transforms import RVTransform
+from pymc.distributions.transforms import Transform
 from pymc.model.core import Model
 from pymc.pytensorf import StringType, find_rng_nodes,
toposort_replace
@@ -59,8 +59,8 @@ def perform(self, *args, **kwargs):
 class ModelValuedVar(ModelVar):
     __props__ = ("transform",)
 
-    def __init__(self, transform: Optional[RVTransform] = None):
-        if transform is not None and not isinstance(transform, RVTransform):
-            raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}")
+    def __init__(self, transform: Optional[Transform] = None):
+        if transform is not None and not isinstance(transform, Transform):
+            raise TypeError(f"transform must be None or Transform type, got {type(transform)}")
         self.transform = transform
         super().__init__()
diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py
index faa31339f..7a8681b33 100644
--- a/pymc/model/transform/conditioning.py
+++ b/pymc/model/transform/conditioning.py
@@ -13,17 +13,14 @@
 # limitations under the License.
 import warnings
 
-from typing import Any, List, Mapping, Optional, Sequence, Union
+from typing import Any, Mapping, Optional, Sequence, Union
 
-from pytensor import Variable
 from pytensor.graph import ancestors
-from pytensor.graph.basic import walk
-from pytensor.graph.op import HasInnerGraph
 from pytensor.tensor import TensorVariable
-from pytensor.tensor.random.op import RandomVariable
 
 from pymc import Model
-from pymc.logprob.transforms import RVTransform
+from pymc.distributions.transforms import Transform
+from pymc.logprob.utils import rvs_in_graph
 from pymc.model.fgraph import (
     ModelDeterministic,
     ModelFreeRV,
@@ -40,7 +37,7 @@
     parse_vars,
     prune_vars_detached_from_observed,
 )
-from pymc.pytensorf import _replace_vars_in_graphs, toposort_replace
+from pymc.pytensorf import replace_vars_in_graphs, toposort_replace
 from pymc.util import get_transformed_name, get_untransformed_name
 
@@ -122,44 +119,6 @@ def observe(
     return model_from_fgraph(fgraph)
 
 
-def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
-    def replacement_fn(var, inner_replacements):
-        if var in replacements:
-            inner_replacements[var] = replacements[var]
-
-        # Handle root inputs as those will never be passed to the replacement_fn
-        for inp in var.owner.inputs:
-            if inp.owner is None and inp in replacements:
-                inner_replacements[inp] = replacements[inp]
-
-        return [var]
-
-    replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
-    return replaced_graphs
-
-
-def rvs_in_graph(vars: Sequence[Variable]) -> bool:
-    """Check if there are any rvs in the graph of vars"""
-
-    from pymc.distributions.distribution import SymbolicRandomVariable
-
-    def expand(r):
-        owner = r.owner
-        if owner:
-            inputs = list(reversed(owner.inputs))
-
-            if isinstance(owner.op, HasInnerGraph):
-                inputs += owner.op.inner_outputs
-
-            return inputs
-
-    return any(
-        node
-        for node in walk(vars, expand, False)
-        if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
-    )
-
-
 def do(
     model: Model,
     vars_to_interventions: Mapping[Union["str", TensorVariable], Any],
@@ -263,7 +222,7 @@ def do(
 
 def change_value_transforms(
     model: Model,
-    vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]],
+    vars_to_transforms: Mapping[ModelVariable, Union[Transform, None]],
 ) -> Model:
     """Change the value variables transforms in the model
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 8480d0dcf..7e6a41f93 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -35,12 +35,13 @@
 from pytensor import scalar
 from pytensor.compile import Function, Mode, get_mode
 from pytensor.gradient import grad
-from pytensor.graph import Type, node_rewriter, rewrite_graph
+from
pytensor.graph import Type, rewrite_graph
 from pytensor.graph.basic import (
     Apply,
     Constant,
     Variable,
     clone_get_equiv,
+    equal_computations,
     graph_inputs,
     walk,
 )
@@ -63,8 +64,6 @@
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 from pymc.exceptions import NotConstantValueError
-from pymc.logprob.transforms import RVTransform
-from pymc.logprob.utils import CheckParameterValue
 from pymc.util import makeiter
 from pymc.vartypes import continuous_types, isgenerator, typefilter
 
@@ -192,6 +191,8 @@ def walk_model(
     expand_fn
         A function that returns the next variable(s) to be traversed.
     """
+    warnings.warn("walk_model will be removed in a future release of PyMC", FutureWarning)
+
     if stop_at_vars is None:
         stop_at_vars = set()
 
@@ -206,197 +207,34 @@ def expand(var):
     yield from walk(graphs, expand, bfs=False)
 
 
-def _replace_vars_in_graphs(
-    graphs: Iterable[TensorVariable],
-    replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
-    **kwargs,
-) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
-    """Replace variables in graphs.
-
-    This will *not* recompute test values.
-
-    Parameters
-    ----------
-    graphs
-        The graphs in which random variables are to be replaced.
-    replacement_fn
-        A callable called on each graph output that populates a replacement dictionary and returns
-        nodes that should be investigated further.
-
-    Returns
-    -------
-    Tuple containing the transformed graphs and a ``dict`` of the replacements
-    that were made.
-    """
-    replacements = {}
-
-    def expand_replace(var):
-        new_nodes = []
-        if var.owner:
-            # Call replacement_fn to update replacements dict inplace and, optionally,
-            # specify new nodes that should also be walked for replacements. This
-            # includes `value` variables that are not simple input variables, and may
-            # contain other `random` variables in their graphs (e.g., IntervalTransform)
-            new_nodes.extend(replacement_fn(var, replacements))
-        return new_nodes
-
-    # This iteration populates the replacements
-    for var in walk_model(graphs, expand_fn=expand_replace, **kwargs):
-        pass
-
-    if replacements:
-        inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
-        equiv = {k: k for k in replacements.keys()}
-        equiv = clone_get_equiv(inputs, graphs, False, False, equiv)
-
-        fg = FunctionGraph(
-            [equiv[i] for i in inputs],
-            [equiv[o] for o in graphs],
-            clone=False,
-        )
-
-        # replacements have to be done in reverse topological order so that nested
-        # expressions get recursively replaced correctly
-        toposort = fg.toposort()
-        sorted_replacements = sorted(
-            tuple(replacements.items()),
-            # Root inputs don't have owner, we give them negative priority -1
-            key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner is not None else -1,
-            reverse=True,
-        )
-        fg.replace_all(sorted_replacements, import_missing=True)
-
-        graphs = list(fg.outputs)
-
-    return graphs, replacements
-
-
-def rvs_to_value_vars(
+def replace_vars_in_graphs(
     graphs: Iterable[Variable],
-    apply_transforms: bool = True,
-    **kwargs,
+    replacements: Dict[Variable, Variable],
 ) -> List[Variable]:
-    """Clone and replace random variables in graphs with their value variables.
-
-    This will *not* recompute test values in the resulting graphs.
+    """Replace variables in graphs.
 
-    Parameters
-    ----------
-    graphs
-        The graphs in which to perform the replacements.
-    apply_transforms
-        If ``True``, apply each value variable's transform.
+    Graphs are cloned and not modified in place.
""" - warnings.warn( - "rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead", - FutureWarning, - ) - - def populate_replacements( - random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable] - ) -> List[TensorVariable]: - # Populate replacements dict with {rv: value} pairs indicating which graph - # RVs should be replaced by what value variables. - - value_var = getattr( - random_var.tag, "observations", getattr(random_var.tag, "value_var", None) - ) - - # No value variable to replace RV with - if value_var is None: - return [] - - transform = getattr(value_var.tag, "transform", None) - if transform is not None and apply_transforms: - # We want to replace uses of the RV by the back-transformation of its value - value_var = transform.backward(value_var, *random_var.owner.inputs) - - replacements[random_var] = value_var - - # Also walk the graph of the value variable to make any additional replacements - # if that is not a simple input variable - return [value_var] - - # Clone original graphs + # Clone graph and get equivalences inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] - equiv = clone_get_equiv(inputs, graphs, False, False, {}) - graphs = [equiv[n] for n in graphs] + equiv = {k: k for k in replacements.keys()} + equiv = clone_get_equiv(inputs, graphs, False, False, equiv) - graphs, _ = _replace_vars_in_graphs( - graphs, - replacement_fn=populate_replacements, - **kwargs, + fg = FunctionGraph( + [equiv[i] for i in inputs], + [equiv[o] for o in graphs], + clone=False, ) - return graphs - + # Filter replacement keys that are actually present in the graph + vars = fg.variables + final_replacements = tuple((k, v) for k, v in replacements.items() if k in vars) -def replace_rvs_by_values( - graphs: Sequence[TensorVariable], - *, - rvs_to_values: Dict[TensorVariable, TensorVariable], - rvs_to_transforms: Optional[Dict[TensorVariable, RVTransform]] = None, - **kwargs, -) -> List[TensorVariable]: - """Clone and replace random variables in graphs with their value variables. + # Replacements have to be done in reverse topological order so that nested + # expressions get recursively replaced correctly + toposort_replace(fg, final_replacements, reverse=True) - This will *not* recompute test values in the resulting graphs. - - Parameters - ---------- - graphs - The graphs in which to perform the replacements. - rvs_to_values - Mapping between the original graph RVs and respective value variables - rvs_to_transforms, optional - Mapping between the original graph RVs and respective value transforms - """ - - # Clone original graphs so that we don't modify variables in place - inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] - equiv = clone_get_equiv(inputs, graphs, False, False, {}) - graphs = [equiv[n] for n in graphs] - - # Get needed mappings for equivalent cloned variables - equiv_rvs_to_values = {} - equiv_rvs_to_transforms = {} - for rv, value in rvs_to_values.items(): - equiv_rv = equiv.get(rv, rv) - equiv_rvs_to_values[equiv_rv] = equiv.get(value, value) - if rvs_to_transforms is not None: - equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv] - - def poulate_replacements(rv, replacements): - # Populate replacements dict with {rv: value} pairs indicating which graph - # RVs should be replaced by what value variables. 
- - # No value variable to replace RV with - value = equiv_rvs_to_values.get(rv, None) - if value is None: - return [] - - if rvs_to_transforms is not None: - transform = equiv_rvs_to_transforms.get(rv, None) - if transform is not None: - # We want to replace uses of the RV by the back-transformation of its value - value = transform.backward(value, *rv.owner.inputs) - # The value may have a less precise type than the rv. In this case - # filter_variable will add a SpecifyShape to ensure they are consistent - value = rv.type.filter_variable(value, allow_convert=True) - value.name = rv.name - - replacements[rv] = value - # Also walk the graph of the value variable to make any additional - # replacements if that is not a simple input variable - return [value] - - graphs, _ = _replace_vars_in_graphs( - graphs, - replacement_fn=poulate_replacements, - **kwargs, - ) - - return graphs + return list(fg.outputs) def inputvars(a): @@ -899,48 +737,6 @@ def largest_common_dtype(tensors): return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype -@node_rewriter(tracks=[CheckParameterValue]) -def local_remove_check_parameter(fgraph, node): - """Rewrite that removes CheckParameterValue - - This is used when compile_rv_inplace - """ - if isinstance(node.op, CheckParameterValue): - return [node.inputs[0]] - - -@node_rewriter(tracks=[CheckParameterValue]) -def local_check_parameter_to_ninf_switch(fgraph, node): - if not node.op.can_be_replaced_by_ninf: - return None - - logp_expr, *logp_conds = node.inputs - if len(logp_conds) > 1: - logp_cond = pt.all(logp_conds) - else: - (logp_cond,) = logp_conds - out = pt.switch(logp_cond, logp_expr, -np.inf) - out.name = node.op.msg - - if out.dtype != node.outputs[0].dtype: - out = pt.cast(out, node.outputs[0].dtype) - - return [out] - - -pytensor.compile.optdb["canonicalize"].register( - "local_remove_check_parameter", - local_remove_check_parameter, - use_db_name_as_tag=False, -) - -pytensor.compile.optdb["canonicalize"].register( - "local_check_parameter_to_ninf_switch", - local_check_parameter_to_ninf_switch, - use_db_name_as_tag=False, -) - - def find_rng_nodes( variables: Iterable[Variable], ) -> List[Union[RandomStateSharedVariable, RandomGeneratorSharedVariable]]: @@ -1045,57 +841,80 @@ def scan_step(xtm1): from pymc.distributions.distribution import SymbolicRandomVariable def find_default_update(clients, rng: Variable) -> Union[None, Variable]: + """Recursively find default update expression for rng. + + Returns None if no unambiguous update can be found. + """ rng_clients = clients.get(rng, None) - # Root case, RNG is not used elsewhere + # Root case, RNG is not used elsewhere and can be used safely if not rng_clients: return rng - if len(rng_clients) > 1: - warnings.warn( - f"RNG Variable {rng} has multiple clients. 
This is likely an inconsistent random graph.",
-            UserWarning,
-        )
+        updates = []
+        for client, _ in rng_clients:
+            # RNG is an output of the function, this is not a problem
+            if client == "output":
+                updates.append("output")
+                # There is no node to recurse into for an output client
+                continue
+
+            # RNG is used by another operator, which should output an update for the RNG
+            elif isinstance(client.op, RandomVariable):
+                # RandomVariable first output is always the update of the input RNG
+                next_rng = client.outputs[0]
+
+            elif isinstance(client.op, SymbolicRandomVariable):
+                # SymbolicRandomVariable has an explicit method that returns an
+                # update mapping for its RNG(s)
+                next_rng = client.op.update(client).get(rng)
+                if next_rng is None:
+                    raise ValueError(
+                        f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}"
+                    )
+            elif isinstance(client.op, Scan):
+                # Check if any shared output corresponds to the RNG
+                rng_idx = client.inputs.index(rng)
+                io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"]
+                out_idx = io_map.get(rng_idx, -1)
+                if out_idx != -1:
+                    next_rng = client.outputs[out_idx]
+                else:
+                    raise ValueError(
+                        f"No update found for at least one RNG used in Scan Op {client.op}.\n"
+                        "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
+                    )
+            else:
+                # We don't know how this RNG should be updated by this Op (e.g., OpFromGraph).
+                # The user should provide an update manually
+                continue
+
+            # Recurse until we find the final update for the RNG
+            final_update = find_default_update(clients, next_rng)
+            if final_update is not None:
+                updates.append(final_update)
+
+        if len(updates) == 0:
             return None
 
-        [client, _] = rng_clients[0]
+        updates = [update for update in updates if update != "output"]
 
-        # RNG is an output of the function, this is not a problem
-        if client == "output":
+        # RNG was only used as an output
+        if len(updates) == 0:
             return rng
 
-        # RNG is used by another operator, which should output an update for the RNG
-        if isinstance(client.op, RandomVariable):
-            # RandomVariable first output is always the update of the input RNG
-            next_rng = client.outputs[0]
-
-        elif isinstance(client.op, SymbolicRandomVariable):
-            # SymbolicRandomVariable have an explicit method that returns an
-            # update mapping for their RNG(s)
-            next_rng = client.op.update(client).get(rng)
-            if next_rng is None:
-                raise ValueError(
-                    f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}"
-                )
-        elif isinstance(client.op, Scan):
-            # Check if any shared output corresponds to the RNG
-            rng_idx = client.inputs.index(rng)
-            io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"]
-            out_idx = io_map.get(rng_idx, -1)
-            if out_idx != -1:
-                next_rng = client.outputs[out_idx]
-            else:  # No break
-                raise ValueError(
-                    f"No update found for at least one RNG used in Scan Op {client.op}.\n"
-                    "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
+        update = updates[0]
+
+        if len(updates) > 1:
+            # RNG was used in multiple places. This is only okay if the update graphs are equivalent
+            if not all(
+                equal_computations([update], [other_update]) for other_update in updates[1:]
+            ):
+                warnings.warn(
+                    f"RNG Variable {rng} has multiple clients. This is likely an inconsistent random graph.",
+                    UserWarning,
                 )
-        else:
-            # We don't know how this RNG should be updated (e.g., OpFromGraph).
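For context, the multiple-clients branch above relies on pytensor's `equal_computations` to decide whether duplicate uses of an RNG imply the same update; a small illustrative sketch:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.basic import equal_computations

    rng = pytensor.shared(np.random.default_rng(42))
    a = pt.random.normal(rng=rng)
    b = pt.random.normal(rng=rng)  # identical use of the same RNG
    # The implied updates (first output of each RandomVariable node) are
    # equivalent graphs, so a single default update can be returned
    assert equal_computations([a.owner.outputs[0]], [b.owner.outputs[0]])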
- # The user should provide an update manually - return None + return None - # Recurse until we find final update for RNG - return find_default_update(clients, next_rng) + return update if inputs is None: inputs = [] diff --git a/pymc/testing.py b/pymc/testing.py index 3eb1b7ba8..431599978 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -36,14 +36,12 @@ from pymc.distributions.shape_utils import change_dist_size from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp -from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph -from pymc.pytensorf import ( - compile_pymc, - floatX, - inputvars, - intX, +from pymc.logprob.utils import ( + ParameterValueError, local_check_parameter_to_ninf_switch, + rvs_in_graph, ) +from pymc.pytensorf import compile_pymc, floatX, inputvars, intX # This mode can be used for tests where model compilations takes the bulk of the runtime # AND where we don't care about posterior numerical or sampling stability (e.g., when @@ -952,6 +950,6 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable: def assert_no_rvs(vars: Sequence[Variable]) -> None: """Assert that there are no `MeasurableVariable` nodes in a graph.""" - rvs = find_rvs_in_graph(vars) + rvs = rvs_in_graph(vars) if rvs: raise AssertionError(f"RV found in graph: {rvs}") diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index 9632efd85..770a27073 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -59,9 +59,13 @@ ) from pymc.distributions.mixture import MixtureTransformWarning from pymc.distributions.shape_utils import change_dist_size, to_tuple -from pymc.distributions.transforms import _default_transform +from pymc.distributions.transforms import ( + IntervalTransform, + LogTransform, + SimplexTransform, + _default_transform, +) from pymc.logprob.basic import logp -from pymc.logprob.transforms import IntervalTransform, LogTransform, SimplexTransform from pymc.math import expand_packed_triangular from pymc.model import Model from pymc.pytensorf import floatX diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index e9027dcf3..00aee57bf 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -13,21 +13,36 @@ # limitations under the License. 
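The Scan error path above directs users to `pytensorf.collect_default_updates`. A minimal sketch of that pattern, assuming the signature in this diff where `inputs` is optional (names are illustrative):

    import pymc as pm
    import pytensor
    import pytensor.tensor as pt
    from pymc.pytensorf import collect_default_updates

    def step(xtm1):
        x = pm.Normal.dist(mu=xtm1, sigma=1.0)
        # Returning the RNG updates from the inner function lets Scan register
        # them, so a default update can be found for the outer graph
        return x, collect_default_updates([x])

    xs, updates = pytensor.scan(fn=step, outputs_info=[pt.zeros(())], n_steps=5)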
-from typing import Union
-
 import numpy as np
 import pytensor
-import pytensor.tensor as pt
 import pytest
 
 from numpy.testing import assert_allclose, assert_array_equal
+from pytensor import tensor as pt
 from pytensor.tensor.variable import TensorConstant
 
 import pymc as pm
 import pymc.distributions.transforms as tr
 
+from pymc.distributions.transforms import (
+    ArccoshTransform,
+    ArcsinhTransform,
+    ArctanhTransform,
+    ChainTransform,
+    CoshTransform,
+    ErfcTransform,
+    ErfcxTransform,
+    ErfTransform,
+    ExpTransform,
+    IntervalTransform,
+    LocTransform,
+    LogTransform,
+    ScaleTransform,
+    SinhTransform,
+    TanhTransform,
+    Transform,
+)
 from pymc.logprob.basic import transformed_conditional_logp
-from pymc.logprob.transforms import RVTransform
 from pymc.pytensorf import floatX, jacobian
 from pymc.testing import (
     Circ,
@@ -120,33 +135,199 @@ def check_jacobian_det(
     assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol)
 
 
-def test_simplex():
-    check_vector_transform(tr.simplex, Simplex(2))
-    check_vector_transform(tr.simplex, Simplex(4))
+class TestTransformBase:
+    @pytest.mark.parametrize("ndim", (0, 1))
+    def test_fallback_log_jac_det(self, ndim):
+        """
+        Test that the fallback log_jac_det in Transform produces the correct graph for a
+        simple transformation: x**2 -> -log(2*x)
+        """
 
-    check_transform(
-        tr.simplex, MultiSimplex(3, 2), constructor=pt.matrix, test=floatX(np.zeros((2, 2)))
-    )
+        class SquareTransform(Transform):
+            name = "square"
+            ndim_supp = ndim
 
+            def forward(self, value, *inputs):
+                return pt.power(value, 2)
 
-def test_simplex_bounds():
-    vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])))
+            def backward(self, value, *inputs):
+                return pt.sqrt(value)
 
-    assert_allclose(vals.sum(axis=1), 1, tol)
-    assert_array_equal(vals > 0, True)
-    assert_array_equal(vals < 1, True)
+        square_tr = SquareTransform()
 
-    check_jacobian_det(
-        tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), lambda x: x[:-1]
+        value = pt.vector("value", dtype="float64")
+        value_tr = square_tr.forward(value)
+        log_jac_det = square_tr.log_jac_det(value_tr)
+
+        test_value = np.r_[3, 4]
+        expected_log_jac_det = -np.log(2 * test_value)
+        if ndim == 1:
+            expected_log_jac_det = expected_log_jac_det.sum()
+        np.testing.assert_array_equal(log_jac_det.eval({value: test_value}), expected_log_jac_det)
+
+    @pytest.mark.parametrize("ndim", (None, 2))
+    def test_fallback_log_jac_det_undefined_ndim(self, ndim):
+        class SquareTransform(Transform):
+            name = "square"
+            ndim_supp = ndim
+
+            def forward(self, value, *inputs):
+                return pt.power(value, 2)
+
+            def backward(self, value, *inputs):
+                return pt.sqrt(value)
+
+        with pytest.raises(
+            NotImplementedError, match=r"only implemented for ndim_supp in \(0, 1\)"
+        ):
+            SquareTransform().log_jac_det(0)
+
+
+class TestInvalidTransform:
+    def test_discrete_trafo(self):
+        with pm.Model():
+            with pytest.raises(ValueError) as err:
+                pm.Binomial("a", n=5, p=0.5, transform="log")
+            err.match("Transformations for discrete distributions")
+
+    def test_univariate_transform_multivariate_dist_raises(self):
+        with pm.Model() as m:
+            pm.Dirichlet("x", [1, 1, 1], transform=tr.log)
+
+        for jacobian in (True, False):
+            with pytest.raises(
+                NotImplementedError,
+                match="Univariate transform LogTransform cannot be applied to multivariate",
+            ):
+                m.logp(jacobian=jacobian)
+
+    def test_invalid_jacobian_broadcast_raises(self):
+        class BuggyTransform(Transform):
+            name = "buggy"
+
+            def forward(self, value, *inputs):
+                return value
+
+            def backward(self,
value, *inputs): + return value + + def log_jac_det(self, value, *inputs): + return pt.zeros_like(value.sum(-1, keepdims=True)) + + buggy_transform = BuggyTransform() + + with pm.Model() as m: + pm.Uniform("x", shape=(4, 3), transform=buggy_transform) + + for jacobian in (True, False): + with pytest.raises( + ValueError, + match="are not allowed to broadcast together. There is a bug in the implementation of either one", + ): + m.logp(jacobian=jacobian) + + +class TestInterval: + def test_lowerbound(self): + trans = tr.Interval(0.0, None) + check_transform(trans, Rplusbig) + + check_jacobian_det(trans, Rplusbig, elemwise=True) + check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True) + + vals = get_values(trans) + assert_array_equal(vals > 0, True) + + def test_upperbound(self): + trans = tr.Interval(None, 0.0) + check_transform(trans, Rminusbig) + + check_jacobian_det(trans, Rminusbig, elemwise=True) + check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True) + + vals = get_values(trans) + assert_array_equal(vals < 0, True) + + def test_interval(self): + for a, b in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]: + domain = Unit * np.float64(b - a) + np.float64(a) + + trans = tr.Interval(a, b) + check_transform(trans, domain) + + check_jacobian_det(trans, domain, elemwise=True) + + vals = get_values(trans) + assert_array_equal(vals > a, True) + assert_array_equal(vals < b, True) + + @pytest.mark.skipif( + pytensor.config.floatX == "float32", reason="Test is designed for 64bit precision" ) + def test_interval_near_boundary(self): + lb = -1.0 + ub = 1e-7 + x0 = np.nextafter(ub, lb) + + with pm.Model() as model: + pm.Uniform("x", initval=x0, lower=lb, upper=ub) + + log_prob = model.point_logps() + assert_allclose(list(log_prob.values()), floatX(np.array([-52.68]))) + + def test_invalid_interval_helper(self): + with pytest.raises(ValueError, match="Lower and upper interval bounds cannot both be None"): + tr.Interval(None, None) + + with pytest.raises(ValueError, match="Interval bounds must be constant values"): + tr.Interval(pt.constant(5) + 1, None) + assert tr.Interval(pt.constant(5), None) -def test_simplex_accuracy(): - val = floatX(np.array([-30])) - x = pt.vector("x") - x.tag.test_value = val - identity_f = pytensor.function([x], tr.simplex.forward(tr.simplex.backward(x))) - assert_allclose(val, identity_f(val), tol) + def test_invalid_interval_transform(self): + x_rv = pt.random.normal(0, 1) + x_vv = x_rv.clone() + + msg = "Both edges of IntervalTransform cannot be None" + tr = IntervalTransform(lambda *inputs: (None, None)) + with pytest.raises(ValueError, match=msg): + tr.forward(x_vv, *x_rv.owner.inputs) + + tr = IntervalTransform(lambda *inputs: (None, None)) + with pytest.raises(ValueError, match=msg): + tr.backward(x_vv, *x_rv.owner.inputs) + + tr = IntervalTransform(lambda *inputs: (None, None)) + with pytest.raises(ValueError, match=msg): + tr.log_jac_det(x_vv, *x_rv.owner.inputs) + + +class TestSimplex: + def test_simplex(self): + check_vector_transform(tr.simplex, Simplex(2)) + check_vector_transform(tr.simplex, Simplex(4)) + + check_transform( + tr.simplex, MultiSimplex(3, 2), constructor=pt.matrix, test=floatX(np.zeros((2, 2))) + ) + + def test_simplex_bounds(self): + vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0]))) + + assert_allclose(vals.sum(axis=1), 1, tol) + assert_array_equal(vals > 0, True) + assert_array_equal(vals < 1, True) + + check_jacobian_det( + tr.simplex, Vector(R, 2), pt.vector, 
floatX(np.array([0, 0])), lambda x: x[:-1] + ) + + def test_simplex_accuracy(self): + val = floatX(np.array([-30])) + x = pt.vector("x") + x.tag.test_value = val + identity_f = pytensor.function([x], tr.simplex.forward(tr.simplex.backward(x))) + assert_allclose(val, identity_f(val), tol) def test_sum_to_1(): @@ -154,7 +335,7 @@ def test_sum_to_1(): check_vector_transform(tr.sum_to_1, Simplex(4)) with pytest.warns(FutureWarning, match="ndim_supp argument is deprecated"): - tr.SumTo1(2) + tr.SumTo1Transform(2) check_jacobian_det( tr.sum_to_1, @@ -206,57 +387,6 @@ def test_logodds(): assert_array_equal(vals < 1, True) -def test_lowerbound(): - trans = tr.Interval(0.0, None) - check_transform(trans, Rplusbig) - - check_jacobian_det(trans, Rplusbig, elemwise=True) - check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True) - - vals = get_values(trans) - assert_array_equal(vals > 0, True) - - -def test_upperbound(): - trans = tr.Interval(None, 0.0) - check_transform(trans, Rminusbig) - - check_jacobian_det(trans, Rminusbig, elemwise=True) - check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True) - - vals = get_values(trans) - assert_array_equal(vals < 0, True) - - -def test_interval(): - for a, b in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]: - domain = Unit * np.float64(b - a) + np.float64(a) - - trans = tr.Interval(a, b) - check_transform(trans, domain) - - check_jacobian_det(trans, domain, elemwise=True) - - vals = get_values(trans) - assert_array_equal(vals > a, True) - assert_array_equal(vals < b, True) - - -@pytest.mark.skipif( - pytensor.config.floatX == "float32", reason="Test is designed for 64bit precision" -) -def test_interval_near_boundary(): - lb = -1.0 - ub = 1e-7 - x0 = np.nextafter(ub, lb) - - with pm.Model() as model: - pm.Uniform("x", initval=x0, lower=lb, upper=ub) - - log_prob = model.point_logps() - assert_allclose(list(log_prob.values()), floatX(np.array([-52.68]))) - - def test_circular(): trans = tr.circular check_transform(trans, Circ) @@ -270,11 +400,47 @@ def test_circular(): assert isinstance(trans.forward(1, None), TensorConstant) +def test_triangular_transform(): + with pm.Model() as m: + x = pm.Triangular("x", lower=0, c=1, upper=2) + + transform = m.rvs_to_transforms[x] + assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0) + assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2) + + +@pytest.mark.parametrize( + "transform", + [ + ErfTransform(), + ErfcTransform(), + ErfcxTransform(), + SinhTransform(), + CoshTransform(), + TanhTransform(), + ArcsinhTransform(), + ArccoshTransform(), + ArctanhTransform(), + LogTransform(), + ExpTransform(), + ], +) +def test_check_jac_det(transform): + check_jacobian_det( + transform, + Vector(Rplusbig, 2), + pt.dvector, + [0.1, 0.1], + elemwise=True, + rv_var=pt.random.normal(0.5, 1, name="base_rv"), + ) + + def test_ordered(): check_vector_transform(tr.ordered, SortedVector(6)) with pytest.warns(FutureWarning, match="ndim_supp argument is deprecated"): - tr.Ordered(1) + tr.OrderedTransform(1) check_jacobian_det( tr.ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False @@ -284,24 +450,64 @@ def test_ordered(): assert_array_equal(np.diff(vals) >= 0, True) -def test_chain_values(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) - vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5))) - assert_array_equal(np.diff(vals) >= 0, True) +class TestChain: + def test_chain_values(self): + chain_tranf = 
tr.Chain([tr.logodds, tr.ordered]) + vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5))) + assert_array_equal(np.diff(vals) >= 0, True) + def test_chain_vector_transform(self): + chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + check_vector_transform(chain_tranf, UnitSortedVector(3)) -def test_chain_vector_transform(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) - check_vector_transform(chain_tranf, UnitSortedVector(3)) + @pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.") + def test_chain_jacob_det(self): + chain_tranf = tr.Chain([tr.logodds, tr.ordered]) + check_jacobian_det( + chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False + ) + def test_chained_transform(self): + loc = 5 + scale = 0.1 + + ch = ChainTransform( + transform_list=[ + ScaleTransform( + transform_args_fn=lambda *inputs: pt.constant(scale), + ), + ExpTransform(), + LocTransform( + transform_args_fn=lambda *inputs: pt.constant(loc), + ), + ], + ) -@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.") -def test_chain_jacob_det(): - chain_tranf = tr.Chain([tr.logodds, tr.ordered]) - check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False) + x = pt.random.multivariate_normal(np.zeros(3), np.eye(3)) + x_val = x.eval() + + x_val_forward = ch.forward(x_val, *x.owner.inputs).eval() + np.testing.assert_allclose( + x_val_forward, + np.exp(x_val * scale) + loc, + rtol=1e-6, + ) + + x_val_backward = ch.backward(x_val_forward, *x.owner.inputs, scale, loc).eval() + np.testing.assert_allclose( + x_val_backward, + x_val, + rtol=1e-5, + ) + + log_jac_det = ch.log_jac_det(x_val_forward, *x.owner.inputs, scale, loc) + np.testing.assert_allclose( + pt.sum(log_jac_det).eval(), + np.sum(-np.log(scale) - np.log(x_val_forward - loc)), + ) -class TestElementWiseLogp: +class TestTransformedRVLogp: def build_model(self, distfam, params, size, transform, initval=None): if initval is not None: initval = pm.floatX(initval) @@ -578,88 +784,23 @@ def test_mvnormal_transform(self, mu, cov, size, shape, transform): ) self.check_vectortransform_elementwise_logp(model) + def test_transform_univariate_dist_logp_shape(self): + with pm.Model() as m: + pm.Uniform("x", shape=(4, 3), transform=tr.logodds) -def test_triangular_transform(): - with pm.Model() as m: - x = pm.Triangular("x", lower=0, c=1, upper=2) - - transform = m.rvs_to_transforms[x] - assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0) - assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2) - - -def test_interval_transform_raises(): - with pytest.raises(ValueError, match="Lower and upper interval bounds cannot both be None"): - tr.Interval(None, None) - - with pytest.raises(ValueError, match="Interval bounds must be constant values"): - tr.Interval(pt.constant(5) + 1, None) - - assert tr.Interval(pt.constant(5), None) - - -def test_discrete_trafo(): - with pm.Model(): - with pytest.raises(ValueError) as err: - pm.Binomial("a", n=5, p=0.5, transform="log") - err.match("Transformations for discrete distributions") - - -def test_transform_univariate_dist_logp_shape(): - with pm.Model() as m: - pm.Uniform("x", shape=(4, 3), transform=tr.logodds) - - assert m.logp(jacobian=False, sum=False)[0].type.shape == (4, 3) - assert m.logp(jacobian=True, sum=False)[0].type.shape == (4, 3) - - with pm.Model() as m: - pm.Uniform("x", shape=(4, 3), transform=tr.ordered) - - assert m.logp(jacobian=False, 
sum=False)[0].type.shape == (4,) - assert m.logp(jacobian=True, sum=False)[0].type.shape == (4,) - - -def test_univariate_transform_multivariate_dist_raises(): - with pm.Model() as m: - pm.Dirichlet("x", [1, 1, 1], transform=tr.log) - - for jacobian in (True, False): - with pytest.raises( - NotImplementedError, - match="Univariate transform LogTransform cannot be applied to multivariate", - ): - m.logp(jacobian=jacobian) - - -def test_invalid_jacobian_broadcast_raises(): - class BuggyTransform(RVTransform): - name = "buggy" - - def forward(self, value, *inputs): - return value - - def backward(self, value, *inputs): - return value - - def log_jac_det(self, value, *inputs): - return pt.zeros_like(value.sum(-1, keepdims=True)) - - buggy_transform = BuggyTransform() + assert m.logp(jacobian=False, sum=False)[0].type.shape == (4, 3) + assert m.logp(jacobian=True, sum=False)[0].type.shape == (4, 3) - with pm.Model() as m: - pm.Uniform("x", shape=(4, 3), transform=buggy_transform) + with pm.Model() as m: + pm.Uniform("x", shape=(4, 3), transform=tr.ordered) - for jacobian in (True, False): - with pytest.raises( - ValueError, - match="are not allowed to broadcast together. There is a bug in the implementation of either one", - ): - m.logp(jacobian=jacobian) + assert m.logp(jacobian=False, sum=False)[0].type.shape == (4,) + assert m.logp(jacobian=True, sum=False)[0].type.shape == (4,) def test_deprecated_ndim_supp_transforms(): with pytest.warns(FutureWarning, match="deprecated"): - tr.Ordered(ndim_supp=1) + tr.OrderedTransform(ndim_supp=1) with pytest.warns(FutureWarning, match="deprecated"): assert tr.univariate_ordered == tr.ordered @@ -668,7 +809,7 @@ def test_deprecated_ndim_supp_transforms(): assert tr.multivariate_ordered == tr.ordered with pytest.warns(FutureWarning, match="deprecated"): - tr.SumTo1(ndim_supp=1) + tr.SumTo1Transform(ndim_supp=1) with pytest.warns(FutureWarning, match="deprecated"): assert tr.univariate_sum_to_1 == tr.sum_to_1 diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index d9d007c51..8f25c4908 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -22,12 +22,11 @@ from pymc import Censored, Model, draw, find_MAP from pymc.distributions.continuous import Exponential, Gamma, TruncatedNormalRV from pymc.distributions.shape_utils import change_dist_size -from pymc.distributions.transforms import _default_transform +from pymc.distributions.transforms import IntervalTransform, _default_transform from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated from pymc.exceptions import TruncationError from pymc.logprob.abstract import _icdf from pymc.logprob.basic import logcdf, logp -from pymc.logprob.transforms import IntervalTransform from pymc.logprob.utils import ParameterValueError from pymc.testing import assert_moment_is_expected diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index 456a8f277..38a75993b 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -44,17 +44,10 @@ from pytensor.graph.basic import ancestors, equal_computations from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.subtensor import ( - AdvancedIncSubtensor, - AdvancedIncSubtensor1, - AdvancedSubtensor, - AdvancedSubtensor1, - IncSubtensor, - Subtensor, -) import pymc as pm +from pymc.distributions.transforms import LogTransform from pymc.logprob.basic import ( conditional_logp, icdf, @@ -62,14 +55,12 @@ logp, 
transformed_conditional_logp, ) -from pymc.logprob.transforms import LogTransform -from pymc.logprob.utils import rvs_to_value_vars, walk_model -from pymc.pytensorf import replace_rvs_by_values +from pymc.logprob.utils import replace_rvs_by_values from pymc.testing import assert_no_rvs -def test_factorized_joint_logprob_basic(): - # A simple check for when `factorized_joint_logprob` is the same as `logprob` +def test_conditional_logp_basic(): + # A simple check for when `conditional_logp` is the same as `logprob` a = pt.random.uniform(0.0, 1.0) a.name = "a" a_value_var = a.clone() @@ -93,9 +84,9 @@ def test_factorized_joint_logprob_basic(): # We need to replace the reference to `sigma` in `Y` with its value # variable ll_Y = logp(Y, y_value_var) - (ll_Y,), _ = rvs_to_value_vars( + (ll_Y,) = replace_rvs_by_values( [ll_Y], - initial_replacements={sigma: sigma_value_var}, + rvs_to_values={sigma: sigma_value_var}, ) total_ll_exp = logp(sigma, sigma_value_var) + ll_Y @@ -118,13 +109,13 @@ def test_factorized_joint_logprob_basic(): # There shouldn't be any `RandomVariable`s in the resulting graph assert_no_rvs(b_logp_combined) - res_ancestors = list(walk_model((b_logp_combined,), walk_past_rvs=True)) + res_ancestors = list(ancestors((b_logp_combined,))) assert b_value_var in res_ancestors assert c_value_var in res_ancestors assert a_value_var in res_ancestors -def test_factorized_joint_logprob_multi_obs(): +def test_conditional_logp_multi_obs(): a = pt.random.uniform(0.0, 1.0) b = pt.random.normal(0.0, 1.0) @@ -151,7 +142,7 @@ def test_factorized_joint_logprob_multi_obs(): assert equal_computations([logp_res_comb], [exp_logp_comb]) -def test_factorized_joint_logprob_diff_dims(): +def test_conditional_logp_diff_dims(): M = pt.matrix("M") x = pt.random.normal(0, 1, size=M.shape[1], name="X") y = pt.random.normal(M.dot(x), 1, name="Y") @@ -177,27 +168,13 @@ def test_factorized_joint_logprob_diff_dims(): assert exp_logp_val == pytest.approx(logp_val) -def test_incsubtensor_original_values_output_dict(): - """ - Test that the original un-incsubtensor value variable appears an the key of - the logprob factor - """ - - base_rv = pt.random.normal(0, 1, size=2) - rv = pt.set_subtensor(base_rv[0], 5) - vv = rv.clone() - - logp_dict = conditional_logp({rv: vv}) - assert vv in logp_dict - - def test_persist_inputs(): """Make sure we don't unnecessarily clone variables.""" x = pt.scalar("x") beta_rv = pt.random.normal(0, 1, name="beta") Y_rv = pt.random.normal(beta_rv * x, 1, name="y") - beta_vv = beta_rv.type() + beta_vv = beta_rv.clone() y_vv = Y_rv.clone() logp = conditional_logp({beta_rv: beta_vv, Y_rv: y_vv}) @@ -207,6 +184,7 @@ def test_persist_inputs(): # Make sure we don't clone value variables when they're graphs. 
y_vv_2 = y_vv * 2 + y_vv_2.name = "y_2" logp_2 = conditional_logp({beta_rv: beta_vv, Y_rv: y_vv_2}) logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()]) @@ -223,7 +201,7 @@ def test_persist_inputs(): assert y_vv_2 in ancestors([logp_2_combined]) -def test_warn_random_found_factorized_joint_logprob(): +def test_warn_random_found_conditional_logp(): x_rv = pt.random.normal(name="x") y_rv = pt.random.normal(x_rv, 1, name="y") @@ -248,7 +226,7 @@ def test_multiple_rvs_to_same_value_raises(): conditional_logp({x_rv1: x, x_rv2: x}) -def test_joint_logp_basic(): +def test_transformed_conditional_logp(): """Make sure we can compute a log-likelihood for a hierarchical model with transforms.""" with pm.Model() as m: @@ -274,60 +252,12 @@ def test_joint_logp_basic(): # There shouldn't be any `RandomVariable`s in the resulting graph assert_no_rvs(b_logp) - res_ancestors = list(walk_model((b_logp,))) + res_ancestors = list(ancestors((b_logp,))) assert b_value_var in res_ancestors assert c_value_var in res_ancestors assert a_value_var in res_ancestors -@pytest.mark.parametrize( - "indices, size", - [ - (slice(0, 2), 5), - (np.r_[True, True, False, False, True], 5), - (np.r_[0, 1, 4], 5), - ((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)), - ], -) -def test_joint_logp_incsubtensor(indices, size): - """Make sure we can compute a log-likelihood for ``Y[idx] = data`` where ``Y`` is univariate.""" - - mu = pm.floatX(np.power(10, np.arange(np.prod(size)))).reshape(size) - data = mu[indices] - sigma = 0.001 - rng = np.random.RandomState(232) - a_val = rng.normal(mu, sigma, size=size).astype(pytensor.config.floatX) - - rng = pytensor.shared(rng, borrow=False) - a = pm.Normal.dist(mu, sigma, size=size, rng=rng) - a_value_var = a.type() - a.name = "a" - - a_idx = pt.set_subtensor(a[indices], data) - - assert isinstance(a_idx.owner.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)) - - a_idx_value_var = a_idx.type() - a_idx_value_var.name = "a_idx_value" - - a_idx_logp = transformed_conditional_logp( - (a_idx,), - rvs_to_values={a_idx: a_value_var}, - rvs_to_transforms={}, - ) - - logp_vals = a_idx_logp[0].eval({a_value_var: a_val}) - - # The indices that were set should all have the same log-likelihood values, - # because the values they were set to correspond to the unique means along - # that dimension. This helps us confirm that the log-likelihood is - # associating the assigned values with their correct parameters. 
- a_val_idx = a_val.copy() - a_val_idx[indices] = data - exp_obs_logps = sp.norm.logpdf(a_val_idx, mu, sigma) - np.testing.assert_almost_equal(logp_vals, exp_obs_logps) - - def test_model_unchanged_logprob_access(): # Issue #5007 with pm.Model() as model: diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index 46c0a69d3..38f11f018 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -42,12 +42,12 @@ import scipy.stats as st from pymc import logp +from pymc.distributions.transforms import LogTransform from pymc.logprob import conditional_logp -from pymc.logprob.transforms import LogTransform, TransformValuesRewrite +from pymc.logprob.transforms import TransformValuesRewrite from pymc.testing import assert_no_rvs -@pytensor.config.change_flags(compute_test_value="raise") def test_continuous_rv_clip(): x_rv = pt.random.normal(0.5, 1) cens_x_rv = pt.clip(x_rv, -2, 2) @@ -180,7 +180,7 @@ def test_fail_base_and_clip_have_values(): x_vv = x_rv.clone() cens_x_vv = cens_x_rv.clone() - with pytest.raises(RuntimeError, match="could not be derived: {cens_x}"): + with pytest.raises(RuntimeError, match=r"could not be derived: {cens_x}"): conditional_logp({cens_x_rv: cens_x_vv, x_rv: x_vv}) @@ -194,7 +194,7 @@ def test_fail_multiple_clip_single_base(): cens_vv1 = cens_rv1.clone() cens_vv2 = cens_rv2.clone() - with pytest.raises(RuntimeError, match="could not be derived: {cens2}"): + with pytest.raises(ValueError, match="too many values to unpack"): conditional_logp({cens_rv1: cens_vv1, cens_rv2: cens_vv2}) diff --git a/tests/logprob/test_composite_logprob.py b/tests/logprob/test_composite_logprob.py index 607a48711..2d075e203 100644 --- a/tests/logprob/test_composite_logprob.py +++ b/tests/logprob/test_composite_logprob.py @@ -121,6 +121,7 @@ def test_nested_scalar_mixtures(): assert np.isclose(logp_fn(0, 0, 1, 50), st.norm.logpdf(150) + np.log(0.5) * 3) +@pytest.mark.xfail(reason="This is not currently enforced") @pytest.mark.parametrize("nested", (False, True)) def test_unvalued_ir_reversion(nested): """Make sure that un-valued IR rewrites are reverted.""" @@ -135,7 +136,7 @@ def test_unvalued_ir_reversion(nested): # measurable IR. 
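For context, the censoring tests above exercise `conditional_logp` on clipped graphs; the basic pattern looks like this (a sketch, not part of the patch):

    import pytensor.tensor as pt
    from pymc.logprob import conditional_logp

    x_rv = pt.random.normal(0.5, 1)
    cens_x_rv = pt.clip(x_rv, -2, 2)
    cens_x_vv = cens_x_rv.clone()
    # Maps each value variable to its logp term; probability mass
    # accumulates at the clipping bounds
    logp_terms = conditional_logp({cens_x_rv: cens_x_vv})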
rv_values = {z_rv: z_vv} - z_fgraph, _, memo = construct_ir_fgraph(rv_values) + z_fgraph, _, _ = construct_ir_fgraph(rv_values) # assert len(z_fgraph.preserve_rv_mappings.measurable_conversions) == 1 assert ( diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 1d2801592..7bbd24257 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -94,10 +94,10 @@ def create_mix_model(size, axis): x_vv = X_rv.clone() x_vv.name = "x" - with pytest.raises(RuntimeError, match="could not be derived: {m}"): + with pytest.raises(RuntimeError, match=r"could not be derived: {m}"): conditional_logp({M_rv: m_vv, I_rv: i_vv, X_rv: x_vv}) - with pytest.raises(RuntimeError, match="could not be derived: {m}"): + with pytest.raises(RuntimeError, match=r"could not be derived: {m}"): axis_at = pt.lscalar("axis") axis_at.tag.test_value = 0 env = create_mix_model((2,), axis_at) @@ -108,40 +108,6 @@ def create_mix_model(size, axis): conditional_logp({M_rv: m_vv, I_rv: i_vv}) -@pytensor.config.change_flags(compute_test_value="warn") -@pytest.mark.parametrize( - "op_constructor", - [ - lambda _I, _X, _Y: pt.stack([_X, _Y])[_I], - lambda _I, _X, _Y: pt.switch(_I, _X, _Y), - ], -) -def test_compute_test_value(op_constructor): - X_rv = pt.random.normal(0, 1, name="X") - Y_rv = pt.random.gamma(0.5, scale=2.0, name="Y") - - p_at = pt.scalar("p") - p_at.tag.test_value = 0.3 - - I_rv = pt.random.bernoulli(p_at, name="I") - - i_vv = I_rv.clone() - i_vv.name = "i" - - M_rv = op_constructor(I_rv, X_rv, Y_rv) - M_rv.name = "M" - - m_vv = M_rv.clone() - m_vv.name = "m" - - del M_rv.tag.test_value - - M_logp = conditional_logp({M_rv: m_vv, I_rv: i_vv}) - M_logp_combined = pt.add(*M_logp.values()) - - assert isinstance(M_logp_combined.tag.test_value, np.ndarray) - - @pytest.mark.parametrize( "p_val, size, supported", [ @@ -183,7 +149,7 @@ def test_hetero_mixture_binomial(p_val, size, supported): M_logp = conditional_logp({M_rv: m_vv, I_rv: i_vv}) M_logp_combined = pt.add(*M_logp.values()) else: - with pytest.raises(RuntimeError, match="could not be derived: {m}"): + with pytest.raises(RuntimeError, match=r"could not be derived: {m}"): conditional_logp({M_rv: m_vv, I_rv: i_vv}) return @@ -589,7 +555,7 @@ def test_hetero_mixture_categorical( if supported: logp_parts = conditional_logp({M_rv: m_vv, I_rv: i_vv}, sum=False) else: - with pytest.raises(RuntimeError, match="could not be derived: {m}"): + with pytest.raises(RuntimeError, match=r"could not be derived: {m}"): conditional_logp({M_rv: m_vv, I_rv: i_vv}, sum=False) return @@ -921,7 +887,9 @@ def test_scalar_switch_mixture(): z_vv.name = "z1" fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) - assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) + ir_valued = fgraph.outputs[0] + ir_rv = ir_valued.owner.inputs[0] + assert isinstance(ir_rv.owner.op, MeasurableSwitchMixture) # building the identical graph but with a stack to check that mixture logps are identical Z2_rv = pt.stack((Y_rv, X_rv))[I_rv] @@ -993,16 +961,21 @@ def test_switch_mixture_invalid_bcast(): valid_mix = pt.switch(valid_switch_cond, valid_true_branch, valid_false_branch) fgraph, _, _ = construct_ir_fgraph({valid_mix: valid_mix.type()}) - assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) - assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) + [ir_valued] = fgraph.outputs + ir_rv = ir_valued.owner.inputs[0] + assert isinstance(ir_rv.owner.op, MeasurableSwitchMixture) invalid_mix = 
pt.switch(invalid_switch_cond, valid_true_branch, valid_false_branch) fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()}) - assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + [ir_valued] = fgraph.outputs + ir_rv = ir_valued.owner.inputs[0] + assert not isinstance(ir_rv.owner.op, MeasurableVariable) invalid_mix = pt.switch(valid_switch_cond, valid_true_branch, invalid_false_branch) fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()}) - assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + [ir_valued] = fgraph.outputs + ir_rv = ir_valued.owner.inputs[0] + assert not isinstance(ir_rv.owner.op, MeasurableVariable) def test_ifelse_mixture_one_component(): @@ -1035,7 +1008,7 @@ def test_ifelse_mixture_multiple_components(): rng = np.random.default_rng(968) if_var = pt.scalar("if_var", dtype="bool") - comp_then1 = pt.random.normal(size=(2,), name="comp_true1") + comp_then1 = pt.random.normal(size=(2,), name="comp_then1") comp_then2 = comp_then1 + pt.random.normal(size=(2, 2), name="comp_then2") comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1") comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2") @@ -1073,11 +1046,15 @@ def test_ifelse_mixture_shared_component(): # comp_shared need not be an output of ifelse at all, # but since we allow arbitrary graphs we test it works as expected. comp_shared = pt.random.normal(size=(2,), name="comp_shared") - comp_then = outer_rv + pt.random.normal(comp_shared, 1, size=(4, 2), name="comp_then") - comp_else = outer_rv + pt.random.normal(comp_shared, 10, size=(8, 2), name="comp_else") + comp_then = outer_rv + pt.random.normal(comp_shared, 1, size=(4, 2)) + comp_then.name = "comp_then" + comp_else = outer_rv + pt.random.normal(comp_shared, 10, size=(8, 2)) + comp_else.name = "comp_else" shared_rv, mix_rv = ifelse( if_var, [comp_shared, comp_then], [comp_shared, comp_else], name="mix" ) + shared_rv.name = "shared" + mix_rv.name = "mix" outer_vv = outer_rv.clone() shared_vv = shared_rv.clone() diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 8eae026c0..b3152cddd 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -52,10 +52,10 @@ def test_argmax(): x = pt.random.normal(0, 1, size=(3,)) x.name = "x" x_max = pt.argmax(x, axis=-1) - x_max_value = pt.vector("x_max_value") + x_max_value = pt.scalar("x_max_value", dtype=x_max.type.dtype) with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented for Argmax")): - x_max_logprob = logp(x_max, x_max_value) + logp(x_max, x_max_value) @pytest.mark.parametrize( @@ -70,7 +70,7 @@ def test_non_iid_fails(pt_op): x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,)) x.name = "x" x_m = pt_op(x, axis=-1) - x_m_value = pt.vector("x_value") + x_m_value = pt.scalar("x_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_max_logprob = logp(x_m, x_m_value) @@ -87,7 +87,7 @@ def test_non_rv_fails(pt_op): x = pt.exp(pt.random.beta(0, 1, size=(3,))) x.name = "x" x_m = pt_op(x, axis=-1) - x_m_value = pt.vector("x_value") + x_m_value = pt.scalar("x_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_max_logprob = logp(x_m, x_m_value) @@ -105,7 +105,7 @@ def test_multivariate_rv_fails(pt_op): x = pm.StickBreakingWeights.dist(_alpha, _k) x.name = "x" x_m = pt_op(x, axis=-1) - x_m_value = pt.vector("x_value") + x_m_value = pt.scalar("x_value") with pytest.raises(RuntimeError, 
match=re.escape("Logprob method not implemented")): x_max_logprob = logp(x_m, x_m_value) @@ -122,7 +122,7 @@ def test_categorical(pt_op): x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,)) x.name = "x" x_m = pt_op(x, axis=-1) - x_m_value = pt.vector("x_value") + x_m_value = pt.scalar("x_value", dtype=x.type.dtype) with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_max_logprob = logp(x_m, x_m_value) @@ -228,7 +228,7 @@ def test_min_non_mul_elemwise_fails(): x = pt.log(pt.random.beta(0, 1, size=(3,))) x.name = "x" x_min = pt.min(x, axis=-1) - x_min_value = pt.vector("x_min_value") + x_min_value = pt.scalar("x_min_value") with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")): x_min_logprob = logp(x_min, x_min_value) @@ -238,9 +238,9 @@ def test_min_non_mul_elemwise_fails(): [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)], ) def test_max_discrete(mu, size, value, axis): - x = pm.Poisson.dist(name="x", mu=mu, size=(size)) + x = pm.Poisson.dist(name="x", mu=mu, size=size) x_max = pt.max(x, axis=axis) - x_max_value = pt.scalar("x_max_value") + x_max_value = pt.scalar("x_max_value", dtype=x.type.dtype) x_max_logprob = logp(x_max, x_max_value) test_value = value diff --git a/tests/logprob/test_rewriting.py b/tests/logprob/test_rewriting.py index a44fe41f6..193fc514b 100644 --- a/tests/logprob/test_rewriting.py +++ b/tests/logprob/test_rewriting.py @@ -108,41 +108,3 @@ def test_local_remove_TransformedVariable(): assert not any( isinstance(v.owner.op, TransformedVariable) for v in ancestors([p_logp]) if v.owner ) - - -@pytest.mark.parametrize( - "indices, size", - [ - (slice(0, 2), 5), - (np.r_[True, True, False, False, True], 5), - (np.r_[0, 1, 4], 5), - ((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)), - ], -) -def test_joint_logprob_incsubtensor(indices, size): - """Make sure we can compute a joint log-probability for ``Y[idx] = data`` where ``Y`` is univariate.""" - - rng = np.random.RandomState(232) - mu = np.power(10, np.arange(np.prod(size))).reshape(size) - sigma = 0.001 - data = rng.normal(mu[indices], 1.0) - y_val = rng.normal(mu, sigma, size=size) - - Y_base_rv = pt.random.normal(mu, sigma, size=size) - Y_rv = pt.set_subtensor(Y_base_rv[indices], data) - Y_rv.name = "Y" - y_value_var = Y_rv.clone() - y_value_var.name = "y" - - assert isinstance(Y_rv.owner.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)) - - Y_rv_logp = conditional_logp({Y_rv: y_value_var}) - Y_rv_logp_combined = pt.add(*Y_rv_logp.values()) - - obs_logps = Y_rv_logp_combined.eval({y_value_var: y_val}) - - y_val_idx = y_val.copy() - y_val_idx[indices] = data - exp_obs_logps = sp.norm.logpdf(y_val_idx, mu, sigma) - - np.testing.assert_almost_equal(obs_logps, exp_obs_logps) diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index 187810770..beb703dd5 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -47,40 +47,9 @@ from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.rewriting import logprob_rewrites_db -from pymc.logprob.tensor import naive_bcast_rv_lift from pymc.testing import assert_no_rvs -def test_naive_bcast_rv_lift(): - r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `Alloc`\s.""" - X_rv = pt.random.normal() - Z_at = Alloc()(X_rv, *()) - - # Make sure we're testing what we intend to test - assert isinstance(Z_at.owner.op, Alloc) - - res = rewrite_graph(Z_at, custom_rewrite=in2out(naive_bcast_rv_lift), clone=False) - assert res is X_rv - - 
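For context, `assert_no_rvs` in `pymc.testing` now delegates to the relocated `rvs_in_graph` helper used throughout these tests; a quick illustrative sketch:

    import pytensor.tensor as pt
    from pymc.logprob.utils import rvs_in_graph

    x = pt.random.normal(name="x")
    assert rvs_in_graph([pt.exp(x)]) == {x}
    assert not rvs_in_graph([pt.exp(pt.scalar("y"))])  # no RVs, empty set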
-def test_naive_bcast_rv_lift_valued_var():
-    r"""Check that `naive_bcast_rv_lift` won't touch valued variables"""
-
-    x_rv = pt.random.normal(name="x")
-    broadcasted_x_rv = pt.broadcast_to(x_rv, (2,))
-
-    y_rv = pt.random.normal(broadcasted_x_rv, name="y")
-
-    x_vv = x_rv.clone()
-    y_vv = y_rv.clone()
-    logp_map = conditional_logp({x_rv: x_vv, y_rv: y_vv})
-    assert x_vv in logp_map
-    assert y_vv in logp_map
-    assert len(logp_map) == 2
-    assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
-    assert np.allclose(logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0]))
-
-
@pytest.mark.xfail(RuntimeError, reason="logprob for broadcasted RVs not implemented")
def test_bcast_rv_logp():
    """Test that derived logp for broadcasted RV is correct"""

@@ -177,13 +146,17 @@ def test_measurable_make_vector_interdependent(reverse):

@pytest.mark.parametrize("reverse", (False, True))
-def test_measurable_join_interdependent(reverse):
+@pytest.mark.parametrize("nested_expr", (True, False))
+def test_measurable_join_interdependent(reverse, nested_expr):
    """Test that we can obtain a proper graph when stacked RVs depend on each other"""
    x = pt.random.normal(name="x")
    y_rvs = []
    prev_rv = x
    for i in range(3):
-        next_rv = pt.random.normal(prev_rv + 1, name=f"y{i}", size=(1, 2))
+        if nested_expr:
+            next_rv = pt.random.normal(prev_rv + 1, name=f"y{i}", size=(1, 2))
+        else:
+            next_rv = pt.random.normal(0, name=f"y{i}", size=(1, 2)) + prev_rv + 1
        y_rvs.append(next_rv)
        prev_rv = next_rv
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index efa1263a3..70878a448 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -48,33 +48,28 @@
from pytensor.scan import scan

from pymc.distributions.continuous import Cauchy
-from pymc.distributions.transforms import _default_transform, log, logodds
-from pymc.logprob.abstract import MeasurableVariable, _logprob
-from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
-from pymc.logprob.transforms import (
+from pymc.distributions.transforms import (
    ArccoshTransform,
    ArcsinhTransform,
    ArctanhTransform,
-    ChainedTransform,
    CoshTransform,
    ErfcTransform,
    ErfcxTransform,
    ErfTransform,
    ExpTransform,
-    IntervalTransform,
-    LocTransform,
    LogOddsTransform,
    LogTransform,
-    RVTransform,
-    ScaleTransform,
    SinhTransform,
    TanhTransform,
-    TransformValuesMapping,
-    TransformValuesRewrite,
+    _default_transform,
+    log,
+    logodds,
)
+from pymc.logprob.abstract import MeasurableVariable, _logprob
+from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
+from pymc.logprob.transforms import TransformValuesMapping, TransformValuesRewrite
from pymc.logprob.utils import ParameterValueError
-from pymc.testing import Rplusbig, Vector, assert_no_rvs
-from tests.distributions.test_transform import check_jacobian_det
+from pymc.testing import assert_no_rvs


class DirichletScipyDist:
@@ -633,135 +628,6 @@ def scan_step(prev_innov):
    np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point))


-class TestRVTransform:
-    @pytest.mark.parametrize("ndim", (0, 1))
-    def test_fallback_log_jac_det(self, ndim):
-        """
-        Test that the fallback log_jac_det in RVTransform produces the correct graph for a
-        simple transformation: x**2 -> -log(2*x)
-        """
-
-        class SquareTransform(RVTransform):
-            name = "square"
-            ndim_supp = ndim
-
-            def forward(self, value, *inputs):
-                return pt.power(value, 2)
-
-            def backward(self, value, *inputs):
-                return pt.sqrt(value)
-
-        square_tr = SquareTransform()
-
-        value = pt.vector("value")
-        value_tr = square_tr.forward(value)
-        log_jac_det = square_tr.log_jac_det(value_tr)
-
-        test_value = np.r_[3, 4]
-        expected_log_jac_det = -np.log(2 * test_value)
-        if ndim == 1:
-            expected_log_jac_det = expected_log_jac_det.sum()
-        np.testing.assert_array_equal(log_jac_det.eval({value: test_value}), expected_log_jac_det)
-
-    @pytest.mark.parametrize("ndim", (None, 2))
-    def test_fallback_log_jac_det_undefined_ndim(self, ndim):
-        class SquareTransform(RVTransform):
-            name = "square"
-            ndim_supp = ndim
-
-            def forward(self, value, *inputs):
-                return pt.power(value, 2)
-
-            def backward(self, value, *inputs):
-                return pt.sqrt(value)
-
-        with pytest.raises(
-            NotImplementedError, match=r"only implemented for ndim_supp in \(0, 1\)"
-        ):
-            SquareTransform().log_jac_det(0)
-
-    def test_invalid_interval_transform(self):
-        x_rv = pt.random.normal(0, 1)
-        x_vv = x_rv.clone()
-
-        msg = "Both edges of IntervalTransform cannot be None"
-        tr = IntervalTransform(lambda *inputs: (None, None))
-        with pytest.raises(ValueError, match=msg):
-            tr.forward(x_vv, *x_rv.owner.inputs)
-
-        tr = IntervalTransform(lambda *inputs: (None, None))
-        with pytest.raises(ValueError, match=msg):
-            tr.backward(x_vv, *x_rv.owner.inputs)
-
-        tr = IntervalTransform(lambda *inputs: (None, None))
-        with pytest.raises(ValueError, match=msg):
-            tr.log_jac_det(x_vv, *x_rv.owner.inputs)
-
-    def test_chained_transform(self):
-        loc = 5
-        scale = 0.1
-
-        ch = ChainedTransform(
-            transform_list=[
-                ScaleTransform(
-                    transform_args_fn=lambda *inputs: pt.constant(scale),
-                ),
-                ExpTransform(),
-                LocTransform(
-                    transform_args_fn=lambda *inputs: pt.constant(loc),
-                ),
-            ],
-            base_op=pt.random.multivariate_normal,
-        )
-
-        x = pt.random.multivariate_normal(np.zeros(3), np.eye(3))
-        x_val = x.eval()
-
-        x_val_forward = ch.forward(x_val, *x.owner.inputs).eval()
-        np.testing.assert_allclose(
-            x_val_forward,
-            np.exp(x_val * scale) + loc,
-        )
-
-        x_val_backward = ch.backward(x_val_forward, *x.owner.inputs, scale, loc).eval()
-        np.testing.assert_allclose(
-            x_val_backward,
-            x_val,
-        )
-
-        log_jac_det = ch.log_jac_det(x_val_forward, *x.owner.inputs, scale, loc)
-        np.testing.assert_allclose(
-            pt.sum(log_jac_det).eval(),
-            np.sum(-np.log(scale) - np.log(x_val_forward - loc)),
-        )
-
-    @pytest.mark.parametrize(
-        "transform",
-        [
-            ErfTransform(),
-            ErfcTransform(),
-            ErfcxTransform(),
-            SinhTransform(),
-            CoshTransform(),
-            TanhTransform(),
-            ArcsinhTransform(),
-            ArccoshTransform(),
-            ArctanhTransform(),
-            LogTransform(),
-            ExpTransform(),
-        ],
-    )
-    def test_check_jac_det(self, transform):
-        check_jacobian_det(
-            transform,
-            Vector(Rplusbig, 2),
-            pt.dvector,
-            [0.1, 0.1],
-            elemwise=True,
-            rv_var=pt.random.normal(0.5, 1, name="base_rv"),
-        )
-
-
def test_exp_transform_rv():
    base_rv = pt.random.normal(0, 1, size=3, name="base_rv")
    y_rv = pt.exp(base_rv)
@@ -1010,8 +876,8 @@ def test_sqrt_transform(self):
    # ICDF is not implemented for chisquare, so we have to test with another identity
    # sqrt(exponential(lam)) = rayleigh(1 / sqrt(2 * lam))
    lam = 2.5
-    y_rv = pt.sqrt(pt.random.exponential(scale=1 / lam))
-    y_vv = x_rv.clone()
+    y_rv = pt.sqrt(pt.random.exponential(scale=1 / lam, size=(4,)))
+    y_vv = y_rv.clone()
    y_icdf_fn = pytensor.function([y_vv], icdf(y_rv, y_vv))
    q_test_val = np.r_[0.2, 0.5, 0.7, 0.9]
    np.testing.assert_allclose(
@@ -1210,7 +1076,7 @@ def test_base_exponent_non_measurable():

    with pytest.raises(
        RuntimeError,
-        match="The logprob terms of the following value variables could not be derived: {x}",
+        match=r"The logprob terms of the following value variables could not be derived: {x}",
    ):
        conditional_logp({x_rv: x_vv})
diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py
index 320de6a36..492461363 100644
--- a/tests/logprob/test_utils.py
+++ b/tests/logprob/test_utils.py
@@ -34,127 +34,238 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

-import warnings
-
import numpy as np
import pytensor
-import pytensor.tensor as pt
import pytest

from pytensor import function
+from pytensor import tensor as pt
from pytensor.compile import get_default_mode
-from pytensor.tensor.random.basic import normal, uniform
+from pytensor.graph.basic import ancestors, equal_computations
+from pytensor.tensor.random.op import RandomVariable

import pymc as pm

-from pymc.logprob.abstract import MeasurableVariable
-from pymc.logprob.basic import logp, transformed_conditional_logp
+from pymc import SymbolicRandomVariable
+from pymc.distributions.transforms import Interval
+from pymc.logprob.abstract import MeasurableVariable, valued_rv
+from pymc.logprob.basic import logp
from pymc.logprob.utils import (
    ParameterValueError,
    check_potential_measurability,
    dirac_delta,
-    rvs_to_value_vars,
-    walk_model,
+    replace_rvs_by_values,
)
from pymc.testing import assert_no_rvs
from tests.logprob.utils import create_pytensor_params, scipy_logprob_tester


-def test_walk_model():
-    d = pt.vector("d")
-    b = pt.vector("b")
-    c = uniform(0.0, d)
-    c.name = "c"
-    e = pt.log(c)
-    a = normal(e, b)
-    a.name = "a"
-
-    test_graph = pt.exp(a + 1)
-    res = list(walk_model((test_graph,)))
-    assert a in res
-    assert c not in res
-
-    res = list(walk_model((test_graph,), walk_past_rvs=True))
-    assert a in res
-    assert c in res
-
-    res = list(walk_model((test_graph,), walk_past_rvs=True, stop_at_vars={e}))
-    assert a in res
-    assert c not in res
-
-
-def test_rvs_to_value_vars():
-    a = pt.random.uniform(0.0, 1.0)
-    a.name = "a"
-    a.tag.value_var = a_value_var = a.clone()
-
-    b = pt.random.uniform(0, a + 1.0)
-    b.name = "b"
-    b.tag.value_var = b_value_var = b.clone()
-
-    c = pt.random.normal()
-    c.name = "c"
-    c.tag.value_var = c_value_var = c.clone()
-
-    d = pt.log(c + b) + 2.0
-
-    initial_replacements = {b: b_value_var, c: c_value_var}
-    (res,), replaced = rvs_to_value_vars((d,), initial_replacements=initial_replacements)
-
-    assert res.owner.op == pt.add
-    log_output = res.owner.inputs[0]
-    assert log_output.owner.op == pt.log
-    log_add_output = res.owner.inputs[0].owner.inputs[0]
-    assert log_add_output.owner.op == pt.add
-    c_output = log_add_output.owner.inputs[0]
-
-    # We make sure that the random variables were replaced
-    # with their value variables
-    assert c_output == c_value_var
-    b_output = log_add_output.owner.inputs[1]
-    assert b_output == b_value_var
-
-    # There shouldn't be any `RandomVariable`s in the resulting graph
-    assert_no_rvs(res)
-
-    res_ancestors = list(walk_model((res,), walk_past_rvs=True))
-
-    assert b_value_var in res_ancestors
-    assert c_value_var in res_ancestors
-    assert a_value_var not in res_ancestors
+class TestReplaceRVsByValues:
+    @pytest.mark.parametrize("symbolic_rv", (False, True))
+    @pytest.mark.parametrize("apply_transforms", (True, False))
+    def test_basic(self, symbolic_rv, apply_transforms):
+        # Interval transform between last two arguments
+        interval = (
+            Interval(bounds_fn=lambda *args: (args[-2], args[-1])) if apply_transforms else None
+        )
+        with pm.Model() as m:
+            a = pm.Uniform("a", 0.0, 1.0)
+            if symbolic_rv:
+                raw_b = pm.Uniform.dist(0, a + 1.0)
+                b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval)
+                # If this assert fails, another distribution has to be used for the test
+                assert isinstance(b.owner.op, SymbolicRandomVariable)
+            else:
+                b = pm.Uniform("b", 0, a + 1.0, transform=interval)
+            c = pm.Normal("c")
+            d = pt.log(c + b) + 2.0
+
+        a_value_var = m.rvs_to_values[a]
+        assert m.rvs_to_transforms[a] is not None
+
+        b_value_var = m.rvs_to_values[b]
+        c_value_var = m.rvs_to_values[c]
+
+        (res,) = replace_rvs_by_values(
+            (d,),
+            rvs_to_values=m.rvs_to_values,
+            rvs_to_transforms=m.rvs_to_transforms,
+        )

-def test_rvs_to_value_vars_intermediate_rv():
-    """Test that function replaces values above an intermediate RV."""
-    a = pt.random.uniform(0.0, 1.0)
-    a.name = "a"
-    a.tag.value_var = a_value_var = a.clone()
+        assert res.owner.op == pt.add
+        log_output = res.owner.inputs[0]
+        assert log_output.owner.op == pt.log
+        log_add_output = res.owner.inputs[0].owner.inputs[0]
+        assert log_add_output.owner.op == pt.add
+        c_output = log_add_output.owner.inputs[0]
+
+        # We make sure that the random variables were replaced
+        # with their value variables
+        assert c_output == c_value_var
+        b_output = log_add_output.owner.inputs[1]
+        # When transforms are applied, the input is the back-transformation of the value_var,
+        # otherwise it is the value_var itself
+        if apply_transforms:
+            assert b_output != b_value_var
+        else:
+            assert b_output == b_value_var
+
+        res_ancestors = list(ancestors((res,)))
+        res_rv_ancestors = [
+            v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
+        ]
+
+        # There shouldn't be any `RandomVariable`s in the resulting graph
+        assert len(res_rv_ancestors) == 0
+        assert b_value_var in res_ancestors
+        assert c_value_var in res_ancestors
+        # When transforms are used, `d` depends on `a` through the back-transformation of
+        # `b`, otherwise there is no direct connection between `d` and `a`
+        if apply_transforms:
+            assert a_value_var in res_ancestors
+        else:
+            assert a_value_var not in res_ancestors
+
+    def test_intermediate_rv(self):
+        """Test that function replaces values above an intermediate RV."""
+        a = pt.random.uniform(0.0, 1.0)
+        a.name = "a"
+        a.tag.value_var = a_value_var = a.clone()
+
+        b = pt.random.uniform(0, a + 1.0)
+        b.name = "b"
+        b.tag.value_var = b.clone()
+
+        c = pt.random.normal()
+        c.name = "c"
+        c.tag.value_var = c_value_var = c.clone()
+
+        d = pt.log(c + b) + 2.0
+
+        initial_replacements = {a: a_value_var, c: c_value_var}
+        (res,) = replace_rvs_by_values((d,), rvs_to_values=initial_replacements)
+
+        # Assert that the only RandomVariable that remains in the graph is `b`
+        res_ancestors = list(ancestors((res,)))
+
+        assert (
+            len(
+                list(
+                    n
+                    for n in res_ancestors
+                    if n.owner and isinstance(n.owner.op, MeasurableVariable)
+                )
+            )
+            == 1
+        )

-    b = pt.random.uniform(0, a + 1.0)
-    b.name = "b"
-    b.tag.value_var = b.clone()
+        assert c_value_var in res_ancestors
+        assert a_value_var in res_ancestors

-    c = pt.random.normal()
-    c.name = "c"
-    c.tag.value_var = c_value_var = c.clone()
+    def test_unvalued_rv_model(self):
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            y = pm.Normal.dist(x)
+            z = pm.Normal("z", y)
+            out = z + y

-    d = pt.log(c + b) + 2.0
+        x_value = m.rvs_to_values[x]
+        z_value = m.rvs_to_values[z]

-    initial_replacements = {a: a_value_var, c: c_value_var}
-    (res,), replaced = rvs_to_value_vars((d,), initial_replacements=initial_replacements)
+        (res,) = replace_rvs_by_values(
+            (out,),
+            rvs_to_values=m.rvs_to_values,
+            rvs_to_transforms=m.rvs_to_transforms,
+        )

-    # Assert that the only RandomVariable that remains in the graph is `b`
-    res_ancestors = list(walk_model((res,), walk_past_rvs=True))
+        assert res.owner.op == pt.add
+        assert res.owner.inputs[0] is z_value
+        res_y = res.owner.inputs[1]
+        # Graph should have been cloned, and therefore y and res_y should have different ids
+        assert res_y is not y
+        assert res_y.owner.op == pt.random.normal
+        assert res_y.owner.inputs[3] is x_value
+
+    def test_no_change_inplace(self):
+        # Test that calling replace_rvs_by_values in models with nested transformations
+        # does not change the original rvs in place. See issue #5172
+        with pm.Model() as m:
+            one = pm.LogNormal("one", mu=0)
+            two = pm.LogNormal("two", mu=pt.log(one))
+
+            # We add potentials or deterministics that are not in topological order
+            pm.Potential("two_pot", two)
+            pm.Potential("one_pot", one)
+
+        before = pytensor.clone_replace(m.free_RVs)
+
+        # This call would change the model free_RVs in place in #5172
+        replace_rvs_by_values(
+            m.potentials,
+            rvs_to_values=m.rvs_to_values,
+            rvs_to_transforms=m.rvs_to_transforms,
+        )

-    assert (
-        len(
-            list(n for n in res_ancestors if n.owner and isinstance(n.owner.op, MeasurableVariable))
+        after = pytensor.clone_replace(m.free_RVs)
+        assert equal_computations(before, after)
+
+    @pytest.mark.parametrize("reversed", (False, True))
+    def test_interdependent_transformed_rvs(self, reversed):
+        # Test that nested transformed variables, whose transformed values depend on other
+        # RVs are properly replaced
+        with pm.Model() as m:
+            transform = pm.distributions.transforms.Interval(
+                bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
+            )
+            x = pm.Uniform("x", lower=0, upper=1, transform=transform)
+            y = pm.Uniform("y", lower=0, upper=x, transform=transform)
+            z = pm.Uniform("z", lower=0, upper=y, transform=transform)
+            w = pm.Uniform("w", lower=0, upper=z, transform=transform)
+
+        rvs = [x, y, z, w]
+        if reversed:
+            rvs = rvs[::-1]
+
+        transform_values = replace_rvs_by_values(
+            rvs,
+            rvs_to_values=m.rvs_to_values,
+            rvs_to_transforms=m.rvs_to_transforms,
        )
-        == 1
-    )
-    assert c_value_var in res_ancestors
-    assert a_value_var in res_ancestors
+        for transform_value in transform_values:
+            assert_no_rvs(transform_value)
+
+        if reversed:
+            transform_values = transform_values[::-1]
+        transform_values_fn = m.compile_fn(transform_values, point_fn=False)
+
+        x_interval_test_value = np.random.rand()
+        y_interval_test_value = np.random.rand()
+        z_interval_test_value = np.random.rand()
+        w_interval_test_value = np.random.rand()
+
+        # The 3 Nones correspond to unused rng, dtype and size arguments
+        expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval()
+        expected_y = transform.backward(
+            y_interval_test_value, None, None, None, 0, expected_x
+        ).eval()
+        expected_z = transform.backward(
+            z_interval_test_value, None, None, None, 0, expected_y
+        ).eval()
+        expected_w = transform.backward(
+            w_interval_test_value, None, None, None, 0, expected_z
+        ).eval()
+
+        np.testing.assert_allclose(
+            transform_values_fn(
+                x_interval__=x_interval_test_value,
+                y_interval__=y_interval_test_value,
+                z_interval__=z_interval_test_value,
+                w_interval__=w_interval_test_value,
+            ),
+            [expected_x, expected_y, expected_z, expected_w],
+        )


def test_CheckParameter():
@@ -199,13 +310,18 @@ def scipy_logprob(obs, c):

def test_check_potential_measurability():
    x1 = pt.random.normal()
+    x1_valued = valued_rv(x1, x1.type())
    x2 = pt.random.normal()
+    x2_valued = valued_rv(x2, x2.type())
    x3 = pt.scalar("x3")
-    y = pt.exp(x1 + x2 + x3)

    # In the first three cases, y is potentially measurable, because it has at least one unvalued RV input
-    assert check_potential_measurability([y], {})
-    assert check_potential_measurability([y], {x1})
-    assert check_potential_measurability([y], {x2})
+    y = pt.exp(x1 + x2 + x3)
+    assert check_potential_measurability([y])
+    y = pt.exp(x1_valued + x2 + x3)
+    assert check_potential_measurability([y])
+    y = pt.exp(x1 + x2_valued + x3)
+    assert check_potential_measurability([y])
    # y is not potentially measurable because both RV inputs are valued
-    assert not check_potential_measurability([y], {x1, x2})
+    y = pt.exp(x1_valued + x2_valued + x3)
+    assert not check_potential_measurability([y])
diff --git a/tests/model/test_core.py b/tests/model/test_core.py
index ff5ed13f2..39eb23b9b 100644
--- a/tests/model/test_core.py
+++ b/tests/model/test_core.py
@@ -44,10 +44,9 @@
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.distributions import Normal, transforms
from pymc.distributions.distribution import PartialObservedRV
-from pymc.distributions.transforms import log, simplex
+from pymc.distributions.transforms import IntervalTransform, log, simplex
from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
from pymc.logprob.basic import transformed_conditional_logp
-from pymc.logprob.transforms import IntervalTransform
from pymc.model import Point, ValueGradFunction, modelcontext
from pymc.util import _FutureWarningValidatingScratchpad
from pymc.variational.minibatch_rv import MinibatchRandomVariable
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index fac7b7046..19b7cd9ff 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -24,9 +24,8 @@
from pytensor import scan, shared
from pytensor.compile.builders import OpFromGraph
-from pytensor.graph.basic import Variable, equal_computations
+from pytensor.graph.basic import Variable
from pytensor.tensor.random.basic import normal, uniform
-from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.var import RandomStateSharedVariable
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
from pytensor.tensor.variable import TensorVariable
@@ -35,23 +34,19 @@
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import SymbolicRandomVariable
-from pymc.distributions.transforms import Interval
from pymc.exceptions import NotConstantValueError
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
-    _replace_vars_in_graphs,
    collect_default_updates,
    compile_pymc,
    constant_fold,
    convert_observed_data,
    extract_obs_data,
    replace_rng_nodes,
-    replace_rvs_by_values,
+    replace_vars_in_graphs,
    reseed_rngs,
-    rvs_to_value_vars,
    walk_model,
)
-from pymc.testing import assert_no_rvs
from pymc.vartypes import int_types
@@ -286,21 +281,24 @@ def test_walk_model():

    test_graph = pt.exp(e + 1)

-    res = list(walk_model((test_graph,)))
+    with pytest.warns(FutureWarning):
+        res = list(walk_model((test_graph,)))
    assert a in res
    assert b in res
    assert c in res
    assert d in res
    assert e in res

-    res = list(walk_model((test_graph,), stop_at_vars={c}))
+    with pytest.warns(FutureWarning):
+        res = list(walk_model((test_graph,), stop_at_vars={c}))
    assert a not in res
    assert b not in res
    assert c in res
    assert d in res
    assert e in res

-    res = list(walk_model((test_graph,), stop_at_vars={b}))
+    with pytest.warns(FutureWarning):
+        res = list(walk_model((test_graph,), stop_at_vars={b}))
    assert a not in res
    assert b in res
    assert c in res
@@ -500,27 +498,53 @@ def test_multiple_updates_same_variable(self):
            warnings.simplefilter("error")
            rng = pytensor.shared(np.random.default_rng(), name="rng")
-            x = pt.random.normal(rng=rng)
-            y = pt.random.normal(rng=rng)
+            # x1 and x2 are identical
+            x1 = pt.random.normal(rng=rng)
+            x2 = pt.random.normal(rng=rng)
+            next_rng_x = x1.owner.outputs[0]
+            next_rng_x.name = "next_rng_x"
+            # y1 and y2 are not!
+            y1 = pt.random.normal(loc=-1, rng=next_rng_x)
+            y2 = pt.random.normal(loc=1, rng=next_rng_x)
+            next_rng_y = y1.owner.outputs[0]
+            next_rng_y.name = "next_rng_y"

            # No warnings if only one variable is used
-            assert compile_pymc([], [x])
-            assert compile_pymc([], [y])
-
+            assert compile_pymc([], [x1])
+            assert compile_pymc([], [x2])
+            assert compile_pymc([], [y1])
+            assert compile_pymc([], [y2])
+
+            # No warnings if two identical variables use the same RNG
+            f = compile_pymc([], [x1, x2], random_seed=456)
+            res1 = f()
+            res2 = f()
+            assert res1[0] == res1[1]
+            assert res2[0] == res2[1]
+            assert res1[0] != res2[0]
+
+            # This could be allowed since the update graph for rng->x2->next_rng
+            # is a complete subset of the update graph for rng->x1->next_rng->y1->next_rng
            user_warn_msg = "RNG Variable rng has multiple clients"
            with pytest.warns(UserWarning, match=user_warn_msg):
-                f = compile_pymc([], [x, y], random_seed=456)
+                f = compile_pymc([], [x1, x2, y1], random_seed=456)
+            assert f() == f()
+
+            # Warnings if two non-identical variables use the same RNG
+            user_warn_msg = "RNG Variable next_rng_x has multiple clients"
+            with pytest.warns(UserWarning, match=user_warn_msg):
+                f = compile_pymc([], [y1, y2], random_seed=456)
            assert f() == f()

            # The user can provide an explicit update, but we will still issue a warning
            with pytest.warns(UserWarning, match=user_warn_msg):
-                f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
+                f = compile_pymc([], [y1, y2], updates={rng: next_rng_y}, random_seed=456)
            assert f() != f()

            # Same with default update
-            rng.default_update = x.owner.outputs[0]
+            rng.default_update = next_rng_y
            with pytest.warns(UserWarning, match=user_warn_msg):
-                f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
+                f = compile_pymc([], [y1, y2], random_seed=456)
            assert f() != f()

    def test_nested_updates(self):
@@ -668,211 +692,15 @@ def test_constant_fold_raises():

    assert tuple(res[1].eval()) == (5,)


-class TestReplaceRVsByValues:
-    @pytest.mark.parametrize("symbolic_rv", (False, True))
-    @pytest.mark.parametrize("apply_transforms", (True, False))
-    @pytest.mark.parametrize("test_deprecated_fn", (True, False))
-    def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
-        # Interval transform between last two arguments
-        interval = (
-            Interval(bounds_fn=lambda *args: (args[-2], args[-1])) if apply_transforms else None
-        )
-
-        with pm.Model() as m:
-            a = pm.Uniform("a", 0.0, 1.0)
-            if symbolic_rv:
-                raw_b = pm.Uniform.dist(0, a + 1.0)
-                b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval)
-                # If not True, another distribution has to be used
-                assert isinstance(b.owner.op, SymbolicRandomVariable)
-            else:
-                b = pm.Uniform("b", 0, a + 1.0, transform=interval)
-            c = pm.Normal("c")
-            d = pt.log(c + b) + 2.0
-
-        a_value_var = m.rvs_to_values[a]
-        assert m.rvs_to_transforms[a] is not None
-
-        b_value_var = m.rvs_to_values[b]
-        c_value_var = m.rvs_to_values[c]
-
-        if test_deprecated_fn:
-            with
-            with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
-                (res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms)
-        else:
-            (res,) = replace_rvs_by_values(
-                (d,),
-                rvs_to_values=m.rvs_to_values,
-                rvs_to_transforms=m.rvs_to_transforms,
-            )
-
-        assert res.owner.op == pt.add
-        log_output = res.owner.inputs[0]
-        assert log_output.owner.op == pt.log
-        log_add_output = res.owner.inputs[0].owner.inputs[0]
-        assert log_add_output.owner.op == pt.add
-        c_output = log_add_output.owner.inputs[0]
-
-        # We make sure that the random variables were replaced
-        # with their value variables
-        assert c_output == c_value_var
-        b_output = log_add_output.owner.inputs[1]
-        # When transforms are applied, the input is the back-transformation of the value_var,
-        # otherwise it is the value_var itself
-        if apply_transforms:
-            assert b_output != b_value_var
-        else:
-            assert b_output == b_value_var
-
-        res_ancestors = list(walk_model((res,)))
-        res_rv_ancestors = [
-            v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
-        ]
-
-        # There shouldn't be any `RandomVariable`s in the resulting graph
-        assert len(res_rv_ancestors) == 0
-        assert b_value_var in res_ancestors
-        assert c_value_var in res_ancestors
-        # When transforms are used, `d` depends on `a` through the back-transformation of
-        # `b`, otherwise there is no direct connection between `d` and `a`
-        if apply_transforms:
-            assert a_value_var in res_ancestors
-        else:
-            assert a_value_var not in res_ancestors
-
-    @pytest.mark.parametrize("test_deprecated_fn", (True, False))
-    def test_unvalued_rv(self, test_deprecated_fn):
-        with pm.Model() as m:
-            x = pm.Normal("x")
-            y = pm.Normal.dist(x)
-            z = pm.Normal("z", y)
-            out = z + y
-
-        x_value = m.rvs_to_values[x]
-        z_value = m.rvs_to_values[z]
-
-        if test_deprecated_fn:
-            with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
-                (res,) = rvs_to_value_vars((out,))
-        else:
-            (res,) = replace_rvs_by_values(
-                (out,),
-                rvs_to_values=m.rvs_to_values,
-                rvs_to_transforms=m.rvs_to_transforms,
-            )
-
-        assert res.owner.op == pt.add
-        assert res.owner.inputs[0] is z_value
-        res_y = res.owner.inputs[1]
-        # Graph should have be cloned, and therefore y and res_y should have different ids
-        assert res_y is not y
-        assert res_y.owner.op == pt.random.normal
-        assert res_y.owner.inputs[3] is x_value
-
-    @pytest.mark.parametrize("test_deprecated_fn", (True, False))
-    def test_no_change_inplace(self, test_deprecated_fn):
-        # Test that calling rvs_to_value_vars in models with nested transformations
-        # does not change the original rvs in place. See issue #5172
-        with pm.Model() as m:
-            one = pm.LogNormal("one", mu=0)
-            two = pm.LogNormal("two", mu=pt.log(one))
-
-            # We add potentials or deterministics that are not in topological order
-            pm.Potential("two_pot", two)
-            pm.Potential("one_pot", one)
-
-        before = pytensor.clone_replace(m.free_RVs)
-
-        # This call would change the model free_RVs in place in #5172
-        if test_deprecated_fn:
-            with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
-                rvs_to_value_vars(m.potentials)
-        else:
-            replace_rvs_by_values(
-                m.potentials,
-                rvs_to_values=m.rvs_to_values,
-                rvs_to_transforms=m.rvs_to_transforms,
-            )
-
-        after = pytensor.clone_replace(m.free_RVs)
-        assert equal_computations(before, after)
-
-    @pytest.mark.parametrize("test_deprecated_fn", (True, False))
-    @pytest.mark.parametrize("reversed", (False, True))
-    def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
-        # Test that nested transformed variables, whose transformed values depend on other
-        # RVs are properly replaced
-        with pm.Model() as m:
-            transform = pm.distributions.transforms.Interval(
-                bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
-            )
-            x = pm.Uniform("x", lower=0, upper=1, transform=transform)
-            y = pm.Uniform("y", lower=0, upper=x, transform=transform)
-            z = pm.Uniform("z", lower=0, upper=y, transform=transform)
-            w = pm.Uniform("w", lower=0, upper=z, transform=transform)
-
-        rvs = [x, y, z, w]
-        if reversed:
-            rvs = rvs[::-1]
-
-        if test_deprecated_fn:
-            with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
-                transform_values = rvs_to_value_vars(rvs)
-        else:
-            transform_values = replace_rvs_by_values(
-                rvs,
-                rvs_to_values=m.rvs_to_values,
-                rvs_to_transforms=m.rvs_to_transforms,
-            )
-
-        for transform_value in transform_values:
-            assert_no_rvs(transform_value)
-
-        if reversed:
-            transform_values = transform_values[::-1]
-        transform_values_fn = m.compile_fn(transform_values, point_fn=False)
-
-        x_interval_test_value = np.random.rand()
-        y_interval_test_value = np.random.rand()
-        z_interval_test_value = np.random.rand()
-        w_interval_test_value = np.random.rand()
-
-        # The 3 Nones correspond to unused rng, dtype and size arguments
-        expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval()
-        expected_y = transform.backward(
-            y_interval_test_value, None, None, None, 0, expected_x
-        ).eval()
-        expected_z = transform.backward(
-            z_interval_test_value, None, None, None, 0, expected_y
-        ).eval()
-        expected_w = transform.backward(
-            w_interval_test_value, None, None, None, 0, expected_z
-        ).eval()
-
-        np.testing.assert_allclose(
-            transform_values_fn(
-                x_interval__=x_interval_test_value,
-                y_interval__=y_interval_test_value,
-                z_interval__=z_interval_test_value,
-                w_interval__=w_interval_test_value,
-            ),
-            [expected_x, expected_y, expected_z, expected_w],
-        )
-
-    def test_replace_input(self):
-        inp = shared(0.0, name="inp")
-        x = pm.Normal.dist(inp)
-
-        assert x.eval() < 50
-
-        new_inp = inp + 100
+def test_replace_vars_in_graphs():
+    inp = shared(0.0, name="inp")
+    x = pm.Normal.dist(inp)

-        def replacement_fn(var, replacements):
-            if var is x:
-                replacements[x.owner.inputs[3]] = new_inp
+    assert x.eval() < 50

-            return []
+    new_inp = inp + 100

-        [new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)
+    replacements = {x.owner.inputs[3]: new_inp}
+    [new_x] = replace_vars_in_graphs([x], replacements=replacements)

-        assert new_x.eval() > 50
+    assert new_x.eval() > 50
diff --git a/tests/test_util.py b/tests/test_util.py
index 1b41d3a73..faf22aabf 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -22,7 +22,7 @@

import pymc as pm

-from pymc.distributions.transforms import RVTransform
+from pymc.distributions.transforms import Transform
from pymc.util import (
    UNSET,
    _get_seeds_per_chain,
@@ -40,7 +40,7 @@ class TestTransformName:
    transform_name = "test"

    def test_get_transformed_name(self):
-        class NewTransform(RVTransform):
+        class NewTransform(Transform):
            name = self.transform_name

            def forward(self, value):
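
Editor's note (not part of the patch): two minimal sketches of the post-refactor API that the updated tests above exercise. Only the imports and call signatures are taken from the diff itself (Transform from pymc.distributions.transforms, replace_rvs_by_values from pymc.logprob.utils); the model, the variable names, and the SquareTransform class are hypothetical illustrations, not code from the PR.

    import pytensor.tensor as pt

    import pymc as pm

    from pymc.distributions.transforms import Transform
    from pymc.logprob.utils import replace_rvs_by_values


    # Custom transforms now subclass Transform (the RVTransform name is deprecated
    # and only kept as a FutureWarning alias). This mirrors the SquareTransform
    # used by the removed TestRVTransform tests.
    class SquareTransform(Transform):
        name = "square"
        ndim_supp = 0

        def forward(self, value, *inputs):
            # The transformation itself
            return value**2

        def backward(self, value, *inputs):
            # Its inverse
            return pt.sqrt(value)


    # replace_rvs_by_values swaps the random variables in a graph for their
    # value variables, applying the registered transforms along the way,
    # exactly as the new TestReplaceRVsByValues tests do:
    with pm.Model() as m:
        a = pm.Uniform("a", 0.0, 1.0)  # gets an interval transform by default
        b = pm.Normal("b", mu=a)
        d = pt.log(a + b) + 2.0

    (d_value,) = replace_rvs_by_values(
        (d,),
        rvs_to_values=m.rvs_to_values,
        rvs_to_transforms=m.rvs_to_transforms,
    )

The resulting d_value graph is free of RandomVariables and depends only on the (back-transformed) value variables, which is the property the ancestors-based assertions in TestReplaceRVsByValues check.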