
Commit 8eb150b

Reorg JAX dispatch structure
1 parent 03bb3a8

File tree

9 files changed: 168 additions & 148 deletions


pytensor/link/jax/dispatch/blas.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.blas import BatchedDot
+
+
+@jax_funcify.register(BatchedDot)
+def jax_funcify_BatchedDot(op, **kwargs):
+    def batched_dot(a, b):
+        if a.shape[0] != b.shape[0]:
+            raise TypeError("Shapes must match in the 0-th dimension")
+        return jnp.matmul(a, b)
+
+    return batched_dot

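For context, this conversion relies on jnp.matmul batching over the leading dimension of 3-D inputs. A minimal standalone sketch of that behaviour, using the same shapes as the new test further down (the variable names here are illustrative only):

import numpy as np
import jax.numpy as jnp

# jnp.matmul contracts the last two dimensions and maps over the leading
# (batch) dimension, which is exactly what BatchedDot needs.
a = jnp.asarray(np.linspace(-1, 1, 10 * 5 * 3).reshape((10, 5, 3)))
b = jnp.asarray(np.linspace(1, -1, 10 * 3 * 2).reshape((10, 3, 2)))
out = jnp.matmul(a, b)
assert out.shape == (10, 5, 2)
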
pytensor/link/jax/dispatch/math.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.math import Argmax, Dot, Max
+
+
+@jax_funcify.register(Dot)
+def jax_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return jnp.dot(x, y)
+
+    return dot
+
+
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
+    axis = op.axis
+
+    def max(x):
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x):
+        if axis is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axis)
+
+        # NumPy does not support multiple axes for argmax; this is a
+        # work-around
+        keep_axes = jnp.array(
+            [i for i in range(x.ndim) if i not in axes], dtype="int64"
+        )
+        # Not-reduced axes in front
+        transposed_x = jnp.transpose(
+            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
+        )
+        kept_shape = transposed_x.shape[: len(keep_axes)]
+        reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+        # Otherwise reshape would complain citing float arg
+        new_shape = (
+            *kept_shape,
+            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
+        )
+        reshaped_x = transposed_x.reshape(new_shape)
+
+        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
+
+        return max_idx_res
+
+    return argmax

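The multi-axis Argmax work-around above moves the reduced axes to the back, flattens them into a single axis, and then takes a plain single-axis argmax. A small illustration of that idea (the shapes and names here are hypothetical, not part of the commit):

import numpy as np
import jax.numpy as jnp

# Reducing a (2, 3, 4) array over axes (1, 2): collapse those axes into one
# axis of size 12; argmax over the last axis then gives the flat index
# within the reduced block for each kept row.
x = jnp.asarray(np.arange(24).reshape((2, 3, 4)))
flat_idx = jnp.argmax(x.reshape((2, 3 * 4)), axis=-1)
assert flat_idx.shape == (2,)
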
pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 0 additions & 69 deletions
@@ -1,8 +1,6 @@
 import jax.numpy as jnp
 
 from pytensor.link.jax.dispatch import jax_funcify
-from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -79,14 +77,6 @@ def qr_full(x, mode=mode):
     return qr_full
 
 
-@jax_funcify.register(Dot)
-def jax_funcify_Dot(op, **kwargs):
-    def dot(x, y):
-        return jnp.dot(x, y)
-
-    return dot
-
-
 @jax_funcify.register(MatrixPinv)
 def jax_funcify_Pinv(op, **kwargs):
     def pinv(x):
@@ -95,68 +85,9 @@ def pinv(x):
     return pinv
 
 
-@jax_funcify.register(BatchedDot)
-def jax_funcify_BatchedDot(op, **kwargs):
-    def batched_dot(a, b):
-        if a.shape[0] != b.shape[0]:
-            raise TypeError("Shapes must match in the 0-th dimension")
-        return jnp.matmul(a, b)
-
-    return batched_dot
-
-
 @jax_funcify.register(KroneckerProduct)
 def jax_funcify_KroneckerProduct(op, **kwargs):
     def _kron(x, y):
         return jnp.kron(x, y)
 
     return _kron
-
-
-@jax_funcify.register(Max)
-def jax_funcify_Max(op, **kwargs):
-    axis = op.axis
-
-    def max(x):
-        max_res = jnp.max(x, axis)
-
-        return max_res
-
-    return max
-
-
-@jax_funcify.register(Argmax)
-def jax_funcify_Argmax(op, **kwargs):
-    axis = op.axis
-
-    def argmax(x):
-        if axis is None:
-            axes = tuple(range(x.ndim))
-        else:
-            axes = tuple(int(ax) for ax in axis)
-
-        # NumPy does not support multiple axes for argmax; this is a
-        # work-around
-        keep_axes = jnp.array(
-            [i for i in range(x.ndim) if i not in axes], dtype="int64"
-        )
-        # Not-reduced axes in front
-        transposed_x = jnp.transpose(
-            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
-        )
-        kept_shape = transposed_x.shape[: len(keep_axes)]
-        reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-        # Otherwise reshape would complain citing float arg
-        new_shape = (
-            *kept_shape,
-            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
-        )
-        reshaped_x = transposed_x.reshape(new_shape)
-
-        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
-
-        return max_idx_res
-
-    return argmax

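All of these modules register converters on the same jax_funcify dispatcher, so moving a registration between files does not change runtime behaviour. A conceptual sketch of that dispatch pattern, as a self-contained analogue rather than PyTensor's actual implementation (to_jax and FakeDot are made-up names):

from functools import singledispatch

import jax.numpy as jnp


@singledispatch
def to_jax(op, **kwargs):
    # Fallback when no conversion has been registered for this Op type
    raise NotImplementedError(f"No JAX conversion for {op}")


class FakeDot:
    # Stand-in for a PyTensor Op class, purely illustrative
    pass


@to_jax.register(FakeDot)
def _(op, **kwargs):
    def dot(x, y):
        return jnp.dot(x, y)

    return dot
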
pytensor/link/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from pytensor.link.pytorch.linker import PytorchLinker

tests/link/jax/test_blas.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import numpy as np
+import pytest
+
+from pytensor.compile.function import function
+from pytensor.compile.mode import Mode
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import get_test_value
+from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+from pytensor.link.jax import JAXLinker
+from pytensor.tensor import blas as pt_blas
+from pytensor.tensor.type import tensor3
+from tests.link.jax.test_basic import compare_jax_and_py
+
+
+jax = pytest.importorskip("jax")
+
+
+def test_jax_BatchedDot():
+    # tensor3 . tensor3
+    a = tensor3("a")
+    a.tag.test_value = (
+        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
+    )
+    b = tensor3("b")
+    b.tag.test_value = (
+        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
+    )
+    out = pt_blas.BatchedDot()(a, b)
+    fgraph = FunctionGraph([a, b], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    # A dimension mismatch should raise a TypeError for compatibility
+    inputs = [get_test_value(a)[:-1], get_test_value(b)]
+    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
+    jax_mode = Mode(JAXLinker(), opts)
+    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
+    with pytest.raises(TypeError):
+        pytensor_jax_fn(*inputs)

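The test builds a JAX Mode by hand so it can exclude the C-only and BLAS rewrites. As a rough usage sketch, the same graph could presumably also be compiled with the registered "JAX" mode string; this is a hypothetical example, not part of the commit:

import numpy as np

from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.tensor import blas as pt_blas
from pytensor.tensor.type import tensor3

a = tensor3("a")
b = tensor3("b")
out = pt_blas.BatchedDot()(a, b)
fn = function([a, b], out, mode="JAX")  # assumes the "JAX" mode is registered
a_val = np.linspace(-1, 1, 10 * 5 * 3).reshape((10, 5, 3)).astype(config.floatX)
b_val = np.linspace(1, -1, 10 * 3 * 2).reshape((10, 3, 2)).astype(config.floatX)
print(fn(a_val, b_val).shape)  # expected: (10, 5, 2)
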
tests/link/jax/test_math.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import numpy as np
+import pytest
+
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import get_test_value
+from pytensor.tensor.math import Argmax, Max, maximum
+from pytensor.tensor.math import max as pt_max
+from pytensor.tensor.type import dvector, matrix, scalar, vector
+from tests.link.jax.test_basic import compare_jax_and_py
+
+
+jax = pytest.importorskip("jax")
+
+
+def test_jax_max_and_argmax():
+    # Test that a single output of a multi-output `Op` can be used as input to
+    # another `Op`
+    x = dvector()
+    mx = Max([0])(x)
+    amx = Argmax([0])(x)
+    out = mx * amx
+    out_fg = FunctionGraph([x], [out])
+    compare_jax_and_py(out_fg, [np.r_[1, 2]])
+
+
+def test_dot():
+    y = vector("y")
+    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
+    x = vector("x")
+    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
+    A = matrix("A")
+    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
+    alpha = scalar("alpha")
+    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
+    beta = scalar("beta")
+    beta.tag.test_value = np.array(5.0, dtype=config.floatX)
+
+    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
+    # optimizations are turned on; however, when using JAX mode, it should
+    # leave the expression alone.
+    out = y.dot(alpha * A).dot(x) + beta * y
+    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = maximum(y, x)
+    fgraph = FunctionGraph([y, x], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = pt_max(y)
+    fgraph = FunctionGraph([y], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

tests/link/jax/test_nlinalg.py

Lines changed: 1 addition & 79 deletions
@@ -1,48 +1,17 @@
 import numpy as np
 import pytest
-from packaging.version import parse as version_parse
 
 from pytensor.compile.function import function
-from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
-from pytensor.link.jax import JAXLinker
-from pytensor.tensor import blas as pt_blas
 from pytensor.tensor import nlinalg as pt_nlinalg
-from pytensor.tensor.math import Argmax, Max, maximum
-from pytensor.tensor.math import max as pt_max
-from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
 
 
-def test_jax_BatchedDot():
-    # tensor3 . tensor3
-    a = tensor3("a")
-    a.tag.test_value = (
-        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
-    )
-    b = tensor3("b")
-    b.tag.test_value = (
-        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
-    )
-    out = pt_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    # A dimension mismatch should raise a TypeError for compatibility
-    inputs = [get_test_value(a)[:-1], get_test_value(b)]
-    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
-    jax_mode = Mode(JAXLinker(), opts)
-    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
-    with pytest.raises(TypeError):
-        pytensor_jax_fn(*inputs)
-
-
 def test_jax_basic_multiout():
     rng = np.random.default_rng(213234)
 
@@ -80,53 +49,6 @@ def assert_fn(x, y):
     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
 
 
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
-def test_jax_basic_multiout_omni():
-    # Test that a single output of a multi-output `Op` can be used as input to
-    # another `Op`
-    x = dvector()
-    mx = Max([0])(x)
-    amx = Argmax([0])(x)
-    out = mx * amx
-    out_fg = FunctionGraph([x], [out])
-    compare_jax_and_py(out_fg, [np.r_[1, 2]])
-
-
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
-def test_tensor_basics():
-    y = vector("y")
-    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
-    x = vector("x")
-    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
-    A = matrix("A")
-    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
-    alpha = scalar("alpha")
-    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
-    beta = scalar("beta")
-    beta.tag.test_value = np.array(5.0, dtype=config.floatX)
-
-    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
-    # optimizations are turned on; however, when using JAX mode, it should
-    # leave the expression alone.
-    out = y.dot(alpha * A).dot(x) + beta * y
-    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = maximum(y, x)
-    fgraph = FunctionGraph([y, x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = pt_max(y)
-    fgraph = FunctionGraph([y], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 def test_pinv():
     x = matrix("x")
     x_inv = pt_nlinalg.pinv(x)
File renamed without changes.
