Commit 8250e32

jessegrabowski authored and zaxtax committed

Implement Einsum as OpFromGraph

Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Rob Zinkov <[email protected]>

1 parent ee4d4f7 · commit 8250e32

11 files changed: +598 −10 lines

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
@@ -154,7 +154,7 @@ jobs:
         shell: micromamba-shell {0}
         run: |

-          micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
+          micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy opt_einsum pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
           if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
           if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
           if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi

@@ -215,7 +215,7 @@ jobs:
       - name: Install dependencies
         shell: micromamba-shell {0}
         run: |
-          micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
+          micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy opt_einsum pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
           pip install -e ./
           micromamba list && pip freeze
           python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

environment.yml

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ dependencies:
   - compilers
   - numpy>=1.17.0,<2
   - scipy>=0.14,<1.14.0
+  - opt_einsum
   - filelock>=3.15
   - etuples
   - logical-unification

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ dependencies = [
     "setuptools>=59.0.0",
     "scipy>=0.14,<1.14",
     "numpy>=1.17.0,<2",
+    "opt_einsum",
     "filelock>=3.15",
     "etuples",
     "logical-unification",

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.sparse
 import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.einsum
 import pytensor.link.jax.dispatch.sort

 # isort: on

pytensor/link/jax/dispatch/einsum.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.einsum import Einsum
+
+
+@jax_funcify.register(Einsum)
+def jax_funcify_Einsum(op, **kwargs):
+    subscripts = op.subscripts
+    optimize = op.optimize
+
+    def einsum(*operands):
+        return jnp.einsum(subscripts, *operands, optimize=optimize)
+
+    return einsum
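
The dispatch above simply forwards the Op's stored subscripts and contraction settings to jnp.einsum. A minimal sketch of how it gets exercised, assuming a JAX installation and using the pt.einsum helper introduced by this commit (names and shapes are illustrative):

import numpy as np

import pytensor
import pytensor.tensor as pt

# Batched matrix multiply written as an einsum
x = pt.tensor("x", shape=(8, 3, 4))
y = pt.tensor("y", shape=(8, 4, 5))
out = pt.einsum("bij,bjk->bik", x, y)

# Compiling in JAX mode routes the Einsum node through jax_funcify_Einsum,
# which delegates the actual contraction to jnp.einsum.
fn = pytensor.function([x, y], out, mode="JAX")
res = fn(np.ones((8, 3, 4)), np.ones((8, 4, 5)))
print(res.shape)  # (8, 3, 5)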

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -150,6 +150,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:


 # isort: off
+from pytensor.tensor.einsum import einsum
 from pytensor.tensor.functional import vectorize
 # isort: on
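
With this re-export, einsum becomes available directly from the pytensor.tensor namespace, next to vectorize. A small usage sketch under the default backend, where the Einsum OpFromGraph evaluates its inner graph of ordinary tensor ops (names and shapes below are illustrative):

import numpy as np
from pytensor.tensor import einsum, tensor

a = tensor("a", shape=(2, 3))
b = tensor("b", shape=(3, 4))

# A plain matrix product expressed via an einsum signature
c = einsum("ij,jk->ik", a, b)
print(c.eval({a: np.ones((2, 3)), b: np.ones((3, 4))}))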

pytensor/tensor/basic.py

Lines changed: 12 additions & 5 deletions
@@ -2016,7 +2016,12 @@ def transpose(x, axes=None):
     _x = as_tensor_variable(x)

     if axes is None:
-        axes = list(range((_x.type.ndim - 1), -1, -1))
+        axes = tuple(range((_x.type.ndim - 1), -1, -1))
+
+    if tuple(axes) == tuple(range(len(axes))):
+        # No-op
+        return _x
+
     ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)

     if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):

@@ -4001,6 +4006,10 @@ def moveaxis(
     source = normalize_axis_tuple(source, a.ndim, "source")
     destination = normalize_axis_tuple(destination, a.ndim, "destination")

+    if source == destination:
+        # It's a no-op
+        return a
+
     if len(source) != len(destination):
         raise ValueError(
             "`source` and `destination` arguments must have the same number of elements"

@@ -4315,9 +4324,7 @@ def atleast_Nd(
 atleast_3d = partial(atleast_Nd, n=3)


-def expand_dims(
-    a: np.ndarray | TensorVariable, axis: tuple[int, ...]
-) -> TensorVariable:
+def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
     """Expand the shape of an array.

     Insert a new axis that will appear at the `axis` position in the expanded

@@ -4336,7 +4343,7 @@ def expand_dims(
     """
     a = as_tensor(a)

-    if not isinstance(axis, tuple | list):
+    if not isinstance(axis, Sequence):
         axis = (axis,)

     out_ndim = len(axis) + a.ndim