11 changes: 11 additions & 0 deletions .github/workflows/test.yml
@@ -82,6 +82,7 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
@@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.13"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx

pip install -e ./
@@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}}

- name: Run tests
2 changes: 1 addition & 1 deletion pytensor/link/jax/dispatch/__init__.py
@@ -1,5 +1,5 @@
# isort: off
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.linker import jax_funcify, jax_typify

# Load dispatch specializations
import pytensor.link.jax.dispatch.blas
16 changes: 4 additions & 12 deletions pytensor/link/jax/dispatch/basic.py
@@ -1,6 +1,5 @@
import warnings
from collections.abc import Callable
from functools import singledispatch

import jax
import jax.numpy as jnp
@@ -10,8 +9,10 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph import Op
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.jax.linker import jax_funcify, jax_typify
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise

@@ -22,24 +23,15 @@
jax.config.update("jax_enable_x64", False)


@singledispatch
def jax_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to JAX types."""
if dtype is None:
return data
else:
return jnp.array(data, dtype=dtype)


@jax_typify.register(np.ndarray)
def jax_typify_ndarray(data, dtype=None, **kwargs):
if len(data.shape) == 0:
return data.item()
return jnp.array(data, dtype=dtype)


@singledispatch
def jax_funcify(op, node=None, storage_map=None, **kwargs):
@jax_funcify.register(Op)
def jax_funcify_op(op, node=None, storage_map=None, **kwargs):
"""Create a JAX compatible function from an PyTensor `Op`."""
raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

21 changes: 18 additions & 3 deletions pytensor/link/jax/linker.py
@@ -1,11 +1,29 @@
import warnings
from functools import singledispatch

from numpy.random import Generator

from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.link.basic import JITLinker


@singledispatch
def jax_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to JAX types."""
import jax.numpy as jnp

if dtype is None:
return data
else:
return jnp.array(data, dtype=dtype)


@singledispatch
def jax_funcify(obj, *args, **kwargs):
"""Create a JAX compatible function from an PyTensor `Op`."""
raise NotImplementedError(f"No JAX conversion for the given type: {type(obj)}")


class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""

@@ -14,7 +32,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.random.type import RandomType

@@ -111,8 +128,6 @@ def convert_scalar_shape_inputs(
return convert_scalar_shape_inputs

def create_thunk_inputs(self, storage_map):
from pytensor.link.jax.dispatch import jax_typify

thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
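
Note: the same refactoring pattern repeats for each backend in this PR. The singledispatch base functions (here `jax_typify` and `jax_funcify`) move from `dispatch/basic.py` into `linker.py`, and the dispatch modules now only register specializations against them. A minimal sketch of that registration pattern, using simplified stand-in names rather than PyTensor's real classes:

    import math
    from functools import singledispatch


    # Base dispatcher defined next to the linker, as jax_funcify now is.
    @singledispatch
    def funcify(obj, **kwargs):
        raise NotImplementedError(f"No conversion for the given type: {type(obj)}")


    class Op:  # stand-in for pytensor.graph.Op
        pass


    class Exp(Op):  # hypothetical concrete Op, for illustration only
        pass


    # Dispatch modules register specializations against the base dispatcher,
    # mirroring the `@jax_funcify.register(Op)` fallback in dispatch/basic.py.
    @funcify.register(Op)
    def funcify_op(op, **kwargs):
        raise NotImplementedError(f"No conversion for the given Op: {op}")


    @funcify.register(Exp)
    def funcify_exp(op, **kwargs):
        return lambda x: math.exp(x)


    print(funcify(Exp())(0.0))  # dispatches to funcify_exp -> 1.0

With the base dispatchers living in the linker module itself, the in-method imports the linker previously needed (deleted above) become unnecessary.
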
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/__init__.py
@@ -1,5 +1,5 @@
# isort: off
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
from pytensor.link.numba.linker import numba_funcify, numba_typify

# Load dispatch specializations
import pytensor.link.numba.dispatch.blockwise
18 changes: 4 additions & 14 deletions pytensor/link/numba/dispatch/basic.py
@@ -2,7 +2,6 @@
import sys
import warnings
from copy import copy
from functools import singledispatch
from textwrap import dedent

import numba
@@ -21,11 +20,13 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import Op
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.numba.linker import numba_funcify, numba_typify
from pytensor.link.utils import (
compile_function_src,
fgraph_to_python,
@@ -276,11 +277,6 @@ def create_arg_string(x):
return args


@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data


def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from a Pytensor `Op`."""

@@ -326,14 +322,8 @@ def perform(*inputs):
return perform


@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Generate a numba function for a given op and apply node.

The resulting function will usually use the `no_cpython_wrapper`
argument in numba, so it can not be called directly from python,
but only from other jit functions.
"""
@numba_funcify.register(Op)
def numba_funcify_op(op, node=None, storage_map=None, **kwargs):
return generate_fallback_impl(op, node, storage_map, **kwargs)


16 changes: 14 additions & 2 deletions pytensor/link/numba/linker.py
@@ -1,12 +1,24 @@
from functools import singledispatch

from pytensor.link.basic import JITLinker


@singledispatch
def numba_typify(data, dtype=None, **kwargs):
raise NotImplementedError(
f"Numba funcify not implemented for data type {type(data)}"
)


@singledispatch
def numba_funcify(obj, *args, **kwargs):
raise NotImplementedError(f"Numba funcify not implemented for type {type(obj)}")


class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""

def fgraph_convert(self, fgraph, **kwargs):
from pytensor.link.numba.dispatch import numba_funcify

return numba_funcify(fgraph, **kwargs)

def jit_compile(self, fn):
2 changes: 1 addition & 1 deletion pytensor/link/pytorch/dispatch/__init__.py
@@ -1,5 +1,5 @@
# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
from pytensor.link.pytorch.linker import pytorch_funcify, pytorch_typify

# # Load dispatch specializations
import pytensor.link.pytorch.dispatch.blas
12 changes: 4 additions & 8 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -1,4 +1,3 @@
from functools import singledispatch
from types import NoneType

import numpy as np
@@ -10,9 +9,11 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import Op
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.pytorch.linker import pytorch_funcify, pytorch_typify
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import (
@@ -27,11 +28,6 @@
)


@singledispatch
def pytorch_typify(data, **kwargs):
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")


@pytorch_typify.register(np.ndarray)
@pytorch_typify.register(torch.Tensor)
def pytorch_typify_tensor(data, dtype=None, **kwargs):
@@ -45,8 +41,8 @@ def pytorch_typify_no_conversion_needed(data, **kwargs):
return data


@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
@pytorch_funcify.register(Op)
def pytorch_funcify_op(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
raise NotImplementedError(
f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"
16 changes: 12 additions & 4 deletions pytensor/link/pytorch/linker.py
@@ -1,7 +1,19 @@
from functools import singledispatch

from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator


@singledispatch
def pytorch_typify(data, **kwargs):
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")


@singledispatch
def pytorch_funcify(obj, *args, **kwargs):
raise NotImplementedError(f"pytorch_funcify is not implemented for {type(obj)}")


class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""

@@ -10,8 +22,6 @@ def __init__(self, *args, **kwargs):
self.gen_functors = []

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify

# We want to have globally unique names
# across the entire pytensor graph, not
# just the subgraph
@@ -40,8 +50,6 @@ def jit_compile(self, fn):
# flag that tend to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True

from pytensor.link.pytorch.dispatch import pytorch_typify

class wrapper:
"""
Pytorch would fail compiling our method when trying
18 changes: 0 additions & 18 deletions pytensor/tensor/extra_ops.py
@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)


class CumsumOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "add"
return obj


class CumprodOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "mul"
return obj


def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.

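
For context on the deletion above: `CumsumOp` and `CumprodOp` were thin shims whose `__new__` handed back a plain `CumOp` with the mode preset, and `cumsum`/`cumprod` already construct `CumOp(axis=..., mode=...)` directly, as the surviving helper shows. A simplified sketch of the removed pattern, using stand-in classes rather than the real Op machinery:

    class CumOp:
        __props__ = ("axis", "mode")

        def __init__(self, axis=None, mode="add"):
            self.axis = axis
            self.mode = mode


    class CumsumOp:
        # The deleted shim set obj.mode directly; calling __init__ here
        # keeps this self-contained sketch well-formed.
        def __new__(cls, *args, **kwargs):
            obj = object.__new__(CumOp)
            obj.__init__(*args, mode="add", **kwargs)
            return obj


    op = CumsumOp(axis=0)
    print(type(op).__name__, op.mode)  # CumOp add

Because the alias classes never produced instances of their own type, removing them does not change the graphs that `cumsum` and `cumprod` build.
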
16 changes: 16 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -0,0 +1,16 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import (
linalg,
special,
)
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")