# Add JAX dispatch for CholeskySolve Op #4588
#
# Workflow file for this run
# Test workflow: triggered by pushes to `main` and by pull requests targeting `main`.
name: Tests
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# A newer run cancels any still-running older run in the same concurrency group.
concurrency:
  # Pull requests are grouped by workflow name + head branch, so a new push to
  # the PR supersedes the previous run. Every other event is grouped by
  # workflow name + commit hash.
  group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }}
  cancel-in-progress: true

jobs:
# Path filter job: publishes whether the commit range touches any file listed
# under the `src` filter. The jobs below compare this output against 'true'.
changes:
  name: "Check for changes"
  runs-on: ubuntu-latest
  outputs:
    changes: ${{ steps.changes.outputs.src }}
  steps:
    - uses: actions/checkout@v4
      with:
        # Full history is fetched; credentials are not left in the git config.
        fetch-depth: 0
        persist-credentials: false
    - id: changes
      uses: dorny/paths-filter@v3
      with:
        # `python` is declared with an anchor so that `src` can reuse its
        # patterns through the `*python` alias.
        filters: |
          python: &python
          - 'pytensor/**/*.py'
          - 'tests/**/*.py'
          - 'pytensor/**/*.pyx'
          - 'tests/**/*.pyx'
          - '*.py'
          src:
          - *python
          - 'pytensor/**/*.c'
          - 'tests/**/*.c'
          - 'pytensor/**/*.h'
          - 'tests/**/*.h'
          - '.github/workflows/*.yml'
          - 'setup.cfg'
          - 'requirements.txt'
          - '.pre-commit-config.yaml'
# Code-style job: runs the pre-commit action on Python 3.10 and 3.13.
# Skipped unless the `changes` job reported relevant file changes.
style:
  name: Check code style
  needs: changes
  runs-on: ubuntu-latest
  if: ${{ needs.changes.outputs.changes == 'true' }}
  strategy:
    matrix:
      # Quoted so that 3.10 is not parsed as the float 3.1.
      python-version: ["3.10", "3.13"]
  steps:
    - uses: actions/checkout@v4
      with:
        persist-credentials: false
    - uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    # NOTE(review): this reference had been mangled into "[email protected]" by
    # e-mail obfuscation, which is not a valid `uses:` value. It is restored
    # here to the v3.0.1 tag of pre-commit/action; confirm the intended tag
    # against the repository history.
    - uses: pre-commit/action@v3.0.1
# Main test job. Runs pytest over a matrix of OS / Python / NumPy / compile
# mode / floatX / test-suite partition, and uploads one coverage XML per
# matrix entry. Only runs when relevant files changed and `style` succeeded.
test:
  name: "${{ matrix.os }} test py${{ matrix.python-version }} numpy${{ matrix.numpy-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
  needs:
    - changes
    - style
  runs-on: ${{ matrix.os }}
  if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
  strategy:
    # Let every matrix entry finish even if another one fails.
    fail-fast: false
    matrix:
      os: ["ubuntu-latest"]
      python-version: ["3.10", "3.13"]
      numpy-version: ["~=1.26.0", ">=2.0"]
      fast-compile: [0, 1]
      float32: [0, 1]
      # The optional backends are off in the base matrix; the `include`
      # entries below switch them on for the matching tests/link/* suites.
      install-numba: [0]
      install-jax: [0]
      install-torch: [0]
      # Each entry is passed verbatim to pytest as its path/option arguments.
      part:
        - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
        - "tests/scan"
        - "tests/sparse"
        - "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
        - "tests/tensor/conv"
        - "tests/tensor/rewriting"
        - "tests/tensor/test_math.py"
        - "tests/tensor/test_basic.py tests/tensor/test_inplace.py"
        - "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
      # Prune the cross product: fast-compile and float32 only run on
      # Python 3.13 with numpy>=2.0 and never together; numpy 1.26 only runs
      # on Python 3.10.
      exclude:
        - python-version: "3.10"
          fast-compile: 1
        - python-version: "3.10"
          float32: 1
        - fast-compile: 1
          float32: 1
        - numpy-version: "~=1.26.0"
          fast-compile: 1
        - numpy-version: "~=1.26.0"
          float32: 1
        - numpy-version: "~=1.26.0"
          python-version: "3.13"
      include:
        # Doctests, on Python 3.12.
        - os: "ubuntu-latest"
          part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
          python-version: "3.12"
          numpy-version: ">=2.0"
          fast-compile: 0
          float32: 0
          install-numba: 0
          install-jax: 0
          install-torch: 0
        # Numba backend tests.
        - install-numba: 1
          os: "ubuntu-latest"
          python-version: "3.10"
          numpy-version: "~=2.1.0"
          fast-compile: 0
          float32: 0
          part: "tests/link/numba"
        - install-numba: 1
          os: "ubuntu-latest"
          python-version: "3.13"
          numpy-version: "~=2.1.0"
          fast-compile: 0
          float32: 0
          part: "tests/link/numba"
        # JAX backend tests.
        - install-jax: 1
          os: "ubuntu-latest"
          python-version: "3.10"
          numpy-version: ">=2.0"
          fast-compile: 0
          float32: 0
          part: "tests/link/jax"
        - install-jax: 1
          os: "ubuntu-latest"
          python-version: "3.13"
          numpy-version: ">=2.0"
          fast-compile: 0
          float32: 0
          part: "tests/link/jax"
        # PyTorch backend tests.
        - install-torch: 1
          os: "ubuntu-latest"
          python-version: "3.10"
          numpy-version: ">=2.0"
          fast-compile: 0
          float32: 0
          part: "tests/link/pytorch"
        # macOS run of the BLAS / elemwise / scipy-math partition.
        - os: macos-15
          python-version: "3.13"
          numpy-version: ">=2.0"
          fast-compile: 0
          float32: 0
          install-numba: 0
          install-jax: 0
          install-torch: 0
          part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
  steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0
        persist-credentials: false
    - name: Set up Python ${{ matrix.python-version }}
      uses: mamba-org/setup-micromamba@v2
      with:
        environment-name: pytensor-test
        micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
        init-shell: bash
        post-cleanup: "all"
        create-args: python=${{ matrix.python-version }}
    # Derive a short, stable identifier from the matrix entry; it is used
    # below to give each entry's coverage file and artifact a distinct name.
    - name: Create matrix id
      id: matrix-id
      env:
        MATRIX_CONTEXT: ${{ toJson(matrix) }}
      run: |
        echo "$MATRIX_CONTEXT"
        MATRIX_ID=$(echo "$MATRIX_CONTEXT" | sha256sum | cut -c 1-32)
        echo "$MATRIX_ID"
        echo "id=$MATRIX_ID" >> "$GITHUB_OUTPUT"
    - name: Install dependencies
      shell: micromamba-shell {0}
      run: |
        if [[ $OS == "macos-15" ]]; then
          micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock "libblas=*=*accelerate";
        else
          micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
        fi
        if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
        if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
        if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
        pip install pytest-sphinx
        pip install -e ./
        micromamba list && pip freeze
        python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
        if [[ $OS == "macos-15" ]]; then
          python -c 'import pytensor; assert pytensor.config.blas__ldflags.startswith("-framework Accelerate"), "Blas flags are not set to MacOS Accelerate"';
        else
          python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"';
        fi
      env:
        PYTHON_VERSION: ${{ matrix.python-version }}
        NUMPY_VERSION: ${{ matrix.numpy-version }}
        INSTALL_NUMBA: ${{ matrix.install-numba }}
        INSTALL_JAX: ${{ matrix.install-jax }}
        INSTALL_TORCH: ${{ matrix.install-torch }}
        OS: ${{ matrix.os }}
    - name: Run tests
      shell: micromamba-shell {0}
      # $PART is left unquoted on purpose: a matrix `part` may hold several
      # paths and --ignore options that must be split into separate arguments.
      run: |
        if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
        if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
        export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
        python -m pytest -r A --verbose --runslow --durations=50 --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
      env:
        MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
        MKL_THREADING_LAYER: GNU
        MKL_NUM_THREADS: 1
        OMP_NUM_THREADS: 1
        PART: ${{ matrix.part }}
        FAST_COMPILE: ${{ matrix.fast-compile }}
        FLOAT32: ${{ matrix.float32 }}
    - name: Upload coverage file
      uses: actions/upload-artifact@v4
      with:
        name: coverage-${{ steps.matrix-id.outputs.id }}
        path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml
# Benchmark job: runs the pytest-benchmark suite in FAST_COMPILE mode and
# hands the JSON result to github-action-benchmark, together with the data
# file kept under ./cache by actions/cache.
# Alerts never fail the job, post a comment, or push anything.
benchmarks:
  name: "Benchmarks"
  needs:
    - changes
    - style
  runs-on: ubuntu-latest
  if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
  strategy:
    fail-fast: false
  steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0
        persist-credentials: false
    - name: Set up Python 3.10
      uses: mamba-org/setup-micromamba@v2
      with:
        environment-name: pytensor-test
        micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
        init-shell: bash
        post-cleanup: "all"
    - name: Install dependencies
      shell: micromamba-shell {0}
      run: |
        micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
        pip install -e ./
        micromamba list && pip freeze
        python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
        python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
      env:
        # Must be quoted: a bare 3.10 is a YAML float and would reach the
        # script as "3.1", turning the spec above into "python~=3.1".
        PYTHON_VERSION: "3.10"
    - name: Download previous benchmark data
      uses: actions/cache@v4
      with:
        path: ./cache
        key: ${{ runner.os }}-benchmark
    - name: Run benchmarks
      shell: micromamba-shell {0}
      run: |
        export PYTENSOR_FLAGS=mode=FAST_COMPILE,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
        python -m pytest --runslow --benchmark-only --benchmark-json output.json
    - name: Store benchmark result
      uses: benchmark-action/github-action-benchmark@v1
      with:
        name: Python Benchmark with pytest-benchmark
        tool: "pytest"
        output-file-path: output.json
        external-data-json-path: ./cache/benchmark-data.json
        alert-threshold: "200%"
        github-token: ${{ secrets.GITHUB_TOKEN }}
        comment-on-alert: false
        fail-on-alert: false
        auto-push: false
# Aggregate gate over `changes`, `style` and `test`. It always runs (even when
# upstream jobs fail or are skipped) and fails when:
#   * the `changes` job itself did not succeed, or
#   * relevant files changed and `style` or `test` did not succeed.
all-checks:
  if: ${{ always() }}
  runs-on: ubuntu-latest
  name: "All tests"
  needs: [changes, style, test]
  steps:
    - name: Check build matrix status
      # The `needs.changes.result` check closes a gap: if `changes` fails, its
      # output is empty and `style`/`test` are skipped, which would otherwise
      # let this gate pass with no tests run.
      if: ${{ needs.changes.result != 'success' || (needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test.result != 'success')) }}
      run: exit 1
# Coverage upload job: downloads every `coverage-*` artifact into ./coverage
# and uploads that directory to Codecov.
# Runs only when relevant files changed and the `all-checks` gate succeeded.
upload-coverage:
  runs-on: ubuntu-latest
  name: "Upload coverage"
  needs: [changes, all-checks]
  if: ${{ needs.changes.outputs.changes == 'true' && needs.all-checks.result == 'success' }}
  steps:
    - uses: actions/checkout@v4
      with:
        persist-credentials: false
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: "3.13"
    - name: Install dependencies
      # The requirement is quoted so the shell does not treat ">=5.1" as an
      # output redirection to a file named "=5.1".
      run: |
        python -m pip install -U "coverage>=5.1" coveralls
    - name: Download coverage file
      uses: actions/download-artifact@v4
      with:
        pattern: coverage-*
        path: coverage
        merge-multiple: true
    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v5
      with:
        directory: ./coverage/
        fail_ci_if_error: true
        token: ${{ secrets.CODECOV_TOKEN }}