|
1 | 1 | Adding JAX, Numba and Pytorch support for `Op`\s |
2 | | -======================================= |
| 2 | +================================================ |
3 | 3 |
|
4 | 4 | PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do |
5 | 5 | this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function. |
6 | 6 |
|
7 | 7 | This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`. |
8 | 8 |
|
9 | 9 | Step 1: Identify the PyTensor :class:`Op` you'd like to implement |
10 | | ------------------------------------------------------------------------- |
| 10 | +----------------------------------------------------------------- |
11 | 11 |
|
12 | 12 | Find the source for the PyTensor :class:`Op` you'd like to be supported and |
13 | 13 | identify the function signature and return values. These can be determined by |
@@ -98,7 +98,7 @@ how the inputs and outputs are used to compute the outputs for an :class:`Op` |
98 | 98 | in Python. This method is effectively what needs to be implemented. |
99 | 99 |
|
100 | 100 | Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close) |
101 | | ---------------------------------------------------------- |
| 101 | +-------------------------------------------------------------------------- |
102 | 102 |
|
103 | 103 | With a precise idea of what the PyTensor :class:`Op` does we need to figure out how |
104 | 104 | to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named |
@@ -269,7 +269,7 @@ and :func:`torch.cumprod` |
269 | 269 | z[0] = np.cumprod(x, axis=self.axis) |
270 | 270 |
|
271 | 271 | Step 3: Register the function with the respective dispatcher |
272 | | ---------------------------------------------------------------- |
| 272 | +------------------------------------------------------------ |
273 | 273 |
|
274 | 274 | With the PyTensor `Op` replicated, we'll need to register the |
275 | 275 | function with the backends `Linker`. This is done through the use of |
|
0 commit comments