@@ -4,7 +4,7 @@ Adding JAX, Numba and Pytorch support for `Op`\s
PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do
this, each :class:`Op` in a PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function.

-This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.
+This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.

Step 1: Identify the PyTensor :class:`Op` you'd like to implement
------------------------------------------------------------------------
@@ -60,7 +60,7 @@ could also have any data type (e.g. floats, ints), so our implementation
must be able to handle all the possible data types.

It also tells us that there's only one return value, that it has a data type
-determined by :meth:`x.type()`, i.e., the data type of the original tensor.
+determined by :meth:`x.type`, i.e., the data type of the original tensor.
This implies that the result is necessarily a matrix.

Some classes may have more complex behavior. For example, the :class:`CumOp`\ :class:`Op`
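The dtype- and shape-preserving contract described above can be sanity-checked against NumPy's `cumsum`, which is what :class:`CumOp`'s `Op.perform` delegates to (a minimal sketch, independent of PyTensor):

```python
import numpy as np

# CumOp's make_node promises an output with the same dtype and
# number of dimensions as the input (x.type); NumPy's cumsum
# honors the same contract.
x = np.arange(9, dtype="float32").reshape(3, 3)
res = np.cumsum(x, axis=0)
assert res.dtype == x.dtype
assert res.shape == x.shape
```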
@@ -116,7 +116,7 @@ Here's an example for :class:`DimShuffle`:

.. tab-set::

-    .. tab-item:: JAX
+    .. tab-item:: JAX

        .. code:: python

@@ -134,7 +134,7 @@ Here's an example for :class:`DimShuffle`:
                    res = jnp.copy(res)

                return res
-
+
    .. tab-item:: Numba

        .. code:: python
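The transpose-then-expand pattern shown in the JAX tab can be illustrated with a standalone, simplified sketch (`dimshuffle_sketch` is a hypothetical helper, not PyTensor's actual registration; the real implementation also honors `op.inplace`):

```python
import jax.numpy as jnp

def dimshuffle_sketch(x, shuffle):
    # Transpose the kept input axes into their requested order...
    res = jnp.transpose(x, [d for d in shuffle if d != "x"])
    # ...then insert broadcastable length-1 axes where "x" appears.
    for i, d in enumerate(shuffle):
        if d == "x":
            res = jnp.expand_dims(res, i)
    # Copy so the result does not alias the input buffer.
    return jnp.copy(res)

# The pattern (1, "x", 0) maps a (2, 3) input to a (3, 1, 2) output.
out = dimshuffle_sketch(jnp.ones((2, 3)), (1, "x", 0))
assert out.shape == (3, 1, 2)
```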
@@ -465,7 +465,7 @@ Step 4: Write tests
    .. tab-item:: JAX

        Test that your registered `Op` is working correctly by adding tests to the
-        appropriate test suites in PyTensor (e.g. in ``tests.link.jax``).
+        appropriate test suites in PyTensor (e.g. in ``tests.link.jax``).
        The tests should ensure that your implementation can
        handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
        Check the existing tests for the general outline of these kinds of tests. In
@@ -478,7 +478,7 @@ Step 4: Write tests
        Here's a small example of a test for :class:`CumOp` above:

        .. code:: python
-
+
            import numpy as np
            import pytensor.tensor as pt
            from pytensor.configdefaults import config
@@ -514,22 +514,22 @@ Step 4: Write tests
        .. code:: python

            import pytest
-
+
            def test_jax_CumOp():
                """Test JAX conversion of the `CumOp` `Op`."""
                a = pt.matrix("a")
                a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
-
+
                with pytest.raises(NotImplementedError):
                    out = pt.cumprod(a, axis=1)
                    fgraph = FunctionGraph([a], [out])
                    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
+
+
    .. tab-item:: Numba

        Test that your registered `Op` is working correctly by adding tests to the
-        appropriate test suites in PyTensor (e.g. in ``tests.link.numba``).
+        appropriate test suites in PyTensor (e.g. in ``tests.link.numba``).
        The tests should ensure that your implementation can
        handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
        Check the existing tests for the general outline of these kinds of tests. In
@@ -542,7 +542,7 @@ Step 4: Write tests
        Here's a small example of a test for :class:`CumOp` above:

        .. code:: python
-
+
            from tests.link.numba.test_basic import compare_numba_and_py
            from pytensor.graph import FunctionGraph
            from pytensor.compile.sharedvalue import SharedVariable
@@ -561,11 +561,11 @@ Step 4: Write tests
                    if not isinstance(i, SharedVariable | Constant)
                ],
            )
-
+


    .. tab-item:: Pytorch
-
+
        Test that your registered `Op` is working correctly by adding tests to the
        appropriate test suites in PyTensor (``tests.link.pytorch``). The tests should ensure that your implementation can
        handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
@@ -579,7 +579,7 @@ Step 4: Write tests
        Here's a small example of a test for :class:`CumOp` above:

        .. code:: python
-
+
            import numpy as np
            import pytest
            import pytensor.tensor as pt
@@ -592,7 +592,7 @@ Step 4: Write tests
                ["float64", "int64"],
            )
            @pytest.mark.parametrize(
-                "axis",
+                "axis",
                [None, 1, (0,)],
            )
            def test_pytorch_CumOp(axis, dtype):
@@ -650,4 +650,4 @@ as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
All jitted functions must now have constant shapes, which means a graph like that of
:class:`Eye` can never be translated to JAX, since it is fundamentally a
function with dynamic shapes. In other words, only PyTensor graphs with static shapes
-can be translated to JAX at the moment.
+can be translated to JAX at the moment.
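The static-shape restriction can be observed in JAX itself (a minimal sketch; `make_eye` is a hypothetical function, not part of PyTensor): an `eye` whose size comes from a traced value fails at trace time, while a size fixed at compile time works.

```python
import jax
import jax.numpy as jnp

@jax.jit
def make_eye(n):
    # `n` is a traced value under jit, so jnp.eye(n) cannot know the
    # output shape at compile time and raises during tracing.
    return jnp.eye(n)

try:
    make_eye(3)
    traced_ok = True
except Exception:
    traced_ok = False

# With the size closed over as a Python int, the shape is static
# and jit compilation succeeds.
fixed_eye = jax.jit(lambda: jnp.eye(3))
assert not traced_ok
assert fixed_eye().shape == (3, 3)
```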