Skip to content

Commit 06ccf91

Browse files
authored
Merge branch 'main' into mlx-poc
2 parents 02ed254 + 73c0d4d commit 06ccf91

File tree

193 files changed

+19779
-4137
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

193 files changed

+19779
-4137
lines changed

.github/workflows/pypi.yml

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ on:
33
push:
44
branches:
55
- main
6-
- auto-release
76
pull_request:
87
branches: [main]
98
release:
@@ -16,10 +15,50 @@ concurrency:
1615
group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }}
1716
cancel-in-progress: true
1817

18+
permissions: {}
19+
1920
jobs:
21+
check_changes:
22+
runs-on: ubuntu-latest
23+
outputs:
24+
should_run: ${{ steps.set_should_run.outputs.should_run }}
25+
steps:
26+
- uses: actions/checkout@v4
27+
with:
28+
persist-credentials: false
29+
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
30+
id: filter
31+
with:
32+
filters: |
33+
any_changed:
34+
- '.github/workflows/pypi.yml'
35+
- 'pyproject.toml'
36+
- 'setup.py'
37+
- 'pytensor/_version.py'
38+
- 'pytensor/scan_perform.pyx'
39+
- 'pytensor/scan_perform_ext.py'
40+
- name: Set should_run output
41+
id: set_should_run
42+
run: |
43+
if [[ "${{ github.event_name == 'release' ||
44+
github.ref == 'refs/heads/main' ||
45+
(
46+
github.event_name == 'pull_request'
47+
&& steps.filter.outputs.any_changed == 'true'
48+
)
49+
}}" == "true" ]]; then
50+
echo "should_run=true" >> $GITHUB_OUTPUT
51+
else
52+
echo "should_run=false" >> $GITHUB_OUTPUT
53+
fi
54+
2055
# The job to build precompiled pypi wheels.
2156
make_sdist:
2257
name: Make SDist
58+
needs: check_changes
59+
# Run if it's a release or if relevant files changed on main
60+
if: |
61+
needs.check_changes.outputs.should_run == 'true'
2362
runs-on: ubuntu-latest
2463
permissions:
2564
# write id-token and attestations are required to attest build provenance
@@ -49,6 +88,10 @@ jobs:
4988

5089
run_checks:
5190
name: Build & inspect our package.
91+
needs: check_changes
92+
# Run if it's a release or if relevant files changed on main
93+
if: |
94+
needs.check_changes.outputs.should_run == 'true'
5295
# Note: the resulting builds are not actually published.
5396
# This is purely for additional testing and diagnostic purposes.
5497
runs-on: ubuntu-latest
@@ -62,6 +105,10 @@ jobs:
62105

63106
build_wheels:
64107
name: Build wheels for ${{ matrix.platform }}
108+
needs: check_changes
109+
# Run if it's a release or if relevant files changed on main
110+
if: |
111+
needs.check_changes.outputs.should_run == 'true'
65112
runs-on: ${{ matrix.platform }}
66113
permissions:
67114
# write id-token and attestations are required to attest build provenance
@@ -80,7 +127,7 @@ jobs:
80127
persist-credentials: false
81128

82129
- name: Build wheels
83-
uses: pypa/[email protected].0
130+
uses: pypa/cibuildwheel@faf86a6ed7efa889faf6996aa23820831055001a # v2.23.3
84131

85132
- name: Attest GitHub build provenance
86133
uses: actions/attest-build-provenance@v2
@@ -96,6 +143,10 @@ jobs:
96143

97144
build_universal_wheel:
98145
name: Build universal wheel for Pyodide
146+
needs: check_changes
147+
# Run if it's a release or if relevant files changed on main
148+
if: |
149+
needs.check_changes.outputs.should_run == 'true'
99150
runs-on: ubuntu-latest
100151
permissions:
101152
# write id-token and attestations are required to attest build provenance
@@ -113,7 +164,7 @@ jobs:
113164
python-version: '3.11'
114165

115166
- name: Install dependencies
116-
run: pip install numpy versioneer wheel
167+
run: pip install --upgrade setuptools numpy versioneer wheel
117168

118169
- name: Build universal wheel
119170
run: |
@@ -133,7 +184,7 @@ jobs:
133184

134185
check_dist:
135186
name: Check dist
136-
needs: [make_sdist,build_wheels]
187+
needs: [check_changes, make_sdist, build_wheels]
137188
runs-on: ubuntu-22.04
138189
steps:
139190
- uses: actions/download-artifact@v4

.github/workflows/test.yml

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,12 @@ jobs:
8383
install-jax: [0]
8484
install-torch: [0]
8585
install-mlx: [0]
86+
install-xarray: [0]
8687
part:
87-
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
88+
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse --ignore=tests/xtensor"
8889
- "tests/scan"
8990
- "tests/sparse"
90-
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
91+
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
9192
- "tests/tensor/conv"
9293
- "tests/tensor/rewriting"
9394
- "tests/tensor/test_math.py"
@@ -108,7 +109,7 @@ jobs:
108109
python-version: "3.13"
109110
include:
110111
- os: "ubuntu-latest"
111-
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
112+
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
112113
python-version: "3.12"
113114
numpy-version: ">=2.0"
114115
fast-compile: 0
@@ -117,6 +118,7 @@ jobs:
117118
install-jax: 0
118119
install-torch: 0
119120
install-mlx: 0
121+
install-xarray: 0
120122
- install-numba: 1
121123
os: "ubuntu-latest"
122124
python-version: "3.10"
@@ -159,6 +161,13 @@ jobs:
159161
fast-compile: 0
160162
float32: 0
161163
part: "tests/link/mlx"
164+
- install-xarray: 1
165+
os: "ubuntu-latest"
166+
python-version: "3.13"
167+
numpy-version: ">=2.0"
168+
fast-compile: 0
169+
float32: 0
170+
part: "tests/xtensor"
162171
- os: macos-15
163172
python-version: "3.13"
164173
numpy-version: ">=2.0"
@@ -206,6 +215,7 @@ jobs:
206215
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
207216
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
208217
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
218+
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
209219
pip install pytest-sphinx
210220
211221
pip install -e ./
@@ -223,6 +233,7 @@ jobs:
223233
INSTALL_JAX: ${{ matrix.install-jax }}
224234
INSTALL_TORCH: ${{ matrix.install-torch}}
225235
INSTALL_MLX: ${{ matrix.install-mlx }}
236+
INSTALL_XARRAY: ${{ matrix.install-xarray }}
226237
OS: ${{ matrix.os}}
227238

228239
- name: Run tests

doc/conf.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import inspect
33
import sys
4+
45
import pytensor
56
from pathlib import Path
67

@@ -12,6 +13,7 @@
1213
"sphinx.ext.autodoc",
1314
"sphinx.ext.todo",
1415
"sphinx.ext.doctest",
16+
"sphinx_copybutton",
1517
"sphinx.ext.napoleon",
1618
"sphinx.ext.linkcode",
1719
"sphinx.ext.mathjax",
@@ -86,8 +88,7 @@
8688

8789
# List of directories, relative to source directories, that shouldn't be
8890
# searched for source files.
89-
exclude_dirs = ["images", "scripts", "sandbox"]
90-
exclude_patterns = ['page_footer.md', '**/*.myst.md']
91+
exclude_patterns = ["README.md", "images/*", "page_footer.md", "**/*.myst.md"]
9192

9293
# The reST default role (used for this markup: `text`) to use for all
9394
# documents.
@@ -235,24 +236,41 @@
235236
# Resolve function
236237
# This function is used to populate the (source) links in the API
237238
def linkcode_resolve(domain, info):
238-
def find_source():
239+
def find_obj() -> object:
239240
# try to find the file and line number, based on code from numpy:
240241
# https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
241242
obj = sys.modules[info["module"]]
242243
for part in info["fullname"].split("."):
243244
obj = getattr(obj, part)
245+
return obj
244246

247+
def find_source(obj):
245248
fn = Path(inspect.getsourcefile(obj))
246-
fn = fn.relative_to(Path(__file__).parent)
249+
fn = fn.relative_to(Path(pytensor.__file__).parent)
247250
source, lineno = inspect.getsourcelines(obj)
248251
return fn, lineno, lineno + len(source) - 1
249252

253+
def fallback_source():
254+
return info["module"].replace(".", "/") + ".py"
255+
250256
if domain != "py" or not info["module"]:
251257
return None
258+
252259
try:
253-
filename = "pytensor/%s#L%d-L%d" % find_source()
260+
obj = find_obj()
254261
except Exception:
255-
filename = info["module"].replace(".", "/") + ".py"
262+
filename = fallback_source()
263+
else:
264+
try:
265+
filename = "pytensor/%s#L%d-L%d" % find_source(obj)
266+
except Exception:
267+
# warnings.warn(f"Could not find source code for {domain}:{info}")
268+
try:
269+
filename = obj.__module__.replace(".", "/") + ".py"
270+
except AttributeError:
271+
# Some objects do not have a __module__ attribute (?)
272+
filename = fallback_source()
273+
256274
import subprocess
257275

258276
tag = subprocess.Popen(

doc/environment.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ dependencies:
1313
- mock
1414
- pillow
1515
- pymc-sphinx-theme
16+
- sphinx-copybutton
1617
- sphinx-design
18+
- sphinx-sitemap
1719
- pygments
1820
- pydot
1921
- ipython
@@ -23,5 +25,4 @@ dependencies:
2325
- ablog
2426
- pip
2527
- pip:
26-
- sphinx_sitemap
2728
- -e ..

0 commit comments

Comments
 (0)