# Workflow file for the run "Optimize matmuls involving block diagonal matrices" (#4602).
# NOTE: the GitHub file viewer attached its standard notice that this file may
# contain hidden or bidirectional Unicode text; open it in an editor that
# reveals such characters to confirm.
name: Tests
# Run on pushes to main and on pull requests that target main.
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# A newer run in the same concurrency group cancels the one still in progress.
concurrency:
  cancel-in-progress: true
  # Pull requests are grouped by workflow name + head branch; every other
  # event is grouped by workflow name + commit SHA.
  group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }}
jobs:
  # Detects whether any source-relevant path changed. The other jobs read the
  # `changes` output and only do work when it is 'true'.
  changes:
    name: "Check for changes"
    runs-on: ubuntu-latest
    outputs:
      changes: ${{ steps.changes.outputs.src }}
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: false
      - uses: dorny/paths-filter@v3
        id: changes
        with:
          # `src` reuses the `python` pattern list through a YAML anchor and
          # adds C sources/headers, workflow files and packaging/config files.
          filters: |
            python: &python
            - 'pytensor/**/*.py'
            - 'tests/**/*.py'
            - 'pytensor/**/*.pyx'
            - 'tests/**/*.pyx'
            - '*.py'
            src:
            - *python
            - 'pytensor/**/*.c'
            - 'tests/**/*.c'
            - 'pytensor/**/*.h'
            - 'tests/**/*.h'
            - '.github/workflows/*.yml'
            - 'setup.cfg'
            - 'requirements.txt'
            - '.pre-commit-config.yaml'

  # Runs the pre-commit hooks on the oldest and newest Python in the matrix.
  style:
    name: Check code style
    needs: changes
    runs-on: ubuntu-latest
    if: ${{ needs.changes.outputs.changes == 'true' }}
    strategy:
      matrix:
        python-version: ["3.10", "3.13"]
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      # The action reference had been mangled into "[email protected]" by
      # e-mail obfuscation; restored to a valid `owner/repo@ref` form.
      - uses: pre-commit/action@v3.0.1

  # Main test suite, split into parts so the matrix jobs run in parallel.
  test:
    name: "${{ matrix.os }} test py${{ matrix.python-version }} numpy${{ matrix.numpy-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
    needs:
      - changes
      - style
    runs-on: ${{ matrix.os }}
    if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
    strategy:
      fail-fast: false
      matrix:
        os: ["ubuntu-latest"]
        python-version: ["3.10", "3.13"]
        numpy-version: ["~=1.26.0", ">=2.0"]
        fast-compile: [0, 1]
        float32: [0, 1]
        install-numba: [0]
        install-jax: [0]
        install-torch: [0]
        part:
          - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
          - "tests/scan"
          - "tests/sparse"
          - "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
          - "tests/tensor/conv"
          - "tests/tensor/rewriting"
          - "tests/tensor/test_math.py"
          - "tests/tensor/test_basic.py tests/tensor/test_inplace.py"
          - "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
        # Net effect of the exclusions: fast-compile and float32 variants run
        # only on Python 3.13 with NumPy >=2.0 and never both at once, and
        # NumPy ~=1.26.0 is exercised only on Python 3.10.
        exclude:
          - python-version: "3.10"
            fast-compile: 1
          - python-version: "3.10"
            float32: 1
          - fast-compile: 1
            float32: 1
          - numpy-version: "~=1.26.0"
            fast-compile: 1
          - numpy-version: "~=1.26.0"
            float32: 1
          - numpy-version: "~=1.26.0"
            python-version: "3.13"
        include:
          # Doctests of the package itself.
          - os: "ubuntu-latest"
            part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
            python-version: "3.12"
            numpy-version: ">=2.0"
            fast-compile: 0
            float32: 0
            install-numba: 0
            install-jax: 0
            install-torch: 0
          # Numba backend tests.
          - install-numba: 1
            os: "ubuntu-latest"
            python-version: "3.10"
            numpy-version: "~=2.1.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/numba"
          - install-numba: 1
            os: "ubuntu-latest"
            python-version: "3.13"
            numpy-version: "~=2.1.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/numba"
          # JAX backend tests.
          - install-jax: 1
            os: "ubuntu-latest"
            python-version: "3.10"
            numpy-version: ">=2.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/jax"
          - install-jax: 1
            os: "ubuntu-latest"
            python-version: "3.13"
            numpy-version: ">=2.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/jax"
          # PyTorch backend tests.
          - install-torch: 1
            os: "ubuntu-latest"
            python-version: "3.10"
            numpy-version: ">=2.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/pytorch"
          # BLAS-related tests on macOS (the install step asserts Accelerate flags).
          - os: macos-15
            python-version: "3.13"
            numpy-version: ">=2.0"
            fast-compile: 0
            float32: 0
            install-numba: 0
            install-jax: 0
            install-torch: 0
            part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: false
      - name: Set up Python ${{ matrix.python-version }}
        uses: mamba-org/setup-micromamba@v2
        with:
          environment-name: pytensor-test
          micromamba-version: "1.5.10-0"  # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
          init-shell: bash
          post-cleanup: "all"
          create-args: python=${{ matrix.python-version }}
      # Hash the JSON-serialised matrix entry into a 32-character id; it names
      # this job's coverage report and the uploaded artifact.
      - name: Create matrix id
        id: matrix-id
        env:
          MATRIX_CONTEXT: ${{ toJson(matrix) }}
        run: |
          echo $MATRIX_CONTEXT
          export MATRIX_ID=`echo $MATRIX_CONTEXT | sha256sum | cut -c 1-32`
          echo $MATRIX_ID
          echo "id=$MATRIX_ID" >> $GITHUB_OUTPUT
      # Installs the base stack (with the Accelerate libblas variant on
      # macos-15, MKL elsewhere), then any optional backend requested by the
      # matrix, and finally asserts that PyTensor picked up BLAS link flags.
      - name: Install dependencies
        shell: micromamba-shell {0}
        run: |
          if [[ $OS == "macos-15" ]]; then
            micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate;
          else
            micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
          fi
          if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
          if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
          if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
          pip install pytest-sphinx
          pip install -e ./
          micromamba list && pip freeze
          python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
          if [[ $OS == "macos-15" ]]; then
            python -c 'import pytensor; assert pytensor.config.blas__ldflags.startswith("-framework Accelerate"), "Blas flags are not set to MacOS Accelerate"';
          else
            python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"';
          fi
        env:
          PYTHON_VERSION: ${{ matrix.python-version }}
          NUMPY_VERSION: ${{ matrix.numpy-version }}
          INSTALL_NUMBA: ${{ matrix.install-numba }}
          INSTALL_JAX: ${{ matrix.install-jax }}
          INSTALL_TORCH: ${{ matrix.install-torch }}
          OS: ${{ matrix.os }}
      # $PART is intentionally left unquoted in the pytest call: a single
      # matrix entry may carry several paths and --ignore options.
      - name: Run tests
        shell: micromamba-shell {0}
        run: |
          if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
          if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
          export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
          python -m pytest -r A --verbose --runslow --durations=50 --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
        env:
          MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
          MKL_THREADING_LAYER: GNU
          MKL_NUM_THREADS: 1
          OMP_NUM_THREADS: 1
          PART: ${{ matrix.part }}
          FAST_COMPILE: ${{ matrix.fast-compile }}
          FLOAT32: ${{ matrix.float32 }}
      - name: Upload coverage file
        uses: actions/upload-artifact@v4
        with:
          name: coverage-${{ steps.matrix-id.outputs.id }}
          path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml

  # Runs the pytest-benchmark suite and compares it with cached earlier data.
  benchmarks:
    name: "Benchmarks"
    needs:
      - changes
      - style
    runs-on: ubuntu-latest
    if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
    strategy:
      fail-fast: false
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: false
      - name: Set up Python 3.10
        uses: mamba-org/setup-micromamba@v2
        with:
          environment-name: pytensor-test
          micromamba-version: "1.5.10-0"  # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
          init-shell: bash
          post-cleanup: "all"
      - name: Install dependencies
        shell: micromamba-shell {0}
        run: |
          micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
          pip install -e ./
          micromamba list && pip freeze
          python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
          python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
        env:
          # Must be quoted: a bare 3.10 is a YAML float and is read as 3.1,
          # which would turn the spec into "python~=3.1" (any Python 3.x).
          PYTHON_VERSION: "3.10"
      - name: Download previous benchmark data
        uses: actions/cache@v4
        with:
          path: ./cache
          key: ${{ runner.os }}-benchmark
      - name: Run benchmarks
        shell: micromamba-shell {0}
        run: |
          export PYTENSOR_FLAGS=mode=FAST_COMPILE,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
          python -m pytest --runslow --benchmark-only --benchmark-json output.json
      # Alerts are configured not to comment, fail the job or push anything.
      - name: Store benchmark result
        uses: benchmark-action/github-action-benchmark@v1
        with:
          name: Python Benchmark with pytest-benchmark
          tool: "pytest"
          output-file-path: output.json
          external-data-json-path: ./cache/benchmark-data.json
          alert-threshold: "200%"
          github-token: ${{ secrets.GITHUB_TOKEN }}
          comment-on-alert: false
          fail-on-alert: false
          auto-push: false

  # Single summary job: always runs, and fails when changes were detected but
  # the style or test job did not succeed.
  all-checks:
    if: ${{ always() }}
    runs-on: ubuntu-latest
    name: "All tests"
    needs: [changes, style, test]
    steps:
      - name: Check build matrix status
        if: ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test.result != 'success') }}
        run: exit 1

  # Gathers the per-matrix coverage artifacts and sends them to Codecov.
  upload-coverage:
    runs-on: ubuntu-latest
    name: "Upload coverage"
    needs: [changes, all-checks]
    if: ${{ needs.changes.outputs.changes == 'true' && needs.all-checks.result == 'success' }}
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.13"
      # The requirement is quoted so the shell does not treat ">=5.1" as an
      # output redirection to a file named "=5.1".
      - name: Install dependencies
        run: |
          python -m pip install -U "coverage>=5.1" coveralls
      - name: Download coverage file
        uses: actions/download-artifact@v4
        with:
          pattern: coverage-*
          path: coverage
          merge-multiple: true
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v5
        with:
          directory: ./coverage/
          fail_ci_if_error: true
          token: ${{ secrets.CODECOV_TOKEN }}