4 changes: 2 additions & 2 deletions pytensor/tensor/random/basic.py
@@ -1625,8 +1625,7 @@ def rng_fn_scipy(cls, rng, n, p, size):
return stats.nbinom.rvs(n, p, size=size, random_state=rng)


nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
nbinom = negative_binomial = NegBinomialRV()


class BetaBinomialRV(ScipyRandomVariable):
@@ -1808,6 +1807,7 @@ def rng_fn(cls, rng, n, p, size):

multinomial = MultinomialRV()


vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")


19 changes: 18 additions & 1 deletion pytensor/tensor/random/op.py
@@ -1,3 +1,4 @@
import abc
import warnings
from collections.abc import Sequence
from copy import deepcopy
@@ -32,7 +33,20 @@
from pytensor.tensor.variable import TensorVariable


class RandomVariable(Op):
class RNGConsumerOp(Op):
"""Baseclass for Ops that consume RNGs."""

@abc.abstractmethod
def update(self, node: Apply) -> dict[Variable, Variable]:
"""Symbolic update expression for input RNG variables.

Returns a dictionary with the symbolic expressions required for correct updating
of RNG variables in repeated function evaluations.
"""
pass


class RandomVariable(RNGConsumerOp):
"""An `Op` that produces a sample from a random variable.

This is essentially `RandomFunction`, except that it removes the
@@ -123,6 +137,9 @@ def __init__(
if self.inplace:
self.destroy_map = {0: [0]}

def update(self, node: Apply) -> dict[Variable, Variable]:
return {node.inputs[0]: node.outputs[0]}

def _supp_shape_from_params(self, dist_params, param_shapes=None):
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.

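
A minimal sketch (not part of the diff) of how the new `update` contract can be consumed when compiling a function; the variable names are illustrative and the RNG is assumed to be a shared NumPy Generator:

import numpy as np
import pytensor
from pytensor.tensor.random.basic import normal

rng = pytensor.shared(np.random.default_rng(123), name="rng")
x = normal(0, 1, rng=rng)

# update() maps the node's RNG input to its "next RNG" output, so the compiled
# function advances the generator between evaluations instead of replaying it.
updates = x.owner.op.update(x.owner)   # {rng: next_rng}
f = pytensor.function([], x, updates=updates)
draws = f(), f()                       # two different draws
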
4 changes: 1 addition & 3 deletions pytensor/tensor/rewriting/basic.py
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
return [new_var]


@register_infer_shape
@node_rewriter([Assert])
def local_remove_all_assert(fgraph, node):
r"""A rewrite that removes all `Assert`\s from a graph.
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
See the :ref:`unsafe` section.

"""
if not isinstance(node.op, Assert):
return

return [node.inputs[0]]


29 changes: 29 additions & 0 deletions pytensor/tensor/utils.py
@@ -9,6 +9,7 @@
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.utils import hash_from_code


@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
https://github.com/numpy/numpy/issues/28921
"""
return product(*(range(s) for s in shape))


def get_static_shape_from_size_variables(
size_vars: Sequence[Variable],
) -> tuple[int | None, ...]:
"""Get static shape from size variables.

Parameters
----------
size_vars : Sequence[Variable]
A sequence of variables representing the size of each dimension.
Returns
-------
tuple[int | None, ...]
A tuple containing the static lengths of each dimension, or None if
the length is not statically known.
"""
from pytensor.tensor.basic import get_scalar_constant_value

static_lengths: list[None | int] = [None] * len(size_vars)
for i, length in enumerate(size_vars):
try:
static_length = get_scalar_constant_value(length)
except NotScalarConstantError:
pass
else:
static_lengths[i] = int(static_length)
return tuple(static_lengths)
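
A quick illustration of the new helper (not from the diff; the example values are made up): constant size entries become static lengths, while purely symbolic ones stay None.

import pytensor.tensor as pt
from pytensor.tensor.utils import get_static_shape_from_size_variables

size = [pt.constant(3, dtype="int64"), pt.scalar("n", dtype="int64")]
assert get_static_shape_from_size_variables(size) == (3, None)
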
3 changes: 3 additions & 0 deletions pytensor/tensor/variable.py
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
pattern = pattern[0]
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
if ds_op.new_order == tuple(range(self.type.ndim)):
# No-op
return self
return ds_op(self)

def flatten(self, ndim=1):
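
A short illustration (not from the diff) of the new shortcut: an identity pattern now returns the variable itself rather than building a DimShuffle node.

import pytensor.tensor as pt

x = pt.matrix("x")
assert x.dimshuffle(0, 1) is x        # identity pattern: no graph node created
assert x.dimshuffle(1, 0) is not x    # a real transpose still builds a DimShuffle
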
6 changes: 3 additions & 3 deletions pytensor/xtensor/basic.py
@@ -67,8 +67,8 @@ def make_node(self, x):
return Apply(self, [x], [output])


def xtensor_from_tensor(x, dims):
return XTensorFromTensor(dims=dims)(x)
def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)


class Rename(XTypeCastOp):
@@ -96,7 +96,7 @@ def rename(x, name_dict: dict[str, str] | None = None, **names: str):
for old_name, new_name in names.items():
try:
new_names[old_names.index(old_name)] = new_name
except IndexError:
except ValueError:
raise ValueError(
f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
)
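
The exception swap above matters because list.index signals a missing item with ValueError, not IndexError, so the old handler could never fire; a minimal reproduction:

["a", "b"].index("c")   # raises ValueError: 'c' is not in list
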
14 changes: 5 additions & 9 deletions pytensor/xtensor/linalg.py
@@ -28,7 +28,7 @@ def cholesky(
((dims[0], dims[1]),),
((dims[0], dims[1]),),
)
x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims)
x_op = XBlockwise(core_op, core_dims=core_dims)
return x_op(x)


@@ -48,18 +48,15 @@ def solve(
[m1_dim] = [dim for dim in dims if dim not in b.type.dims]
m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
output_core_dims = ((m2_dim,),)
# The shared dim disappears in the output
output_core_dims = ((m1_dim,),)
elif len(dims) == 3:
b_ndim = 2
[n_dim] = [dim for dim in dims if dim not in a.type.dims]
[m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
output_core_dims = (
(
m2_dim,
n_dim,
),
)
# The shared dim disappears in the output
output_core_dims = ((m1_dim, n_dim),)
else:
raise ValueError("Solve dims must have length 2 or 3")

@@ -68,7 +65,6 @@
)
x_op = XBlockwise(
core_op,
signature=core_op.gufunc_signature,
core_dims=(input_core_dims, output_core_dims),
)
return x_op(a, b)
167 changes: 167 additions & 0 deletions pytensor/xtensor/random.py
@@ -0,0 +1,167 @@
from collections.abc import Sequence
from functools import wraps

import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.math import sqrt
from pytensor.xtensor.vectorization import XRV


def _as_xrv(
core_op: RandomVariable,
core_inps_dims_map: Sequence[Sequence[int]] | None = None,
core_out_dims_map: Sequence[int] | None = None,
):
"""Helper function to define an XRV constructor.

Parameters
----------
core_op : RandomVariable
The core random variable operation to wrap.
core_inps_dims_map : Sequence[Sequence[int]] | None, optional
A sequence of sequences mapping the core dimensions (specified by the user)
for each input parameter. This is used when lowering to a RandomVariable operation,
to decide the ordering of the core dimensions for each input.
If None, it assumes the core dimensions are positional from left to right.
core_out_dims_map : Sequence[int] | None, optional
A sequence mapping the core dimensions (specified by the user) for the output variable.
This is used when lowering to a RandomVariable operation,
to decide the ordering of the core dimensions for the output.
If None, it assumes the core dimensions are positional from left to right.

"""
if core_inps_dims_map is None:
# Assume core_dims map positionally from left to right
core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params]
if core_out_dims_map is None:
# Assume core_dims map positionally from left to right
core_out_dims_map = tuple(range(core_op.ndim_supp))

core_dims_needed = max(
(*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
)

@wraps(core_op)
def xrv_constructor(
*params,
core_dims: Sequence[str] | str | None = None,
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
if core_dims is None:
core_dims = ()
if core_dims_needed:
raise ValueError(
f"{core_op.name} needs {core_dims_needed} core_dims to be specified"
)
elif isinstance(core_dims, str):
core_dims = (core_dims,)

if len(core_dims) != core_dims_needed:
raise ValueError(
f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
)

full_input_core_dims = tuple(
tuple(core_dims[i] for i in inp_dims_map)
for inp_dims_map in core_inps_dims_map
)
full_output_core_dims = tuple(core_dims[i] for i in core_out_dims_map)
full_core_dims = (full_input_core_dims, full_output_core_dims)

if extra_dims is None:
extra_dims = {}

return XRV(
core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
)(rng, *extra_dims.values(), *params)

return xrv_constructor


bernoulli = _as_xrv(ptr.bernoulli)
beta = _as_xrv(ptr.beta)
betabinom = _as_xrv(ptr.betabinom)
binomial = _as_xrv(ptr.binomial)
categorical = _as_xrv(ptr.categorical)
cauchy = _as_xrv(ptr.cauchy)
dirichlet = _as_xrv(ptr.dirichlet)
exponential = _as_xrv(ptr.exponential)
gamma = _as_xrv(ptr._gamma)
gengamma = _as_xrv(ptr.gengamma)
geometric = _as_xrv(ptr.geometric)
gumbel = _as_xrv(ptr.gumbel)
halfcauchy = _as_xrv(ptr.halfcauchy)
halfnormal = _as_xrv(ptr.halfnormal)
hypergeometric = _as_xrv(ptr.hypergeometric)
integers = _as_xrv(ptr.integers)
invgamma = _as_xrv(ptr.invgamma)
laplace = _as_xrv(ptr.laplace)
logistic = _as_xrv(ptr.logistic)
lognormal = _as_xrv(ptr.lognormal)
multinomial = _as_xrv(ptr.multinomial)
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial)
normal = _as_xrv(ptr.normal)
pareto = _as_xrv(ptr.pareto)
poisson = _as_xrv(ptr.poisson)
t = _as_xrv(ptr.t)
triangular = _as_xrv(ptr.triangular)
truncexpon = _as_xrv(ptr.truncexpon)
uniform = _as_xrv(ptr.uniform)
vonmises = _as_xrv(ptr.vonmises)
wald = _as_xrv(ptr.wald)
weibull = _as_xrv(ptr.weibull)


def multivariate_normal(
mean,
cov,
*,
core_dims: Sequence[str],
extra_dims=None,
rng=None,
method="cholesky",
):
mean = as_xtensor(mean)
if len(core_dims) != 2:
raise ValueError(
f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
)

# Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov
# This will be the core dimension of the output
if core_dims[0] not in mean.type.dims:
core_dims = core_dims[::-1]

xop = _as_xrv(ptr.MvNormalRV(method=method))
return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)


def standard_normal(
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
"""Standard normal random variable."""
return normal(0, 1, extra_dims=extra_dims, rng=rng)


def chisquare(
df,
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
"""Chi-square random variable."""
return gamma(df / 2.0, 2.0, extra_dims=extra_dims, rng=rng)


def rayleigh(
scale,
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
"""Rayleigh random variable."""

df = scale * 0 + 2 # Poor man's broadcasting, to pass dimensions of scale to the RV
return sqrt(chisquare(df, extra_dims=extra_dims, rng=rng)) * scale