15 changes: 13 additions & 2 deletions .github/workflows/test.yml
@@ -82,11 +82,12 @@ jobs:
         install-numba: [0]
         install-jax: [0]
         install-torch: [0]
+        install-xarray: [0]
         part:
-          - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
+          - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse --ignore=tests/xtensor"
           - "tests/scan"
           - "tests/sparse"
-          - "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
+          - "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
           - "tests/tensor/conv"
           - "tests/tensor/rewriting"
           - "tests/tensor/test_math.py"
Expand Down Expand Up @@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
           fast-compile: 0
           float32: 0
           part: "tests/link/pytorch"
+        - install-xarray: 1
+          os: "ubuntu-latest"
+          python-version: "3.13"
+          numpy-version: ">=2.0"
+          fast-compile: 0
+          float32: 0
+          part: "tests/xtensor"
         - os: macos-15
           python-version: "3.13"
           numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
         if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
         if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
         if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
+        if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
         pip install pytest-sphinx
 
         pip install -e ./
@@ -212,6 +222,7 @@
         INSTALL_NUMBA: ${{ matrix.install-numba }}
         INSTALL_JAX: ${{ matrix.install-jax }}
         INSTALL_TORCH: ${{ matrix.install-torch}}
+        INSTALL_XARRAY: ${{ matrix.install-xarray }}
         OS: ${{ matrix.os}}
 
     - name: Run tests
5 changes: 5 additions & 0 deletions pytensor/compile/mode.py
@@ -67,6 +67,8 @@ def register_linker(name, linker):
 if not config.cxx:
     exclude = ["cxx_only"]
 OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
+# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
+OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
 # Even if multiple merge optimizer call will be there, this shouldn't
 # impact performance.
 OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
@@ -77,6 +79,7 @@ def register_linker(name, linker):
 OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
 OPT_STABILIZE.position_cutoff = 1.5000001
 OPT_NONE.name = "OPT_NONE"
+OPT_MINIMUM.name = "OPT_MINIMUM"
 OPT_MERGE.name = "OPT_MERGE"
 OPT_FAST_RUN.name = "OPT_FAST_RUN"
 OPT_FAST_RUN_STABLE.name = "OPT_FAST_RUN_STABLE"
@@ -95,6 +98,7 @@ def register_linker(name, linker):
     None: OPT_NONE,
     "None": OPT_NONE,
     "merge": OPT_MERGE,
+    "minimum_compile": OPT_MINIMUM,
     "o4": OPT_FAST_RUN,
     "o3": OPT_O3,
     "o2": OPT_O2,
@@ -191,6 +195,7 @@ def apply(self, fgraph):
     "merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0
 )
 
+
 # After scan1 opt at 0.5 and before ShapeOpt at 1
 # This should only remove nodes.
 # The opt should not do anything that need shape inference.
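The new query plugs into the same predefined-optimizer table as `"None"` and `"merge"`, so it can be selected by name when building a `Mode`. A minimal sketch (toy graph; the Python linker choice is only for illustration):

```python
# Compile with only the "minimum_compile" rewrites: the smallest set
# needed to make a graph with "dummy" Ops evaluable.
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode

x = pt.vector("x")
fn = pytensor.function([x], 2 * x, mode=Mode(linker="py", optimizer="minimum_compile"))
print(fn([1.0, 2.0]))  # [2. 4.]
```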
23 changes: 11 additions & 12 deletions pytensor/compile/ops.py
@@ -33,11 +33,8 @@ def register_view_op_c_code(type, code, version=()):
     ViewOp.c_code_and_version[type] = (code, version)
 
 
-class ViewOp(COp):
-    """
-    Returns an inplace view of the input. Used internally by PyTensor.
-
-    """
+class TypeCastingOp(COp):
+    """Op that performs a graph-level type cast operation, but has no effect computation-wise (identity function)."""
 
     view_map = {0: [0]}
     # Mapping from Type to C code (and version) to use.
@@ -47,13 +44,8 @@ class ViewOp(COp):
     __props__: tuple = ()
     _f16_ok: bool = True
 
-    def make_node(self, x):
-        return Apply(self, [x], [x.type()])
-
-    def perform(self, node, inp, out):
-        (x,) = inp
-        (z,) = out
-        z[0] = x
+    def perform(self, node, inputs, outputs_storage):
+        outputs_storage[0][0] = inputs[0]
 
     def __str__(self):
         return f"{self.__class__.__name__}"
@@ -90,6 +82,13 @@ def c_code_cache_version(self):
 
         return tuple(version)
 
+
+class ViewOp(TypeCastingOp):
+    """Returns an inplace view of the input. Used internally by PyTensor."""
+
+    def make_node(self, x):
+        return Apply(self, [x], [x.type()])
+
     def infer_shape(self, fgraph, node, input_shapes):
         return input_shapes
 
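A sketch of what the refactor preserves: `ViewOp` keeps its old behavior, now inherited from `TypeCastingOp`, whose `perform` forwards the input object untouched (toy values; calling `perform` directly is for illustration only):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.compile.ops import TypeCastingOp, ViewOp

assert issubclass(ViewOp, TypeCastingOp)

op = ViewOp()
node = op.make_node(pt.vector("x"))
out_storage = [[None]]
x_val = np.array([1.0, 2.0])
op.perform(node, [x_val], out_storage)
assert out_storage[0][0] is x_val  # identity: same object, no copy
```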
10 changes: 5 additions & 5 deletions pytensor/link/jax/dispatch/basic.py
@@ -8,7 +8,7 @@
 
 from pytensor.compile import JAX
 from pytensor.compile.builders import OpFromGraph
-from pytensor.compile.ops import DeepCopyOp, ViewOp
+from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
@@ -111,12 +111,12 @@ def deepcopyop(x):
     return deepcopyop
 
 
-@jax_funcify.register(ViewOp)
-def jax_funcify_ViewOp(op, **kwargs):
-    def viewop(x):
+@jax_funcify.register(TypeCastingOp)
+def jax_funcify_TypeCastingOp(op, **kwargs):
+    def type_cast(x):
         return x
 
-    return viewop
+    return type_cast
 
 
 @jax_funcify.register(OpFromGraph)
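Since dispatch is by type and `ViewOp` now subclasses `TypeCastingOp`, this single registration covers every casting-style Op via the MRO. A sketch (assumes JAX is installed):

```python
from pytensor.compile.ops import ViewOp
from pytensor.link.jax.dispatch.basic import jax_funcify

f = jax_funcify(ViewOp())  # resolves to the TypeCastingOp registration
assert f(42) == 42         # lowered to a plain identity
```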
10 changes: 5 additions & 5 deletions pytensor/link/numba/dispatch/scalar.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from pytensor.compile.ops import ViewOp
+from pytensor.compile.ops import TypeCastingOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
@@ -198,14 +198,14 @@ def cast(x):
 
 
 @numba_basic.numba_njit
-def viewop(x):
+def identity(x):
     return x
 
 
 @numba_funcify.register(Identity)
-@numba_funcify.register(ViewOp)
-def numba_funcify_ViewOp(op, **kwargs):
-    return numba_basic.global_numba_func(viewop)
+@numba_funcify.register(TypeCastingOp)
+def numba_funcify_type_casting(op, **kwargs):
+    return numba_basic.global_numba_func(identity)
 
 
 @numba_basic.numba_njit
10 changes: 9 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -9,7 +9,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function.types import add_supervisor_to_fgraph
-from pytensor.compile.ops import DeepCopyOp
+from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
@@ -71,6 +71,14 @@
 )
 
 
+@pytorch_funcify.register(TypeCastingOp)
+def pytorch_funcify_CastingOp(op, node, **kwargs):
+    def type_cast(x):
+        return x
+
+    return type_cast
+
+
 @pytorch_funcify.register(CheckAndRaise)
 def pytorch_funcify_CheckAndRaise(op, **kwargs):
     error = op.exc_type
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
@@ -4551,7 +4551,7 @@ def ix_(*args):
         new = as_tensor(new)
         if new.ndim != 1:
             raise ValueError("Cross index must be 1 dimensional")
-        new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
+        new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))
         out.append(new)
     return tuple(out)
 
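Why `dimshuffle` rather than `reshape`: the reshape form goes through a symbolic `new.size`, while inserting broadcastable axes preserves the index's static length. A toy check (values illustrative):

```python
import numpy as np
import pytensor.tensor as pt

new = pt.as_tensor(np.array([0, 2]))  # 1-d index with static length 2
k, nd = 1, 3
out = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))
print(out.type.shape)  # (1, 2, 1): the length stays statically known
```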
18 changes: 0 additions & 18 deletions pytensor/tensor/extra_ops.py
@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
     return CumOp(axis=axis, mode="mul")(x)
 
 
-class CumsumOp(Op):
-    __props__ = ("axis",)
-
-    def __new__(typ, *args, **kwargs):
-        obj = object.__new__(CumOp, *args, **kwargs)
-        obj.mode = "add"
-        return obj
-
-
-class CumprodOp(Op):
-    __props__ = ("axis",)
-
-    def __new__(typ, *args, **kwargs):
-        obj = object.__new__(CumOp, *args, **kwargs)
-        obj.mode = "mul"
-        return obj
-
-
 def diff(x, n=1, axis=-1):
     """Calculate the `n`-th order discrete difference along the given `axis`.
 
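The deleted classes were backward-compatibility shims that hijacked `__new__` to construct a `CumOp`; the functional API is the supported spelling. Sketch:

```python
import pytensor.tensor as pt

x = pt.vector("x")
pt.cumsum(x)   # builds CumOp(axis=None, mode="add")
pt.cumprod(x)  # builds CumOp(axis=None, mode="mul")
```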
4 changes: 2 additions & 2 deletions pytensor/tensor/random/basic.py
@@ -1625,8 +1625,7 @@ def rng_fn_scipy(cls, rng, n, p, size):
         return stats.nbinom.rvs(n, p, size=size, random_state=rng)
 
 
-nbinom = NegBinomialRV()
-negative_binomial = NegBinomialRV()
+nbinom = negative_binomial = NegBinomialRV()
 
 
 class BetaBinomialRV(ScipyRandomVariable):
@@ -1808,6 +1807,7 @@ def rng_fn(cls, rng, n, p, size):
 
 multinomial = MultinomialRV()
 
+
 vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")
 
 
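The collapsed alias is a real change, not just style: both names now refer to one Op instance rather than two equal-but-distinct ones. Sketch:

```python
from pytensor.tensor.random.basic import nbinom, negative_binomial

assert nbinom is negative_binomial  # one shared NegBinomialRV instance
```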
19 changes: 18 additions & 1 deletion pytensor/tensor/random/op.py
@@ -1,3 +1,4 @@
+import abc
 import warnings
 from collections.abc import Sequence
 from copy import deepcopy
@@ -32,7 +33,20 @@
 from pytensor.tensor.variable import TensorVariable
 
 
-class RandomVariable(Op):
+class RNGConsumerOp(Op):
+    """Baseclass for Ops that consume RNGs."""
+
+    @abc.abstractmethod
+    def update(self, node: Apply) -> dict[Variable, Variable]:
+        """Symbolic update expression for input RNG variables.
+
+        Returns a dictionary with the symbolic expressions required for correct updating
+        of RNG variables in repeated function evaluations.
+        """
+        pass
+
+
+class RandomVariable(RNGConsumerOp):
     """An `Op` that produces a sample from a random variable.
 
     This is essentially `RandomFunction`, except that it removes the
@@ -123,6 +137,9 @@
         if self.inplace:
             self.destroy_map = {0: [0]}
 
+    def update(self, node: Apply) -> dict[Variable, Variable]:
+        return {node.inputs[0]: node.outputs[0]}
+
     def _supp_shape_from_params(self, dist_params, param_shapes=None):
         """Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
 
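A sketch of the default mapping in action (standard RV node layout: the first input and first output are the RNG):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(0))
next_rng, x = pt.random.normal(rng=rng).owner.outputs
node = x.owner
print(node.op.update(node))  # {rng: next_rng}: feed the new RNG state back in
```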
4 changes: 1 addition & 3 deletions pytensor/tensor/rewriting/basic.py
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
     return [new_var]
 
 
+@register_infer_shape
 @node_rewriter([Assert])
 def local_remove_all_assert(fgraph, node):
     r"""A rewrite that removes all `Assert`\s from a graph.
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
     See the :ref:`unsafe` section.
 
     """
-    if not isinstance(node.op, Assert):
-        return
-
     return [node.inputs[0]]
 
 
29 changes: 29 additions & 0 deletions pytensor/tensor/utils.py
@@ -9,6 +9,7 @@
 import pytensor
 from pytensor.graph import FunctionGraph, Variable
 from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.utils import hash_from_code
 
 
@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
     https://github.com/numpy/numpy/issues/28921
     """
     return product(*(range(s) for s in shape))
+
+
+def get_static_shape_from_size_variables(
+    size_vars: Sequence[Variable],
+) -> tuple[int | None, ...]:
+    """Get static shape from size variables.
+
+    Parameters
+    ----------
+    size_vars : Sequence[Variable]
+        A sequence of variables representing the size of each dimension.
+    Returns
+    -------
+    tuple[int | None, ...]
+        A tuple containing the static lengths of each dimension, or None if
+        the length is not statically known.
+    """
+    from pytensor.tensor.basic import get_scalar_constant_value
+
+    static_lengths: list[None | int] = [None] * len(size_vars)
+    for i, length in enumerate(size_vars):
+        try:
+            static_length = get_scalar_constant_value(length)
+        except NotScalarConstantError:
+            pass
+        else:
+            static_lengths[i] = int(static_length)
+    return tuple(static_lengths)
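Usage sketch for the new helper (toy inputs): constant sizes resolve to ints, symbolic ones stay `None`.

```python
import pytensor.tensor as pt
from pytensor.tensor.utils import get_static_shape_from_size_variables

n = pt.scalar("n", dtype="int64")
print(get_static_shape_from_size_variables([pt.constant(3), n]))  # (3, None)
```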
3 changes: 3 additions & 0 deletions pytensor/tensor/variable.py
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
         if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
             pattern = pattern[0]
         ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
+        if ds_op.new_order == tuple(range(self.type.ndim)):
+            # No-op
+            return self
         return ds_op(self)
 
     def flatten(self, ndim=1):
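Effect of the short-circuit, as a quick sketch: identity patterns now return the variable itself instead of wrapping it in a `DimShuffle` node.

```python
import pytensor.tensor as pt

x = pt.matrix("x")
assert x.dimshuffle(0, 1) is x      # identity pattern: no node is created
assert x.dimshuffle(1, 0) is not x  # a real transpose still builds a node
```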
14 changes: 14 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -0,0 +1,14 @@
+import warnings
+
+import pytensor.xtensor.rewriting
+from pytensor.xtensor import linalg
+from pytensor.xtensor.math import dot
+from pytensor.xtensor.shape import concat
+from pytensor.xtensor.type import (
+    as_xtensor,
+    xtensor,
+    xtensor_constant,
+)
+
+
+warnings.warn("xtensor module is experimental and full of bugs")
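A hedged sketch of the surface this new module exposes; the `xtensor` and `as_xtensor` keyword signatures (`dims`, `shape`) are assumptions based on `pytensor.xtensor.type` and, per the warning above, subject to change:

```python
import numpy as np
from pytensor.xtensor import as_xtensor, xtensor

x = xtensor("x", dims=("time",), shape=(3,))  # symbolic tensor with named dims (assumed signature)
c = as_xtensor(np.zeros(3), dims=("time",))   # wrap a value, labeling its axis (assumed signature)
y = x + c                                     # elementwise math aligns dims by name
```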