
Commit b67ff22 (1 parent: 7779b07)

Implement wrap_jax and rename as_op to wrap_py (#1614)

Authored by aseyboldt, jdehning, and ricardoV94
Co-authored-by: Jonas <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>

File tree: 12 files changed, +1164 −25 lines

.github/workflows/test.yml
Lines changed: 1 addition & 1 deletion

@@ -208,7 +208,7 @@ jobs:
           micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
         fi
         if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
-        if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi
+        if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
         if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
         if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi

doc/environment.yml
Lines changed: 1 addition & 1 deletion

@@ -25,4 +25,4 @@ dependencies:
   - ablog
   - pip
   - pip:
-    - -e ..
+    - -e ..[jax]

doc/extending/creating_an_op.rst
Lines changed: 6 additions & 6 deletions

@@ -803,10 +803,10 @@ You can omit the :meth:`Rop` functions. Try to implement the testing apparatus d
 :download:`Solution<extending_pytensor_solution_1.py>`


-:func:`as_op`
+:func:`wrap_py`
 -------------

-:func:`as_op` is a Python decorator that converts a Python function into a
+:func:`wrap_py` is a Python decorator that converts a Python function into a
 basic PyTensor :class:`Op` that will call the supplied function during execution.

 This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation.
@@ -839,11 +839,11 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
     inputs PyTensor variables that were declared.

 .. note::
-    The python function wrapped by the :func:`as_op` decorator needs to return a new
+    The python function wrapped by the :func:`wrap_py` decorator needs to return a new
     data allocation, no views or in place modification of the input.


-:func:`as_op` Example
+:func:`wrap_py` Example
 ^^^^^^^^^^^^^^^^^^^^^

 .. testcode:: asop
@@ -852,14 +852,14 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
     import pytensor.tensor as pt
     import numpy as np
     from pytensor import function
-    from pytensor.compile.ops import as_op
+    from pytensor.compile.ops import wrap_py

     def infer_shape_numpy_dot(fgraph, node, input_shapes):
         ashp, bshp = input_shapes
         return [ashp[:-1] + bshp[-1:]]


-    @as_op(
+    @wrap_py(
         itypes=[pt.dmatrix, pt.dmatrix],
         otypes=[pt.dmatrix],
         infer_shape=infer_shape_numpy_dot,
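
Taken end to end, the updated documentation example compiles like any other Op. A minimal runnable sketch using the doc's own names (the compile-and-call step at the end is illustrative, not part of the hunk):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor import function
    from pytensor.compile.ops import wrap_py

    def infer_shape_numpy_dot(fgraph, node, input_shapes):
        # Shape of a matrix product: rows of the first, columns of the second.
        ashp, bshp = input_shapes
        return [ashp[:-1] + bshp[-1:]]

    @wrap_py(
        itypes=[pt.dmatrix, pt.dmatrix],
        otypes=[pt.dmatrix],
        infer_shape=infer_shape_numpy_dot,
    )
    def numpy_dot(a, b):
        # Returns a fresh array, as the note in the docs requires.
        return np.dot(a, b)

    x, y = pt.dmatrix("x"), pt.dmatrix("y")
    f = function([x, y], numpy_dot(x, y))
    print(f(np.ones((2, 3)), np.ones((3, 4))))  # 2x4 array of 3.0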

doc/extending/extending_pytensor_solution_1.py
Lines changed: 5 additions & 5 deletions

@@ -167,17 +167,17 @@ def test_infer_shape(self):

 import numpy as np

-# as_op exercice
+# wrap_py exercice
 import pytensor
-from pytensor.compile.ops import as_op
+from pytensor.compile.ops import wrap_py


 def infer_shape_numpy_dot(fgraph, node, input_shapes):
     ashp, bshp = input_shapes
     return [ashp[:-1] + bshp[-1:]]


-@as_op(
+@wrap_py(
     itypes=[pt.fmatrix, pt.fmatrix],
     otypes=[pt.fmatrix],
     infer_shape=infer_shape_numpy_dot,
@@ -192,7 +192,7 @@ def infer_shape_numpy_add_sub(fgraph, node, input_shapes):
     return [ashp[0]]


-@as_op(
+@wrap_py(
     itypes=[pt.fmatrix, pt.fmatrix],
     otypes=[pt.fmatrix],
     infer_shape=infer_shape_numpy_add_sub,
@@ -201,7 +201,7 @@ def numpy_add(a, b):
     return np.add(a, b)


-@as_op(
+@wrap_py(
     itypes=[pt.fmatrix, pt.fmatrix],
     otypes=[pt.fmatrix],
     infer_shape=infer_shape_numpy_add_sub,
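
What the infer_shape arguments in this file buy you: with a shape rule registered, PyTensor's shape rewrites can answer shape-only queries without executing the wrapped NumPy function. A small sketch of that behavior, assuming the default rewrites are active:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.ops import wrap_py

    def infer_shape_numpy_dot(fgraph, node, input_shapes):
        ashp, bshp = input_shapes
        return [ashp[:-1] + bshp[-1:]]

    @wrap_py(
        itypes=[pt.fmatrix, pt.fmatrix],
        otypes=[pt.fmatrix],
        infer_shape=infer_shape_numpy_dot,
    )
    def numpy_dot(a, b):
        return np.dot(a, b)

    x, y = pt.fmatrix("x"), pt.fmatrix("y")
    # Only the shape is requested, so the shape graph is built symbolically
    # from infer_shape and the Op itself can drop out of the compiled graph.
    shape_fn = pytensor.function([x, y], numpy_dot(x, y).shape)
    print(shape_fn(np.zeros((2, 3), "float32"), np.zeros((3, 5), "float32")))  # [2 5]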

doc/library/index.rst
Lines changed: 7 additions & 1 deletion

@@ -61,10 +61,16 @@ Convert to Variable

 .. autofunction:: pytensor.as_symbolic(...)

+Wrap JAX functions
+==================
+
+.. autofunction:: wrap_jax(...)
+
+Alias for :func:`pytensor.link.jax.ops.wrap_jax`
+
 Debug
 =====

 .. autofunction:: pytensor.dprint(...)

 Alias for :func:`pytensor.printing.debugprint`
-
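
The new pytensor/link/jax/ops.py module that defines wrap_jax is not among the hunks shown on this page, so the snippet below is a hypothetical usage sketch inferred from the commit title and the autofunction entry above; consult the rendered docs for the actual signature:

    import jax
    import pytensor
    import pytensor.tensor as pt

    # Hypothetical sketch: wrap a JAX-traceable function as a PyTensor Op.
    # pytensor.wrap_jax is exported per the pytensor/__init__.py hunk below.
    @pytensor.wrap_jax
    def jax_sigmoid(x):
        return jax.nn.sigmoid(x)

    x = pt.dvector("x")
    f = pytensor.function([x], jax_sigmoid(x))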

pytensor/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ def get_underlying_scalar_constant(v):
 from pytensor.scan.basic import scan
 from pytensor.scan.views import foldl, foldr, map, reduce
 from pytensor.compile.builders import OpFromGraph
-
+from pytensor.link.jax.ops import wrap_jax
 # isort: on


pytensor/compile/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@
     register_deep_copy_op_c_code,
     register_view_op_c_code,
     view_op,
+    wrap_py,
 )
 from pytensor.compile.profiling import ProfileStats
 from pytensor.compile.sharedvalue import SharedVariable, shared, shared_constructor

pytensor/compile/ops.py
Lines changed: 17 additions & 5 deletions

@@ -1,6 +1,6 @@
 """
 This file contains auxiliary Ops, used during the compilation phase and Ops
-building class (:class:`FromFunctionOp`) and decorator (:func:`as_op`) that
+building class (:class:`FromFunctionOp`) and decorator (:func:`wrap_py`) that
 help make new Ops more rapidly.

 """
@@ -268,12 +268,12 @@ def __reduce__(self):
             obj = load_back(mod, name)
         except (ImportError, KeyError, AttributeError):
             raise pickle.PicklingError(
-                f"Can't pickle as_op(), not found as {mod}.{name}"
+                f"Can't pickle wrap_py(), not found as {mod}.{name}"
             )
         else:
             if obj is not self:
                 raise pickle.PicklingError(
-                    f"Can't pickle as_op(), not the object at {mod}.{name}"
+                    f"Can't pickle wrap_py(), not the object at {mod}.{name}"
                 )
         return load_back, (mod, name)

@@ -282,6 +282,18 @@ def _infer_shape(self, fgraph, node, input_shapes):


 def as_op(itypes, otypes, infer_shape=None):
+    import warnings
+
+    warnings.warn(
+        "pytensor.as_op is deprecated and will be removed in a future release. "
+        "Please use pytensor.wrap_py instead.",
+        FutureWarning,
+        stacklevel=2,
+    )
+    return wrap_py(itypes, otypes, infer_shape)
+
+
+def wrap_py(itypes, otypes, infer_shape=None):
     """
     Decorator that converts a function into a basic PyTensor op that will call
     the supplied function as its implementation.
@@ -301,8 +313,8 @@ def infer_shape(fgraph, node, input_shapes):

     Examples
     --------
-    @as_op(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix],
-           otypes=[pytensor.tensor.fmatrix])
+    @wrap_py(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix],
+             otypes=[pytensor.tensor.fmatrix])
     def numpy_dot(a, b):
         return numpy.dot(a, b)
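
The shim keeps existing code working: as_op now emits a FutureWarning at decoration time and forwards its arguments to wrap_py. A quick sketch verifying that behavior (the warning text is quoted from the hunk above; the op itself is illustrative):

    import warnings

    import pytensor.tensor as pt
    from pytensor.compile.ops import as_op

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # The warning fires here, when as_op(...) builds the decorator.
        @as_op(itypes=[pt.dmatrix], otypes=[pt.dmatrix])
        def double(a):
            return 2 * a  # fresh allocation, as wrap_py requires

    assert any(issubclass(w.category, FutureWarning) for w in caught)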

pytensor/link/jax/dispatch/basic.py
Lines changed: 6 additions & 0 deletions

@@ -13,6 +13,7 @@
 from pytensor.graph import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
+from pytensor.link.jax.ops import JAXOp
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise

@@ -142,3 +143,8 @@ def opfromgraph(*inputs):
         return fgraph_fn(*inputs)

     return opfromgraph
+
+
+@jax_funcify.register(JAXOp)
+def jax_op_funcify(op, **kwargs):
+    return op.perform_jax
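
This registration is the whole JAX-backend story for JAXOp: jax_funcify must hand the linker a JAX-traceable callable for each Op, and JAXOp already carries one as perform_jax. The same single-dispatch hook is how any custom Op gains JAX support; a self-contained sketch with a toy Op (DoubleOp is illustrative, not from this commit):

    import pytensor.tensor as pt
    from pytensor.graph.basic import Apply
    from pytensor.graph.op import Op
    from pytensor.link.jax.dispatch import jax_funcify

    class DoubleOp(Op):
        """Toy Op used only to illustrate the dispatch hook."""

        def make_node(self, x):
            x = pt.as_tensor_variable(x)
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            # Default Python/C backend implementation.
            output_storage[0][0] = 2 * inputs[0]

    @jax_funcify.register(DoubleOp)
    def jax_funcify_DoubleOp(op, **kwargs):
        # Return a JAX-traceable callable; the JAX linker jits the whole graph.
        def double(x):
            return 2 * x

        return double

Compiling a graph that contains DoubleOp with mode="JAX" then routes through this callable instead of perform.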
