
Commit eee63d4

Author: Ian Schweer (committed)

Merge branch 'main' into scalarloop

2 parents cb7e4db + a180b88; commit eee63d4


70 files changed: +3168 additions, -782 deletions

.github/workflows/pypi.yml

Lines changed: 31 additions & 1 deletion

@@ -50,13 +50,38 @@ jobs:
           fetch-depth: 0

       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.19.2
+        uses: pypa/cibuildwheel@v2.21.0

       - uses: actions/upload-artifact@v4
         with:
           name: wheels-${{ matrix.platform }}
           path: ./wheelhouse/*.whl

+  build_universal_wheel:
+    name: Build universal wheel for Pyodide
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+
+      - name: Install dependencies
+        run: pip install numpy versioneer wheel
+
+      - name: Build universal wheel
+        run: |
+          PYODIDE=1 python setup.py bdist_wheel --universal
+
+      - uses: actions/upload-artifact@v4
+        with:
+          name: universal_wheel
+          path: dist/*.whl
+
   check_dist:
     name: Check dist
     needs: [make_sdist,build_wheels]

@@ -103,6 +128,11 @@ jobs:
         path: dist
         merge-multiple: true

+      - uses: actions/download-artifact@v4
+        with:
+          name: universal_wheel
+          path: dist
+
       - uses: pypa/[email protected]
         with:
           user: __token__
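
For context on the new ``build_universal_wheel`` job: a ``--universal`` wheel only makes sense when the build produces no native extensions, so the ``PYODIDE=1`` flag presumably tells ``setup.py`` to skip them. A minimal, hypothetical sketch of that pattern (not PyTensor's actual ``setup.py``; ``get_native_extensions`` is a made-up helper):

import os

from setuptools import Extension, setup


def get_native_extensions():
    # Hypothetical helper standing in for the real C-extension list.
    return [Extension("pkg._c_impl", sources=["pkg/_c_impl.c"])]


if os.environ.get("PYODIDE", "0") == "1":
    # PYODIDE=1 (as set by the workflow above): build no C extensions,
    # so `bdist_wheel --universal` yields a pure-Python, Pyodide-friendly wheel.
    ext_modules = []
else:
    ext_modules = get_native_extensions()

setup(name="example-pkg", version="0.1", ext_modules=ext_modules)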

.pre-commit-config.yaml

Lines changed: 6 additions & 1 deletion

@@ -21,8 +21,13 @@ repos:
            pytensor/tensor/variable\.py|
          )$
       - id: check-merge-conflict
+  - repo: https://github.com/sphinx-contrib/sphinx-lint
+    rev: v1.0.0
+    hooks:
+      - id: sphinx-lint
+        args: ["."]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.5
+    rev: v0.6.5
     hooks:
       - id: ruff
         args: ["--fix", "--output-format=full"]

readthedocs.yml renamed to .readthedocs.yaml

Lines changed: 2 additions & 2 deletions

@@ -4,6 +4,6 @@ sphinx:
 conda:
   environment: doc/environment.yml
 build:
-  os: "ubuntu-20.04"
+  os: "ubuntu-lts-latest"
   tools:
-    python: "mambaforge-4.10"
+    python: "mambaforge-latest"

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ For issues a minimal working example (MWE) is strongly recommended when relevant
 (fixing a typo in the documentation does not require a MWE). For discussions,
 MWEs are generally required. All MWEs must be implemented using PyTensor. Please
 do not submit MWEs if they are not implemented in PyTensor. In certain cases,
-pseudocode may be acceptable, but an PyTensor implementation is always preferable.
+pseudocode may be acceptable, but a PyTensor implementation is always preferable.

 ## Quick links

doc/extending/creating_a_c_op.rst

Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ This distance between consecutive elements of an array over a given dimension,
 is called the stride of that dimension.


-Accessing NumPy :class`ndarray`\s' data and properties
+Accessing NumPy :class:`ndarray`'s data and properties
 ------------------------------------------------------

 The following macros serve to access various attributes of NumPy :class:`ndarray`\s.
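
As a concrete illustration of the stride concept in the corrected heading's section, NumPy exposes strides (in bytes) directly on arrays; a minimal sketch:

import numpy as np

a = np.arange(12, dtype=np.float64).reshape(3, 4)
# Strides are in bytes: stepping one row skips 4 elements * 8 bytes = 32,
# stepping one column skips a single 8-byte element.
print(a.strides)  # (32, 8)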

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 17 additions & 17 deletions

@@ -4,7 +4,7 @@ Adding JAX, Numba and Pytorch support for `Op`\s
 PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do
 this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function.

-This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`. 
+This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.

 Step 1: Identify the PyTensor :class:`Op` you'd like to implement
 ------------------------------------------------------------------------

@@ -60,7 +60,7 @@ could also have any data type (e.g. floats, ints), so our implementation
 must be able to handle all the possible data types.

 It also tells us that there's only one return value, that it has a data type
-determined by :meth:`x.type()` i.e., the data type of the original tensor.
+determined by :meth:`x.type` i.e., the data type of the original tensor.
 This implies that the result is necessarily a matrix.

 Some class may have a more complex behavior. For example, the :class:`CumOp`\ :class:`Op`

@@ -116,7 +116,7 @@ Here's an example for :class:`DimShuffle`:

 .. tab-set::

-   .. tab-item:: JAX 
+   .. tab-item:: JAX

       .. code:: python

@@ -134,7 +134,7 @@ Here's an example for :class:`DimShuffle`:
              res = jnp.copy(res)

          return res
-
+
   .. tab-item:: Numba

      .. code:: python

@@ -465,7 +465,7 @@ Step 4: Write tests
   .. tab-item:: JAX

      Test that your registered `Op` is working correctly by adding tests to the
-     appropriate test suites in PyTensor (e.g. in ``tests.link.jax``). 
+     appropriate test suites in PyTensor (e.g. in ``tests.link.jax``).
      The tests should ensure that your implementation can
      handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
      Check the existing tests for the general outline of these kinds of tests. In

@@ -478,7 +478,7 @@ Step 4: Write tests
      Here's a small example of a test for :class:`CumOp` above:

      .. code:: python
-
+
         import numpy as np
         import pytensor.tensor as pt
         from pytensor.configdefaults import config

@@ -514,22 +514,22 @@ Step 4: Write tests
      .. code:: python

         import pytest
-
+
        def test_jax_CumOp():
            """Test JAX conversion of the `CumOp` `Op`."""
            a = pt.matrix("a")
            a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
-
+
            with pytest.raises(NotImplementedError):
               out = pt.cumprod(a, axis=1)
               fgraph = FunctionGraph([a], [out])
               compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
+
+
   .. tab-item:: Numba

      Test that your registered `Op` is working correctly by adding tests to the
-     appropriate test suites in PyTensor (e.g. in ``tests.link.numba``). 
+     appropriate test suites in PyTensor (e.g. in ``tests.link.numba``).
      The tests should ensure that your implementation can
      handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
      Check the existing tests for the general outline of these kinds of tests. In

@@ -542,7 +542,7 @@ Step 4: Write tests
      Here's a small example of a test for :class:`CumOp` above:

      .. code:: python
-
+
         from tests.link.numba.test_basic import compare_numba_and_py
         from pytensor.graph import FunctionGraph
         from pytensor.compile.sharedvalue import SharedVariable

@@ -561,11 +561,11 @@ Step 4: Write tests
               if not isinstance(i, SharedVariable | Constant)
            ],
         )
-
+


   .. tab-item:: Pytorch
-
+
      Test that your registered `Op` is working correctly by adding tests to the
      appropriate test suites in PyTensor (``tests.link.pytorch``). The tests should ensure that your implementation can
      handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.

@@ -579,7 +579,7 @@ Step 4: Write tests
      Here's a small example of a test for :class:`CumOp` above:

      .. code:: python
-
+
         import numpy as np
         import pytest
         import pytensor.tensor as pt

@@ -592,7 +592,7 @@ Step 4: Write tests
            ["float64", "int64"],
         )
         @pytest.mark.parametrize(
-           "axis", 
+           "axis",
            [None, 1, (0,)],
         )
         def test_pytorch_CumOp(axis, dtype):

@@ -650,4 +650,4 @@ as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
 All jitted functions now must have constant shape, which means a graph like the
 one of :class:`Eye` can never be translated to JAX, since it's fundamentally a
 function with dynamic shapes. In other words, only PyTensor graphs with static shapes
-can be translated to JAX at the moment. 
+can be translated to JAX at the moment.
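
For context, the registration pattern this tutorial page documents looks roughly like the following for :class:`CumOp` (a sketch based on the tutorial's own examples; check ``pytensor.link.jax.dispatch`` for the current import paths):

import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.extra_ops import CumOp


@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op, **kwargs):
    # The Op's static attributes are closed over; JAX traces only `x`.
    axis = op.axis
    mode = op.mode

    def cumop(x):
        if mode == "add":
            return jnp.cumsum(x, axis=axis)
        return jnp.cumprod(x, axis=axis)

    return cumop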

doc/extending/type.rst

Lines changed: 1 addition & 1 deletion

@@ -333,7 +333,7 @@ returns eitehr a new transferred variable (which can be the same as
 the input if no transfer is necessary) or returns None if the transfer
 can't be done.

-Then register that function by calling :func:`register_transfer()`
+Then register that function by calling :func:`register_transfer`
 with it as argument.

 An example
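
A minimal sketch of the pattern this section describes, assuming ``register_transfer`` is importable from ``pytensor.compile`` as in the Theano-era API (the target name below is hypothetical):

from pytensor.compile import register_transfer  # assumed import path


def my_transfer(var, target):
    # Return a transferred variable (possibly `var` itself if no transfer
    # is necessary), or None if this target can't be handled here.
    if target != "my_device":  # hypothetical target name
        return None
    return var  # no-op transfer in this sketch


register_transfer(my_transfer)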

doc/library/compile/io.rst

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ The ``inputs`` argument to ``pytensor.function`` is a list, containing the ``Var
     ``self.<name>``. The default value is ``None``.

     ``value``: literal or ``Container``. The initial/default value for this
-    input. If update is`` None``, this input acts just like
+    input. If update is ``None``, this input acts just like
     an argument with a default value in Python. If update is not ``None``,
     changes to this
     value will "stick around", whether due to an update or a user's
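
The corrected sentence is about the default-argument behavior of ``value`` when ``update`` is ``None``; a minimal sketch using the ``In`` wrapper this page documents:

import pytensor
import pytensor.tensor as pt
from pytensor import In

x = pt.dscalar("x")
y = pt.dscalar("y")

# With update=None, `value` acts like a Python default argument for `y`.
f = pytensor.function([x, In(y, value=1.0)], x + y)

print(f(2.0))       # 3.0 -- falls back to the default y=1.0
print(f(2.0, 4.0))  # 6.0 -- the default is overridden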

doc/library/config.rst

Lines changed: 1 addition & 1 deletion

@@ -226,7 +226,7 @@ import ``pytensor`` and print the config variable, as in:
     in the future.

     The ``'numpy+floatX'`` setting attempts to mimic NumPy casting rules,
-    although it prefers to use ``float32` `numbers instead of ``float64`` when
+    although it prefers to use ``float32`` numbers instead of ``float64`` when
     ``config.floatX`` is set to ``'float32'`` and the associated data is not
     explicitly typed as ``float64`` (e.g. regular Python floats). Note that
     ``'numpy+floatX'`` is not currently behaving exactly as planned (it is a