diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 08da614ce0..ff85db3abe 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -13,7 +13,9 @@ jobs: shell: bash -leo pipefail {0} steps: - uses: actions/checkout@v4 - - uses: mamba-org/setup-micromamba@v1 + with: + persist-credentials: false + - uses: mamba-org/setup-micromamba@v2 with: micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved environment-file: environment.yml diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index cfb16750e9..3462dd00ff 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -21,15 +21,27 @@ jobs: make_sdist: name: Make SDist runs-on: ubuntu-latest + permissions: + # write id-token and attestations are required to attest build provenance + id-token: write + attestations: write steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: true + persist-credentials: false - name: Build SDist run: pipx run build --sdist + - name: Attest GitHub build provenance + uses: actions/attest-build-provenance@v2 + # Don't attest from forks + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository + with: + subject-path: dist/*.tar.gz + - uses: actions/upload-artifact@v4 with: name: sdist @@ -45,24 +57,37 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - uses: hynek/build-and-inspect-python-package@v2 build_wheels: name: Build wheels for ${{ matrix.platform }} runs-on: ${{ matrix.platform }} + permissions: + # write id-token and attestations are required to attest build provenance + id-token: write + attestations: write strategy: matrix: platform: - - macos-12 - - windows-2022 - - ubuntu-20.04 + - macos-latest + - windows-latest + - ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - name: Build wheels - uses: pypa/cibuildwheel@v2.21.2 + uses: pypa/cibuildwheel@v2.22.0 + + - name: Attest GitHub build provenance + uses: actions/attest-build-provenance@v2 + # Don't attest from forks + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository + with: + subject-path: ./wheelhouse/*.whl - uses: actions/upload-artifact@v4 with: @@ -72,13 +97,18 @@ jobs: build_universal_wheel: name: Build universal wheel for Pyodide runs-on: ubuntu-latest + permissions: + # write id-token and attestations are required to attest build provenance + id-token: write + attestations: write steps: - uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' @@ -89,6 +119,13 @@ jobs: run: | PYODIDE=1 python setup.py bdist_wheel --universal + - name: Attest GitHub build provenance + uses: actions/attest-build-provenance@v2 + # Don't attest from forks + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository + with: + subject-path: dist/*.whl + - uses: actions/upload-artifact@v4 with: name: universal_wheel @@ -125,9 +162,16 @@ jobs: upload_pypi: name: Upload to PyPI on release + # Use the `release` GitHub environment to protect the Trusted Publishing (OIDC) + # workflow by requiring signoff from a maintainer. + environment: release + permissions: + # write id-token is required for trusted publishing (OIDC) + id-token: write needs: [check_dist] runs-on: ubuntu-latest - if: github.event_name == 'release' && github.event.action == 'published' + # Don't publish from forks + if: github.repository_owner == 'pymc-devs' && github.event_name == 'release' && github.event.action == 'published' steps: - uses: actions/download-artifact@v4 with: @@ -145,7 +189,5 @@ jobs: name: universal_wheel path: dist - - uses: pypa/gh-action-pypi-publish@v1.10.3 - with: - user: __token__ - password: ${{ secrets.pypi_password }} + - uses: pypa/gh-action-pypi-publish@v1.12.2 + # Implicitly attests that the packages were uploaded in the context of this workflow. diff --git a/.github/workflows/rtd-link-preview.yml b/.github/workflows/rtd-link-preview.yml index 23a967e123..0eb2acd377 100644 --- a/.github/workflows/rtd-link-preview.yml +++ b/.github/workflows/rtd-link-preview.yml @@ -1,15 +1,15 @@ name: Read the Docs Pull Request Preview on: - pull_request_target: + # See + pull_request_target: # zizmor: ignore[dangerous-triggers] types: - opened -permissions: - pull-requests: write - jobs: documentation-links: runs-on: ubuntu-latest + permissions: + pull-requests: write steps: - uses: readthedocs/actions/preview@v1 with: diff --git a/.github/workflows/slow-tests-issue.yml b/.github/workflows/slow-tests-issue.yml new file mode 100644 index 0000000000..643853f617 --- /dev/null +++ b/.github/workflows/slow-tests-issue.yml @@ -0,0 +1,31 @@ +# Taken from https://github.com/pymc-labs/pymc-marketing/tree/main/.github/workflows/slow-tests-issue.yml +# See the scripts in the `scripts/slowest_tests` directory for more information +--- +name: Slow Tests Issue Body + +on: + workflow_dispatch: + schedule: + - cron: '0 */6 * * *' + +permissions: + issues: write + +jobs: + update-comment: + runs-on: ubuntu-latest + steps: + - name: Install ZSH + run: sudo apt-get update && sudo apt-get install -y zsh + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Trigger the script + working-directory: scripts/slowest_tests + shell: zsh {0} + run: source update-slowest-times-issue.sh + env: + GITHUB_TOKEN: ${{ github.token }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a8456c8292..53f1e16606 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,6 +25,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - uses: dorny/paths-filter@v3 id: changes with: @@ -56,6 +57,8 @@ jobs: python-version: ["3.10", "3.12"] steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -133,7 +136,7 @@ jobs: fast-compile: 0 float32: 0 part: "tests/link/pytorch" - - os: macos-latest + - os: macos-15 python-version: "3.12" fast-compile: 0 float32: 0 @@ -146,8 +149,9 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v2 with: environment-name: pytensor-test micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved @@ -169,7 +173,7 @@ jobs: shell: micromamba-shell {0} run: | - if [[ $OS == "macos-latest" ]]; then + if [[ $OS == "macos-15" ]]; then micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" numpy scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate; else micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; @@ -182,7 +186,7 @@ jobs: pip install -e ./ micromamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' - if [[ $OS == "macos-latest" ]]; then + if [[ $OS == "macos-15" ]]; then python -c 'import pytensor; assert pytensor.config.blas__ldflags.startswith("-framework Accelerate"), "Blas flags are not set to MacOS Accelerate"'; else python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'; @@ -229,8 +233,9 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - name: Set up Python 3.10 - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v2 with: environment-name: pytensor-test micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved @@ -286,6 +291,8 @@ jobs: if: ${{ needs.changes.outputs.changes == 'true' && needs.all-checks.result == 'success' }} steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python uses: actions/setup-python@v5 @@ -304,7 +311,7 @@ jobs: merge-multiple: true - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: directory: ./coverage/ fail_ci_if_error: true diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml new file mode 100644 index 0000000000..b747897eb8 --- /dev/null +++ b/.github/workflows/zizmor.yml @@ -0,0 +1,36 @@ +# https://github.com/woodruffw/zizmor +name: zizmor GHA analysis + +on: + push: + branches: ["main"] + pull_request: + branches: ["**"] + +jobs: + zizmor: + name: zizmor latest via PyPI + runs-on: ubuntu-latest + permissions: + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + persist-credentials: false + + - uses: hynek/setup-cached-uv@v2 + + - name: Run zizmor ๐ŸŒˆ + run: uvx zizmor --format sarif . > results.sarif + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload SARIF file + uses: github/codeql-action/upload-sarif@v3 + with: + # Path to SARIF file relative to the root of the repository + sarif_file: results.sarif + # Optional category for the results + # Used to differentiate multiple results for one commit + category: zizmor diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 29626ea4c3..73139a4d58 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: sphinx-lint args: ["."] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.1 + rev: v0.7.3 hooks: - id: ruff args: ["--fix", "--output-format=full"] diff --git a/doc/.templates/nb-badges.html b/doc/.templates/nb-badges.html new file mode 100644 index 0000000000..a955510bb0 --- /dev/null +++ b/doc/.templates/nb-badges.html @@ -0,0 +1,24 @@ +{% if pagename in ablog %} + + +{% set gh_basepath = github_user + '/' + github_repo + '/blob/' + github_version + '/' %} +{% set encoded_base = github_user + '%252F' + github_repo %} +{% set gh_binder = github_user + '/' + github_repo + '/' + github_version %} +{% set doc_path_aux = doc_path | trim('/') %} +{% set file_path = doc_path_aux + '/' + pagename + ".ipynb" %} +{% set encoded_path = file_path | replace("/", "%252F") %} + + +
+

+ + View On GitHub + + + Open In Binder + + + Open In Colab +

+
+{% endif %} \ No newline at end of file diff --git a/doc/.templates/rendered_citation.html b/doc/.templates/rendered_citation.html new file mode 100644 index 0000000000..ccb53efa6f --- /dev/null +++ b/doc/.templates/rendered_citation.html @@ -0,0 +1,13 @@ + +{% if pagename in ablog %} + {% set post = ablog[pagename] %} + {% for coll in post.author %} + {% if coll|length %} + {{ coll }} + {% if loop.index < post.author | length %},{% endif %} + {% else %} + {{ coll }} + {% if loop.index < post.author | length %},{% endif %} + {% endif %} + {% endfor %}. "{{ title.split(' โ€” ')[0] }}". In: Pytensor Examples. Ed. by Pytensor Team. +{% endif %} \ No newline at end of file diff --git a/doc/blog.md b/doc/blog.md new file mode 100644 index 0000000000..88ebe9dc5b --- /dev/null +++ b/doc/blog.md @@ -0,0 +1,7 @@ +--- +orphan: true +--- + +# Recent updates + + diff --git a/doc/conf.py b/doc/conf.py index 5b2d0c71a4..1729efc4b1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,31 +1,13 @@ -# pytensor documentation build configuration file, created by -# sphinx-quickstart on Tue Oct 7 16:34:06 2008. -# -# This file is execfile()d with the current directory set to its containing -# directory. -# -# The contents of this file are pickled, so don't put values in the namespace -# that aren't pickleable (module imports are okay, they're removed -# automatically). -# -# All configuration values have a default value; values that are commented out -# serve to show the default value. - -# If your extensions are in another directory, add it here. If the directory -# is relative to the documentation root, use Path.absolute to make it -# absolute, like shown here. -# sys.path.append(str(Path("some/directory").absolute())) - import os import inspect import sys import pytensor +from pathlib import Path + +sys.path.insert(0, str(Path("..").resolve() / "scripts")) # General configuration # --------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ "sphinx.ext.autodoc", "sphinx.ext.todo", @@ -34,9 +16,22 @@ "sphinx.ext.linkcode", "sphinx.ext.mathjax", "sphinx_design", - "sphinx.ext.intersphinx" + "sphinx.ext.intersphinx", + "sphinx.ext.autosummary", + "sphinx.ext.autosectionlabel", + "ablog", + "myst_nb", + "generate_gallery", + "sphinx_sitemap", ] +# Don't auto-generate summary for class members. +numpydoc_show_class_members = False +autosummary_generate = True +autodoc_typehints = "none" +remove_from_toctrees = ["**/classmethods/*"] + + intersphinx_mapping = { "jax": ("https://jax.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), @@ -92,6 +87,7 @@ # List of directories, relative to source directories, that shouldn't be # searched for source files. exclude_dirs = ["images", "scripts", "sandbox"] +exclude_patterns = ['page_footer.md', '**/*.myst.md'] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -115,19 +111,15 @@ # Options for HTML output # ----------------------- -# The style sheet to use for HTML and HTML Help pages. A file of that name -# must exist either in Sphinx' static/ path, or in one of the custom paths -# given in html_static_path. -# html_style = 'default.css' -# html_theme = 'sphinxdoc' +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "pymc_sphinx_theme" +html_logo = "images/PyTensor_RGB.svg" + +html_baseurl = "https://pytensor.readthedocs.io" +sitemap_url_scheme = f"{{lang}}{rtd_version}/{{link}}" -# html4_writer added to Fix colon & whitespace misalignment -# https://github.com/readthedocs/sphinx_rtd_theme/issues/766#issuecomment-513852197 -# https://github.com/readthedocs/sphinx_rtd_theme/issues/766#issuecomment-629666319 -# html4_writer = False -html_logo = "images/PyTensor_RGB.svg" -html_theme = "pymc_sphinx_theme" html_theme_options = { "use_search_override": False, "icon_links": [ @@ -156,15 +148,27 @@ "type": "fontawesome", }, ], + "secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink", "donate"], + "navbar_start": ["navbar-logo"], + "article_header_end": ["nb-badges"], + "article_footer_items": ["rendered_citation.html"], } html_context = { + "github_url": "https://github.com", "github_user": "pymc-devs", "github_repo": "pytensor", - "github_version": "main", + "github_version": version if "." in rtd_version else "main", + "sandbox_repo": f"pymc-devs/pymc-sandbox/{version}", "doc_path": "doc", "default_mode": "light", } +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ["../_static"] +html_extra_path = ["_thumbnails", 'images', "robots.txt"] +templates_path = [".templates"] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". @@ -295,3 +299,62 @@ def find_source(): # If false, no module index is generated. # latex_use_modindex = True + + +# -- MyST config ------------------------------------------------- +myst_enable_extensions = [ + "colon_fence", + "deflist", + "dollarmath", + "amsmath", + "substitution", +] +myst_dmath_double_inline = True + +citation_code = f""" +```bibtex +@incollection{{citekey, + author = "", + title = "", + editor = "Pytensor Team", + booktitle = "Pytensor Examples", +}} +``` +""" + +myst_substitutions = { + "pip_dependencies": "{{ extra_dependencies }}", + "conda_dependencies": "{{ extra_dependencies }}", + "extra_install_notes": "", + "citation_code": citation_code, +} + +nb_execution_mode = "off" +nbsphinx_execute = "never" +nbsphinx_allow_errors = True + +rediraffe_redirects = { + "index.md": "gallery.md", +} + +# -- Bibtex config ------------------------------------------------- +bibtex_bibfiles = ["references.bib"] +bibtex_default_style = "unsrt" +bibtex_reference_style = "author_year" + + +# -- ablog config ------------------------------------------------- +blog_baseurl = "https://pytensor.readthedocs.io/en/latest/index.html" +blog_title = "Pytensor Examples" +blog_path = "blog" +blog_authors = { + "contributors": ("Pytensor Contributors", "https://pytensor.readthedocs.io"), +} +blog_default_author = "contributors" +post_show_prev_next = False +fontawesome_included = True +# post_redirect_refresh = 1 +# post_auto_image = 1 +# post_auto_excerpt = 2 + +# notfound_urls_prefix = "" diff --git a/doc/core_development_guide.rst b/doc/core_development_guide.rst index 082fbaa514..82c15ddc8f 100644 --- a/doc/core_development_guide.rst +++ b/doc/core_development_guide.rst @@ -26,12 +26,4 @@ some of them might be outdated though: * :ref:`unittest` -- Tutorial on how to use unittest in testing PyTensor. -* :ref:`sandbox_debugging_step_mode` -- How to step through the execution of - an PyTensor function and print the inputs and outputs of each op. - -* :ref:`sandbox_elemwise` -- Description of element wise operations. - -* :ref:`sandbox_randnb` -- Description of how PyTensor deals with random - numbers. - * :ref:`sparse` -- Description of the ``sparse`` type in PyTensor. diff --git a/doc/environment.yml b/doc/environment.yml index ae17b6379d..d58af79cc6 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -14,6 +14,14 @@ dependencies: - pillow - pymc-sphinx-theme - sphinx-design + - pygments + - pydot + - ipython + - myst-nb + - matplotlib + - watermark + - ablog - pip - pip: + - sphinx_sitemap - -e .. diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 23faea9465..1fb25f83b6 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -1,5 +1,5 @@ Adding JAX, Numba and Pytorch support for `Op`\s -======================================= +================================================ PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function. @@ -7,7 +7,7 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Py This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`. Step 1: Identify the PyTensor :class:`Op` you'd like to implement ------------------------------------------------------------------------- +----------------------------------------------------------------- Find the source for the PyTensor :class:`Op` you'd like to be supported and identify the function signature and return values. These can be determined by @@ -98,7 +98,7 @@ how the inputs and outputs are used to compute the outputs for an :class:`Op` in Python. This method is effectively what needs to be implemented. Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close) ---------------------------------------------------------- +-------------------------------------------------------------------------- With a precise idea of what the PyTensor :class:`Op` does we need to figure out how to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named @@ -269,7 +269,7 @@ and :func:`torch.cumprod` z[0] = np.cumprod(x, axis=self.axis) Step 3: Register the function with the respective dispatcher ---------------------------------------------------------------- +------------------------------------------------------------ With the PyTensor `Op` replicated, we'll need to register the function with the backends `Linker`. This is done through the use of @@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`: if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit() def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`: else: if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit() def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: diff --git a/doc/gallery/page_footer.md b/doc/gallery/page_footer.md new file mode 100644 index 0000000000..6f9c88f801 --- /dev/null +++ b/doc/gallery/page_footer.md @@ -0,0 +1,27 @@ +## License notice +All the notebooks in this example gallery are provided under a +[3-Clause BSD License](https://github.com/pymc-devs/pytensor/blob/main/doc/LICENSE.txt) +which allows modification, and redistribution for any +use provided the copyright and license notices are preserved. + +## Citing Pytensor Examples + +To cite this notebook, please use the suggested citation below. + +:::{important} +Many notebooks are adapted from other sources: blogs, books... In such cases you should +cite the original source as well. + +Also remember to cite the relevant libraries used by your code. +::: + +Here is an example citation template in bibtex: + +{{ citation_code }} + +which once rendered could look like: + + + \ No newline at end of file diff --git a/doc/gallery/rewrites/graph_rewrites.ipynb b/doc/gallery/rewrites/graph_rewrites.ipynb new file mode 100644 index 0000000000..298e13b95e --- /dev/null +++ b/doc/gallery/rewrites/graph_rewrites.ipynb @@ -0,0 +1,1104 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(Graph_rewrites)=\n", + "\n", + "# PyTensor graph rewrites from scratch\n", + "\n", + ":::{post} Jan 11, 2025 \n", + ":tags: Graph rewrites \n", + ":category: avanced, explanation \n", + ":author: Ricardo Vieira \n", + ":::" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Manipulating nodes directly" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This section walks through the low level details of PyTensor graph manipulation. \n", + "Users are not supposed to work or even be aware of these details, but it may be helpful for developers.\n", + "We start with very **bad practices** and move on towards the **right** way of doing rewrites.\n", + "\n", + "* {doc}`Graph structures `\n", + "is a required precursor to this guide\n", + "* {doc}`Graph rewriting ` provides the user-level summary of what is covered in here. Feel free to revisit once you're done here.\n", + "\n", + "As described in {doc}`Graph structures`, PyTensor graphs are composed of sequences {class}`Apply` nodes, which link {class}`Variable`s\n", + "that form the inputs and outputs of a computational {class}`Op`eration.\n", + "\n", + "The list of inputs of an {class}`Apply` node can be changed inplace to modify the computational path that leads to it.\n", + "Consider the following simple example:" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:46.104335Z", + "start_time": "2025-01-11T07:37:46.100021Z" + } + }, + "source": [ + "%env PYTENSOR_FLAGS=cxx=\"\"" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: PYTENSOR_FLAGS=cxx=\"\"\n" + ] + } + ], + "execution_count": 1 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:49.384149Z", + "start_time": "2025-01-11T07:37:46.201672Z" + } + }, + "source": [ + "import pytensor\n", + "import pytensor.tensor as pt\n", + "\n", + "x = pt.scalar(\"x\")\n", + "y = pt.log(1 + x)\n", + "out = y * 2\n", + "pytensor.dprint(out, id_type=\"\");" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mul\n", + " โ”œโ”€ Log\n", + " โ”‚ โ””โ”€ Add\n", + " โ”‚ โ”œโ”€ 1\n", + " โ”‚ โ””โ”€ x\n", + " โ””โ”€ 2\n" + ] + } + ], + "execution_count": 2 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A standard rewrite replaces `pt.log(1 + x)` by the more stable form `pt.log1p(x)`.\n", + "We can do this by changing the inputs of the `out` node inplace." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:49.924153Z", + "start_time": "2025-01-11T07:37:49.920272Z" + } + }, + "source": [ + "out.owner.inputs[0] = pt.log1p(x)\n", + "pytensor.dprint(out, id_type=\"\");" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mul\n", + " โ”œโ”€ Log1p\n", + " โ”‚ โ””โ”€ x\n", + " โ””โ”€ 2\n" + ] + } + ], + "execution_count": 3 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are two problems with this direct approach:\n", + "1. We are modifying variables in place\n", + "2. We have to know which nodes have as input the variable we want to replace\n", + "\n", + "Point 1. is important because some rewrites are \"destructive\" and the user may want to reuse the same graph in multiple functions.\n", + "\n", + "Point 2. is important because it forces us to shift the focus of attention from the operation we want to rewrite to the variables where the operation is used. It also risks unneccessary duplication of variables, if we perform the same replacement independently for each use. This could make graph rewriting consideraby slower!\n", + "\n", + "PyTensor makes use of {class}`FunctionGraph`s to solve these two issues.\n", + "By default, a FunctionGraph will clone all the variables between the inputs and outputs,\n", + "so that the corresponding graph can be rewritten.\n", + "In addition, it will create a {term}`client`s dictionary that maps all the variables to the nodes where they are used.\n", + "\n", + "\n", + "Let's see how we can use a FunctionGraph to achieve the same rewrite:" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.005393Z", + "start_time": "2025-01-11T07:37:49.997328Z" + } + }, + "source": [ + "from pytensor.graph import FunctionGraph\n", + "\n", + "x = pt.scalar(\"x\")\n", + "y = pt.log(1 + x)\n", + "out1 = y * 2\n", + "out2 = 2 / y\n", + "\n", + "# Create an empty dictionary which FunctionGraph will populate\n", + "# with the mappings from old variables to cloned ones\n", + "memo = {}\n", + "fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)\n", + "fg_x = memo[x]\n", + "fg_y = memo[y]\n", + "print(\"Before:\\n\")\n", + "pytensor.dprint(fg.outputs)\n", + "\n", + "# Create expression of interest with cloned variables\n", + "fg_y_repl = pt.log1p(fg_x)\n", + "\n", + "# Update all uses of old variable to new one\n", + "# Each entry in the clients dictionary, \n", + "# contains a node and the input index where the variable is used\n", + "# Note: Some variables could be used multiple times in a single node\n", + "for client, idx in fg.clients[fg_y]:\n", + " client.inputs[idx] = fg_y_repl\n", + " \n", + "print(\"\\nAfter:\\n\")\n", + "pytensor.dprint(fg.outputs);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before:\n", + "\n", + "Mul [id A]\n", + " โ”œโ”€ Log [id B]\n", + " โ”‚ โ””โ”€ Add [id C]\n", + " โ”‚ โ”œโ”€ 1 [id D]\n", + " โ”‚ โ””โ”€ x [id E]\n", + " โ””โ”€ 2 [id F]\n", + "True_div [id G]\n", + " โ”œโ”€ 2 [id H]\n", + " โ””โ”€ Log [id B]\n", + " โ””โ”€ ยทยทยท\n", + "\n", + "After:\n", + "\n", + "Mul [id A]\n", + " โ”œโ”€ Log1p [id B]\n", + " โ”‚ โ””โ”€ x [id C]\n", + " โ””โ”€ 2 [id D]\n", + "True_div [id E]\n", + " โ”œโ”€ 2 [id F]\n", + " โ””โ”€ Log1p [id B]\n", + " โ””โ”€ ยทยทยท\n" + ] + } + ], + "execution_count": 4 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that both uses of `log(1 + x)` were replaced by the new `log1p(x)`.\n", + "\n", + "It would probably be a good idea to update the clients dictionary\n", + "if we wanted to perform another rewrite.\n", + "\n", + "There are a couple of other variables in the FunctionGraph that we would also want to update,\n", + "but there is no point to doing all this bookeeping manually. \n", + "FunctionGraph offers a {meth}`replace ` method that takes care of all this for the user." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.078947Z", + "start_time": "2025-01-11T07:37:50.072465Z" + } + }, + "source": [ + "# We didn't modify the variables in place so we can just reuse them!\n", + "memo = {}\n", + "fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)\n", + "fg_x = memo[x]\n", + "fg_y = memo[y]\n", + "print(\"Before:\\n\")\n", + "pytensor.dprint(fg.outputs)\n", + "\n", + "# Create expression of interest with cloned variables\n", + "fg_y_repl = pt.log1p(fg_x)\n", + "fg.replace(fg_y, fg_y_repl)\n", + " \n", + "print(\"\\nAfter:\\n\")\n", + "pytensor.dprint(fg.outputs);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before:\n", + "\n", + "Mul [id A]\n", + " โ”œโ”€ Log [id B]\n", + " โ”‚ โ””โ”€ Add [id C]\n", + " โ”‚ โ”œโ”€ 1 [id D]\n", + " โ”‚ โ””โ”€ x [id E]\n", + " โ””โ”€ 2 [id F]\n", + "True_div [id G]\n", + " โ”œโ”€ 2 [id H]\n", + " โ””โ”€ Log [id B]\n", + " โ””โ”€ ยทยทยท\n", + "\n", + "After:\n", + "\n", + "Mul [id A]\n", + " โ”œโ”€ Log1p [id B]\n", + " โ”‚ โ””โ”€ x [id C]\n", + " โ””โ”€ 2 [id D]\n", + "True_div [id E]\n", + " โ”œโ”€ 2 [id F]\n", + " โ””โ”€ Log1p [id B]\n", + " โ””โ”€ ยทยทยท\n" + ] + } + ], + "execution_count": 5 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There is still one big limitation with this approach.\n", + "We have to know in advance \"where\" the variable we want to replace is present.\n", + "It also doesn't scale to multiple instances of the same pattern.\n", + "\n", + "A more sensible approach would be to iterate over the nodes in the FunctionGraph\n", + "and apply the rewrite wherever `log(1 + x)` may be present.\n", + "\n", + "To keep code organized we will create a function \n", + "that takes as input a node and returns a valid replacement." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.161507Z", + "start_time": "2025-01-11T07:37:50.156975Z" + } + }, + "source": [ + "from pytensor.graph import Constant\n", + "\n", + "def local_log1p(node):\n", + " # Check that this node is a Log op\n", + " if node.op != pt.log:\n", + " return None\n", + " \n", + " # Check that the input is another node (it could be an input variable)\n", + " add_node = node.inputs[0].owner\n", + " if add_node is None:\n", + " return None\n", + " \n", + " # Check that the input to this node is an Add op\n", + " # with 2 inputs (Add can have more inputs)\n", + " if add_node.op != pt.add or len(add_node.inputs) != 2:\n", + " return None\n", + " \n", + " # Check wether we have add(1, y) or add(x, 1)\n", + " [x, y] = add_node.inputs\n", + " if isinstance(x, Constant) and x.data == 1:\n", + " return [pt.log1p(y)]\n", + " if isinstance(y, Constant) and y.data == 1:\n", + " return [pt.log1p(x)]\n", + "\n", + " return None" + ], + "outputs": [], + "execution_count": 6 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.248106Z", + "start_time": "2025-01-11T07:37:50.242014Z" + } + }, + "source": [ + "# We no longer need the memo, because our rewrite works with the node information\n", + "fg = FunctionGraph([x], [out1, out2], clone=True)\n", + "\n", + "# Toposort gives a list of all nodes in a graph in topological order\n", + "# The strategy of iteration can be important when we are dealing with multiple rewrites\n", + "for node in fg.toposort():\n", + " repl = local_log1p(node)\n", + " if repl is None:\n", + " continue\n", + " # We should get one replacement of each output of the node\n", + " assert len(repl) == len(node.outputs)\n", + " # We could use `fg.replace_all` to avoid this loop\n", + " for old, new in zip(node.outputs, repl):\n", + " fg.replace(old, new)\n", + "\n", + "pytensor.dprint(fg);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mul [id A] 1\n", + " โ”œโ”€ Log1p [id B] 0\n", + " โ”‚ โ””โ”€ x [id C]\n", + " โ””โ”€ 2 [id D]\n", + "True_div [id E] 2\n", + " โ”œโ”€ 2 [id F]\n", + " โ””โ”€ Log1p [id B] 0\n", + " โ””โ”€ ยทยทยท\n" + ] + } + ], + "execution_count": 7 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is starting to look much more scalable!\n", + "\n", + "We are still reinventing may wheels that already exist in PyTensor, but we're getting there.\n", + "Before we move up the ladder of abstraction, let's discuss two gotchas:\n", + "\n", + "1. The replacement variables should have types that are compatible with the original ones.\n", + "2. We have to be careful about introducing circular dependencies\n", + "\n", + "For 1. let's look at a simple graph simplification, where we replace a costly operation that is ultimately multiplied by zero." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.344446Z", + "start_time": "2025-01-11T07:37:50.328071Z" + } + }, + "source": [ + "x = pt.vector(\"x\", dtype=\"float32\")\n", + "zero = pt.zeros(())\n", + "zero.name = \"zero\"\n", + "y = pt.exp(x) * zero\n", + "\n", + "fg = FunctionGraph([x], [y], clone=False)\n", + "try:\n", + " fg.replace(y, pt.zeros(()))\n", + "except TypeError as exc:\n", + " print(f\"TypeError: {exc}\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TypeError: Cannot convert Type Scalar(float64, shape=()) (of Variable Alloc.0) into Type Vector(float64, shape=(?,)). You can try to manually convert Alloc.0 into a Vector(float64, shape=(?,)).\n" + ] + } + ], + "execution_count": 8 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first achievement of a new PyTensor developer is unlocked by stumbling upon an error like that!\n", + "\n", + "It's important to keep in mind the Tensor part of PyTensor.\n", + "\n", + "The problem here is that we are trying to replace the `y` variable which is a float32 vector by the `zero` variable which is a float64 scalar!" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.408682Z", + "start_time": "2025-01-11T07:37:50.404355Z" + } + }, + "source": [ + "pytensor.dprint(fg.outputs, id_type=\"\", print_type=True);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mul \n", + " โ”œโ”€ Exp \n", + " โ”‚ โ””โ”€ x \n", + " โ””โ”€ ExpandDims{axis=0} \n", + " โ””โ”€ Alloc 'zero'\n", + " โ””โ”€ 0.0 \n" + ] + } + ], + "execution_count": 9 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.512585Z", + "start_time": "2025-01-11T07:37:50.488176Z" + } + }, + "source": [ + "vector_zero = pt.zeros(x.shape)\n", + "vector_zero.name = \"vector_zero\"\n", + "fg.replace(y, vector_zero)\n", + "pytensor.dprint(fg.outputs, id_type=\"\", print_type=True);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Alloc 'vector_zero'\n", + " โ”œโ”€ 0.0 \n", + " โ””โ”€ Subtensor{i} \n", + " โ”œโ”€ Shape \n", + " โ”‚ โ””โ”€ x \n", + " โ””โ”€ 0 \n" + ] + } + ], + "execution_count": 10 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now to the second (less common) gotcha. Introducing circular dependencies:" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.572844Z", + "start_time": "2025-01-11T07:37:50.567175Z" + } + }, + "source": [ + "x = pt.scalar(\"x\")\n", + "y = x + 1\n", + "y.name = \"y\"\n", + "z = y + 1\n", + "z.name = \"z\"\n", + "\n", + "fg = FunctionGraph([x], [z], clone=False)\n", + "fg.replace(x, z)\n", + "pytensor.dprint(fg.outputs);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add [id A] 'z'\n", + " โ”œโ”€ Add [id B] 'y'\n", + " โ”‚ โ”œโ”€ Add [id A] 'z'\n", + " โ”‚ โ”‚ โ””โ”€ ยทยทยท\n", + " โ”‚ โ””โ”€ 1 [id C]\n", + " โ””โ”€ 1 [id D]\n" + ] + } + ], + "execution_count": 11 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Oops! There is not much to say about this one, other than don't do it!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using graph rewriters" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.634996Z", + "start_time": "2025-01-11T07:37:50.631699Z" + } + }, + "source": [ + "from pytensor.graph.rewriting.basic import NodeRewriter\n", + "\n", + "class LocalLog1pNodeRewriter(NodeRewriter):\n", + " \n", + " def tracks(self):\n", + " return [pt.log]\n", + " \n", + " def transform(self, fgraph, node):\n", + " return local_log1p(node) \n", + " \n", + " def __str__(self):\n", + " return \"local_log1p\"\n", + " \n", + " \n", + "local_log1p_node_rewriter = LocalLog1pNodeRewriter()" + ], + "outputs": [], + "execution_count": 12 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A {class}`NodeRewriter` is required to implement only the {meth}`transform ` method.\n", + "As before, this method expects a node and should return a valid replacement for each output or `None`.\n", + "\n", + "We also receive the {class}`FunctionGraph` object, as some node rewriters may want to use global information to decide whether to return a replacement or not.\n", + "\n", + "For example some rewrites that skip intermediate computations may not be useful if those intermediate computations are used by other variables.\n", + "\n", + "The {meth}`tracks ` optional method is very useful for filtering out \"useless\" rewrites. When {class}`NodeRewriter`s only applies to a specific rare {class}`Op` it can be ignored completely when that {class}`Op` is not present in the graph.\n", + "\n", + "On its own, a {class}`NodeRewriter` isn't any better than what we had before. Where it becomes useful is when included inside a {class}`GraphRewriter`, which will apply it to a whole {class}`FunctionGraph `." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.702188Z", + "start_time": "2025-01-11T07:37:50.696179Z" + } + }, + "source": [ + "from pytensor.graph.rewriting.basic import in2out\n", + "\n", + "x = pt.scalar(\"x\")\n", + "y = pt.log(1 + x)\n", + "out = pt.exp(y)\n", + "\n", + "fg = FunctionGraph([x], [out])\n", + "in2out(local_log1p_node_rewriter, name=\"local_log1p\").rewrite(fg)\n", + "\n", + "pytensor.dprint(fg.outputs);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exp [id A]\n", + " โ””โ”€ Log1p [id B]\n", + " โ””โ”€ x [id C]\n" + ] + } + ], + "execution_count": 13 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we used {func}`in2out` which creates a {class}`GraphRewriter` (specifically a {class}`WalkingGraphRewriter`) which walks from the inputs to the outputs of a FunctionGraph trying to apply whatever nodes are \"registered\" in it.\n", + "\n", + "Wrapping simple functions in {class}`NodeRewriter`s is so common that PyTensor offers a decorator for it.\n", + "\n", + "Let's create a new rewrite that removes useless `abs(exp(x)) -> exp(x)`." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.761196Z", + "start_time": "2025-01-11T07:37:50.757401Z" + } + }, + "source": [ + "from pytensor.graph.rewriting.basic import node_rewriter\n", + "\n", + "@node_rewriter(tracks=[pt.abs])\n", + "def local_useless_abs_exp(fgraph, node):\n", + " # Because of the tracks we don't need to check \n", + " # that `node` has a `Sign` Op.\n", + " # We still need to check whether it's input is an `Abs` Op\n", + " exp_node = node.inputs[0].owner\n", + " if exp_node is None or exp_node.op != pt.exp:\n", + " return None\n", + " return exp_node.outputs" + ], + "outputs": [], + "execution_count": 14 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "Another very useful helper is the {class}`PatternNodeRewriter`, which allows you to specify a rewrite via \"template matching\"." + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.848713Z", + "start_time": "2025-01-11T07:37:50.845435Z" + } + }, + "source": [ + "from pytensor.graph.rewriting.basic import PatternNodeRewriter\n", + "\n", + "local_useless_abs_square = PatternNodeRewriter(\n", + " (pt.abs, (pt.pow, \"x\", 2)),\n", + " (pt.pow, \"x\", 2),\n", + " name=\"local_useless_abs_square\",\n", + ")" + ], + "outputs": [], + "execution_count": 15 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is very useful for simple Elemwise rewrites, but becomes a bit cumbersome with Ops that must be parametrized\n", + "everytime they are used." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.925407Z", + "start_time": "2025-01-11T07:37:50.897320Z" + } + }, + "source": [ + "x = pt.scalar(\"x\")\n", + "y = pt.exp(x)\n", + "z = pt.abs(y)\n", + "w = pt.log(1.0 + z)\n", + "out = pt.abs(w ** 2)\n", + "\n", + "fg = FunctionGraph([x], [out])\n", + "in2out_rewrite = in2out(\n", + " local_log1p_node_rewriter, \n", + " local_useless_abs_exp, \n", + " local_useless_abs_square,\n", + " name=\"custom_rewrites\"\n", + ")\n", + "in2out_rewrite.rewrite(fg)\n", + "\n", + "pytensor.dprint(fg.outputs);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pow [id A]\n", + " โ”œโ”€ Log1p [id B]\n", + " โ”‚ โ””โ”€ Exp [id C]\n", + " โ”‚ โ””โ”€ x [id D]\n", + " โ””โ”€ 2 [id E]\n" + ] + } + ], + "execution_count": 16 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Besides {class}`WalkingGraphRewriter`s, there are:\n", + " - {class}`SequentialGraphRewriter`s, which apply a set of {class}`GraphRewriters` sequentially \n", + " - {class}`EquilibriumGraphRewriter`s which apply a set of {class}`GraphRewriters` (and {class}`NodeRewriters`) repeatedly until the graph stops changing.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Registering graph rewriters in a database" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, at the top of the rewrite mountain, there are {class}`RewriteDatabase`s! These allow \"querying\" for subsets of rewrites registered in a database.\n", + "\n", + "Most users trigger this when they change the `mode` of a PyTensor function `mode=\"FAST_COMPILE\"` or `mode=\"FAST_RUN\"`, or `mode=\"JAX\"` will lead to a different rewrite database query to be applied to the function before compilation.\n", + "\n", + "The most relevant {class}`RewriteDatabase` is called `optdb` and contains all the standard rewrites in PyTensor. You can manually register your {class}`GraphRewriter` in it. \n", + "\n", + "More often than not, you will want to register your rewrite in a pre-existing sub-database, like {term}`canonicalize`, {term}`stabilize`, or {term}`specialize`." + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:50.979283Z", + "start_time": "2025-01-11T07:37:50.976168Z" + } + }, + "source": [ + "from pytensor.compile.mode import optdb" + ], + "outputs": [], + "execution_count": 17 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:51.032996Z", + "start_time": "2025-01-11T07:37:51.029510Z" + } + }, + "source": [ + "optdb[\"canonicalize\"].register(\n", + " \"local_log1p_node_rewriter\",\n", + " local_log1p_node_rewriter,\n", + " \"fast_compile\",\n", + " \"fast_run\",\n", + " \"custom\",\n", + ")" + ], + "outputs": [], + "execution_count": 18 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:51.156080Z", + "start_time": "2025-01-11T07:37:51.095154Z" + } + }, + "source": [ + "with pytensor.config.change_flags(optimizer_verbose = True):\n", + " fn = pytensor.function([x], out, mode=\"FAST_COMPILE\")\n", + " \n", + "print(\"\")\n", + "pytensor.dprint(fn);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)\n", + "\n", + "Abs [id A] 4\n", + " โ””โ”€ Pow [id B] 3\n", + " โ”œโ”€ Log1p [id C] 2\n", + " โ”‚ โ””โ”€ Abs [id D] 1\n", + " โ”‚ โ””โ”€ Exp [id E] 0\n", + " โ”‚ โ””โ”€ x [id F]\n", + " โ””โ”€ 2 [id G]\n" + ] + } + ], + "execution_count": 19 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "There's also a decorator, {func}`register_canonicalize`, that automatically registers a {class}`NodeRewriter` in one of these standard databases. (It's placed in a weird location)" + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:51.220260Z", + "start_time": "2025-01-11T07:37:51.216259Z" + } + }, + "source": [ + "from pytensor.tensor.rewriting.basic import register_canonicalize\n", + "\n", + "@register_canonicalize(\"custom\")\n", + "@node_rewriter(tracks=[pt.abs])\n", + "def local_useless_abs_exp(fgraph, node):\n", + " # Because of the tracks we don't need to check \n", + " # that `node` has a `Sign` Op.\n", + " # We still need to check whether it's input is an `Abs` Op\n", + " exp_node = node.inputs[0].owner\n", + " if exp_node is None or exp_node.op != pt.exp:\n", + " return None\n", + " return exp_node.outputs" + ], + "outputs": [], + "execution_count": 20 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And you can also use the decorator directly" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:51.292003Z", + "start_time": "2025-01-11T07:37:51.286043Z" + } + }, + "source": [ + "register_canonicalize(local_useless_abs_square, \"custom\")" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "local_useless_abs_square" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 21 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:51.380138Z", + "start_time": "2025-01-11T07:37:51.362056Z" + } + }, + "source": [ + "with pytensor.config.change_flags(optimizer_verbose = True):\n", + " fn = pytensor.function([x], out, mode=\"FAST_COMPILE\")\n", + " \n", + "print(\"\")\n", + "pytensor.dprint(fn);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rewriting: rewrite local_useless_abs_square replaces Abs.0 of Abs(Pow.0) with Pow.0 of Pow(Log.0, 2)\n", + "rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)\n", + "rewriting: rewrite local_useless_abs_exp replaces Abs.0 of Abs(Exp.0) with Exp.0 of Exp(x)\n", + "\n", + "Pow [id A] 2\n", + " โ”œโ”€ Log1p [id B] 1\n", + " โ”‚ โ””โ”€ Exp [id C] 0\n", + " โ”‚ โ””โ”€ x [id D]\n", + " โ””โ”€ 2 [id E]\n" + ] + } + ], + "execution_count": 22 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And if you wanted to exclude your custom rewrites you can do it like this:" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:51.487102Z", + "start_time": "2025-01-11T07:37:51.459955Z" + } + }, + "source": [ + "from pytensor.compile.mode import get_mode\n", + "\n", + "with pytensor.config.change_flags(optimizer_verbose = True):\n", + " fn = pytensor.function([x], out, mode=get_mode(\"FAST_COMPILE\").excluding(\"custom\"))\n", + " \n", + "print(\"\")\n", + "pytensor.dprint(fn);" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rewriting: rewrite local_upcast_elemwise_constant_inputs replaces Add.0 of Add(1.0, Abs.0) with Add.0 of Add(Cast{float64}.0, Abs.0)\n", + "rewriting: rewrite constant_folding replaces Cast{float64}.0 of Cast{float64}(1.0) with 1.0 of None\n", + "\n", + "Abs [id A] 5\n", + " โ””โ”€ Pow [id B] 4\n", + " โ”œโ”€ Log [id C] 3\n", + " โ”‚ โ””โ”€ Add [id D] 2\n", + " โ”‚ โ”œโ”€ 1.0 [id E]\n", + " โ”‚ โ””โ”€ Abs [id F] 1\n", + " โ”‚ โ””โ”€ Exp [id G] 0\n", + " โ”‚ โ””โ”€ x [id H]\n", + " โ””โ”€ 2 [id I]\n" + ] + } + ], + "execution_count": 23 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Authors\n", + "\n", + "- Authored by Ricardo Vieira in May 2023" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## References\n", + "\n", + ":::{bibliography} :filter: docname in docnames" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Watermark " + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:37:51.621272Z", + "start_time": "2025-01-11T07:37:51.580753Z" + } + }, + "cell_type": "code", + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w -p pytensor" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last updated: Sat Jan 11 2025\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.12.0\n", + "IPython version : 8.31.0\n", + "\n", + "pytensor: 2.26.4+16.g8be5c5323.dirty\n", + "\n", + "sys : 3.12.0 | packaged by conda-forge | (main, Oct 3 2023, 08:43:22) [GCC 12.3.0]\n", + "pytensor: 2.26.4+16.g8be5c5323.dirty\n", + "\n", + "Watermark: 2.5.0\n", + "\n" + ] + } + ], + "execution_count": 24 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + ":::{include} ../page_footer.md \n", + ":::" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "" + } + ], + "metadata": { + "hide_input": false, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/gallery/scan/scan_tutorial.ipynb b/doc/gallery/scan/scan_tutorial.ipynb new file mode 100644 index 0000000000..3428698450 --- /dev/null +++ b/doc/gallery/scan/scan_tutorial.ipynb @@ -0,0 +1,852 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(Scan_tutorial)=\n", + "# Introduction to Scan\n", + ":::{post} Jan 11, 2025 \n", + ":tags: scan, worked examples, tutorial\n", + ":category: beginner, explanation \n", + ":author: Pascal Lamblin, Jesse Grabowski\n", + ":::\n", + "\n", + "A Pytensor function graph is composed of two types of nodes: Variable nodes which represent data, and Apply node which apply Ops (which represent some computation) to Variables to produce new Variables.\n", + "\n", + "From this point of view, a node that applies a Scan Op is just like any other. Internally, however, it is very different from most Ops.\n", + "\n", + "Inside a Scan op is yet another Pytensor graph which represents the computation to be performed at every iteration of the loop. During compilation, that graph is compiled into a function. During execution, the Scan Op will call that function repeatedly on its inputs to produce its outputs.\n", + "\n", + "## Examples\n", + "\n", + "Scan's interface is complex and, thus, best introduced by examples. \n" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Example 1: As Simple as it Gets\n", + "So, let's dive right in and start with a simple example; perform an element-wise multiplication between two vectors. \n", + "\n", + "This particular example is simple enough that Scan is not the best way to do things but we'll gradually work our way to more complex examples where Scan gets more interesting.\n", + "\n", + "Let's first setup our use case by defining Pytensor variables for the inputs :" + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-10T17:39:58.951346Z", + "start_time": "2025-01-10T17:39:53.088554Z" + } + }, + "source": [ + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import numpy as np\n", + "\n", + "vector1 = pt.dvector('vector1')\n", + "vector2 = pt.dvector('vector2')" + ], + "outputs": [], + "execution_count": 1 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we call the `scan` function. It has many parameters but, because our use case is simple, we only need two of them. We'll introduce other parameters in the next examples.\n", + "\n", + "The parameter `sequences` allows us to specify variables that Scan should iterate over as it loops. The first iteration will take as input the first element of every sequence, the second iteration will take as input the second element of every sequence, etc. These individual element have will have one less dimension than the original sequences. For example, for a matrix sequence, the individual elements will be vectors.\n", + "\n", + "The parameter `fn` receives a function or lambda expression that expresses the computation to do at every iteration. It operates on the symbolic inputs to produce symbolic outputs. It will **only ever be called once**, to assemble the Pytensor graph used by Scan at every the iterations.\n", + "\n", + "Since we wish to iterate over both `vector1` and `vector2` simultaneously, we provide them as sequences. This means that every iteration will operate on two inputs: an element from `vector1` and the corresponding element from `vector2`. \n", + "\n", + "Because what we want is the elementwise product between the vectors, we provide a lambda expression that takes an element `a` from `vector1` and an element `b` from `vector2` then computes and return the product." + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-10T17:39:59.004407Z", + "start_time": "2025-01-10T17:39:58.955818Z" + } + }, + "source": [ + "output, updates = pytensor.scan(fn=lambda a, b : a * b,\n", + " sequences=[vector1, vector2])" + ], + "outputs": [], + "execution_count": 2 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling `scan`, we see that it returns two outputs.\n", + "\n", + "The first output contains the outputs of `fn` from every timestep concatenated into a tensor. In our case, the output of a single timestep is a scalar so output is a vector where `output[i]` is the output of the i-th iteration.\n", + "\n", + "The second output details if and how the execution of the `Scan` updates any shared variable in the graph. It should be provided as an argument when compiling the Pytensor function." + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "scrolled": true, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.081533Z", + "start_time": "2025-01-10T17:39:59.741663Z" + } + }, + "source": [ + "f = pytensor.function(inputs=[vector1, vector2],\n", + " outputs=output,\n", + " updates=updates)" + ], + "outputs": [], + "execution_count": 3 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If `updates` is omitted, the state of any shared variables modified by `Scan` will not be updated properly. Random number sampling, for instance, relies on shared variables. If `updates` is not provided, the state of the random number generator won't be updated properly and the same numbers might be sampled repeatedly. **Always** provide `updates` when compiling your Pytensor function, unless you are sure that you don't need it!\n", + "\n", + "Now that we've defined how to do elementwise multiplication with Scan, we can see that the result is as expected :" + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.128785Z", + "start_time": "2025-01-10T17:40:00.125260Z" + } + }, + "source": [ + "floatX = pytensor.config.floatX\n", + "\n", + "vector1_value = np.arange(0, 5).astype(floatX) # [0,1,2,3,4]\n", + "vector2_value = np.arange(1, 6).astype(floatX) # [1,2,3,4,5]\n", + "print(f(vector1_value, vector2_value))" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0. 2. 6. 12. 20.]\n" + ] + } + ], + "execution_count": 4 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "An interesting thing is that we never explicitly told Scan how many iteration it needed to run. It was automatically inferred; when given sequences, Scan will run as many iterations as the length of the shortest sequence. Here we just truncate one of the sequences to 4 elements, and we get only 4 outputs." + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.199150Z", + "start_time": "2025-01-10T17:40:00.195450Z" + } + }, + "source": [ + "print(f(vector1_value, vector2_value[:4]))" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0. 2. 6. 12.]\n" + ] + } + ], + "execution_count": 5 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 2: Non-sequences\n", + "\n", + "In this example, we introduce another of Scan's features; non-sequences. To demonstrate how to use them, we use Scan to compute the activations of a linear MLP layer over a minibatch.\n", + "\n", + "It is not yet a use case where Scan is truly useful but it introduces a requirement that sequences cannot fulfill; if we want to use Scan to iterate over the minibatch elements and compute the activations for each of them, then we need some variables (the parameters of the layer), to be available 'as is' at every iteration of the loop. We do *not* want Scan to iterate over them and give only part of them at every iteration.\n", + "\n", + "Once again, we begin by setting up our Pytensor variables :" + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.263086Z", + "start_time": "2025-01-10T17:40:00.259308Z" + } + }, + "source": [ + "X = pt.dmatrix('X') # Minibatch of data\n", + "W = pt.dmatrix('W') # Weights of the layer\n", + "b = pt.dvector('b') # Biases of the layer" + ], + "outputs": [], + "execution_count": 6 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the sake of variety, in this example we define the computation to be done at every iteration of the loop using a Python function, `step()`, instead of a lambda expression.\n", + "\n", + "To have the full weight matrix W and the full bias vector b available at every iteration, we use the argument `non_sequences`. Contrary to `sequences`, `non_sequences` are not iterated upon by Scan. Every non-sequence is passed as input to every iteration.\n", + "\n", + "This means that our `step()` function will need to operate on three symbolic inputs; one for our sequence X and one for each of our non-sequences W and b. \n", + "\n", + "The inputs that correspond to the non-sequences are **always** last and in the same order at the non-sequences are provided to Scan. This means that the correspondence between the inputs of the `step()` function and the arguments to `scan()` is the following : \n", + "\n", + "* `v` : individual element of the sequence `X` \n", + "* `W` and `b` : non-sequences `W` and `b`, respectively" + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.366395Z", + "start_time": "2025-01-10T17:40:00.316085Z" + } + }, + "source": [ + "def step(v, W, b):\n", + " return v @ W + b\n", + "\n", + "output, updates = pytensor.scan(fn=step,\n", + " sequences=[X],\n", + " non_sequences=[W, b])\n", + "print(updates)" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{}\n" + ] + } + ], + "execution_count": 7 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "We can now compile our Pytensor function and see that it gives the expected results." + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.666677Z", + "start_time": "2025-01-10T17:40:00.403399Z" + } + }, + "source": [ + "f = pytensor.function(inputs=[X, W, b],\n", + " outputs=output,\n", + " updates=updates)\n", + "\n", + "X_value = np.arange(-3, 3).reshape(3, 2).astype(floatX)\n", + "W_value = np.eye(2).astype(floatX)\n", + "b_value = np.arange(2).astype(floatX)\n", + "\n", + "print(f(X_value, W_value, b_value))" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[-3. -1.]\n", + " [-1. 1.]\n", + " [ 1. 3.]]\n" + ] + } + ], + "execution_count": 8 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 3 : Reusing outputs from the previous iterations\n", + "\n", + "In this example, we will use Scan to compute a cumulative sum over the first dimension of a matrix $M$. This means that the output will be a matrix $S$ in which the first row will be equal to the first row of $M$, the second row will be equal to the sum of the two first rows of $M$, and so on.\n", + "\n", + "Another way to express this, which is the way we will implement here, is that $S_t = S_{t-1} + M_t$. Implementing this with Scan would involve iterating over the rows of the matrix $M$ and, at every iteration, reuse the cumulative row that was output at the previous iteration and return the sum of it and the current row of $M$.\n", + "\n", + "If we assume for a moment that we can get Scan to provide the output value from the previous iteration as an input for every iteration, implementing a step function is simple :" + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.698967Z", + "start_time": "2025-01-10T17:40:00.695951Z" + } + }, + "source": [ + "def step(m_row, cumulative_sum):\n", + " return m_row + cumulative_sum" + ], + "outputs": [], + "execution_count": 9 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The trick part is informing Scan that our step function expects as input the output of a previous iteration. To achieve this, we need to use a new parameter of the `scan()` function: `outputs_info`. This parameter is used to tell Scan how we intend to use each of the outputs that are computed at each iteration.\n", + "\n", + "This parameter can be omitted (like we did so far) when the step function doesn't depend on any output of a previous iteration. However, now that we wish to have recurrent outputs, we need to start using it.\n", + "\n", + "`outputs_info` takes a sequence with one element for every output of the `step()` function :\n", + "* For a **non-recurrent output** (like in every example before this one), the element should be `None`.\n", + "* For a **simple recurrent output** (iteration $t$ depends on the value at iteration $t-1$), the element must be a tensor. Scan will interpret it as being an initial state for a recurrent output and give it as input to the first iteration, pretending it is the output value from a previous iteration. For subsequent iterations, Scan will automatically handle giving the previous output value as an input.\n", + "\n", + "The `step()` function needs to expect one additional input for each simple recurrent output. These inputs correspond to outputs from previous iteration and are **always** after the inputs that correspond to sequences but before those that correspond to non-sequences. The are received by the `step()` function in the order in which the recurrent outputs are declared in the outputs_info sequence." + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.767156Z", + "start_time": "2025-01-10T17:40:00.740203Z" + } + }, + "source": [ + "M = pt.dmatrix('X')\n", + "s = pt.dvector('s') # Initial value for the cumulative sum\n", + "\n", + "output, updates = pytensor.scan(fn=step,\n", + " sequences=[M],\n", + " outputs_info=[s])" + ], + "outputs": [], + "execution_count": 10 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "We can now compile and test the Pytensor function :" + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.933590Z", + "start_time": "2025-01-10T17:40:00.814705Z" + } + }, + "source": [ + "f = pytensor.function(inputs=[M, s],\n", + " outputs=output,\n", + " updates=updates)\n", + "\n", + "M_value = np.arange(9).reshape(3, 3).astype(floatX)\n", + "s_value = np.zeros((3, ), dtype=floatX)\n", + "\n", + "print(f(M_value, s_value))" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0. 1. 2.]\n", + " [ 3. 5. 7.]\n", + " [ 9. 12. 15.]]\n" + ] + } + ], + "execution_count": 11 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "An important thing to notice here, is that the output computed by the Scan does **not** include the initial state that we provided. It only outputs the states that it has computed itself.\n", + "\n", + "If we want to have both the initial state and the computed states in the same Pytensor variable, we have to join them ourselves." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 4 : Reusing outputs from multiple past iterations\n", + "\n", + "The Fibonacci sequence is a sequence of numbers F where the two first numbers both 1 and every subsequence number is defined as such : $F_n = F_{n-1} + F_{n-2}$. Thus, the Fibonacci sequence goes : 1, 1, 2, 3, 5, 8, 13, ...\n", + "\n", + "In this example, we will cover how to compute part of the Fibonacci sequence using Scan. Most of the tools required to achieve this have been introduced in the previous examples. The only one missing is the ability to use, at iteration $i$, outputs from iterations older than $i-1$.\n", + "\n", + "Also, since every example so far had only one output at every iteration of the loop, we will also compute, at each timestep, the ratio between the new term of the Fibonacci sequence and the previous term.\n", + "\n", + "Writing an appropriate step function given two inputs, representing the two previous terms of the Fibonacci sequence, is easy:" + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:00.960658Z", + "start_time": "2025-01-10T17:40:00.956657Z" + } + }, + "source": [ + "def step(f_minus2, f_minus1):\n", + " new_f = f_minus2 + f_minus1\n", + " ratio = new_f / f_minus1\n", + " return new_f, ratio" + ], + "outputs": [], + "execution_count": 12 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The next step is defining the value of `outputs_info`.\n", + "\n", + "Recall that, for **non-recurrent outputs**, the value is `None` and, for **simple recurrent outputs**, the value is a single initial state. For **general recurrent outputs**, where iteration $t$ may depend on multiple past values, the value is a dictionary. That dictionary has two values:\n", + "* taps : list declaring which previous values of that output every iteration will need. `[-3, -2, -1]` would mean every iteration should take as input the last 3 values of that output. `[-2]` would mean every iteration should take as input the value of that output from two iterations ago.\n", + "* initial : tensor of initial values. If every initial value has $n$ dimensions, `initial` will be a single tensor of $n+1$ dimensions with as many initial values as the oldest requested tap. In the case of the Fibonacci sequence, the individual initial values are scalars so the `initial` will be a vector. \n", + "\n", + "In our example, we have two outputs. The first output is the next computed term of the Fibonacci sequence so every iteration should take as input the two last values of that output. The second output is the ratio between successive terms and we don't reuse its value so this output is non-recurrent. We define the value of `outputs_info` as such :" + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:01.023497Z", + "start_time": "2025-01-10T17:40:01.019867Z" + } + }, + "source": [ + "f_init = pt.fvector()\n", + "outputs_info = [dict(initial=f_init, taps=[-2, -1]),\n", + " None]" + ], + "outputs": [], + "execution_count": 13 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we've defined the step function and the properties of our outputs, we can call the `scan()` function. Because the `step()` function has multiple outputs, the first output of `scan()` function will be a list of tensors: the first tensor containing all the states of the first output and the second tensor containing all the states of the second input.\n", + "\n", + "In every previous example, we used sequences and Scan automatically inferred the number of iterations it needed to run from the length of these\n", + "sequences. Now that we have no sequence, we need to explicitly tell Scan how many iterations to run using the `n_step` parameter. The value can be real or symbolic." + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:01.080129Z", + "start_time": "2025-01-10T17:40:01.069348Z" + } + }, + "source": [ + "output, updates = pytensor.scan(fn=step,\n", + " outputs_info=outputs_info,\n", + " n_steps=10)\n", + "\n", + "next_fibonacci_terms = output[0]\n", + "ratios_between_terms = output[1]" + ], + "outputs": [], + "execution_count": 14 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "Let's compile our Pytensor function which will take a vector of consecutive values from the Fibonacci sequence and compute the next 10 values :" + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:01.254196Z", + "start_time": "2025-01-10T17:40:01.134565Z" + } + }, + "source": [ + "f = pytensor.function(inputs=[f_init],\n", + " outputs=[next_fibonacci_terms, ratios_between_terms],\n", + " updates=updates)\n", + "\n", + "out = f([1, 1])\n", + "print(out[0])\n", + "print(out[1])" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 2. 3. 5. 8. 13. 21. 34. 55. 89. 144.]\n", + "[2. 1.5 1.6666666 1.6 1.625 1.6153846 1.6190476\n", + " 1.617647 1.6181818 1.6179775]\n" + ] + } + ], + "execution_count": 15 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Order of inputs \n", + "\n", + "When we start using many sequences, recurrent outputs and non-sequences, it's easy to get confused regarding the order in which the step function receives the corresponding inputs. Below is the full order:\n", + "\n", + "* Element from the first sequence\n", + "* ...\n", + "* Element from the last sequence\n", + "* First requested tap from first recurrent output\n", + "* ...\n", + "* Last requested tap from first recurrent output\n", + "* ...\n", + "* First requested tap from last recurrent output\n", + "* ...\n", + "* Last requested tap from last recurrent output\n", + "* First non-sequence\n", + "* ...\n", + "* Last non-sequence" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## When to use Scan \n", + "\n", + "Scan is not appropriate for every problem. Here's some information to help you figure out if Scan is the best solution for a given use case.\n", + "\n", + "### Execution speed\n", + "\n", + "Using Scan in a Pytensor function typically makes it slightly slower compared to the equivalent Pytensor graph in which the loop is unrolled. Both of these approaches tend to be much slower than a vectorized implementation in which large chunks of the computation can be done in parallel.\n", + "\n", + "### Compilation speed\n", + "\n", + "Scan also adds an overhead to the compilation, potentially making it slower, but using it can also dramatically reduce the size of your graph, making compilation much faster. In the end, the effect of Scan on compilation speed will heavily depend on the size of the graph with and without Scan.\n", + "\n", + "The compilation speed of a Pytensor function using Scan will usually be comparable to one in which the loop is unrolled if the number of iterations is small. It the number of iterations is large, however, the compilation will usually be much faster with Scan.\n", + "\n", + "### In summary\n", + "\n", + "If you have one of the following cases, Scan can help :\n", + "* A vectorized implementation is not possible (due to the nature of the computation and/or memory usage)\n", + "* You want to do a large or variable number of iterations\n", + "\n", + "If you have one of the following cases, you should consider other options :\n", + "* A vectorized implementation could perform the same computation => Use the vectorized approach. It will often be faster during both compilation and execution.\n", + "* You want to do a small, fixed, number of iterations (ex: 2 or 3) => It's probably better to simply unroll the computation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exercises\n", + "\n", + "### Exercise 1 - Computing a polynomial\n", + "\n", + "In this exercise, the initial version already works. It computes the value of a polynomial ($n_0 + n_1 x + n_2 x^2 + ... $) of at most 10000 degrees given the coefficients of the various terms and the value of x.\n", + "\n", + "You must modify it such that the reduction (the sum() call) is done by Scan." + ] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:01.466495Z", + "start_time": "2025-01-10T17:40:01.288716Z" + } + }, + "source": [ + "coefficients = pt.dvector(\"coefficients\")\n", + "x = pt.dscalar(\"x\")\n", + "max_coefficients_supported = 10000\n", + "\n", + "def step(coeff, power, free_var):\n", + " return coeff * free_var ** power\n", + "\n", + "# Generate the components of the polynomial\n", + "full_range = pt.arange(max_coefficients_supported)\n", + "components, updates = pytensor.scan(fn=step,\n", + " outputs_info=None,\n", + " sequences=[coefficients, full_range],\n", + " non_sequences=x)\n", + "\n", + "polynomial = components.sum()\n", + "calculate_polynomial = pytensor.function(inputs=[coefficients, x],\n", + " outputs=polynomial,\n", + " updates=updates)\n", + "\n", + "test_coeff = np.asarray([1, 0, 2], dtype=floatX)\n", + "print(calculate_polynomial(test_coeff, 3))\n", + "# 19.0" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19.0\n" + ] + } + ], + "execution_count": 16 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Solution** : run the cell below to display the solution to this exercise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Exercise 2 - Sampling without replacement\n", + "\n", + "In this exercise, the goal is to implement a Pytensor function that :\n", + "* takes as input a vector of probabilities and a scalar\n", + "* performs sampling without replacements from those probabilities as many times as the value of the scalar\n", + "* returns a vector containing the indices of the sampled elements.\n", + "\n", + "Partial code is provided to help with the sampling of random numbers since this is not something that was covered in this tutorial." + ] + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-10T17:40:01.513298Z", + "start_time": "2025-01-10T17:40:01.482238Z" + } + }, + "cell_type": "code", + "source": [ + "rng = pytensor.shared(np.random.default_rng(1234))\n", + "p_vec = pt.dvector(\"p_vec\")\n", + "next_rng, onehot_sample = pt.random.multinomial(n=1, p=p_vec, rng=rng).owner.outputs\n", + "f = pytensor.function([p_vec], onehot_sample, updates={rng:next_rng})" + ], + "outputs": [], + "execution_count": 17 + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-01-10T17:40:01.703547Z", + "start_time": "2025-01-10T17:40:01.536499Z" + } + }, + "source": [ + "def sample_from_pvect(p, rng):\n", + " \"\"\" Provided utility function: given a symbolic vector of\n", + " probabilities (which MUST sum to 1), sample one element\n", + " and return its index.\n", + " \"\"\"\n", + " next_rng, onehot_sample = pt.random.multinomial(n=1, p=p, rng=rng).owner.outputs\n", + " idx = onehot_sample.argmax()\n", + " \n", + " return idx, {rng: next_rng}\n", + "\n", + "def set_p_to_zero(p, i):\n", + " \"\"\" Provided utility function: given a symbolic vector of\n", + " probabilities and an index 'i', set the probability of the\n", + " i-th element to 0 and renormalize the probabilities so they\n", + " sum to 1.\n", + " \"\"\"\n", + " new_p = p[i].set(0.)\n", + " new_p = new_p / new_p.sum()\n", + " return new_p\n", + "\n", + "def sample(p, rng):\n", + " idx, updates = sample_from_pvect(p, rng)\n", + " p = set_p_to_zero(p, idx)\n", + " return (p, idx), updates\n", + "\n", + "probabilities = pt.dvector()\n", + "nb_samples = pt.iscalar()\n", + "\n", + "SEED = sum(map(ord, 'PyTensor Scan'))\n", + "rng = pytensor.shared(np.random.default_rng(SEED))\n", + "\n", + "\n", + "# TODO use Scan to sample from the vector of probabilities and\n", + "# symbolically obtain 'samples' the vector of sampled indices.\n", + "[probs, samples], updates = pytensor.scan(fn=sample,\n", + " outputs_info=[probabilities, None],\n", + " non_sequences=[rng],\n", + " n_steps=nb_samples)\n", + "\n", + "# Compiling the function\n", + "f = pytensor.function(inputs=[probabilities, nb_samples], outputs=samples, updates=updates)\n", + "\n", + "# Testing the function\n", + "test_probs = np.asarray([0.6, 0.3, 0.1], dtype=floatX)\n", + "\n", + "for i in range(10):\n", + " print(f(test_probs, 2))" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0 1]\n", + "[0 1]\n", + "[2 1]\n", + "[2 0]\n", + "[0 1]\n", + "[0 1]\n", + "[0 1]\n", + "[0 1]\n", + "[0 1]\n", + "[0 1]\n" + ] + } + ], + "execution_count": 18 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Authors\n", + "\n", + "- Authored by Pascal Lamblin in Feburary 2016\n", + "- Updated by Jesse Grabowski in January 2025" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## References\n", + "\n", + ":::{bibliography} :filter: docname in docnames" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Watermark " + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-11T07:50:45.845462Z", + "start_time": "2025-01-11T07:50:45.809393Z" + } + }, + "cell_type": "code", + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w -p pytensor" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The watermark extension is already loaded. To reload it, use:\n", + " %reload_ext watermark\n", + "Last updated: Sat Jan 11 2025\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.12.0\n", + "IPython version : 8.31.0\n", + "\n", + "pytensor: 2.26.4+16.g8be5c5323.dirty\n", + "\n", + "numpy : 1.26.4\n", + "pytensor: 2.26.4+16.g8be5c5323.dirty\n", + "sys : 3.12.0 | packaged by conda-forge | (main, Oct 3 2023, 08:43:22) [GCC 12.3.0]\n", + "\n", + "Watermark: 2.5.0\n", + "\n" + ] + } + ], + "execution_count": 20 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + ":::{include} ../page_footer.md \n", + ":::" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/doc/images/PyTensor.png b/doc/images/PyTensor.png new file mode 100644 index 0000000000..e6097693af Binary files /dev/null and b/doc/images/PyTensor.png differ diff --git a/doc/images/PyTensor_logo.png b/doc/images/PyTensor_logo.png new file mode 100644 index 0000000000..c8947735de Binary files /dev/null and b/doc/images/PyTensor_logo.png differ diff --git a/doc/images/binder.svg b/doc/images/binder.svg new file mode 100644 index 0000000000..327f6b639a --- /dev/null +++ b/doc/images/binder.svg @@ -0,0 +1 @@ + launchlaunchbinderbinder \ No newline at end of file diff --git a/doc/images/colab.svg b/doc/images/colab.svg new file mode 100644 index 0000000000..c08066ee33 --- /dev/null +++ b/doc/images/colab.svg @@ -0,0 +1 @@ + Open in ColabOpen in Colab diff --git a/doc/images/github.svg b/doc/images/github.svg new file mode 100644 index 0000000000..e02d8ed55b --- /dev/null +++ b/doc/images/github.svg @@ -0,0 +1 @@ + View On GitHubView On GitHub \ No newline at end of file diff --git a/doc/index.rst b/doc/index.rst index ac5bc0876c..a70a28df82 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -80,6 +80,7 @@ Community introduction user_guide API + Examples Contributing .. _Theano: https://github.com/Theano/Theano diff --git a/doc/library/index.rst b/doc/library/index.rst index 6a05a5a7bf..08a5b51c34 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -22,7 +22,6 @@ Modules gradient misc/pkl_utils printing - sandbox/index scalar/index scan sparse/index diff --git a/doc/library/misc/pkl_utils.rst b/doc/library/misc/pkl_utils.rst index 0299d15204..f22e5e8bd7 100644 --- a/doc/library/misc/pkl_utils.rst +++ b/doc/library/misc/pkl_utils.rst @@ -9,10 +9,6 @@ from pytensor.misc.pkl_utils import * -.. autofunction:: pytensor.misc.pkl_utils.dump - -.. autofunction:: pytensor.misc.pkl_utils.load - .. autoclass:: pytensor.misc.pkl_utils.StripPickler .. seealso:: diff --git a/doc/library/sandbox/index.rst b/doc/library/sandbox/index.rst deleted file mode 100644 index b4012cd9df..0000000000 --- a/doc/library/sandbox/index.rst +++ /dev/null @@ -1,16 +0,0 @@ - -.. _libdoc_sandbox: - -============================================================== -:mod:`sandbox` -- Experimental Code -============================================================== - -.. module:: sandbox - :platform: Unix, Windows - :synopsis: Experimental code -.. moduleauthor:: LISA - -.. toctree:: - :maxdepth: 1 - - linalg diff --git a/doc/library/sandbox/linalg.rst b/doc/library/sandbox/linalg.rst deleted file mode 100644 index 9ee5fe9f51..0000000000 --- a/doc/library/sandbox/linalg.rst +++ /dev/null @@ -1,19 +0,0 @@ -.. ../../../../pytensor/sandbox/linalg/ops.py -.. ../../../../pytensor/sandbox/linalg - -.. _libdoc_sandbox_linalg: - -=================================================================== -:mod:`sandbox.linalg` -- Linear Algebra Ops -=================================================================== - -.. module:: sandbox.linalg - :platform: Unix, Windows - :synopsis: Linear Algebra Ops -.. moduleauthor:: LISA - -API -=== - -.. automodule:: pytensor.sandbox.linalg.ops - :members: diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst index 50da46449a..8d22c1e577 100644 --- a/doc/library/tensor/basic.rst +++ b/doc/library/tensor/basic.rst @@ -477,7 +477,7 @@ them perfectly, but a `dscalar` otherwise. you'll want to call. -.. autoclass:: pytensor.tensor.var._tensor_py_operators +.. autoclass:: pytensor.tensor.variable._tensor_py_operators :members: This mix-in class adds convenient attributes, methods, and support diff --git a/doc/library/tensor/random/index.rst b/doc/library/tensor/random/index.rst index 210b77d5c1..d1f87af77b 100644 --- a/doc/library/tensor/random/index.rst +++ b/doc/library/tensor/random/index.rst @@ -30,6 +30,7 @@ sophisticated `Op`\s like `Scan`, which makes it a user-friendly random variable interface in PyTensor. For an example of how to use random numbers, see :ref:`Using Random Numbers `. +For a technical explanation of how PyTensor implements random variables see :ref:`prng`. .. class:: RandomStream() diff --git a/doc/robots.txt b/doc/robots.txt new file mode 100644 index 0000000000..73cf5dba3b --- /dev/null +++ b/doc/robots.txt @@ -0,0 +1,3 @@ +User-agent: * + +Sitemap: https://pytensor.readthedocs.io/en/latest/sitemap.xml diff --git a/doc/tutorial/examples.rst b/doc/tutorial/examples.rst index 51ea8496b2..e74d604f63 100644 --- a/doc/tutorial/examples.rst +++ b/doc/tutorial/examples.rst @@ -357,6 +357,9 @@ hold here as well. PyTensor's random objects are defined and implemented in :ref:`RandomStream` and, at a lower level, in :ref:`RandomVariable`. +For a more technical explanation of how PyTensor implements random variables see :ref:`prng`. + + Brief Example ------------- diff --git a/doc/tutorial/loading_and_saving.rst b/doc/tutorial/loading_and_saving.rst index dc6eb9b097..d099ecb026 100644 --- a/doc/tutorial/loading_and_saving.rst +++ b/doc/tutorial/loading_and_saving.rst @@ -145,7 +145,7 @@ might not have PyTensor installed, who are using a different Python version, or you are planning to save your model for a long time (in which case version mismatches might make it difficult to unpickle objects). -See :func:`pytensor.misc.pkl_utils.dump` and :func:`pytensor.misc.pkl_utils.load`. +See :meth:`pytensor.misc.pkl_utils.StripPickler.dump` and :meth:`pytensor.misc.pkl_utils.StripPickler.load`. Long-Term Serialization diff --git a/doc/tutorial/prng.rst b/doc/tutorial/prng.rst index fe541ab71e..65f0e43479 100644 --- a/doc/tutorial/prng.rst +++ b/doc/tutorial/prng.rst @@ -5,7 +5,9 @@ Pseudo random number generation in PyTensor =========================================== PyTensor has native support for `pseudo random number generation (PRNG) `_. -This document describes how PRNGs are implemented in PyTensor, via the RandomVariable Operator. + +This document describes the details of how PRNGs are implemented in PyTensor, via the RandomVariable Operator. +For a more applied example see :ref:`using_random_numbers` We also discuss how initial seeding and seeding updates are implemented, and some harder cases such as using RandomVariables inside Scan, or with other backends like JAX. diff --git a/environment-osx-arm64.yml b/environment-osx-arm64.yml index 13a68faaaa..3a83cde5e6 100644 --- a/environment-osx-arm64.yml +++ b/environment-osx-arm64.yml @@ -9,7 +9,7 @@ channels: dependencies: - python=>3.10 - compilers - - numpy>=1.17.0,<2 + - numpy>=1.17.0,<2.1 - scipy>=1,<2 - filelock>=3.15 - etuples diff --git a/environment.yml b/environment.yml index 4b213fd851..0c7454f8ac 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ channels: dependencies: - python>=3.10 - compilers - - numpy>=1.17.0,<2 + - numpy>=1.17.0,<2.1 - scipy>=1,<2 - filelock>=3.15 - etuples @@ -43,6 +43,10 @@ dependencies: - ipython - pymc-sphinx-theme - sphinx-design + - myst-nb + - matplotlib + - watermark + # code style - ruff # developer tools diff --git a/pyproject.toml b/pyproject.toml index 42c2289dde..7688a39720 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ keywords = [ dependencies = [ "setuptools>=59.0.0", "scipy>=1,<2", - "numpy>=1.17.0,<2", + "numpy>=1.17.0,<2.1", "filelock>=3.15", "etuples", "logical-unification", @@ -129,8 +129,12 @@ exclude = ["doc/", "pytensor/_version.py"] docstring-code-format = true [tool.ruff.lint] -select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"] +select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "NPY201"] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] +unfixable = [ + # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead + "B905", +] [tool.ruff.lint.isort] diff --git a/pytensor/__init__.py b/pytensor/__init__.py index dd6117c527..3c925ac2f2 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -24,6 +24,7 @@ # pytensor code, since this code may want to log some messages. import logging import sys +import warnings from functools import singledispatch from pathlib import Path from typing import Any, NoReturn, Optional @@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v): If `v` is not some view of constant data, then raise a `NotScalarConstantError`. """ - # Is it necessary to test for presence of pytensor.sparse at runtime? - sparse = globals().get("sparse") - if sparse and isinstance(v.type, sparse.SparseTensorType): - if v.owner is not None and isinstance(v.owner.op, sparse.CSM): - data = v.owner.inputs[0] - return tensor.get_underlying_scalar_constant_value(data) - return tensor.get_underlying_scalar_constant_value(v) + warnings.warn( + "get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.", + FutureWarning, + ) + from pytensor.tensor.basic import get_underlying_scalar_constant_value + + return get_underlying_scalar_constant_value(v) # isort: off @@ -164,6 +165,7 @@ def get_underlying_scalar_constant(v): from pytensor.scan import checkpoints from pytensor.scan.basic import scan from pytensor.scan.views import foldl, foldr, map, reduce +from pytensor.compile.builders import OpFromGraph # isort: on diff --git a/pytensor/compile/__init__.py b/pytensor/compile/__init__.py index 9bd140d746..f6a95fe163 100644 --- a/pytensor/compile/__init__.py +++ b/pytensor/compile/__init__.py @@ -37,7 +37,6 @@ PrintCurrentFunctionGraph, get_default_mode, get_mode, - instantiated_default_mode, local_useless, optdb, predefined_linkers, diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index ff0b742975..49baa3bb26 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -43,7 +43,7 @@ def infer_shape(outs, inputs, input_shapes): # TODO: ShapeFeature should live elsewhere from pytensor.tensor.rewriting.shape import ShapeFeature - for inp, inp_shp in zip(inputs, input_shapes): + for inp, inp_shp in zip(inputs, input_shapes, strict=True): if inp_shp is not None and len(inp_shp) != inp.type.ndim: assert len(inp_shp) == inp.type.ndim @@ -51,7 +51,7 @@ def infer_shape(outs, inputs, input_shapes): shape_feature.on_attach(FunctionGraph([], [])) # Initialize shape_of with the input shapes - for inp, inp_shp in zip(inputs, input_shapes): + for inp, inp_shp in zip(inputs, input_shapes, strict=True): shape_feature.set_shape(inp, inp_shp) def local_traverse(out): @@ -108,7 +108,9 @@ def construct_nominal_fgraph( replacements = dict( zip( - inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs + inputs + implicit_shared_inputs, + dummy_inputs + dummy_implicit_shared_inputs, + strict=True, ) ) @@ -138,7 +140,7 @@ def construct_nominal_fgraph( NominalVariable(n, var.type) for n, var in enumerate(local_inputs) ) - fgraph.replace_all(zip(local_inputs, nominal_local_inputs)) + fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True)) for i, inp in enumerate(fgraph.inputs): nom_inp = nominal_local_inputs[i] @@ -562,7 +564,9 @@ def lop_overrides(inps, grads): # compute non-overriding downsteam grads from upstreams grads # it's normal some input may be disconnected, thus the 'ignore' wrt = [ - lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None + lin + for lin, gov in zip(inner_inputs, custom_input_grads, strict=True) + if gov is None ] default_input_grads = fn_grad(wrt=wrt) if wrt else [] input_grads = self._combine_list_overrides( @@ -653,7 +657,7 @@ def _build_and_cache_rop_op(self): f = [ output for output, custom_output_grad in zip( - inner_outputs, custom_output_grads + inner_outputs, custom_output_grads, strict=True ) if custom_output_grad is None ] @@ -733,18 +737,24 @@ def make_node(self, *inputs): non_shared_inputs = [ inp_t.filter_variable(inp) - for inp, inp_t in zip(non_shared_inputs, self.input_types) + for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True) ] new_shared_inputs = inputs[num_expected_inps:] - inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs)) + inner_and_input_shareds = list( + zip(self.shared_inputs, new_shared_inputs, strict=True) + ) if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds): # The shared variables are not equal to the original shared # variables, so we construct a new `Op` that uses the new shared # variables instead. replace = dict( - zip(self.inner_inputs[num_expected_inps:], new_shared_inputs) + zip( + self.inner_inputs[num_expected_inps:], + new_shared_inputs, + strict=True, + ) ) # If the new shared variables are inconsistent with the inner-graph, @@ -811,7 +821,7 @@ def infer_shape(self, fgraph, node, shapes): # each shape call. PyTensor optimizer will clean this up later, but this # will make extra work for the optimizer. - repl = dict(zip(self.inner_inputs, node.inputs)) + repl = dict(zip(self.inner_inputs, node.inputs, strict=True)) clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)] cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl) ret = [] @@ -853,5 +863,6 @@ def clone(self): def perform(self, node, inputs, outputs): variables = self.fn(*inputs) assert len(variables) == len(outputs) - for output, variable in zip(outputs, variables): + # strict=False because asserted above + for output, variable in zip(outputs, variables, strict=False): output[0] = variable diff --git a/pytensor/compile/compilelock.py b/pytensor/compile/compilelock.py index 83bf42866d..a1697e43d1 100644 --- a/pytensor/compile/compilelock.py +++ b/pytensor/compile/compilelock.py @@ -8,8 +8,6 @@ from contextlib import contextmanager from pathlib import Path -import filelock - from pytensor.configdefaults import config @@ -35,8 +33,9 @@ def force_unlock(lock_dir: os.PathLike): lock_dir : os.PathLike Path to a directory that was locked with `lock_ctx`. """ + from filelock import FileLock - fl = filelock.FileLock(Path(lock_dir) / ".lock") + fl = FileLock(Path(lock_dir) / ".lock") fl.release(force=True) dir_key = f"{lock_dir}-{os.getpid()}" @@ -62,6 +61,8 @@ def lock_ctx( Timeout in seconds for waiting in lock acquisition. Defaults to `pytensor.config.compile__timeout`. """ + from filelock import FileLock + if lock_dir is None: lock_dir = config.compiledir @@ -73,7 +74,7 @@ def lock_ctx( if dir_key not in local_mem._locks: local_mem._locks[dir_key] = True - fl = filelock.FileLock(Path(lock_dir) / ".lock") + fl = FileLock(Path(lock_dir) / ".lock") fl.acquire(timeout=timeout) try: yield diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index bfcaf1ecf0..cc1a5b225a 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -865,7 +865,7 @@ def _get_preallocated_maps( # except if broadcastable, or for dimensions above # config.DebugMode__check_preallocated_output_ndim buf_shape = [] - for s, b in zip(r_vals[r].shape, r.broadcastable): + for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True): if b or ((r.ndim - len(buf_shape)) > check_ndim): buf_shape.append(s) else: @@ -943,7 +943,7 @@ def _get_preallocated_maps( r_shape_diff = shape_diff[: r.ndim] new_buf_shape = [ max((s + sd), 0) - for s, sd in zip(r_vals[r].shape, r_shape_diff) + for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True) ] new_buf = np.empty(new_buf_shape, dtype=r.type.dtype) new_buf[...] = np.asarray(def_val).astype(r.type.dtype) @@ -1575,7 +1575,7 @@ def f(): # try: # compute the value of all variables for i, (thunk_py, thunk_c, node) in enumerate( - zip(thunks_py, thunks_c, order) + zip(thunks_py, thunks_c, order, strict=True) ): _logger.debug(f"{i} - starting node {i} {node}") @@ -1855,7 +1855,7 @@ def thunk(): assert s[0] is None # store our output variables to their respective storage lists - for output, storage in zip(fgraph.outputs, output_storage): + for output, storage in zip(fgraph.outputs, output_storage, strict=True): storage[0] = r_vals[output] # transfer all inputs back to their respective storage lists @@ -1931,11 +1931,11 @@ def deco(): f, [ Container(input, storage, readonly=False) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks_py, order, @@ -2122,7 +2122,9 @@ def __init__( no_borrow = [ output - for output, spec in zip(fgraph.outputs, outputs + additional_outputs) + for output, spec in zip( + fgraph.outputs, outputs + additional_outputs, strict=True + ) if not spec.borrow ] if no_borrow: diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 49a6840719..935c77219a 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs( new_inputs = [] - for i, iv in zip(inputs, input_variables): + for i, iv in zip(inputs, input_variables, strict=True): new_i = copy(i) new_i.variable = iv @@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs( assert len(fgraph.inputs) == len(inputs) assert len(fgraph.outputs) == len(outputs) - for fg_inp, inp in zip(fgraph.inputs, inputs): + for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True): if fg_inp != getattr(inp, "variable", inp): raise ValueError( f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}" ) - for fg_out, out in zip(fgraph.outputs, outputs): + for fg_out, out in zip(fgraph.outputs, outputs, strict=True): if fg_out != getattr(out, "variable", out): raise ValueError( f"`fgraph`'s output does not match the provided output: {fg_out}, {out}" diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index b7caff1bf4..e2e612ac93 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -241,7 +241,7 @@ def std_fgraph( fgraph.attach_feature( Supervisor( input - for spec, input in zip(input_specs, fgraph.inputs) + for spec, input in zip(input_specs, fgraph.inputs, strict=True) if not ( spec.mutable or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input])) @@ -393,6 +393,8 @@ def __init__( assert len(self.input_storage) == len(self.maker.fgraph.inputs) assert len(self.output_storage) == len(self.maker.fgraph.outputs) + self.has_defaults = any(refeed for _, refeed, _ in self.defaults) + # Group indexes of inputs that are potentially aliased to each other # Note: Historically, we only worried about aliasing inputs if they belonged to the same type, # even though there could be two distinct types that use the same kinds of underlying objects. @@ -442,7 +444,7 @@ def __init__( # this loop works by modifying the elements (as variable c) of # self.input_storage inplace. for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate( - zip(self.indices, defaults) + zip(self.indices, defaults, strict=True) ): if indices is None: # containers is being used as a stack. Here we pop off @@ -540,14 +542,40 @@ def __contains__(self, item): self._value = ValueAttribute() self._container = ContainerAttribute() - # TODO: Get rid of all this `expanded_inputs` nonsense - assert len(self.maker.expanded_inputs) == len(self.input_storage) + update_storage = [ + container + for inp, container in zip( + self.maker.expanded_inputs, input_storage, strict=True + ) + if inp.update is not None + ] + # Updates are the last inner outputs that are not returned by Function.__call__ + self.n_returned_outputs = len(self.output_storage) - len(update_storage) + + # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself + self.update_input_storage: tuple[int, Container] = () + if getattr(vm, "need_update_inputs", True): + self.update_input_storage = tuple( + zip( + range(self.n_returned_outputs, len(output_storage)), + update_storage, + strict=True, + ) + ) - # This is used only when `vm.need_update_inputs` is `False`, because - # we're using one of the VM objects and it is putting updates back into - # the input containers all by itself. - self.n_returned_outputs = len(self.output_storage) - sum( - inp.update is not None for inp in self.maker.expanded_inputs + # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage + # After the call, we want to erase (some of) these references, to allow Python to GC them if unused + # Required input containers are the non-default inputs, must always be provided again, so we GC them + self.clear_input_storage_data = tuple( + container.storage for container in input_storage if container.required + ) + # This is only done when `vm.allow_gc` is True, which can change at runtime. + self.clear_output_storage_data = tuple( + container.storage + for container, variable in zip( + self.output_storage, self.maker.fgraph.outputs, strict=True + ) + if variable.owner is not None # Not a constant output ) for node in self.maker.fgraph.apply_nodes: @@ -671,7 +699,7 @@ def checkSV(sv_ori, sv_rpl): else: outs = list(map(SymbolicOutput, fg_cpy.outputs)) - for out_ori, out_cpy in zip(maker.outputs, outs): + for out_ori, out_cpy in zip(maker.outputs, outs, strict=False): out_cpy.borrow = out_ori.borrow # swap SharedVariable @@ -684,7 +712,7 @@ def checkSV(sv_ori, sv_rpl): raise ValueError(f"SharedVariable: {sv.name} not found") # Swap SharedVariable in fgraph and In instances - for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)): + for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)): # Variables in maker.inputs are defined by user, therefore we # use them to make comparison and do the mapping. # Otherwise we don't touch them. @@ -708,7 +736,7 @@ def checkSV(sv_ori, sv_rpl): # Delete update if needed rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()} - for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)): + for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)): inp.variable = in_var if not delete_updates and inp.update is not None: out_idx = rev_update_mapping[n] @@ -747,7 +775,7 @@ def checkSV(sv_ori, sv_rpl): elif isinstance(profile, str): profile = pytensor.compile.profiling.ProfileStats(message=profile) - f_cpy = maker.__class__( + f_cpy = type(maker)( inputs=ins, outputs=outs, fgraph=fg_cpy, @@ -765,10 +793,16 @@ def checkSV(sv_ori, sv_rpl): # check that. accept_inplace=True, no_fgraph_prep=True, + output_keys=maker.output_keys, + name=name, ).create(input_storage, storage_map=new_storage_map) for in_ori, in_cpy, ori, cpy in zip( - maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage + maker.inputs, + f_cpy.maker.inputs, + self.input_storage, + f_cpy.input_storage, + strict=True, ): # Share immutable ShareVariable and constant input's storage swapped = swap is not None and in_ori.variable in swap @@ -793,8 +827,6 @@ def checkSV(sv_ori, sv_rpl): f_cpy.trust_input = self.trust_input f_cpy.unpack_single = self.unpack_single - f_cpy.name = name - f_cpy.maker.fgraph.name = name return f_cpy def _restore_defaults(self): @@ -804,7 +836,7 @@ def _restore_defaults(self): value = value.storage[0] self[i] = value - def __call__(self, *args, **kwargs): + def __call__(self, *args, output_subset=None, **kwargs): """ Evaluates value of a function on given arguments. @@ -832,20 +864,21 @@ def __call__(self, *args, **kwargs): List of outputs on indices/keys from ``output_subset`` or all of them, if ``output_subset`` is not passed. """ + trust_input = self.trust_input input_storage = self.input_storage + vm = self.vm profile = self.profile if profile: t0 = time.perf_counter() - output_subset = kwargs.pop("output_subset", None) if output_subset is not None: warnings.warn("output_subset is deprecated.", FutureWarning) if self.output_keys is not None: output_subset = [self.output_keys.index(key) for key in output_subset] # Reinitialize each container's 'provided' counter - if self.trust_input: + if trust_input: for arg_container, arg in zip(input_storage, args, strict=False): arg_container.storage[0] = arg else: @@ -904,7 +937,7 @@ def __call__(self, *args, **kwargs): for k, arg in kwargs.items(): self[k] = arg - if not self.trust_input: + if not trust_input: # Collect aliased inputs among the storage space for potential_group in self._potential_aliased_input_groups: args_share_memory: list[list[int]] = [] @@ -956,11 +989,7 @@ def __call__(self, *args, **kwargs): if profile: t0_fn = time.perf_counter() try: - outputs = ( - self.vm() - if output_subset is None - else self.vm(output_subset=output_subset) - ) + outputs = vm() if output_subset is None else vm(output_subset=output_subset) except Exception: self._restore_defaults() if hasattr(self.vm, "position_of_error"): @@ -987,37 +1016,23 @@ def __call__(self, *args, **kwargs): # Retrieve the values that were computed if outputs is None: - outputs = [x.data for x in self.output_storage] - - # Remove internal references to required inputs. - # These cannot be re-used anyway. - for arg_container in input_storage: - if arg_container.required: - arg_container.storage[0] = None - - # if we are allowing garbage collection, remove the - # output reference from the internal storage cells - if getattr(self.vm, "allow_gc", False): - for o_container, o_variable in zip( - self.output_storage, self.maker.fgraph.outputs - ): - if o_variable.owner is not None: - # this node is the variable of computation - # WARNING: This circumvents the 'readonly' attribute in x - o_container.storage[0] = None - - if getattr(self.vm, "need_update_inputs", True): - # Update the inputs that have an update function - for input, storage in reversed( - list(zip(self.maker.expanded_inputs, input_storage)) - ): - if input.update is not None: - storage.data = outputs.pop() - else: - outputs = outputs[: self.n_returned_outputs] + outputs = [x.storage[0] for x in self.output_storage] + + # Set updates and filter them out from the returned outputs + for i, input_storage in self.update_input_storage: + input_storage.storage[0] = outputs[i] + outputs = outputs[: self.n_returned_outputs] + + # Remove input and output values from storage data + for storage_data in self.clear_input_storage_data: + storage_data[0] = None + if getattr(vm, "allow_gc", False): + for storage_data in self.clear_output_storage_data: + storage_data[0] = None # Put default values back in the storage - self._restore_defaults() + if self.has_defaults: + self._restore_defaults() if profile: dt_call = time.perf_counter() - t0 @@ -1025,32 +1040,29 @@ def __call__(self, *args, **kwargs): self.maker.mode.call_time += dt_call profile.fct_callcount += 1 profile.fct_call_time += dt_call - if hasattr(self.vm, "update_profile"): - self.vm.update_profile(profile) + if hasattr(vm, "update_profile"): + vm.update_profile(profile) if profile.ignore_first_call: profile.reset() profile.ignore_first_call = False if self.return_none: return None - elif self.unpack_single and len(outputs) == 1 and output_subset is None: - return outputs[0] - else: - if self.output_keys is not None: - assert len(self.output_keys) == len(outputs) - if output_subset is None: - return dict(zip(self.output_keys, outputs)) - else: - return { - self.output_keys[index]: outputs[index] - for index in output_subset - } + if output_subset is not None: + outputs = [outputs[i] for i in output_subset] - if output_subset is None: - return outputs + if self.output_keys is None: + if self.unpack_single: + [out] = outputs + return out else: - return [outputs[i] for i in output_subset] + return outputs + else: + output_keys = self.output_keys + if output_subset is not None: + output_keys = [output_keys[i] for i in output_subset] + return dict(zip(output_keys, outputs, strict=True)) value = property( lambda self: self._value, @@ -1070,9 +1082,10 @@ def free(self): # 1.no allow_gc return False # 2.has allow_gc, if allow_gc is False, return True if not getattr(self.vm, "allow_gc", True): - for key in self.vm.storage_map: - if not isinstance(key, Constant): - self.vm.storage_map[key][0] = None + storage_map = self.vm.storage_map + for key, value in storage_map.items(): + if key.owner is not None: # Not a constant + value[0] = None for node in self.nodes_with_inner_function: if hasattr(node.fn, "free"): @@ -1084,10 +1097,6 @@ def get_shared(self): """ return [i.variable for i in self.maker.inputs if i.implicit] - def sync_shared(self): - # NOTE: sync was needed on old gpu backend - pass - def dprint(self, **kwargs): """Debug print itself @@ -1107,8 +1116,9 @@ def _pickle_Function(f): ins = list(f.input_storage) input_storage = [] + # strict=False because we are in a hot loop for (input, indices, inputs), (required, refeed, default) in zip( - f.indices, f.defaults + f.indices, f.defaults, strict=False ): input_storage.append(ins[0]) del ins[0] @@ -1150,7 +1160,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False): f = maker.create(input_storage) assert len(f.input_storage) == len(inputs_data) - for container, x in zip(f.input_storage, inputs_data): + for container, x in zip(f.input_storage, inputs_data, strict=True): assert ( (container.data is x) or (isinstance(x, np.ndarray) and (container.data == x).all()) @@ -1184,7 +1194,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): reason = "insert_deepcopy" updated_fgraph_inputs = { fgraph_i - for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs) + for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True) if getattr(i, "update", False) } @@ -1521,7 +1531,9 @@ def __init__( # return the internal storage pointer. no_borrow = [ output - for output, spec in zip(fgraph.outputs, outputs + found_updates) + for output, spec in zip( + fgraph.outputs, outputs + found_updates, strict=True + ) if not spec.borrow ] @@ -1590,7 +1602,7 @@ def create(self, input_storage=None, storage_map=None): # defaults lists. assert len(self.indices) == len(input_storage) for i, ((input, indices, subinputs), input_storage_i) in enumerate( - zip(self.indices, input_storage) + zip(self.indices, input_storage, strict=True) ): # Replace any default value given as a variable by its # container. Note that this makes sense only in the diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 152ad3554d..ae905089b5 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -492,7 +492,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "PYTORCH": PYTORCH, } -instantiated_default_mode = None +_CACHED_RUNTIME_MODES: dict[str, Mode] = {} def get_mode(orig_string): @@ -500,50 +500,46 @@ def get_mode(orig_string): string = config.mode else: string = orig_string + if not isinstance(string, str): return string # it is hopefully already a mode... - global instantiated_default_mode - # The default mode is cached. However, config.mode can change - # If instantiated_default_mode has the right class, use it. - if orig_string is None and instantiated_default_mode: - if string in predefined_modes: - default_mode_class = predefined_modes[string].__class__.__name__ - else: - default_mode_class = string - if instantiated_default_mode.__class__.__name__ == default_mode_class: - return instantiated_default_mode - - if string in ("Mode", "DebugMode", "NanGuardMode"): - if string == "DebugMode": - # need to import later to break circular dependency. - from .debugmode import DebugMode - - # DebugMode use its own linker. - ret = DebugMode(optimizer=config.optimizer) - elif string == "NanGuardMode": - # need to import later to break circular dependency. - from .nanguardmode import NanGuardMode - - # NanGuardMode use its own linker. - ret = NanGuardMode(True, True, True, optimizer=config.optimizer) - else: - # TODO: Can't we look up the name and invoke it rather than using eval here? - ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)") - elif string in predefined_modes: - ret = predefined_modes[string] - else: - raise Exception(f"No predefined mode exist for string: {string}") + # Keep the original string for error messages + upper_string = string.upper() - if orig_string is None: - # Build and cache the default mode - if config.optimizer_excluding: - ret = ret.excluding(*config.optimizer_excluding.split(":")) - if config.optimizer_including: - ret = ret.including(*config.optimizer_including.split(":")) - if config.optimizer_requiring: - ret = ret.requiring(*config.optimizer_requiring.split(":")) - instantiated_default_mode = ret + if upper_string in predefined_modes: + return predefined_modes[upper_string] + + global _CACHED_RUNTIME_MODES + + if upper_string in _CACHED_RUNTIME_MODES: + return _CACHED_RUNTIME_MODES[upper_string] + + # Need to define the mode for the first time + if upper_string == "MODE": + ret = Mode(linker=config.linker, optimizer=config.optimizer) + elif upper_string in ("DEBUGMODE", "DEBUG_MODE"): + from pytensor.compile.debugmode import DebugMode + + # DebugMode use its own linker. + ret = DebugMode(optimizer=config.optimizer) + elif upper_string == "NANGUARDMODE": + from pytensor.compile.nanguardmode import NanGuardMode + + # NanGuardMode use its own linker. + ret = NanGuardMode(True, True, True, optimizer=config.optimizer) + + else: + raise ValueError(f"No predefined mode exist for string: {string}") + + if config.optimizer_excluding: + ret = ret.excluding(*config.optimizer_excluding.split(":")) + if config.optimizer_including: + ret = ret.including(*config.optimizer_including.split(":")) + if config.optimizer_requiring: + ret = ret.requiring(*config.optimizer_requiring.split(":")) + # Cache the mode for next time + _CACHED_RUNTIME_MODES[upper_string] = ret return ret diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index a81fd63905..6000311df7 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -3,11 +3,10 @@ import os import platform import re -import shutil -import socket import sys import textwrap from pathlib import Path +from shutil import which import numpy as np @@ -349,7 +348,7 @@ def add_compile_configvars(): # Try to find the full compiler path from the name if param != "": - newp = shutil.which(param) + newp = which(param) if newp is not None: param = newp del newp @@ -388,7 +387,8 @@ def add_compile_configvars(): config.add( "linker", "Default linker used if the pytensor flags mode is Mode", - EnumStr("cvm", linker_options), + # Not mutable because the default mode is cached after the first use. + EnumStr("cvm", linker_options, mutable=False), in_c_key=False, ) @@ -411,6 +411,7 @@ def add_compile_configvars(): EnumStr( "o4", ["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"], + mutable=False, # Not mutable because the default mode is cached after the first use. ), in_c_key=False, ) @@ -1190,7 +1191,7 @@ def _get_home_dir() -> Path: "pytensor_version": pytensor.__version__, "numpy_version": np.__version__, "gxx_version": "xxx", - "hostname": socket.gethostname(), + "hostname": platform.node(), } diff --git a/pytensor/configparser.py b/pytensor/configparser.py index 8c6da4a144..4f71e85240 100644 --- a/pytensor/configparser.py +++ b/pytensor/configparser.py @@ -1,6 +1,5 @@ import logging import os -import shlex import sys import warnings from collections.abc import Callable, Sequence @@ -14,6 +13,7 @@ from functools import wraps from io import StringIO from pathlib import Path +from shlex import shlex from pytensor.utils import hash_from_code @@ -541,7 +541,7 @@ def parse_config_string( Parses a config string (comma-separated key=value components) into a dict. """ config_dict = {} - my_splitter = shlex.shlex(config_string, posix=True) + my_splitter = shlex(config_string, posix=True) my_splitter.whitespace = "," my_splitter.whitespace_split = True for kv_pair in my_splitter: diff --git a/pytensor/d3viz/formatting.py b/pytensor/d3viz/formatting.py index 80936a513d..df39335c19 100644 --- a/pytensor/d3viz/formatting.py +++ b/pytensor/d3viz/formatting.py @@ -12,13 +12,7 @@ from pytensor.compile import Function, builders from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs from pytensor.graph.fg import FunctionGraph -from pytensor.printing import pydot_imported, pydot_imported_msg - - -try: - from pytensor.printing import pd -except ImportError: - pass +from pytensor.printing import _try_pydot_import class PyDotFormatter: @@ -41,8 +35,7 @@ class PyDotFormatter: def __init__(self, compact=True): """Construct PyDotFormatter object.""" - if not pydot_imported: - raise ImportError("Failed to import pydot. " + pydot_imported_msg) + _try_pydot_import() self.compact = compact self.node_colors = { @@ -115,6 +108,8 @@ def __call__(self, fct, graph=None): pydot.Dot Pydot graph of `fct` """ + pd = _try_pydot_import() + if graph is None: graph = pd.Dot() @@ -244,14 +239,14 @@ def format_map(m): ext_inputs = [self.__node_id(x) for x in node.inputs] int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs] assert len(ext_inputs) == len(int_inputs) - h = format_map(zip(ext_inputs, int_inputs)) + h = format_map(zip(ext_inputs, int_inputs, strict=True)) pd_node.get_attributes()["subg_map_inputs"] = h # Outputs mapping ext_outputs = [self.__node_id(x) for x in node.outputs] int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs] assert len(ext_outputs) == len(int_outputs) - h = format_map(zip(int_outputs, ext_outputs)) + h = format_map(zip(int_outputs, ext_outputs, strict=True)) pd_node.get_attributes()["subg_map_outputs"] = h return graph @@ -356,6 +351,8 @@ def type_to_str(t): def dict_to_pdnode(d): """Create pydot node from dict.""" + pd = _try_pydot_import() + e = dict() for k, v in d.items(): if v is not None: diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 5946a20dd4..13ca943383 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -210,7 +210,7 @@ def Rop( # Check that each element of wrt corresponds to an element # of eval_points with the same dimensionality. - for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)): + for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)): try: if wrt_elem.type.ndim != eval_point.type.ndim: raise ValueError( @@ -259,7 +259,7 @@ def _traverse(node): seen_nodes[inp.owner][inp.owner.outputs.index(inp)] ) same_type_eval_points = [] - for x, y in zip(inputs, local_eval_points): + for x, y in zip(inputs, local_eval_points, strict=True): if y is not None: if not isinstance(x, Variable): x = pytensor.tensor.as_tensor_variable(x) @@ -396,7 +396,7 @@ def Lop( _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] assert len(_f) == len(grads) - known = dict(zip(_f, grads)) + known = dict(zip(_f, grads, strict=True)) ret = grad( cost=None, @@ -800,7 +800,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False): for i in range(len(grads)): grads[i] += cost_grads[i] - pgrads = dict(zip(params, grads)) + pgrads = dict(zip(params, grads, strict=True)) # separate wrt from end grads: wrt_grads = [pgrads[k] for k in wrt] end_grads = [pgrads[k] for k in end] @@ -1066,7 +1066,7 @@ def access_term_cache(node): any( input_to_output and output_to_cost for input_to_output, output_to_cost in zip( - input_to_outputs, outputs_connected + input_to_outputs, outputs_connected, strict=True ) ) ) @@ -1091,7 +1091,7 @@ def access_term_cache(node): not any( in_to_out and out_to_cost and not out_nan for in_to_out, out_to_cost, out_nan in zip( - in_to_outs, outputs_connected, ograd_is_nan + in_to_outs, outputs_connected, ograd_is_nan, strict=True ) ) ) @@ -1151,7 +1151,7 @@ def try_to_copy_if_needed(var): # DO NOT force integer variables to have integer dtype. # This is a violation of the op contract. new_output_grads = [] - for o, og in zip(node.outputs, output_grads): + for o, og in zip(node.outputs, output_grads, strict=True): o_dt = getattr(o.type, "dtype", None) og_dt = getattr(og.type, "dtype", None) if ( @@ -1165,7 +1165,7 @@ def try_to_copy_if_needed(var): # Make sure that, if new_output_grads[i] has a floating point # dtype, it is the same dtype as outputs[i] - for o, ng in zip(node.outputs, new_output_grads): + for o, ng in zip(node.outputs, new_output_grads, strict=True): o_dt = getattr(o.type, "dtype", None) ng_dt = getattr(ng.type, "dtype", None) if ( @@ -1187,7 +1187,9 @@ def try_to_copy_if_needed(var): # by the user, not computed by Op.grad, and some gradients are # only computed and returned, but never passed as another # node's output grads. - for idx, packed in enumerate(zip(node.outputs, new_output_grads)): + for idx, packed in enumerate( + zip(node.outputs, new_output_grads, strict=True) + ): orig_output, new_output_grad = packed if not hasattr(orig_output, "shape"): continue @@ -1253,7 +1255,7 @@ def try_to_copy_if_needed(var): not in [ in_to_out and out_to_cost and not out_int for in_to_out, out_to_cost, out_int in zip( - in_to_outs, outputs_connected, output_is_int + in_to_outs, outputs_connected, output_is_int, strict=True ) ] ) @@ -1327,14 +1329,14 @@ def try_to_copy_if_needed(var): f" {i}. Since this input is only connected " "to integer-valued outputs, it should " "evaluate to zeros, but it evaluates to" - f"{pytensor.get_underlying_scalar_constant(term)}." + f"{pytensor.get_underlying_scalar_constant_value(term)}." ) raise ValueError(msg) # Check that op.connection_pattern matches the connectivity # logic driving the op.grad method for i, (ipt, ig, connected) in enumerate( - zip(inputs, input_grads, inputs_connected) + zip(inputs, input_grads, inputs_connected, strict=True) ): actually_connected = not isinstance(ig.type, DisconnectedType) @@ -1621,7 +1623,7 @@ def abs_rel_errors(self, g_pt): if len(g_pt) != len(self.gf): raise ValueError("argument has wrong number of elements", len(g_pt)) errs = [] - for i, (a, b) in enumerate(zip(g_pt, self.gf)): + for i, (a, b) in enumerate(zip(g_pt, self.gf, strict=True)): if a.shape != b.shape: raise ValueError( f"argument element {i} has wrong shapes {a.shape}, {b.shape}" @@ -1770,14 +1772,9 @@ def verify_grad( if rel_tol is None: rel_tol = max(_type_tol[str(p.dtype)] for p in pt) + # Initialize RNG if not provided if rng is None: - raise TypeError( - "rng should be a valid instance of " - "numpy.random.RandomState. You may " - "want to use tests.unittest" - "_tools.verify_grad instead of " - "pytensor.gradient.verify_grad." - ) + rng = np.random.default_rng() # We allow input downcast in `function`, because `numeric_grad` works in # the most precise dtype used among the inputs, so we may need to cast @@ -2160,6 +2157,9 @@ def _is_zero(x): 'maybe' means that x is an expression that is complicated enough that we can't tell that it simplifies to 0. """ + from pytensor.tensor import get_underlying_scalar_constant_value + from pytensor.tensor.exceptions import NotScalarConstantError + if not hasattr(x, "type"): return np.all(x == 0.0) if isinstance(x.type, NullType): @@ -2169,9 +2169,9 @@ def _is_zero(x): no_constant_value = True try: - constant_value = pytensor.get_underlying_scalar_constant(x) + constant_value = get_underlying_scalar_constant_value(x) no_constant_value = False - except pytensor.tensor.exceptions.NotScalarConstantError: + except NotScalarConstantError: pass if no_constant_value: diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 6b4ca7570d..512f0ef3ab 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -272,7 +272,7 @@ def clone_with_new_inputs( # as the output type depends on the input values and not just their types output_type_depends_on_input_value = self.op._output_type_depends_on_input_value - for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)): + for i, (curr, new) in enumerate(zip(self.inputs, new_inputs, strict=True)): # Check if the input type changed or if the Op has output types that depend on input values if (curr.type != new.type) or output_type_depends_on_input_value: # In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one. @@ -616,16 +616,20 @@ def eval( """ from pytensor.compile.function import function + ignore_unused_input = kwargs.get("on_unused_input", None) in ("ignore", "warn") + def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]: new_input_to_values = {} for key, value in inputs_to_values.items(): if isinstance(key, str): matching_vars = get_var_by_name([self], key) if not matching_vars: - raise Exception(f"{key} not found in graph") + if not ignore_unused_input: + raise ValueError(f"{key} not found in graph") elif len(matching_vars) > 1: - raise Exception(f"Found multiple variables with name {key}") - new_input_to_values[matching_vars[0]] = value + raise ValueError(f"Found multiple variables with name {key}") + else: + new_input_to_values[matching_vars[0]] = value else: new_input_to_values[key] = value return new_input_to_values @@ -1308,7 +1312,7 @@ def clone_node_and_cache( if new_node.op is not node.op: clone_d.setdefault(node.op, new_node.op) - for old_o, new_o in zip(node.outputs, new_node.outputs): + for old_o, new_o in zip(node.outputs, new_node.outputs, strict=True): clone_d.setdefault(old_o, new_o) return new_node @@ -1898,7 +1902,7 @@ def equal_computations( if in_ys is None: in_ys = [] - for x, y in zip(xs, ys): + for x, y in zip(xs, ys, strict=True): if not isinstance(x, Variable) and not isinstance(y, Variable): return np.array_equal(x, y) if not isinstance(x, Variable): @@ -1921,13 +1925,13 @@ def equal_computations( if len(in_xs) != len(in_ys): return False - for _x, _y in zip(in_xs, in_ys): + for _x, _y in zip(in_xs, in_ys, strict=True): if not (_y.type.in_same_class(_x.type)): return False - common = set(zip(in_xs, in_ys)) + common = set(zip(in_xs, in_ys, strict=True)) different: set[tuple[Variable, Variable]] = set() - for dx, dy in zip(xs, ys): + for dx, dy in zip(xs, ys, strict=True): assert isinstance(dx, Variable) # We checked above that both dx and dy have an owner or not if dx.owner is None: @@ -1963,7 +1967,7 @@ def compare_nodes(nd_x, nd_y, common, different): return False else: all_in_common = True - for dx, dy in zip(nd_x.outputs, nd_y.outputs): + for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True): if (dx, dy) in different: return False if (dx, dy) not in common: @@ -1973,7 +1977,7 @@ def compare_nodes(nd_x, nd_y, common, different): return True # Compare the individual inputs for equality - for dx, dy in zip(nd_x.inputs, nd_y.inputs): + for dx, dy in zip(nd_x.inputs, nd_y.inputs, strict=True): if (dx, dy) not in common: # Equality between the variables is unknown, compare # their respective owners, if they have some @@ -2008,7 +2012,7 @@ def compare_nodes(nd_x, nd_y, common, different): # If the code reaches this statement then the inputs are pair-wise # equivalent so the outputs of the current nodes are also # pair-wise equivalents - for dx, dy in zip(nd_x.outputs, nd_y.outputs): + for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True): common.add((dx, dy)) return True diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 519abe49d8..690bb44df5 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -231,14 +231,14 @@ def make_node(self, *inputs: Variable) -> Apply: ) if not all( expected_type.is_super(var.type) - for var, expected_type in zip(inputs, self.itypes) + for var, expected_type in zip(inputs, self.itypes, strict=True) ): raise TypeError( f"Invalid input types for Op {self}:\n" + "\n".join( f"Input {i}/{len(inputs)}: Expected {inp}, got {out}" for i, (inp, out) in enumerate( - zip(self.itypes, (inp.type for inp in inputs)), + zip(self.itypes, (inp.type for inp in inputs), strict=True), start=1, ) if inp != out diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index 9b12192452..5092d55e6b 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -78,7 +78,7 @@ def clone_replace( items = list(_format_replace(replace).items()) tmp_replace = [(x, x.type()) for x, y in items] - new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)] + new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items, strict=True)] _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds) # TODO Explain why we call it twice ?! @@ -295,11 +295,11 @@ def vectorize_graph( inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys()) new_inputs = [replace.get(inp, inp) for inp in inputs] - vect_vars = dict(zip(inputs, new_inputs)) + vect_vars = dict(zip(inputs, new_inputs, strict=True)) for node in io_toposort(inputs, seq_outputs): vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] vect_node = vectorize_node(node, *vect_inputs) - for output, vect_output in zip(node.outputs, vect_node.outputs): + for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True): if output in vect_vars: # This can happen when some outputs of a multi-output node are given a replacement, # while some of the remaining outputs are still needed in the graph. diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 2bc0508f7d..344d6a1940 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -5,7 +5,6 @@ import functools import inspect import logging -import pdb import sys import time import traceback @@ -237,6 +236,8 @@ def warn(cls, exc, self, rewriter): if config.on_opt_error == "raise": raise exc elif config.on_opt_error == "pdb": + import pdb + pdb.post_mortem(sys.exc_info()[2]) def __init__(self, *rewrites, failure_callback=None): @@ -399,14 +400,14 @@ def print_profile(cls, stream, prof, level=0): file=stream, ) ll = [] - for rewrite, nb_n in zip(rewrites, nb_nodes): + for rewrite, nb_n in zip(rewrites, nb_nodes, strict=True): if hasattr(rewrite, "__name__"): name = rewrite.__name__ else: name = rewrite.name idx = rewrites.index(rewrite) ll.append((name, rewrite.__class__.__name__, idx, *nb_n)) - lll = sorted(zip(prof, ll), key=lambda a: a[0]) + lll = sorted(zip(prof, ll, strict=True), key=lambda a: a[0]) for t, rewrite in lll[::-1]: i = rewrite[2] @@ -480,7 +481,8 @@ def merge_profile(prof1, prof2): new_rewrite = SequentialGraphRewriter(*new_l) new_nb_nodes = [ - (p1[0] + p2[0], p1[1] + p2[1]) for p1, p2 in zip(prof1[8], prof2[8]) + (p1[0] + p2[0], p1[1] + p2[1]) + for p1, p2 in zip(prof1[8], prof2[8], strict=True) ] new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :]) new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :]) @@ -635,7 +637,7 @@ def process_node(self, fgraph, node): inputs_match = all( node_in is cand_in - for node_in, cand_in in zip(node.inputs, candidate.inputs) + for node_in, cand_in in zip(node.inputs, candidate.inputs, strict=True) ) if inputs_match and node.op == candidate.op: @@ -649,6 +651,7 @@ def process_node(self, fgraph, node): node.outputs, candidate.outputs, ["merge"] * len(node.outputs), + strict=True, ) ) @@ -721,7 +724,9 @@ def apply(self, fgraph): inputs_match = all( node_in is cand_in for node_in, cand_in in zip( - var.owner.inputs, candidate_var.owner.inputs + var.owner.inputs, + candidate_var.owner.inputs, + strict=True, ) ) @@ -1434,7 +1439,7 @@ def transform(self, fgraph, node): repl = self.op2.make_node(*node.inputs) if self.transfer_tags: repl.tag = copy.copy(node.tag) - for output, new_output in zip(node.outputs, repl.outputs): + for output, new_output in zip(node.outputs, repl.outputs, strict=True): new_output.tag = copy.copy(output.tag) return repl.outputs @@ -1614,7 +1619,7 @@ def transform(self, fgraph, node, get_nodes=True): for real_node in self.get_nodes(fgraph, node): ret = self.transform(fgraph, real_node, get_nodes=False) if ret is not False and ret is not None: - return dict(zip(real_node.outputs, ret)) + return dict(zip(real_node.outputs, ret, strict=True)) if node.op != self.op: return False @@ -1646,7 +1651,7 @@ def transform(self, fgraph, node, get_nodes=True): len(node.outputs) == len(ret.owner.outputs) and all( o.type.is_super(new_o.type) - for o, new_o in zip(node.outputs, ret.owner.outputs) + for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True) ) ): return False @@ -1748,6 +1753,8 @@ def warn(cls, exc, nav, repl_pairs, node_rewriter, node): _logger.error("TRACEBACK:") _logger.error(traceback.format_exc()) if config.on_opt_error == "pdb": + import pdb + pdb.post_mortem(sys.exc_info()[2]) elif isinstance(exc, AssertionError) or config.on_opt_error == "raise": # We always crash on AssertionError because something may be @@ -1935,7 +1942,7 @@ def process_node( ) # None in the replacement mean that this variable isn't used # and we want to remove it - for r, rnew in zip(old_vars, replacements): + for r, rnew in zip(old_vars, replacements, strict=True): if rnew is None and len(fgraph.clients[r]) > 0: raise ValueError( f"Node rewriter {node_rewriter} tried to remove a variable" @@ -1945,7 +1952,7 @@ def process_node( # the replacement repl_pairs = [ (r, rnew) - for r, rnew in zip(old_vars, replacements) + for r, rnew in zip(old_vars, replacements, strict=True) if rnew is not r and rnew is not None ] @@ -2628,17 +2635,23 @@ def print_profile(cls, stream, prof, level=0): print(blanc, "Global, final, and clean up rewriters", file=stream) for i in range(len(loop_timing)): print(blanc, f"Iter {int(i)}", file=stream) - for o, prof in zip(rewrite.global_rewriters, global_sub_profs[i]): + for o, prof in zip( + rewrite.global_rewriters, global_sub_profs[i], strict=True + ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: print(blanc, "merge not implemented for ", o) - for o, prof in zip(rewrite.final_rewriters, final_sub_profs[i]): + for o, prof in zip( + rewrite.final_rewriters, final_sub_profs[i], strict=True + ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: print(blanc, "merge not implemented for ", o) - for o, prof in zip(rewrite.cleanup_rewriters, cleanup_sub_profs[i]): + for o, prof in zip( + rewrite.cleanup_rewriters, cleanup_sub_profs[i], strict=True + ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: @@ -2856,7 +2869,7 @@ def local_recursive_function( outs, rewritten_vars = local_recursive_function( rewrite_list, inp, rewritten_vars, depth + 1 ) - for k, v in zip(inp.owner.outputs, outs): + for k, v in zip(inp.owner.outputs, outs, strict=True): rewritten_vars[k] = v nw_in = outs[inp.owner.outputs.index(inp)] @@ -2874,7 +2887,7 @@ def local_recursive_function( if ret is not False and ret is not None: assert isinstance(ret, Sequence) assert len(ret) == len(node.outputs), rewrite - for k, v in zip(node.outputs, ret): + for k, v in zip(node.outputs, ret, strict=True): rewritten_vars[k] = v results = ret if ret[0].owner: diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py index d797504ae6..9c2eef5049 100644 --- a/pytensor/graph/utils.py +++ b/pytensor/graph/utils.py @@ -107,8 +107,6 @@ def add_tag_trace(thing: T, user_line: int | None = None) -> T: "pytensor\\graph\\", "pytensor/scalar/basic.py", "pytensor\\scalar\\basic.py", - "pytensor/sandbox/", - "pytensor\\sandbox\\", "pytensor/scan/", "pytensor\\scan\\", "pytensor/sparse/", diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index b7c2c52ee4..c458e5b296 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -170,7 +170,9 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any): output_vars = [] new_inputs_true_branch = [] new_inputs_false_branch = [] - for input_t, input_f in zip(inputs_true_branch, inputs_false_branch): + for input_t, input_f in zip( + inputs_true_branch, inputs_false_branch, strict=True + ): if not isinstance(input_t, Variable): input_t = as_symbolic(input_t) if not isinstance(input_f, Variable): @@ -207,7 +209,9 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any): # allowed to have distinct shapes from either branch new_shape = tuple( s_t if s_t == s_f else None - for s_t, s_f in zip(input_t.type.shape, input_f.type.shape) + for s_t, s_f in zip( + input_t.type.shape, input_f.type.shape, strict=True + ) ) # TODO FIXME: The presence of this keyword is a strong # assumption. Find something that's guaranteed by the/a @@ -301,7 +305,8 @@ def thunk(): if len(ls) > 0: return ls else: - for out, t in zip(outputs, input_true_branch): + # strict=False because we are in a hot loop + for out, t in zip(outputs, input_true_branch, strict=False): compute_map[out][0] = 1 val = storage_map[t][0] if self.as_view: @@ -321,7 +326,8 @@ def thunk(): if len(ls) > 0: return ls else: - for out, f in zip(outputs, inputs_false_branch): + # strict=False because we are in a hot loop + for out, f in zip(outputs, inputs_false_branch, strict=False): compute_map[out][0] = 1 # can't view both outputs unless destroyhandler # improves @@ -637,7 +643,7 @@ def apply(self, fgraph): old_outs += [proposal.outputs] else: old_outs += proposal.outputs - pairs = list(zip(old_outs, new_outs)) + pairs = list(zip(old_outs, new_outs, strict=True)) fgraph.replace_all_validate(pairs, reason="cond_merge") @@ -736,7 +742,7 @@ def cond_merge_random_op(fgraph, main_node): old_outs += [proposal.outputs] else: old_outs += proposal.outputs - pairs = list(zip(old_outs, new_outs)) + pairs = list(zip(old_outs, new_outs, strict=True)) main_outs = clone_replace(main_node.outputs, replace=pairs) return main_outs diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index 30154a98ce..9cf34983f2 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -385,11 +385,11 @@ def make_all( f, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, order, @@ -509,7 +509,9 @@ def make_thunk(self, **kwargs): kwargs.pop("input_storage", None) make_all += [x.make_all(**kwargs) for x in self.linkers[1:]] - fns, input_lists, output_lists, thunk_lists, order_lists = zip(*make_all) + fns, input_lists, output_lists, thunk_lists, order_lists = zip( + *make_all, strict=True + ) order_list0 = order_lists[0] for order_list in order_lists[1:]: @@ -521,12 +523,12 @@ def make_thunk(self, **kwargs): inputs0 = input_lists[0] outputs0 = output_lists[0] - thunk_groups = list(zip(*thunk_lists)) - order = [x[0] for x in zip(*order_lists)] + thunk_groups = list(zip(*thunk_lists, strict=True)) + order = [x[0] for x in zip(*order_lists, strict=True)] to_reset = [ thunk.outputs[j] - for thunks, node in zip(thunk_groups, order) + for thunks, node in zip(thunk_groups, order, strict=True) for j, output in enumerate(node.outputs) if output in no_recycling for thunk in thunks @@ -537,12 +539,14 @@ def make_thunk(self, **kwargs): def f(): for inputs in input_lists[1:]: - for input1, input2 in zip(inputs0, inputs): + # strict=False because we are in a hot loop + for input1, input2 in zip(inputs0, inputs, strict=False): input2.storage[0] = copy(input1.storage[0]) for x in to_reset: x[0] = None pre(self, [input.data for input in input_lists[0]], order, thunk_groups) - for i, (thunks, node) in enumerate(zip(thunk_groups, order)): + # strict=False because we are in a hot loop + for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)): try: wrapper(self.fgraph, i, node, *thunks) except Exception: @@ -649,38 +653,36 @@ def create_jitable_thunk( ) thunk_inputs = self.create_thunk_inputs(storage_map) - - thunks = [] - thunk_outputs = [storage_map[n] for n in self.fgraph.outputs] - fgraph_jit = self.jit_compile(converted_fgraph) def thunk( - fgraph=self.fgraph, fgraph_jit=fgraph_jit, thunk_inputs=thunk_inputs, thunk_outputs=thunk_outputs, ): - outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs]) + try: + outputs = fgraph_jit(*(x[0] for x in thunk_inputs)) + except Exception: + # TODO: Should we add a fake node that combines all outputs, + # since the error may come from any of them? + raise_with_op(self.fgraph, output_nodes[0], thunk) - for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs): - compute_map[o_var][0] = True - o_storage[0] = self.output_filter(o_var, o_val) - return outputs + # strict=False because we are in a hot loop + for o_storage, o_val in zip(thunk_outputs, outputs, strict=False): + o_storage[0] = o_val thunk.inputs = thunk_inputs thunk.outputs = thunk_outputs thunk.lazy = False - thunks.append(thunk) + thunks = [thunk] return thunks, output_nodes, fgraph_jit def make_all(self, input_storage=None, output_storage=None, storage_map=None): fgraph = self.fgraph nodes = self.schedule(fgraph) - no_recycling = self.no_recycling input_storage, output_storage, storage_map = map_storage( fgraph, nodes, input_storage, output_storage, storage_map @@ -694,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None): compute_map, nodes, input_storage, output_storage, storage_map ) - computed, last_user = gc_helper(nodes) - - if self.allow_gc: - post_thunk_old_storage = [ - [ - storage_map[input] - for input in node.inputs - if (input in computed) - and (input not in fgraph.outputs) - and (node == last_user[input]) - ] - for node in nodes - ] - else: - post_thunk_old_storage = None - - if no_recycling is True: - no_recycling = list(storage_map.values()) - no_recycling = difference(no_recycling, input_storage) - else: - no_recycling = [ - storage_map[r] for r in no_recycling if r not in fgraph.inputs - ] - - fn = streamline( - fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling - ) - + [fn] = thunks fn.jit_fn = jit_fn fn.allow_gc = self.allow_gc fn.storage_map = storage_map @@ -730,11 +705,11 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None): fn, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, nodes, diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index 417580e09c..fa540bd9e6 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -1112,11 +1112,15 @@ def __compile__( module, [ Container(input, storage) - for input, storage in zip(self.fgraph.inputs, input_storage) + for input, storage in zip( + self.fgraph.inputs, input_storage, strict=True + ) ], [ Container(output, storage, readonly=True) - for output, storage in zip(self.fgraph.outputs, output_storage) + for output, storage in zip( + self.fgraph.outputs, output_storage, strict=True + ) ], error_storage, ) @@ -1363,8 +1367,8 @@ def cmodule_key_( # We must always add the numpy ABI version here as # DynamicModule always add the include - if np.lib.NumpyVersion(np.__version__) < "1.16.0a": - ndarray_c_version = np.core.multiarray._get_ndarray_c_version() + if np.lib.NumpyVersion(np.__version__) > "1.27.0": + ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() else: ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}") @@ -1887,11 +1891,11 @@ def make_all( f, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, order, @@ -1989,22 +1993,27 @@ def make_thunk(self, **kwargs): ) def f(): - for input1, input2 in zip(i1, i2): + # strict=False because we are in a hot loop + for input1, input2 in zip(i1, i2, strict=False): # Set the inputs to be the same in both branches. # The copy is necessary in order for inplace ops not to # interfere. input2.storage[0] = copy(input1.storage[0]) - for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2): - for output, storage in zip(node1.outputs, thunk1.outputs): + for thunk1, thunk2, node1, node2 in zip( + thunks1, thunks2, order1, order2, strict=False + ): + for output, storage in zip(node1.outputs, thunk1.outputs, strict=False): if output in no_recycling: storage[0] = None - for output, storage in zip(node2.outputs, thunk2.outputs): + for output, storage in zip(node2.outputs, thunk2.outputs, strict=False): if output in no_recycling: storage[0] = None try: thunk1() thunk2() - for output1, output2 in zip(thunk1.outputs, thunk2.outputs): + for output1, output2 in zip( + thunk1.outputs, thunk2.outputs, strict=False + ): self.checker(output1, output2) except Exception: raise_with_op(fgraph, node1) diff --git a/pytensor/link/c/cmodule.py b/pytensor/link/c/cmodule.py index 62f5adea01..c992d0506e 100644 --- a/pytensor/link/c/cmodule.py +++ b/pytensor/link/c/cmodule.py @@ -26,19 +26,12 @@ from typing import TYPE_CHECKING, Protocol, cast import numpy as np -from setuptools._distutils.sysconfig import ( - get_config_h_filename, - get_config_var, - get_python_inc, - get_python_lib, -) # we will abuse the lockfile mechanism when reading and writing the registry from pytensor.compile.compilelock import lock_ctx from pytensor.configdefaults import config, gcc_version_str from pytensor.configparser import BoolParam, StrParam from pytensor.graph.op import Op -from pytensor.link.c.exceptions import CompileError, MissingGXX from pytensor.utils import ( LOCAL_BITWIDTH, flatten, @@ -266,6 +259,8 @@ def list_code(self, ofile=sys.stdout): def _get_ext_suffix(): """Get the suffix for compiled extensions""" + from setuptools._distutils.sysconfig import get_config_var + dist_suffix = get_config_var("EXT_SUFFIX") if dist_suffix is None: dist_suffix = get_config_var("SO") @@ -1697,6 +1692,8 @@ def get_gcc_shared_library_arg(): def std_include_dirs(): + from setuptools._distutils.sysconfig import get_python_inc + numpy_inc_dirs = [np.get_include()] py_inc = get_python_inc() py_plat_spec_inc = get_python_inc(plat_specific=True) @@ -1709,6 +1706,12 @@ def std_include_dirs(): @is_StdLibDirsAndLibsType def std_lib_dirs_and_libs() -> tuple[list[str], ...] | None: + from setuptools._distutils.sysconfig import ( + get_config_var, + get_python_inc, + get_python_lib, + ) + # We cache the results as on Windows, this trigger file access and # this method is called many times. if std_lib_dirs_and_libs.data is not None: @@ -2379,23 +2382,14 @@ def join_options(init_part): if sys.platform == "darwin": # Use the already-loaded python symbols. cxxflags.extend(["-undefined", "dynamic_lookup"]) - - if sys.platform == "win32": - # Workaround for https://github.com/Theano/Theano/issues/4926. - # https://github.com/python/cpython/pull/11283/ removed the "hypot" - # redefinition for recent CPython versions (>=2.7.16 and >=3.7.3). - # The following nullifies that redefinition, if it is found. - python_version = sys.version_info[:3] - if (3,) <= python_version < (3, 7, 3): - config_h_filename = get_config_h_filename() - try: - with open(config_h_filename) as config_h: - if any( - line.startswith("#define hypot _hypot") for line in config_h - ): - cxxflags.append("-D_hypot=hypot") - except OSError: - pass + # XCode15 introduced ld_prime linker. At the time of writing, this linker + # leads to multiple issues, so we supply a flag to use the older dynamic + # linker: ld64 + if int(platform.mac_ver()[0].split(".")[0]) >= 15: + # This might be incorrect. We know that ld_prime was introduced in + # XCode15, but we don't know if the platform version is aligned with + # xcode's version. + cxxflags.append("-ld64") return cxxflags @@ -2451,7 +2445,7 @@ def patch_ldflags(flag_list: list[str]) -> list[str]: if not libs: return flag_list libs = GCC_compiler.linking_patch(lib_dirs, libs) - for flag_idx, lib in zip(flag_idxs, libs): + for flag_idx, lib in zip(flag_idxs, libs, strict=True): flag_list[flag_idx] = lib return flag_list @@ -2547,8 +2541,9 @@ def compile_str( """ # TODO: Do not do the dlimport in this function - if not config.cxx: + from pytensor.link.c.exceptions import MissingGXX + raise MissingGXX("g++ not available! We can't compile c code.") if include_dirs is None: @@ -2578,6 +2573,8 @@ def compile_str( cppfile.write("\n") if platform.python_implementation() == "PyPy": + from setuptools._distutils.sysconfig import get_config_var + suffix = "." + get_lib_extension() dist_suffix = get_config_var("SO") @@ -2634,6 +2631,8 @@ def print_command_line_error(): status = p_out[2] if status: + from pytensor.link.c.exceptions import CompileError + tf = tempfile.NamedTemporaryFile( mode="w", prefix="pytensor_compilation_error_", delete=False ) diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 61c90d2b10..74905d686f 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -59,7 +59,7 @@ def make_c_thunk( e = FunctionGraph(node.inputs, node.outputs) e_no_recycling = [ new_o - for (new_o, old_o) in zip(e.outputs, node.outputs) + for (new_o, old_o) in zip(e.outputs, node.outputs, strict=True) if old_o in no_recycling ] cl = pytensor.link.c.basic.CLinker().accept(e, no_recycling=e_no_recycling) @@ -352,7 +352,7 @@ def load_c_code(self, func_files: Iterable[Path]) -> None: "be used at the same time." ) - for func_file, code in zip(func_files, self.func_codes): + for func_file, code in zip(func_files, self.func_codes, strict=True): if self.backward_re.search(code): # This is backward compat code that will go away in a while diff --git a/pytensor/link/c/params_type.py b/pytensor/link/c/params_type.py index e81efc8647..457983ce03 100644 --- a/pytensor/link/c/params_type.py +++ b/pytensor/link/c/params_type.py @@ -725,7 +725,7 @@ def c_support_code(self, **kwargs): c_init_list = [] c_cleanup_list = [] c_extract_list = [] - for attribute_name, type_instance in zip(self.fields, self.types): + for attribute_name, type_instance in zip(self.fields, self.types, strict=True): try: # c_support_code() may return a code string or a list of code strings. support_code = type_instance.c_support_code() diff --git a/pytensor/link/jax/dispatch/extra_ops.py b/pytensor/link/jax/dispatch/extra_ops.py index a9e36667ef..87e55f1007 100644 --- a/pytensor/link/jax/dispatch/extra_ops.py +++ b/pytensor/link/jax/dispatch/extra_ops.py @@ -10,6 +10,7 @@ FillDiagonalOffset, RavelMultiIndex, Repeat, + SearchsortedOp, Unique, UnravelIndex, ) @@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs): # return filldiagonaloffset raise NotImplementedError("flatiter not implemented in JAX") + + +@jax_funcify.register(SearchsortedOp) +def jax_funcify_SearchsortedOp(op, **kwargs): + side = op.side + + def searchsorted(a, v, side=side, sorter=None): + return jnp.searchsorted(a=a, v=v, side=side, sorter=sorter) + + return searchsorted diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 9a89bf1406..3767946455 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -56,7 +56,7 @@ def assert_size_argument_jax_compatible(node): @jax_typify.register(Generator) def jax_typify_Generator(rng, **kwargs): - state = rng.__getstate__() + state = rng.bit_generator.state state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] # XXX: Is this a reasonable approach? diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index 71ea40de0f..d3e5ac11f7 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -31,6 +31,7 @@ GammaIncInv, Iv, Ive, + Kve, Log1mexp, Psi, TriGamma, @@ -288,9 +289,12 @@ def iv(v, x): @jax_funcify.register(Ive) def jax_funcify_Ive(op, **kwargs): - ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive") + return try_import_tfp_jax_op(op, jax_op_name="bessel_ive") + - return ive +@jax_funcify.register(Kve) +def jax_funcify_Kve(op, **kwargs): + return try_import_tfp_jax_op(op, jax_op_name="bessel_kve") @jax_funcify.register(Log1mexp) diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index b82fd67e3f..d98328f0cf 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -30,7 +30,9 @@ def scan(*outer_inputs): seqs = op.outer_seqs(outer_inputs) # JAX `xs` mit_sot_init = [] - for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)): + for tap, seq in zip( + op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True + ): init_slice = seq[: abs(min(tap))] mit_sot_init.append(init_slice) @@ -61,7 +63,9 @@ def jax_args_to_inner_func_args(carry, x): inner_seqs = x mit_sot_flatten = [] - for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices): + for array, index in zip( + inner_mit_sot, op.info.mit_sot_in_slices, strict=True + ): mit_sot_flatten.extend(array[jnp.array(index)]) inner_scan_inputs = [ @@ -98,8 +102,7 @@ def inner_func_outs_to_jax_outs( inner_mit_sot_new = [ jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0) for old_mit_sot, new_val in zip( - inner_mit_sot, - inner_mit_sot_outs, + inner_mit_sot, inner_mit_sot_outs, strict=True ) ] @@ -152,7 +155,9 @@ def get_partial_traces(traces): + op.outer_nitsot(outer_inputs) ) partial_traces = [] - for init_state, trace, buffer in zip(init_states, traces, buffers): + for init_state, trace, buffer in zip( + init_states, traces, buffers, strict=True + ): if init_state is not None: # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer trace = jnp.atleast_1d(trace) diff --git a/pytensor/link/jax/dispatch/shape.py b/pytensor/link/jax/dispatch/shape.py index 6d75b7ae6f..6d809252a7 100644 --- a/pytensor/link/jax/dispatch/shape.py +++ b/pytensor/link/jax/dispatch/shape.py @@ -96,7 +96,7 @@ def shape_i(x): def jax_funcify_SpecifyShape(op, node, **kwargs): def specifyshape(x, *shape): assert x.ndim == len(shape) - for actual, expected in zip(x.shape, shape): + for actual, expected in zip(x.shape, shape, strict=True): if expected is None: continue if actual != expected: diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index bf1a93ce5b..2956afad02 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -18,7 +18,7 @@ Split, TensorFromScalar, Tri, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i @@ -103,7 +103,7 @@ def join(axis, *tensors): def jax_funcify_Split(op: Split, node, **kwargs): _, axis, splits = node.inputs try: - constant_axis = get_underlying_scalar_constant_value(axis) + constant_axis = get_scalar_constant_value(axis) except NotScalarConstantError: constant_axis = None warnings.warn( @@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( [ - get_underlying_scalar_constant_value(splits[i]) + get_scalar_constant_value(splits[i]) for i in range(get_vector_length(splits)) ] ) @@ -200,7 +200,8 @@ def jax_funcify_Tri(op, node, **kwargs): def tri(*args): # args is N, M, k args = [ - x if const_x is None else const_x for x, const_x in zip(args, const_args) + x if const_x is None else const_x + for x, const_x in zip(args, const_args, strict=True) ] return jnp.tri(*args, dtype=op.dtype) diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 667806a80f..06370b4514 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -3,7 +3,6 @@ from numpy.random import Generator, RandomState from pytensor.compile.sharedvalue import SharedVariable, shared -from pytensor.graph.basic import Constant from pytensor.link.basic import JITLinker @@ -35,12 +34,14 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): ] fgraph.replace_all( - zip(shared_rng_inputs, new_shared_rng_inputs), + zip(shared_rng_inputs, new_shared_rng_inputs, strict=True), import_missing=True, reason="JAXLinker.fgraph_convert", ) - for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs): + for old_inp, new_inp in zip( + shared_rng_inputs, new_shared_rng_inputs, strict=True + ): new_inp_storage = [new_inp.get_value(borrow=True)] storage_map[new_inp] = new_inp_storage old_inp_storage = storage_map.pop(old_inp) @@ -70,12 +71,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): def jit_compile(self, fn): import jax - # I suppose we can consider `Constant`s to be "static" according to - # JAX. - static_argnums = [ - n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant) - ] - return jax.jit(fn, static_argnums=static_argnums) + return jax.jit(fn) def create_thunk_inputs(self, storage_map): from pytensor.link.jax.dispatch import jax_typify diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 6dd0e8211b..56a3e2c9b2 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -2,15 +2,16 @@ from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify # Load dispatch specializations -import pytensor.link.numba.dispatch.scalar -import pytensor.link.numba.dispatch.tensor_basic +import pytensor.link.numba.dispatch.blockwise +import pytensor.link.numba.dispatch.elemwise import pytensor.link.numba.dispatch.extra_ops import pytensor.link.numba.dispatch.nlinalg import pytensor.link.numba.dispatch.random -import pytensor.link.numba.dispatch.elemwise import pytensor.link.numba.dispatch.scan -import pytensor.link.numba.dispatch.sparse +import pytensor.link.numba.dispatch.scalar import pytensor.link.numba.dispatch.slinalg +import pytensor.link.numba.dispatch.sparse import pytensor.link.numba.dispatch.subtensor +import pytensor.link.numba.dispatch.tensor_basic # isort: on diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 2b934d049c..843a4dbf1f 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -49,10 +49,23 @@ def global_numba_func(func): return func -def numba_njit(*args, **kwargs): +def numba_njit(*args, fastmath=None, **kwargs): kwargs.setdefault("cache", config.numba__cache) kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True) + if fastmath is None: + if config.numba__fastmath: + # Opinionated default on fastmath flags + # https://llvm.org/docs/LangRef.html#fast-math-flags + fastmath = { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # no-signed zeros + } + else: + fastmath = False # Suppress cache warning for internal functions # We have to add an ansi escape code for optional bold text by numba @@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs): ) if len(args) > 0 and callable(args[0]): - return numba.njit(*args[1:], **kwargs)(args[0]) + return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) - return numba.njit(*args, **kwargs) + return numba.njit(*args, fastmath=fastmath, **kwargs) def numba_vectorize(*args, **kwargs): @@ -401,9 +414,10 @@ def py_perform_return(inputs): else: def py_perform_return(inputs): + # strict=False because we are in a hot loop return tuple( out_type.filter(out[0]) - for out_type, out in zip(output_types, py_perform(inputs)) + for out_type, out in zip(output_types, py_perform(inputs), strict=False) ) @numba_njit @@ -566,7 +580,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): func_conditions = [ f"assert x.shape[{i}] == {shape_input_names}" for i, (shape_input, shape_input_names) in enumerate( - zip(shape_inputs, shape_input_names) + zip(shape_inputs, shape_input_names, strict=True) ) if shape_input is not NoneConst ] diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py new file mode 100644 index 0000000000..b7481bd5a3 --- /dev/null +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -0,0 +1,91 @@ +import sys +from typing import cast + +from numba.core.extending import overload +from numba.np.unsafe.ndarray import to_fixed_tuple + +from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit +from pytensor.link.numba.dispatch.vectorize_codegen import ( + _jit_options, + _vectorized, + encode_literals, + store_core_outputs, +) +from pytensor.link.utils import compile_function_src +from pytensor.tensor import TensorVariable, get_vector_length +from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape + + +@numba_funcify.register +def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): + [blockwise_node] = op.fgraph.apply_nodes + blockwise_op: Blockwise = blockwise_node.op + core_op = blockwise_op.core_op + nin = len(blockwise_node.inputs) + nout = len(blockwise_node.outputs) + core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:]) + + core_node = blockwise_op._create_dummy_core_node( + cast(tuple[TensorVariable], blockwise_node.inputs) + ) + core_op_fn = numba_funcify( + core_op, + node=core_node, + parent_node=node, + **kwargs, + ) + core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout) + + batch_ndim = blockwise_op.batch_ndim(node) + + # numba doesn't support nested literals right now... + input_bc_patterns = encode_literals( + tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs[:nin]) + ) + output_bc_patterns = encode_literals( + tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs) + ) + output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs)) + inplace_pattern = encode_literals(()) + + # Numba does not allow a tuple generator in the Jitted function so we have to compile a helper to convert core_shapes into tuples + # Alternatively, add an Op that converts shape vectors into tuples, like we did for JAX + src = "def to_tuple(core_shapes): return (" + for i in range(nout): + src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]})," + src += ")" + + to_tuple = numba_njit( + compile_function_src( + src, + "to_tuple", + global_env={"to_fixed_tuple": to_fixed_tuple}, + ), + # cache=True leads to a numba.cloudpickle dump failure in Python 3.10 + # May be fine in Python 3.11, but I didn't test. It was fine in 3.12 + cache=sys.version_info >= (3, 12), + ) + + def blockwise_wrapper(*inputs_and_core_shapes): + inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:] + tuple_core_shapes = to_tuple(core_shapes) + return _vectorized( + core_op_fn, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + (), # constant_inputs + inputs, + tuple_core_shapes, + None, # size + ) + + def blockwise(*inputs_and_core_shapes): + raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented") + + @overload(blockwise, jit_options=_jit_options) + def ov_blockwise(*inputs_and_core_shapes): + return blockwise_wrapper + + return blockwise diff --git a/pytensor/link/numba/dispatch/cython_support.py b/pytensor/link/numba/dispatch/cython_support.py index 36b3e80850..8dccf98836 100644 --- a/pytensor/link/numba/dispatch/cython_support.py +++ b/pytensor/link/numba/dispatch/cython_support.py @@ -45,7 +45,7 @@ def arg_numba_types(self) -> list[DTypeLike]: def can_cast_args(self, args: list[DTypeLike]) -> bool: ok = True count = 0 - for name, dtype in zip(self.arg_names, self.arg_dtypes): + for name, dtype in zip(self.arg_names, self.arg_dtypes, strict=True): if name == "__pyx_skip_dispatch": continue if len(args) <= count: @@ -164,7 +164,12 @@ def __wrapper_address__(self): return self._func_ptr def __call__(self, *args, **kwargs): - args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)] + # no strict argument because of the JIT + # TODO: check + args = [ + dtype(arg) + for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905 + ] if self.has_pyx_skip_dispatch(): output = self._pyfunc(*args[:-1], **kwargs) else: diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 35f23b4aa2..b8a982f2b7 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,21 +1,24 @@ -from collections.abc import Callable +import warnings from functools import singledispatch -from numbers import Number -from textwrap import indent -from typing import Any +from textwrap import dedent, indent import numba import numpy as np from numba.core.extending import overload -from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple -from pytensor import config -from pytensor.graph.basic import Apply + +try: + from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.multiarray import normalize_axis_index + from numpy.core.numeric import normalize_axis_tuple + from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( create_numba_signature, - create_tuple_creator, numba_funcify, numba_njit, use_optimized_cheap_pass, @@ -26,27 +29,25 @@ encode_literals, store_core_outputs, ) -from pytensor.link.utils import compile_function_src, get_name_for_object +from pytensor.link.utils import compile_function_src from pytensor.scalar.basic import ( AND, OR, XOR, Add, - Composite, IntDiv, - Mean, Mul, ScalarMaximum, ScalarMinimum, Sub, TrueDiv, + get_scalar_type, scalar_maximum, ) from pytensor.scalar.basic import add as add_as from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad -from pytensor.tensor.type import scalar @singledispatch @@ -77,11 +78,6 @@ def scalar_in_place_fn_Sub(op, idx, res, arr): return f"{res}[{idx}] -= {arr}" -@scalar_in_place_fn.register(Mean) -def scalar_in_place_fn_Mean(op, idx, res, arr): - return f"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)" - - @scalar_in_place_fn.register(Mul) def scalar_in_place_fn_Mul(op, idx, res, arr): return f"{res}[{idx}] *= {arr}" @@ -133,76 +129,32 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): """ -def create_vectorize_func( - scalar_op_fn: Callable, - node: Apply, - use_signature: bool = False, - identity: Any | None = None, - **kwargs, -) -> Callable: - r"""Create a vectorized Numba function from a `Apply`\s Python function.""" - - if len(node.outputs) > 1: - raise NotImplementedError( - "Multi-output Elemwise Ops are not supported by the Numba backend" - ) - - if use_signature: - signature = [create_numba_signature(node, force_scalar=True)] - else: - signature = [] - - target = ( - getattr(node.tag, "numba__vectorize_target", None) - or config.numba__vectorize_target - ) - - numba_vectorized_fn = numba_basic.numba_vectorize( - signature, identity=identity, target=target, fastmath=config.numba__fastmath - ) - - py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn) - - elemwise_fn = numba_vectorized_fn(scalar_op_fn) - elemwise_fn.py_scalar_func = py_scalar_func - - return elemwise_fn - - -def create_axis_reducer( - scalar_op: Op, - identity: np.ndarray | Number, - axis: int, - ndim: int, - dtype: numba.types.Type, +def create_multiaxis_reducer( + scalar_op, + identity, + axes, + ndim, + dtype, keepdims: bool = False, - return_scalar=False, -) -> numba.core.dispatcher.Dispatcher: - r"""Create Python function that performs a NumPy-like reduction on a given axis. +): + r"""Construct a function that reduces multiple axes. The functions generated by this function take the following form: .. code-block:: python - def careduce_axis(x): - res_shape = tuple( - shape[i] if i < axis else shape[i + 1] for i in range(ndim - 1) - ) - res = np.full(res_shape, identity, dtype=dtype) - - x_axis_first = x.transpose(reaxis_first) - - for m in range(x.shape[axis]): - reduce_fn(res, x_axis_first[m], res) - - if keepdims: - return np.expand_dims(res, axis) - else: - return res + def careduce_add(x): + # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add" + x_shape = x.shape + res_shape = x_shape[2] + res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype) + for i0 in range(x_shape[0]): + for i1 in range(x_shape[1]): + for i2 in range(x_shape[2]): + res[i2] += x[i0, i1, i2] - This can be removed/replaced when - https://github.com/numba/numba/issues/4504 is implemented. + return res Parameters ========== @@ -210,25 +162,29 @@ def careduce_axis(x): The scalar :class:`Op` that performs the desired reduction. identity: The identity value for the reduction. - axis: - The axis to reduce. + axes: + The axes to reduce. ndim: - The number of dimensions of the result. + The number of dimensions of the input variable. dtype: The data type of the result. - keepdims: - Determines whether or not the reduced dimension is retained. - - + keepdims: boolean, default False + Whether to keep the reduced dimensions. Returns ======= A Python function that can be JITed. """ + # if len(axes) == 1: + # return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) - axis = normalize_axis_index(axis, ndim) + axes = normalize_axis_tuple(axes, ndim) + if keepdims and len(axes) > 1: + raise NotImplementedError( + "Cannot keep multiple dimensions when reducing multiple axes" + ) - reduce_elemwise_fn_name = "careduce_axis" + careduce_fn_name = f"careduce_{scalar_op}" identity = str(identity) if identity == "inf": @@ -241,162 +197,55 @@ def careduce_axis(x): "numba_basic": numba_basic, "out_dtype": dtype, } + complete_reduction = len(axes) == ndim + kept_axis = tuple(i for i in range(ndim) if i not in axes) + + res_indices = [] + arr_indices = [] + for i in range(ndim): + index_label = f"i{i}" + arr_indices.append(index_label) + if i not in axes: + res_indices.append(index_label) + res_indices = ", ".join(res_indices) if res_indices else () + arr_indices = ", ".join(arr_indices) if arr_indices else () + + inplace_update_stmt = scalar_in_place_fn( + scalar_op, res_indices, "res", f"x[{arr_indices}]" + ) - if ndim > 1: - res_shape_tuple_ctor = create_tuple_creator( - lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1 - ) - global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor - - res_indices = [] - arr_indices = [] - count = 0 - - for i in range(ndim): - if i == axis: - arr_indices.append("i") - else: - res_indices.append(f"idx_arr[{count}]") - arr_indices.append(f"idx_arr[{count}]") - count = count + 1 - - res_indices = ", ".join(res_indices) - arr_indices = ", ".join(arr_indices) - - inplace_update_statement = scalar_in_place_fn( - scalar_op, res_indices, "res", f"x[{arr_indices}]" - ) - inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3) - - return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res" - reduce_elemwise_def_src = f""" -def {reduce_elemwise_fn_name}(x): - - x_shape = np.shape(x) - res_shape = res_shape_tuple_ctor(x_shape) - res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype) - - axis_shape = x.shape[{axis}] - - for idx_arr in np.ndindex(res_shape): - for i in range(axis_shape): -{inplace_update_statement} - - return {return_expr} - """ + res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})" + if complete_reduction and ndim > 0: + # We accumulate on a scalar, not an array + res_creator = f"np.asarray({identity}).astype(out_dtype).item()" + inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res") + return_obj = "np.asarray(res)" else: - inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]") - inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2) - - return_expr = "res" if keepdims else "res.item()" - if not return_scalar: - return_expr = f"np.asarray({return_expr})" - reduce_elemwise_def_src = f""" -def {reduce_elemwise_fn_name}(x): - - res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype) - - axis_shape = x.shape[{axis}] - - for i in range(axis_shape): -{inplace_update_statement} - - return {return_expr} + res_creator = ( + f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)" + ) + return_obj = "res" + + if keepdims: + [axis] = axes + return_obj = f"np.expand_dims({return_obj}, {axis})" + + careduce_def_src = dedent( + f""" + def {careduce_fn_name}(x): + x_shape = x.shape + res_shape = {res_shape} + res = {res_creator} """ - - reduce_elemwise_fn_py = compile_function_src( - reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env} ) - - return reduce_elemwise_fn_py - - -def create_multiaxis_reducer( - scalar_op, - identity, - axes, - ndim, - dtype, - input_name="input", - return_scalar=False, -): - r"""Construct a function that reduces multiple axes. - - The functions generated by this function take the following form: - - .. code-block:: python - - def careduce_maximum(input): - axis_0_res = careduce_axes_fn_0(input) - axis_1_res = careduce_axes_fn_1(axis_0_res) - ... - axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res) - return axis_N_res - - The range 0-N is determined by the `axes` argument (i.e. the - axes to be reduced). - - - Parameters - ========== - scalar_op: - The scalar :class:`Op` that performs the desired reduction. - identity: - The identity value for the reduction. - axes: - The axes to reduce. - ndim: - The number of dimensions of the result. - dtype: - The data type of the result. - return_scalar: - If True, return a scalar, otherwise an array. - - Returns - ======= - A Python function that can be JITed. - - """ - if len(axes) == 1: - return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) - - axes = normalize_axis_tuple(axes, ndim) - - careduce_fn_name = f"careduce_{scalar_op}" - global_env = {} - to_reduce = sorted(axes, reverse=True) - careduce_lines_src = [] - var_name = input_name - - for i, axis in enumerate(to_reduce): - careducer_axes_fn_name = f"careduce_axes_fn_{i}" - reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype) - reducer_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - )(reducer_py_fn) - - global_env[careducer_axes_fn_name] = reducer_fn - - ndim -= 1 - last_var_name = var_name - var_name = f"axis_{i}_res" - careduce_lines_src.append( - f"{var_name} = {careducer_axes_fn_name}({last_var_name})" + for axis in range(ndim): + careduce_def_src += indent( + f"for i{axis} in range(x_shape[{axis}]):\n", + " " * (4 + 4 * axis), ) - - careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) - if not return_scalar: - pre_result = "np.asarray" - post_result = "" - else: - pre_result = "np.asarray" - post_result = ".item()" - - careduce_def_src = f""" -def {careduce_fn_name}({input_name}): -{careduce_assign_lines} - return {pre_result}({var_name}){post_result} - """ + careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim)) + careduce_def_src += "\n\n" + careduce_def_src += indent(f"return {return_obj}", " " * 4) careduce_fn = compile_function_src( careduce_def_src, careduce_fn_name, {**globals(), **global_env} @@ -440,7 +289,6 @@ def jit_compile_reducer( res = numba_basic.numba_njit( *args, boundscheck=False, - fastmath=config.numba__fastmath, **kwds, )(fn) @@ -467,19 +315,13 @@ def axis_apply_fn(x): @numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): - # Creating a new scalar node is more involved and unnecessary - # if the scalar_op is composite, as the fgraph already contains - # all the necessary information. - scalar_node = None - if not isinstance(op.scalar_op, Composite): - scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs] - scalar_node = op.scalar_op.make_node(*scalar_inputs) + scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] + scalar_node = op.scalar_op.make_node(*scalar_inputs) scalar_op_fn = numba_funcify( op.scalar_op, node=scalar_node, parent_node=node, - fastmath=_jit_options["fastmath"], **kwargs, ) @@ -517,8 +359,10 @@ def elemwise(*inputs): inputs = [np.asarray(input) for input in inputs] inputs_bc = np.broadcast_arrays(*inputs) shape = inputs[0].shape - for input, bc in zip(inputs, input_bc_patterns): - for length, allow_bc, iter_length in zip(input.shape, bc, shape): + for input, bc in zip(inputs, input_bc_patterns, strict=True): + for length, allow_bc, iter_length in zip( + input.shape, bc, shape, strict=True + ): if length == 1 and shape and iter_length != 1 and not allow_bc: raise ValueError("Broadcast not allowed.") @@ -529,11 +373,11 @@ def elemwise(*inputs): outs = scalar_op_fn(*vals) if not isinstance(outs, tuple): outs = (outs,) - for out, out_val in zip(outputs, outs): + for out, out_val in zip(outputs, outs, strict=True): out[idx] = out_val outputs_summed = [] - for output, bc in zip(outputs, output_bc_patterns): + for output, bc in zip(outputs, output_bc_patterns, strict=True): axes = tuple(np.nonzero(bc)[0]) outputs_summed.append(output.sum(axes, keepdims=True)) if len(outputs_summed) != 1: @@ -549,32 +393,29 @@ def ov_elemwise(*inputs): @numba_funcify.register(Sum) def numba_funcify_Sum(op, node, **kwargs): + ndim_input = node.inputs[0].ndim axes = op.axis if axes is None: axes = list(range(node.inputs[0].ndim)) - - axes = tuple(axes) - - ndim_input = node.inputs[0].ndim + else: + axes = normalize_axis_tuple(axes, ndim_input) if hasattr(op, "acc_dtype") and op.acc_dtype is not None: acc_dtype = op.acc_dtype else: acc_dtype = node.outputs[0].type.dtype - np_acc_dtype = np.dtype(acc_dtype) - out_dtype = np.dtype(node.outputs[0].dtype) if ndim_input == len(axes): - - @numba_njit(fastmath=True) + # Slightly faster than `numba_funcify_CAReduce` for this case + @numba_njit def impl_sum(array): return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) elif len(axes) == 0: - - @numba_njit(fastmath=True) + # These cases should be removed by rewrites! + @numba_njit def impl_sum(array): return np.asarray(array, dtype=out_dtype) @@ -607,7 +448,6 @@ def numba_funcify_CAReduce(op, node, **kwargs): # Make sure it has the correct dtype scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype) - input_name = get_name_for_object(node.inputs[0]) ndim = node.inputs[0].ndim careduce_py_fn = create_multiaxis_reducer( op.scalar_op, @@ -615,7 +455,6 @@ def numba_funcify_CAReduce(op, node, **kwargs): axes, ndim, np.dtype(node.outputs[0].type.dtype), - input_name=input_name, ) careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False) @@ -728,16 +567,14 @@ def numba_funcify_Softmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) - reduce_max_py = create_axis_reducer( + reduce_max_py = create_multiaxis_reducer( scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True ) - reduce_sum_py = create_axis_reducer( - add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True + reduce_sum_py = create_multiaxis_reducer( + add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: @@ -765,13 +602,11 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): axis = op.axis if axis is not None: axis = normalize_axis_index(axis, sm_at.ndim) - reduce_sum_py = create_axis_reducer( - add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True + reduce_sum_py = create_multiaxis_reducer( + add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_sum = jit_fn(reduce_sum_py) else: reduce_sum = np.sum @@ -797,21 +632,19 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) - reduce_max_py = create_axis_reducer( + reduce_max_py = create_multiaxis_reducer( scalar_maximum, -np.inf, - axis, + (axis,), x_at.ndim, x_dtype, keepdims=True, ) - reduce_sum_py = create_axis_reducer( - add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True + reduce_sum_py = create_multiaxis_reducer( + add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index e2a4668242..1f0a33e595 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -4,7 +4,6 @@ import numba import numpy as np -from pytensor import config from pytensor.graph import Apply from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify @@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -74,13 +73,13 @@ def cumop(x): else: if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -186,7 +185,8 @@ def ravelmultiindex(*inp): new_arr = arr.T.astype(np.float64).copy() for i, b in enumerate(new_arr): - for j, (d, v) in enumerate(zip(shape, b)): + # no strict argument to this zip because numba doesn't support it + for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905 if v < 0 or v >= d: mode_fn(new_arr, i, j, v, d) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 29584daa5f..4bd5c2fc28 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from copy import copy +from copy import copy, deepcopy from functools import singledispatch from textwrap import dedent @@ -34,7 +34,7 @@ def copy_NumPyRandomGenerator(rng): def impl(rng): # TODO: Open issue on Numba? with numba.objmode(new_rng=types.npy_rng): - new_rng = copy(rng) + new_rng = deepcopy(rng) return new_rng @@ -388,7 +388,7 @@ def random_wrapper(core_shape, rng, size, *dist_params): return rng, draws def random(core_shape, rng, size, *dist_params): - pass + raise NotImplementedError("Non-jitted random variable not implemented") @overload(random, jit_options=_jit_options) def ov_random(core_shape, rng, size, *dist_params): diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index f2c1bbc185..e9b637b00f 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -2,7 +2,6 @@ import numpy as np -from pytensor import config from pytensor.compile.ops import ViewOp from pytensor.graph.basic import Variable from pytensor.link.numba.dispatch import basic as numba_basic @@ -114,7 +113,9 @@ def {scalar_op_fn_name}({input_names}): input_names = [unique_names(v, force_unique=True) for v in node.inputs] converted_call_args = ", ".join( f"direct_cast({i_name}, {i_tmp_dtype_name})" - for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names) + for i_name, i_tmp_dtype_name in zip( + input_names, input_tmp_dtype_names, strict=False + ) ) if not has_pyx_skip_dispatch: scalar_op_src = f""" @@ -135,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}): return numba_basic.numba_njit( signature, - fastmath=config.numba__fastmath, # Functions that call a function pointer can't be cached cache=False, )(scalar_op_fn) @@ -175,9 +175,7 @@ def numba_funcify_Add(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") - return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( - nary_add_fn - ) + return numba_basic.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Mul) @@ -185,9 +183,7 @@ def numba_funcify_Mul(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") - return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( - nary_add_fn - ) + return numba_basic.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Cast) @@ -237,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs): _ = kwargs.pop("storage_map", None) - composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( + composite_fn = numba_basic.numba_njit(signature)( numba_funcify(op.fgraph, squeeze_output=True, **kwargs) ) return composite_fn @@ -265,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs): return numba_basic.global_numba_func(reciprocal) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def sigmoid(x): return 1 / (1 + np.exp(-x)) @@ -275,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs): return numba_basic.global_numba_func(sigmoid) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def gammaln(x): return math.lgamma(x) @@ -285,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs): return numba_basic.global_numba_func(gammaln) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def logp1mexp(x): if x < np.log(0.5): return np.log1p(-np.exp(x)) @@ -298,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs): return numba_basic.global_numba_func(logp1mexp) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def erf(x): return math.erf(x) @@ -308,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs): return numba_basic.global_numba_func(erf) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def erfc(x): return math.erfc(x) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 92566a7f78..cc75fc3742 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -163,10 +163,11 @@ def add_inner_in_expr( op.info.mit_mot_in_slices + op.info.mit_sot_in_slices + op.info.sit_sot_in_slices, + strict=True, ) ) inner_in_names_to_output_taps: dict[str, tuple[int, ...] | None] = dict( - zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices) + zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices, strict=True) ) # Inner-outputs consist of: @@ -373,7 +374,8 @@ def add_output_storage_post_proc_stmt( inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) inner_out_to_outer_out_stmts = "\n".join( - f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names) + f"{s} = {d}" + for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names, strict=True) ) scan_op_src = f""" diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 1bf5a6c8fa..96a8da282e 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -420,7 +420,8 @@ def block_diag(*arrs): out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype) r, c = 0, 0 - for arr, shape in zip(arrs, shapes): + # no strict argument because it is incompatible with numba + for arr, shape in zip(arrs, shapes): # noqa: B905 rr, cc = shape out[r : r + rr, c : c + cc] = arr r += rr diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 178ce0b857..6dc4d4c294 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -5,6 +5,7 @@ from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.tensor import TensorType +from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -13,6 +14,7 @@ IncSubtensor, Subtensor, ) +from pytensor.tensor.type_other import NoneTypeT, SliceType @numba_funcify.register(Subtensor) @@ -104,18 +106,73 @@ def {function_name}({", ".join(input_names)}): @numba_funcify.register(AdvancedSubtensor) @numba_funcify.register(AdvancedIncSubtensor) def numba_funcify_AdvancedSubtensor(op, node, **kwargs): - idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:] - adv_idxs_dims = [ - idx.type.ndim + if isinstance(op, AdvancedSubtensor): + x, y, idxs = node.inputs[0], None, node.inputs[1:] + else: + x, y, *idxs = node.inputs + + basic_idxs = [ + idx for idx in idxs - if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) + if ( + isinstance(idx.type, NoneTypeT) + or (isinstance(idx.type, SliceType) and not is_full_slice(idx)) + ) + ] + adv_idxs = [ + { + "axis": i, + "dtype": idx.type.dtype, + "bcast": idx.type.broadcastable, + "ndim": idx.type.ndim, + } + for i, idx in enumerate(idxs) + if isinstance(idx.type, TensorType) ] + # Special case for consecutive consecutive vector indices + def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]): + # Check that x is not broadcasted to y based on broadcastable info + if len(x_bcast) < len(to_bcast): + return True + for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True): + if x_bcast_dim and not to_bcast_dim: + return True + return False + + # Special implementation for consecutive integer vector indices + if ( + not basic_idxs + and len(adv_idxs) >= 2 + # Must be integer vectors + # Todo: we could allow shape=(1,) if this is the shape of x + and all( + (adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool") + for adv_idx in adv_idxs + ) + # Must be consecutive + and not op.non_contiguous_adv_indexing(node) + # y in set/inc_subtensor cannot be broadcasted + and ( + y is None + or not broadcasted_to( + y.type.broadcastable, + ( + x.type.broadcastable[: adv_idxs[0]["axis"]] + + x.type.broadcastable[adv_idxs[-1]["axis"] :] + ), + ) + ) + ): + return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs) + + # Other cases not natively supported by Numba (fallback to obj-mode) if ( # Numba does not support indexes with more than one dimension + any(idx["ndim"] > 1 for idx in adv_idxs) # Nor multiple vector indexes - (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1) - # The default index implementation does not handle duplicate indices correctly + or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1 + # The default PyTensor implementation does not handle duplicate indices correctly or ( isinstance(op, AdvancedIncSubtensor) and not op.set_instead_of_inc @@ -124,9 +181,91 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ): return generate_fallback_impl(op, node, **kwargs) + # What's left should all be supported natively by numba return numba_funcify_default_subtensor(op, node, **kwargs) +def numba_funcify_multiple_integer_vector_indexing( + op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs +): + # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor) + if isinstance(op, AdvancedSubtensor): + y, idxs = None, node.inputs[1:] + else: + y, *idxs = node.inputs[1:] + + first_axis = next( + i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType) + ) + try: + after_last_axis = next( + i + for i, idx in enumerate(idxs[first_axis:], start=first_axis) + if not isinstance(idx.type, TensorType) + ) + except StopIteration: + after_last_axis = len(idxs) + + if isinstance(op, AdvancedSubtensor): + + @numba_njit + def advanced_subtensor_multiple_vector(x, *idxs): + none_slices = idxs[:first_axis] + vec_idxs = idxs[first_axis:after_last_axis] + + x_shape = x.shape + idx_shape = vec_idxs[0].shape + shape_bef = x_shape[:first_axis] + shape_aft = x_shape[after_last_axis:] + out_shape = (*shape_bef, *idx_shape, *shape_aft) + out_buffer = np.empty(out_shape, dtype=x.dtype) + for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)] + return out_buffer + + return advanced_subtensor_multiple_vector + + elif op.set_instead_of_inc: + inplace = op.inplace + + @numba_njit + def advanced_set_subtensor_multiple_vector(x, y, *idxs): + vec_idxs = idxs[first_axis:after_last_axis] + x_shape = x.shape + + if inplace: + out = x + else: + out = x.copy() + + for outer in np.ndindex(x_shape[:first_axis]): + for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + out[(*outer, *scalar_idxs)] = y[(*outer, i)] + return out + + return advanced_set_subtensor_multiple_vector + + else: + inplace = op.inplace + + @numba_njit + def advanced_inc_subtensor_multiple_vector(x, y, *idxs): + vec_idxs = idxs[first_axis:after_last_axis] + x_shape = x.shape + + if inplace: + out = x + else: + out = x.copy() + + for outer in np.ndindex(x_shape[:first_axis]): + for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + out[(*outer, *scalar_idxs)] += y[(*outer, i)] + return out + + return advanced_inc_subtensor_multiple_vector + + @numba_funcify.register(AdvancedIncSubtensor1) def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): inplace = op.inplace @@ -158,7 +297,8 @@ def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, vals, idxs): if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") - for idx, val in zip(idxs, vals): + # no strict argument because incompatible with numba + for idx, val in zip(idxs, vals): # noqa: B905 x[idx] = val return x else: @@ -184,7 +324,9 @@ def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, vals, idxs): if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") - for idx, val in zip(idxs, vals): + # no strict argument because unsupported by numba + # TODO: this doesn't come up in tests + for idx, val in zip(idxs, vals): # noqa: B905 x[idx] += val return x diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 09421adeb6..80b05d4e81 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -36,7 +36,9 @@ def numba_funcify_AllocEmpty(op, node, **kwargs): shapes_to_items_src = indent( "\n".join( f"{item_name} = to_scalar({shape_name})" - for item_name, shape_name in zip(shape_var_item_names, shape_var_names) + for item_name, shape_name in zip( + shape_var_item_names, shape_var_names, strict=True + ) ), " " * 4, ) @@ -68,7 +70,9 @@ def numba_funcify_Alloc(op, node, **kwargs): shapes_to_items_src = indent( "\n".join( f"{item_name} = to_scalar({shape_name})" - for item_name, shape_name in zip(shape_var_item_names, shape_var_names) + for item_name, shape_name in zip( + shape_var_item_names, shape_var_names, strict=True + ) ), " " * 4, ) diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index a680f9747d..74870e29bd 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -44,7 +44,7 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): inner_out_signature = ", ".join(inner_outputs) store_outputs = "\n".join( f"{output}[...] = {inner_output}" - for output, inner_output in zip(outputs, inner_outputs) + for output, inner_output in zip(outputs, inner_outputs, strict=True) ) func_src = f""" def store_core_outputs({inp_signature}, {out_signature}): @@ -137,7 +137,7 @@ def _vectorized( ) core_input_types = [] - for input_type, bc_pattern in zip(input_types, input_bc_patterns): + for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True): core_ndim = input_type.ndim - len(bc_pattern) # TODO: Reconsider this if core_ndim == 0: @@ -150,14 +150,18 @@ def _vectorized( core_out_types = [ types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C") - for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + for dtype, output_core_shape in zip( + output_dtypes, output_core_shape_types, strict=True + ) ] out_types = [ types.Array( numba.from_dtype(np.dtype(dtype)), batch_ndim + len(output_core_shape), "C" ) - for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + for dtype, output_core_shape in zip( + output_dtypes, output_core_shape_types, strict=True + ) ] for output_idx, input_idx in inplace_pattern: @@ -211,7 +215,7 @@ def codegen( inputs = [ arrayobj.make_array(ty)(ctx, builder, val) - for ty, val in zip(input_types, inputs) + for ty, val in zip(input_types, inputs, strict=True) ] in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] @@ -283,7 +287,9 @@ def compute_itershape( if size is not None: shape = size for i in range(batch_ndim): - for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + for j, (bc, in_shape) in enumerate( + zip(broadcast_pattern, in_shapes, strict=True) + ): length = in_shape[i] if bc[i]: with builder.if_then( @@ -318,7 +324,9 @@ def compute_itershape( else: # Size is implied by the broadcast pattern for i in range(batch_ndim): - for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + for j, (bc, in_shape) in enumerate( + zip(broadcast_pattern, in_shapes, strict=True) + ): length = in_shape[i] if bc[i]: with builder.if_then( @@ -374,7 +382,7 @@ def make_outputs( one = ir.IntType(64)(1) inplace_dict = dict(inplace) for i, (core_shape, bc, dtype) in enumerate( - zip(output_core_shapes, out_bc, dtypes) + zip(output_core_shapes, out_bc, dtypes, strict=True) ): if i in inplace_dict: output_arrays.append(inputs[inplace_dict[i]]) @@ -388,7 +396,8 @@ def make_outputs( # This is actually an internal numba function, I guess we could # call `numba.nd.unsafe.ndarray` instead? batch_shape = [ - length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc) + length if not bc_dim else one + for length, bc_dim in zip(iter_shape, bc, strict=True) ] shape = batch_shape + core_shape array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) @@ -458,10 +467,10 @@ def make_loop_call( # Load values from input arrays input_vals = [] - for input, input_type, bc in zip(inputs, input_types, input_bc): + for input, input_type, bc in zip(inputs, input_types, input_bc, strict=True): core_ndim = input_type.ndim - len(bc) - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ zero ] * core_ndim ptr = cgutils.get_item_pointer2( @@ -506,13 +515,13 @@ def make_loop_call( # Create output slices to pass to inner func output_slices = [] - for output, output_type, bc in zip(outputs, output_types, output_bc): + for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True): core_ndim = output_type.ndim - len(bc) size_type = output.shape.type.element # type: ignore output_shape = cgutils.unpack_tuple(builder, output.shape) # type: ignore output_strides = cgutils.unpack_tuple(builder, output.strides) # type: ignore - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ zero ] * core_ndim ptr = cgutils.get_item_pointer2( diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index f120706f3b..553c5ef217 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -1,26 +1,9 @@ -from typing import TYPE_CHECKING, Any - -import numpy as np - -import pytensor from pytensor.link.basic import JITLinker -if TYPE_CHECKING: - from pytensor.graph.basic import Variable - - class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" - def output_filter(self, var: "Variable", out: Any) -> Any: - if not isinstance(var, np.ndarray) and isinstance( - var.type, pytensor.tensor.TensorType - ): - return var.type.filter(out, allow_downcast=True) - - return out - def fgraph_convert(self, fgraph, **kwargs): from pytensor.link.numba.dispatch import numba_funcify diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index e0aa80e18b..11e1d6c63a 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -54,14 +54,16 @@ def pytorch_funcify_FunctionGraph( fgraph, node=None, fgraph_name="pytorch_funcified_fgraph", + conversion_func=pytorch_funcify, **kwargs, ): + built_kwargs = {"conversion_func": conversion_func, **kwargs} return fgraph_to_python( fgraph, - pytorch_funcify, + conversion_func, type_conversion_fn=pytorch_typify, fgraph_name=fgraph_name, - **kwargs, + **built_kwargs, ) @@ -173,11 +175,8 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs): # Apply inner rewrites PYTORCH.optimizer(op.fgraph) - fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) - # Disable one step inlining to prevent torch from trying to import local functions - # defined in `pytorch_funcify` - return torch.compiler.disable(fgraph_fn, recursive=False) + return fgraph_fn @pytorch_funcify.register(TensorFromScalar) diff --git a/pytensor/link/pytorch/dispatch/blockwise.py b/pytensor/link/pytorch/dispatch/blockwise.py index 524e706633..0681d32a8e 100644 --- a/pytensor/link/pytorch/dispatch/blockwise.py +++ b/pytensor/link/pytorch/dispatch/blockwise.py @@ -1,5 +1,4 @@ import torch -import torch.compiler from pytensor.graph import FunctionGraph from pytensor.link.pytorch.dispatch import pytorch_funcify @@ -11,12 +10,13 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): batched_dims = op.batch_ndim(node) core_node = op._create_dummy_core_node(node.inputs) core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) - inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1) + inner_func = pytorch_funcify( + core_fgraph, squeeze_output=len(node.outputs) == 1, **kwargs + ) for _ in range(batched_dims): inner_func = torch.vmap(inner_func) - @torch.compiler.disable(recursive=False) def batcher(*inputs): op._check_runtime_broadcast(node, inputs) # broadcast on batched_dims diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index b1ad5582c5..c22945d914 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -1,6 +1,9 @@ +import importlib + import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.scalar import ScalarLoop from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -9,11 +12,41 @@ @pytorch_funcify.register(Elemwise) def pytorch_funcify_Elemwise(op, node, **kwargs): scalar_op = op.scalar_op + base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) - def elemwise_fn(*inputs): - Elemwise._check_runtime_broadcast(node, inputs) - return base_fn(*inputs) + def check_special_scipy(func_name): + if "scipy." not in func_name: + return False + loc = func_name.split(".")[1:] + try: + mod = importlib.import_module(".".join(loc[:-1]), "torch") + return getattr(mod, loc[-1], False) + except ImportError: + return False + + if hasattr(scalar_op, "nfunc_spec") and ( + hasattr(torch, scalar_op.nfunc_spec[0]) + or check_special_scipy(scalar_op.nfunc_spec[0]) + ): + # torch can handle this scalar + # broadcast, we'll let it. + def elemwise_fn(*inputs): + Elemwise._check_runtime_broadcast(node, inputs) + return base_fn(*inputs) + + elif isinstance(scalar_op, ScalarLoop): + return elemwise_ravel_fn(base_fn, op, node, **kwargs) + + else: + + def elemwise_fn(*inputs): + Elemwise._check_runtime_broadcast(node, inputs) + broadcast_inputs = torch.broadcast_tensors(*inputs) + ufunc = base_fn + for _ in range(broadcast_inputs[0].dim()): + ufunc = torch.vmap(ufunc) + return ufunc(*broadcast_inputs) return elemwise_fn @@ -148,3 +181,37 @@ def softmax_grad(dy, sm): return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm return softmax_grad + + +def elemwise_ravel_fn(base_fn, op, node, **kwargs): + """ + Dispatch methods using `.item()` (ScalarLoop + Elemwise) is common, but vmap + in torch has a limitation: https://github.com/pymc-devs/pytensor/issues/1031, + Instead, we can ravel all the inputs, broadcasted according to torch + """ + + n_outputs = len(node.outputs) + + def elemwise_fn(*inputs): + bcasted_inputs = torch.broadcast_tensors(*inputs) + raveled_inputs = [inp.ravel() for inp in bcasted_inputs] + + out_shape = bcasted_inputs[0].size() + out_size = out_shape.numel() + raveled_outputs = [torch.empty(out_size) for out in node.outputs] + + for i in range(out_size): + core_outs = base_fn(*(inp[i] for inp in raveled_inputs)) + if n_outputs == 1: + raveled_outputs[0][i] = core_outs + else: + for o in range(n_outputs): + raveled_outputs[o][i] = core_outs[o] + + outputs = tuple(out.view(out_shape) for out in raveled_outputs) + if n_outputs == 1: + return outputs[0] + else: + return outputs + + return elemwise_fn diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index a977c6d4b2..65170b1f53 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -1,3 +1,5 @@ +import importlib + import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify @@ -5,6 +7,8 @@ Cast, ScalarOp, ) +from pytensor.scalar.loop import ScalarLoop +from pytensor.scalar.math import Softplus @pytorch_funcify.register(ScalarOp) @@ -19,9 +23,14 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs): if nfunc_spec is None: raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}") - func_name = nfunc_spec[0] + func_name = nfunc_spec[0].replace("scipy.", "") - pytorch_func = getattr(torch, func_name) + if "." in func_name: + loc = func_name.split(".") + mod = importlib.import_module(".".join(["torch", *loc[:-1]])) + pytorch_func = getattr(mod, loc[-1]) + else: + pytorch_func = getattr(torch, func_name) if len(node.inputs) > op.nfunc_spec[1]: # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function, @@ -49,3 +58,42 @@ def cast(x): return x.to(dtype=dtype) return cast + + +@pytorch_funcify.register(Softplus) +def pytorch_funcify_Softplus(op, node, **kwargs): + return torch.nn.Softplus() + + +@pytorch_funcify.register(ScalarLoop) +def pytorch_funicify_ScalarLoop(op, node, **kwargs): + update = pytorch_funcify(op.fgraph, **kwargs) + state_length = op.nout + if op.is_while: + + def scalar_loop(steps, *start_and_constants): + carry, constants = ( + start_and_constants[:state_length], + start_and_constants[state_length:], + ) + done = True + for _ in range(steps): + *carry, done = update(*carry, *constants) + if torch.any(done): + break + return *carry, done + else: + + def scalar_loop(steps, *start_and_constants): + carry, constants = ( + start_and_constants[:state_length], + start_and_constants[state_length:], + ) + for _ in range(steps): + carry = update(*carry, *constants) + if len(node.outputs) == 1: + return carry[0] + else: + return carry + + return scalar_loop diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py index e249a81a70..f771ac7211 100644 --- a/pytensor/link/pytorch/dispatch/shape.py +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -34,7 +34,8 @@ def shape_i(x): def pytorch_funcify_SpecifyShape(op, node, **kwargs): def specifyshape(x, *shape): assert x.ndim == len(shape) - for actual, expected in zip(x.shape, shape): + # strict=False because asserted above + for actual, expected in zip(x.shape, shape, strict=False): if expected is None: continue if actual != expected: diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 035d654c83..d47aa43dda 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,31 +1,83 @@ -from typing import Any - -from pytensor.graph.basic import Variable from pytensor.link.basic import JITLinker +from pytensor.link.utils import unique_name_generator class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" - def input_filter(self, inp: Any) -> Any: - from pytensor.link.pytorch.dispatch import pytorch_typify - - return pytorch_typify(inp) - - def output_filter(self, var: Variable, out: Any) -> Any: - return out.cpu() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.gen_functors = [] def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.pytorch.dispatch import pytorch_funcify + # We want to have globally unique names + # across the entire pytensor graph, not + # just the subgraph + generator = unique_name_generator(["torch_linker"]) + + # Ensure that torch is aware of the generated + # code so we can compile without graph breaks + def conversion_func_register(*args, **kwargs): + functor = pytorch_funcify(*args, **kwargs) + name = kwargs["unique_name"](functor) + self.gen_functors.append((f"_{name}", functor)) + return functor + + built_kwargs = { + "unique_name": generator, + "conversion_func": conversion_func_register, + **kwargs, + } return pytorch_funcify( - fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs + fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs ) def jit_compile(self, fn): import torch - return torch.compile(fn) + from pytensor.link.pytorch.dispatch import pytorch_typify + + class wrapper: + """ + Pytorch would fail compiling our method when trying + to resolve some of the methods returned from dispatch + calls. We want to be careful to not leak the methods, + so this class just holds them and provisions the expected + location accordingly + + https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319 + """ + + def __init__(self, fn, gen_functors): + self.fn = torch.compile(fn) + self.gen_functors = gen_functors.copy() + + def __call__(self, *inputs, **kwargs): + import pytensor.link.utils + + # set attrs + for n, fn in self.gen_functors: + setattr(pytensor.link.utils, n[1:], fn) + + # Torch does not accept numpy inputs and may return GPU objects + outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) + + # unset attrs + for n, _ in self.gen_functors: + if getattr(pytensor.link.utils, n[1:], False): + delattr(pytensor.link.utils, n[1:]) + + return tuple(out.cpu().numpy() for out in outs) + + def __del__(self): + del self.gen_functors + + inner_fn = wrapper(fn, self.gen_functors) + self.gen_functors = [] + + return inner_fn def create_thunk_inputs(self, storage_map): thunk_inputs = [] diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index c51b13c427..9cbc3838dd 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -88,7 +88,7 @@ def map_storage( assert len(fgraph.inputs) == len(input_storage) # add input storage into storage_map - for r, storage in zip(fgraph.inputs, input_storage): + for r, storage in zip(fgraph.inputs, input_storage, strict=True): if r in storage_map: assert storage_map[r] is storage, ( "Given input_storage conflicts " @@ -108,7 +108,7 @@ def map_storage( # allocate output storage if output_storage is not None: assert len(fgraph.outputs) == len(output_storage) - for r, storage in zip(fgraph.outputs, output_storage): + for r, storage in zip(fgraph.outputs, output_storage, strict=True): if r in storage_map: assert storage_map[r] is storage, ( "Given output_storage confl" @@ -190,8 +190,9 @@ def streamline_default_f(): for x in no_recycling: x[0] = None try: + # strict=False because we are in a hot loop for thunk, node, old_storage in zip( - thunks, order, post_thunk_old_storage + thunks, order, post_thunk_old_storage, strict=False ): thunk() for old_s in old_storage: @@ -206,7 +207,8 @@ def streamline_nice_errors_f(): for x in no_recycling: x[0] = None try: - for thunk, node in zip(thunks, order): + # strict=False because we are in a hot loop + for thunk, node in zip(thunks, order, strict=False): thunk() except Exception: raise_with_op(fgraph, node, thunk) @@ -673,6 +675,7 @@ def fgraph_to_python( local_env: dict[Any, Any] | None = None, get_name_for_object: Callable[[Any], str] = get_name_for_object, squeeze_output: bool = False, + unique_name: Callable | None = None, **kwargs, ) -> Callable: """Convert a `FunctionGraph` into a regular Python function. @@ -704,6 +707,8 @@ def fgraph_to_python( get_name_for_object A function used to provide names for the objects referenced within the generated function. + unique_name + A function to make random function names for generated code squeeze_output If the `FunctionGraph` has only one output and this option is ``True``, return the single output instead of a tuple with the output. @@ -717,7 +722,11 @@ def fgraph_to_python( if storage_map is None: storage_map = {} - unique_name = unique_name_generator([fgraph_name]) + if not unique_name: + unique_name = unique_name_generator([fgraph_name]) + + # make sure we plumb this through + kwargs["unique_name"] = unique_name if global_env is None: global_env = {} diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 587b379cf0..af44af3254 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -19,7 +19,6 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Constant, Variable from pytensor.link.basic import Container, LocalLinker -from pytensor.link.c.exceptions import MissingGXX from pytensor.link.utils import ( gc_helper, get_destroy_dependencies, @@ -244,7 +243,7 @@ def clear_storage(self): def update_profile(self, profile): """Update a profile object.""" for node, thunk, t, c in zip( - self.nodes, self.thunks, self.call_times, self.call_counts + self.nodes, self.thunks, self.call_times, self.call_counts, strict=True ): profile.apply_time[(self.fgraph, node)] += t @@ -310,7 +309,9 @@ def __init__( self.output_storage = output_storage self.inp_storage_and_out_idx = tuple( (inp_storage, self.fgraph.outputs.index(update_vars[inp])) - for inp, inp_storage in zip(self.fgraph.inputs, self.input_storage) + for inp, inp_storage in zip( + self.fgraph.inputs, self.input_storage, strict=True + ) if inp in update_vars ) @@ -1004,6 +1005,8 @@ def make_vm( compute_map, updated_vars, ): + from pytensor.link.c.exceptions import MissingGXX + pre_call_clear = [storage_map[v] for v in self.no_recycling] try: @@ -1241,7 +1244,7 @@ def make_all( self.profile.linker_node_make_thunks += t1 - t0 self.profile.linker_make_thunk_time = linker_make_thunk_time - for node, thunk in zip(order, thunks): + for node, thunk in zip(order, thunks, strict=True): thunk.inputs = [storage_map[v] for v in node.inputs] thunk.outputs = [storage_map[v] for v in node.outputs] @@ -1298,11 +1301,11 @@ def make_all( vm, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, order, diff --git a/pytensor/misc/check_blas.py b/pytensor/misc/check_blas.py index 8ee4482f0e..fc2fe02377 100644 --- a/pytensor/misc/check_blas.py +++ b/pytensor/misc/check_blas.py @@ -59,7 +59,7 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, iters=10, order= if any(x.op.__class__.__name__ == "Gemm" for x in f.maker.fgraph.toposort()): c_impl = [ hasattr(thunk, "cthunk") - for node, thunk in zip(f.vm.nodes, f.vm.thunks) + for node, thunk in zip(f.vm.nodes, f.vm.thunks, strict=True) if node.op.__class__.__name__ == "Gemm" ] assert len(c_impl) == 1 diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py new file mode 100644 index 0000000000..bfdd567eae --- /dev/null +++ b/pytensor/npy_2_compat.py @@ -0,0 +1,223 @@ +from textwrap import dedent + + +def npy_2_compat_header() -> str: + return dedent(""" + #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ + #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ + + + /* + * This header is meant to be included by downstream directly for 1.x compat. + * In that case we need to ensure that users first included the full headers + * and not just `ndarraytypes.h`. + */ + + #ifndef NPY_FEATURE_VERSION + #error "The NumPy 2 compat header requires `import_array()` for which " \\ + "the `ndarraytypes.h` header include is not sufficient. Please " \\ + "include it after `numpy/ndarrayobject.h` or similar." \\ + "" \\ + "To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\ + "which is defined in the compat header and is lightweight (can be)." + #endif + + #if NPY_ABI_VERSION < 0x02000000 + /* + * Define 2.0 feature version as it is needed below to decide whether we + * compile for both 1.x and 2.x (defining it gaurantees 1.x only). + */ + #define NPY_2_0_API_VERSION 0x00000012 + /* + * If we are compiling with NumPy 1.x, PyArray_RUNTIME_VERSION so we + * pretend the `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`. + * This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to. + */ + #define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION + /* Compiling on NumPy 1.x where these are the same: */ + #define PyArray_DescrProto PyArray_Descr + #endif + + + /* + * Define a better way to call `_import_array()` to simplify backporting as + * we now require imports more often (necessary to make ABI flexible). + */ + #ifdef import_array1 + + static inline int + PyArray_ImportNumPyAPI() + { + if (NPY_UNLIKELY(PyArray_API == NULL)) { + import_array1(-1); + } + return 0; + } + + #endif /* import_array1 */ + + + /* + * NPY_DEFAULT_INT + * + * The default integer has changed, `NPY_DEFAULT_INT` is available at runtime + * for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`. + * + * NPY_RAVEL_AXIS + * + * This was introduced in NumPy 2.0 to allow indicating that an axis should be + * raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose. + * + * NPY_MAXDIMS + * + * A constant indicating the maximum number dimensions allowed when creating + * an ndarray. + * + * NPY_NTYPES_LEGACY + * + * The number of built-in NumPy dtypes. + */ + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + #define NPY_DEFAULT_INT NPY_INTP + #define NPY_RAVEL_AXIS NPY_MIN_INT + #define NPY_MAXARGS 64 + + #elif NPY_ABI_VERSION < 0x02000000 + #define NPY_DEFAULT_INT NPY_LONG + #define NPY_RAVEL_AXIS 32 + #define NPY_MAXARGS 32 + + /* Aliases of 2.x names to 1.x only equivalent names */ + #define NPY_NTYPES NPY_NTYPES_LEGACY + #define PyArray_DescrProto PyArray_Descr + #define _PyArray_LegacyDescr PyArray_Descr + /* NumPy 2 definition always works, but add it for 1.x only */ + #define PyDataType_ISLEGACY(dtype) (1) + #else + #define NPY_DEFAULT_INT \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG) + #define NPY_RAVEL_AXIS \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32) + #define NPY_MAXARGS \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32) + #endif + + + /* + * Access inline functions for descriptor fields. Except for the first + * few fields, these needed to be moved (elsize, alignment) for + * additional space. Or they are descriptor specific and are not generally + * available anymore (metadata, c_metadata, subarray, names, fields). + * + * Most of these are defined via the `DESCR_ACCESSOR` macro helper. + */ + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000 + /* Compiling for 1.x or 2.x only, direct field access is OK: */ + + static inline void + PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size) + { + dtype->elsize = size; + } + + static inline npy_uint64 + PyDataType_FLAGS(const PyArray_Descr *dtype) + { + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + return dtype->flags; + #else + return (unsigned char)dtype->flags; /* Need unsigned cast on 1.x */ + #endif + } + + #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\ + static inline type \\ + PyDataType_##FIELD(const PyArray_Descr *dtype) { \\ + if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\ + return (type)0; \\ + } \\ + return ((_PyArray_LegacyDescr *)dtype)->field; \\ + } + #else /* compiling for both 1.x and 2.x */ + + static inline void + PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + ((_PyArray_DescrNumPy2 *)dtype)->elsize = size; + } + else { + ((PyArray_DescrProto *)dtype)->elsize = (int)size; + } + } + + static inline npy_uint64 + PyDataType_FLAGS(const PyArray_Descr *dtype) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + return ((_PyArray_DescrNumPy2 *)dtype)->flags; + } + else { + return (unsigned char)((PyArray_DescrProto *)dtype)->flags; + } + } + + /* Cast to LegacyDescr always fine but needed when `legacy_only` */ + #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\ + static inline type \\ + PyDataType_##FIELD(const PyArray_Descr *dtype) { \\ + if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\ + return (type)0; \\ + } \\ + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { \\ + return ((_PyArray_LegacyDescr *)dtype)->field; \\ + } \\ + else { \\ + return ((PyArray_DescrProto *)dtype)->field; \\ + } \\ + } + #endif + + DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0) + DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0) + DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1) + DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1) + DESCR_ACCESSOR(NAMES, names, PyObject *, 1) + DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1) + DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1) + + #undef DESCR_ACCESSOR + + + #if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD) + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return _PyDataType_GetArrFuncs(descr); + } + #elif NPY_ABI_VERSION < 0x02000000 + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return descr->f; + } + #else + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + return _PyDataType_GetArrFuncs(descr); + } + else { + return ((PyArray_DescrProto *)descr)->f; + } + } + #endif + + + #endif /* not internal build */ + + #endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */ + + """) diff --git a/pytensor/printing.py b/pytensor/printing.py index 92bcf5ff23..6a18f6e8e5 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -26,39 +26,6 @@ IDTypesType = Literal["id", "int", "CHAR", "auto", ""] -pydot_imported = False -pydot_imported_msg = "" -try: - # pydot-ng is a fork of pydot that is better maintained - import pydot_ng as pd - - if pd.find_graphviz(): - pydot_imported = True - else: - pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." -except ImportError: - try: - # fall back on pydot if necessary - import pydot as pd - - if hasattr(pd, "find_graphviz"): - if pd.find_graphviz(): - pydot_imported = True - else: - pydot_imported_msg = "pydot can't find graphviz" - else: - pd.Dot.create(pd.Dot()) - pydot_imported = True - except ImportError: - # tests should not fail on optional dependency - pydot_imported_msg = ( - "Install the python package pydot or pydot-ng. Install graphviz." - ) - except Exception as e: - pydot_imported_msg = "An error happened while importing/trying pydot: " - pydot_imported_msg += str(e.args) - - _logger = logging.getLogger("pytensor.printing") VALID_ASSOC = {"left", "right", "either"} @@ -311,7 +278,7 @@ def debugprint( ) for var, profile, storage_map, topo_order in zip( - outputs_to_print, profile_list, storage_maps, topo_orders + outputs_to_print, profile_list, storage_maps, topo_orders, strict=True ): if hasattr(var.owner, "op"): if ( @@ -930,7 +897,7 @@ def process(self, output, pstate): ) idx = node.outputs.index(output) pattern, precedences = self.patterns[idx] - precedences += (1000,) * len(node.inputs) + precedences += (1000,) * (len(node.inputs) - len(precedences)) def pp_process(input, new_precedence): with set_precedence(pstate, new_precedence): @@ -938,10 +905,9 @@ def pp_process(input, new_precedence): return r d = { - str(i): x - for i, x in enumerate( - pp_process(input, precedence) - for input, precedence in zip(node.inputs, precedences) + str(i): pp_process(input, precedence) + for i, (input, precedence) in enumerate( + zip(node.inputs, precedences, strict=True) ) } r = pattern % d @@ -1197,6 +1163,48 @@ def __call__(self, *args): } +def _try_pydot_import(): + pydot_imported = False + pydot_imported_msg = "" + try: + # pydot-ng is a fork of pydot that is better maintained + import pydot_ng as pd + + if pd.find_graphviz(): + pydot_imported = True + else: + pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." + except ImportError: + try: + # fall back on pydot if necessary + import pydot as pd + + if hasattr(pd, "find_graphviz"): + if pd.find_graphviz(): + pydot_imported = True + else: + pydot_imported_msg = "pydot can't find graphviz" + else: + pd.Dot.create(pd.Dot()) + pydot_imported = True + except ImportError: + # tests should not fail on optional dependency + pydot_imported_msg = ( + "Install the python package pydot or pydot-ng. Install graphviz." + ) + except Exception as e: + pydot_imported_msg = "An error happened while importing/trying pydot: " + pydot_imported_msg += str(e.args) + + if not pydot_imported: + raise ImportError( + "Failed to import pydot. You must install graphviz " + "and either pydot or pydot-ng for " + f"`pydotprint` to work:\n {pydot_imported_msg}", + ) + return pd + + def pydotprint( fct, outfile: Path | str | None = None, @@ -1289,6 +1297,8 @@ def pydotprint( scan separately after the top level debugprint output. """ + pd = _try_pydot_import() + from pytensor.scan.op import Scan if colorCodes is None: @@ -1321,12 +1331,6 @@ def pydotprint( outputs = fct.outputs topo = fct.toposort() fgraph = fct - if not pydot_imported: - raise RuntimeError( - "Failed to import pydot. You must install graphviz " - "and either pydot or pydot-ng for " - f"`pydotprint` to work:\n {pydot_imported_msg}", - ) g = pd.Dot() @@ -1449,7 +1453,7 @@ def apply_name(node): if isinstance(fct, Function): # TODO: Get rid of all this `expanded_inputs` nonsense and use # `fgraph.update_mapping` - function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs) + function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs, strict=True) for i, fg_ii in reversed(list(function_inputs)): if i.update is not None: k = outputs.pop() diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 55414a94d0..54936a9720 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -184,7 +184,9 @@ def __call__(self, x): for dtype in try_dtypes: x_ = np.asarray(x).astype(dtype=dtype) - if np.all(x == x_): + if np.all( + np.asarray(x) == x_ + ): # use np.asarray(x) to match TensorType.filter break # returns either an exact x_==x, or the last cast x_ return x_ @@ -348,8 +350,10 @@ def c_headers(self, c_compiler=None, **kwargs): l = [""] # These includes are needed by ScalarType and TensorType, # we declare them here and they will be re-used by TensorType + l.append("") l.append("") + l.append("") if config.lib__amdlibm and c_compiler.supports_amdlibm: l += [""] return l @@ -502,7 +506,9 @@ def c_cleanup(self, name, sub): def c_support_code(self, **kwargs): if self.dtype.startswith("complex"): - cplx_types = ["pytensor_complex64", "pytensor_complex128"] + # complex types are: "pytensor_complex64", "pytensor_complex128" + # but it is more convenient to have their bit widths: + cplx_types_bit_widths = ["64", "128"] real_types = [ "npy_int8", "npy_int16", @@ -518,83 +524,135 @@ def c_support_code(self, **kwargs): # In that case we add the 'int' type to the real types. real_types.append("int") + def _make_get_set_real_imag(scalar_type: str) -> str: + """Make overloaded getter/setter functions for real/imag parts of numpy complex types. + + The functions called by these getter/setter functions are defining in npy_math.h + + Args: + scalar_type: float, double, or longdouble + + Returns: + C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the + given type. + """ + complex_type = "npy_c" + scalar_type + suffix = "" if scalar_type == "double" else scalar_type[0] + return_type = scalar_type + + if scalar_type == "longdouble": + scalar_type += "_t" + return_type = "npy_" + return_type + + template = f""" + static inline {return_type} get_real(const {complex_type} z) + {{ + return npy_creal{suffix}(z); + }} + + static inline void set_real({complex_type} *z, const {scalar_type} r) + {{ + npy_csetreal{suffix}(z, r); + }} + + static inline {return_type} get_imag(const {complex_type} z) + {{ + return npy_cimag{suffix}(z); + }} + + static inline void set_imag({complex_type} *z, const {scalar_type} i) + {{ + npy_csetimag{suffix}(z, i); + }} + """ + return template + + # TODO: add guard code to prevent this from being defining twice, in case we need to add it somewhere else + get_set_aliases = "\n".join( + _make_get_set_real_imag(stype) + for stype in ["float", "double", "longdouble"] + ) + template = """ - struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s - { - typedef pytensor_complex%(nbits)s complex_type; - typedef npy_float%(half_nbits)s scalar_type; - - complex_type operator +(const complex_type &y) const { - complex_type ret; - ret.real = this->real + y.real; - ret.imag = this->imag + y.imag; - return ret; - } - - complex_type operator -() const { - complex_type ret; - ret.real = -this->real; - ret.imag = -this->imag; - return ret; - } - bool operator ==(const complex_type &y) const { - return (this->real == y.real) && (this->imag == y.imag); - } - bool operator ==(const scalar_type &y) const { - return (this->real == y) && (this->imag == 0); - } - complex_type operator -(const complex_type &y) const { - complex_type ret; - ret.real = this->real - y.real; - ret.imag = this->imag - y.imag; - return ret; - } - complex_type operator *(const complex_type &y) const { - complex_type ret; - ret.real = this->real * y.real - this->imag * y.imag; - ret.imag = this->real * y.imag + this->imag * y.real; - return ret; - } - complex_type operator /(const complex_type &y) const { - complex_type ret; - scalar_type y_norm_square = y.real * y.real + y.imag * y.imag; - ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square; - ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square; - return ret; - } - template - complex_type& operator =(const T& y); - - pytensor_complex%(nbits)s() {} - - template - pytensor_complex%(nbits)s(const T& y) { *this = y; } - - template - pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; } + struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s { + typedef pytensor_complex%(nbits)s complex_type; + typedef npy_float32 scalar_type; + + complex_type operator+(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) + get_real(y)); + set_imag(&ret, get_imag(*this) + get_imag(y)); + return ret; + } + + complex_type operator-() const { + complex_type ret; + set_real(&ret, -get_real(*this)); + set_imag(&ret, -get_imag(*this)); + return ret; + } + bool operator==(const complex_type &y) const { + return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y)); + } + bool operator==(const scalar_type &y) const { + return (get_real(*this) == y) && (get_real(*this) == 0); + } + complex_type operator-(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) - get_real(y)); + set_imag(&ret, get_imag(*this) - get_imag(y)); + return ret; + } + complex_type operator*(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y)); + set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y)); + return ret; + } + complex_type operator/(const complex_type &y) const { + complex_type ret; + scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y); + set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square); + set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square); + return ret; + } + template complex_type &operator=(const T &y); + + pytensor_complex%(nbits)s() {} + + template pytensor_complex%(nbits)s(const T &y) { *this = y; } + + template + pytensor_complex%(nbits)s(const TR &r, const TI &i) { + set_real(this, r); + set_imag(this, i); + } }; """ - def operator_eq_real(mytype, othertype): + def operator_eq_real(bit_width, othertype): + mytype = f"pytensor_complex{bit_width}" return f""" template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y; this->imag=0; return *this; }} + {{ set_real(this, y); set_imag(this, 0); return *this; }} """ - def operator_eq_cplx(mytype, othertype): + def operator_eq_cplx(bit_width1, bit_width2): + mytype = f"pytensor_complex{bit_width1}" + othertype = f"pytensor_complex{bit_width2}" return f""" template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y.real; this->imag=y.imag; return *this; }} + {{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }} """ operator_eq = "".join( - operator_eq_real(ctype, rtype) - for ctype in cplx_types + operator_eq_real(bit_width, rtype) + for bit_width in cplx_types_bit_widths for rtype in real_types ) + "".join( - operator_eq_cplx(ctype1, ctype2) - for ctype1 in cplx_types - for ctype2 in cplx_types + operator_eq_cplx(bit_width1, bit_width2) + for bit_width1 in cplx_types_bit_widths + for bit_width2 in cplx_types_bit_widths ) # We are not using C++ generic templating here, because this would @@ -603,53 +661,57 @@ def operator_eq_cplx(mytype, othertype): # and the compiler complains it is ambiguous. # Instead, we generate code for known and safe types only. - def operator_plus_real(mytype, othertype): + def operator_plus_real(bit_width, othertype): + mytype = f"pytensor_complex{bit_width}" return f""" const {mytype} operator+(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real+y, x.imag); }} + {{ return {mytype}(get_real(x) + y, get_imag(x)); }} const {mytype} operator+(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real+y, x.imag); }} + {{ return {mytype}(get_real(x) + y, get_imag(x)); }} """ operator_plus = "".join( - operator_plus_real(ctype, rtype) - for ctype in cplx_types + operator_plus_real(bit_width, rtype) + for bit_width in cplx_types_bit_widths for rtype in real_types ) - def operator_minus_real(mytype, othertype): + def operator_minus_real(bit_width, othertype): + mytype = f"pytensor_complex{bit_width}" return f""" const {mytype} operator-(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real-y, x.imag); }} + {{ return {mytype}(get_real(x) - y, get_imag(x)); }} const {mytype} operator-(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(y-x.real, -x.imag); }} + {{ return {mytype}(y - get_real(x), -get_imag(x)); }} """ operator_minus = "".join( - operator_minus_real(ctype, rtype) - for ctype in cplx_types + operator_minus_real(bit_width, rtype) + for bit_width in cplx_types_bit_widths for rtype in real_types ) - def operator_mul_real(mytype, othertype): + def operator_mul_real(bit_width, othertype): + mytype = f"pytensor_complex{bit_width}" return f""" const {mytype} operator*(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real*y, x.imag*y); }} + {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }} const {mytype} operator*(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real*y, x.imag*y); }} + {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }} """ operator_mul = "".join( - operator_mul_real(ctype, rtype) - for ctype in cplx_types + operator_mul_real(bit_width, rtype) + for bit_width in cplx_types_bit_widths for rtype in real_types ) return ( - template % dict(nbits=64, half_nbits=32) + get_set_aliases + + template % dict(nbits=64, half_nbits=32) + template % dict(nbits=128, half_nbits=64) + operator_eq + operator_plus @@ -664,7 +726,7 @@ def c_init_code(self, **kwargs): return ["import_array();"] def c_code_cache_version(self): - return (13, np.__version__) + return (14, np.version.git_revision) def get_shape_info(self, obj): return obj.itemsize @@ -1150,7 +1212,10 @@ def perform(self, node, inputs, output_storage): else: variables = from_return_values(self.impl(*inputs)) assert len(variables) == len(output_storage) - for out, storage, variable in zip(node.outputs, output_storage, variables): + # strict=False because we are in a hot loop + for out, storage, variable in zip( + node.outputs, output_storage, variables, strict=False + ): dtype = out.dtype storage[0] = self._cast_scalar(variable, dtype) @@ -1868,32 +1933,6 @@ def L_op(self, inputs, outputs, gout): add = Add(upcast_out, name="add") -class Mean(ScalarOp): - identity = 0 - commutative = True - associative = False - nfunc_spec = ("mean", 2, 1) - nfunc_variadic = "mean" - - def impl(self, *inputs): - return sum(inputs) / len(inputs) - - def c_code(self, node, name, inputs, outputs, sub): - (z,) = outputs - if not inputs: - return f"{z} = 0;" - else: - return f"{z} = ({' + '.join(inputs)}) / ((double) {len(inputs)});" - - def L_op(self, inputs, outputs, gout): - (gz,) = gout - retval = [gz / len(inputs)] * len(inputs) - return retval - - -mean = Mean(float_out, name="mean") - - class Mul(ScalarOp): identity = 1 commutative = True @@ -2591,7 +2630,7 @@ def c_code(self, node, name, inputs, outputs, sub): if type in float_types: return f"{z} = fabs({x});" if type in complex_types: - return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);" + return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));" if node.outputs[0].type == bool: return f"{z} = ({x}) ? 1 : 0;" if type in uint_types: @@ -3148,7 +3187,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * exp2(x) * log(np.cast[x.type](2)),) + return (gz * exp2(x) * log(np.asarray(2, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3395,7 +3434,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (-gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (-gz / sqrt(np.asarray(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3469,7 +3508,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (gz / sqrt(np.asarray(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3541,7 +3580,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) + sqr(x)),) + return (gz / (np.asarray(1, dtype=x.type) + sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3664,7 +3703,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) - np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) - np.asarray(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3741,7 +3780,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) + np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) + np.asarray(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3819,7 +3858,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) - sqr(x)),) + return (gz / (np.asarray(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -4115,7 +4154,9 @@ def c_support_code(self, **kwargs): def c_support_code_apply(self, node, name): rval = [] - for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): + for subnode, subnodename in zip( + self.fgraph.toposort(), self.nodenames, strict=True + ): subnode_support_code = subnode.op.c_support_code_apply( subnode, subnodename % dict(nodename=name) ) @@ -4221,7 +4262,7 @@ def __init__(self, inputs, outputs, name="Composite"): res2 = pytensor.compile.rebuild_collect_shared( inputs=outputs[0].owner.op.inputs, outputs=outputs[0].owner.op.outputs, - replace=dict(zip(outputs[0].owner.op.inputs, res[1])), + replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)), ) assert len(res2[1]) == len(outputs) assert len(res[0]) == len(inputs) @@ -4311,7 +4352,7 @@ def make_node(self, *inputs): assert len(inputs) == self.nin res = pytensor.compile.rebuild_collect_shared( self.outputs, - replace=dict(zip(self.inputs, inputs)), + replace=dict(zip(self.inputs, inputs, strict=True)), rebuild_strict=False, ) # After rebuild_collect_shared, the Variable in inputs @@ -4324,7 +4365,8 @@ def make_node(self, *inputs): def perform(self, node, inputs, output_storage): outputs = self.py_perform_fn(*inputs) - for storage, out_val in zip(output_storage, outputs): + # strict=False because we are in a hot loop + for storage, out_val in zip(output_storage, outputs, strict=False): storage[0] = out_val def grad(self, inputs, output_grads): @@ -4394,8 +4436,8 @@ def c_code_template(self): def c_code(self, node, nodename, inames, onames, sub): d = dict( chain( - zip((f"i{int(i)}" for i in range(len(inames))), inames), - zip((f"o{int(i)}" for i in range(len(onames))), onames), + zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True), + zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True), ), **sub, ) @@ -4443,7 +4485,7 @@ def apply(self, fgraph): ) # make sure we don't produce any float16. assert not any(o.dtype == "float16" for o in new_node.outputs) - mapping.update(zip(node.outputs, new_node.outputs)) + mapping.update(zip(node.outputs, new_node.outputs, strict=True)) new_ins = [mapping[inp] for inp in fgraph.inputs] new_outs = [mapping[out] for out in fgraph.outputs] @@ -4486,7 +4528,7 @@ def handle_composite(node, mapping): new_op = node.op.clone_float32() new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True) assert len(new_outs) == len(node.outputs) - for o, no in zip(node.outputs, new_outs): + for o, no in zip(node.outputs, new_outs, strict=True): mapping[o] = no diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 59664374f9..0b59195722 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -93,7 +93,7 @@ def _validate_updates( ) else: update = outputs - for i, u in zip(init, update): + for i, u in zip(init, update, strict=False): if i.type != u.type: raise TypeError( "Init and update types must be the same: " @@ -166,7 +166,7 @@ def make_node(self, n_steps, *inputs): # Make a new op with the right input types. res = rebuild_collect_shared( self.outputs, - replace=dict(zip(self.inputs, inputs)), + replace=dict(zip(self.inputs, inputs, strict=True)), rebuild_strict=False, ) if self.is_while: @@ -207,7 +207,8 @@ def perform(self, node, inputs, output_storage): for i in range(n_steps): carry = inner_fn(*carry, *constant) - for storage, out_val in zip(output_storage, carry): + # strict=False because we are in a hot loop + for storage, out_val in zip(output_storage, carry, strict=False): storage[0] = out_val @property @@ -295,7 +296,7 @@ def c_code_template(self): # Set the carry variables to the output variables _c_code += "\n" - for init, update in zip(carry_subd.values(), update_subd.values()): + for init, update in zip(carry_subd.values(), update_subd.values(), strict=True): _c_code += f"{init} = {update};\n" # _c_code += 'printf("%%ld\\n", i);\n' @@ -321,8 +322,8 @@ def c_code_template(self): def c_code(self, node, nodename, inames, onames, sub): d = dict( chain( - zip((f"i{int(i)}" for i in range(len(inames))), inames), - zip((f"o{int(i)}" for i in range(len(onames))), onames), + zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True), + zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True), ), **sub, ) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index e3379492fa..a5512c6564 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -1281,6 +1281,38 @@ def c_code(self, *args, **kwargs): ive = Ive(upgrade_to_float, name="ive") +class Kve(BinaryScalarOp): + """Exponentially scaled modified Bessel function of the second kind of real order v.""" + + nfunc_spec = ("scipy.special.kve", 2, 1) + + @staticmethod + def st_impl(v, x): + return scipy.special.kve(v, x) + + def impl(self, v, x): + return self.st_impl(v, x) + + def L_op(self, inputs, outputs, output_grads): + v, x = inputs + [kve_vx] = outputs + [g_out] = output_grads + # (1 -v/x) * kve(v, x) - kve(v - 1, x) + kve_vm1x = self(v - 1, x) + dx = (1 - v / x) * kve_vx - kve_vm1x + + return [ + grad_not_implemented(self, 0, v), + g_out * dx, + ] + + def c_code(self, *args, **kwargs): + raise NotImplementedError() + + +kve = Kve(upgrade_to_float, name="kve") + + class Sigmoid(UnaryScalarOp): """ Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 931e105597..dcae273aef 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -484,7 +484,7 @@ def wrap_into_list(x): n_fixed_steps = int(n_steps) else: try: - n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps) + n_fixed_steps = pt.get_scalar_constant_value(n_steps) except NotScalarConstantError: n_fixed_steps = None @@ -892,7 +892,9 @@ def wrap_into_list(x): if condition is not None: outputs.append(condition) fake_nonseqs = [x.type() for x in non_seqs] - fake_outputs = clone_replace(outputs, replace=dict(zip(non_seqs, fake_nonseqs))) + fake_outputs = clone_replace( + outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True)) + ) all_inputs = filter( lambda x: ( isinstance(x, Variable) @@ -1055,7 +1057,7 @@ def wrap_into_list(x): if not isinstance(arg, SharedVariable | Constant) ] - inner_replacements.update(dict(zip(other_scan_args, other_inner_args))) + inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True))) if strict: non_seqs_set = set(non_sequences if non_sequences is not None else []) @@ -1077,7 +1079,7 @@ def wrap_into_list(x): ] inner_replacements.update( - dict(zip(other_shared_scan_args, other_shared_inner_args)) + dict(zip(other_shared_scan_args, other_shared_inner_args, strict=True)) ) ## diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 4f6dc7e0be..a01347ef9c 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -74,7 +74,6 @@ from pytensor.graph.replace import clone_replace from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.link.c.basic import CLinker -from pytensor.link.c.exceptions import MissingGXX from pytensor.printing import op_debug_information from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new from pytensor.tensor.basic import as_tensor_variable @@ -170,7 +169,7 @@ def check_broadcast(v1, v2): ) size = min(v1.type.ndim, v2.type.ndim) for n, (b1, b2) in enumerate( - zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:]) + zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False) ): if b1 != b2: a1 = n + size - v1.type.ndim + 1 @@ -577,6 +576,7 @@ def get_oinp_iinp_iout_oout_mappings(self): inner_input_indices, inner_output_indices, outer_output_indices, + strict=True, ): if oout != -1: mappings["outer_inp_from_outer_out"][oout] = oinp @@ -958,7 +958,7 @@ def make_node(self, *inputs): # them have the same dtype argoffset = 0 for inner_seq, outer_seq in zip( - self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs) + self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True ): check_broadcast(outer_seq, inner_seq) new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq)) @@ -977,6 +977,7 @@ def make_node(self, *inputs): self.info.mit_mot_in_slices, self.info.mit_mot_out_slices[: self.info.n_mit_mot], self.outer_mitmot(inputs), + strict=True, ) ): outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos]) @@ -1031,6 +1032,7 @@ def make_node(self, *inputs): self.info.mit_sot_in_slices, self.outer_mitsot(inputs), self.inner_mitsot_outs(self.inner_outputs), + strict=True, ) ): outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos]) @@ -1083,6 +1085,7 @@ def make_node(self, *inputs): self.inner_sitsot(self.inner_inputs), self.outer_sitsot(inputs), self.inner_sitsot_outs(self.inner_outputs), + strict=True, ) ): outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot) @@ -1130,6 +1133,7 @@ def make_node(self, *inputs): self.inner_shared(self.inner_inputs), self.inner_shared_outs(self.inner_outputs), self.outer_shared(inputs), + strict=True, ) ): outer_shared = copy_var_format(_outer_shared, as_var=inner_shared) @@ -1188,7 +1192,9 @@ def make_node(self, *inputs): # type of tensor as the output, it is always a scalar int. new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)] for inner_nonseq, _outer_nonseq in zip( - self.inner_non_seqs(self.inner_inputs), self.outer_non_seqs(inputs) + self.inner_non_seqs(self.inner_inputs), + self.outer_non_seqs(inputs), + strict=True, ): outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq) new_inputs.append(outer_nonseq) @@ -1271,7 +1277,10 @@ def __eq__(self, other): if len(self.inner_outputs) != len(other.inner_outputs): return False - for self_in, other_in in zip(self.inner_inputs, other.inner_inputs): + # strict=False because length already compared above + for self_in, other_in in zip( + self.inner_inputs, other.inner_inputs, strict=False + ): if self_in.type != other_in.type: return False @@ -1406,7 +1415,7 @@ def prepare_fgraph(self, fgraph): fgraph.attach_feature( Supervisor( inp - for spec, inp in zip(wrapped_inputs, fgraph.inputs) + for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True) if not ( getattr(spec, "mutable", None) or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp])) @@ -1489,6 +1498,7 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): then it must not do so for variables in the no_recycling list. """ + from pytensor.link.c.exceptions import MissingGXX # Before building the thunk, validate that the inner graph is # coherent @@ -2086,7 +2096,9 @@ def perform(self, node, inputs, output_storage): jout = j + offset_out output_storage[j][0] = inner_output_storage[jout].storage[0] - pos = [(idx + 1) % store for idx, store in zip(pos, store_steps)] + pos = [ + (idx + 1) % store for idx, store in zip(pos, store_steps, strict=True) + ] i = i + 1 # 6. Check if you need to re-order output buffers @@ -2171,7 +2183,7 @@ def perform(self, node, inputs, output_storage): def infer_shape(self, fgraph, node, input_shapes): # input_shapes correspond to the shapes of node.inputs - for inp, inp_shp in zip(node.inputs, input_shapes): + for inp, inp_shp in zip(node.inputs, input_shapes, strict=True): assert inp_shp is None or len(inp_shp) == inp.type.ndim # Here we build 2 variables; @@ -2240,7 +2252,9 @@ def infer_shape(self, fgraph, node, input_shapes): # Non-sequences have a direct equivalent from self.inner_inputs in # node.inputs inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :] - out_equivalent.update(zip(inner_non_sequences, node.inputs[offset:])) + out_equivalent.update( + zip(inner_non_sequences, node.inputs[offset:], strict=True) + ) if info.as_while: self_outs = self.inner_outputs[:-1] @@ -2274,7 +2288,7 @@ def infer_shape(self, fgraph, node, input_shapes): r = node.outputs[n_outs + x] assert r.ndim == 1 + len(out_shape_x) shp = [node.inputs[offset + info.n_shared_outs + x]] - for i, shp_i in zip(range(1, r.ndim), out_shape_x): + for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True): # Validate shp_i. v_shape_i is either None (if invalid), # or a (variable, Boolean) tuple. The Boolean indicates # whether variable is shp_i (if True), or an valid @@ -2296,7 +2310,7 @@ def infer_shape(self, fgraph, node, input_shapes): if info.as_while: scan_outs_init = scan_outs scan_outs = [] - for o, x in zip(node.outputs, scan_outs_init): + for o, x in zip(node.outputs, scan_outs_init, strict=True): if x is None: scan_outs.append(None) else: @@ -2572,7 +2586,9 @@ def compute_all_gradients(known_grads): dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx]) else: disconnected_dC_dinps_t[dx] = False - for Xt, Xt_placeholder in zip(diff_outputs[info.n_mit_mot_outs :], Xts): + for Xt, Xt_placeholder in zip( + diff_outputs[info.n_mit_mot_outs :], Xts, strict=True + ): tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder) dC_dinps_t[dx] = tmp @@ -2652,7 +2668,9 @@ def compute_all_gradients(known_grads): n = n_steps.tag.test_value else: n = inputs[0].tag.test_value - for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)): + for taps, x in zip( + info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True + ): mintap = np.min(taps) if hasattr(x[::-1][:mintap], "test_value"): assert x[::-1][:mintap].tag.test_value.shape[0] == n @@ -2667,7 +2685,9 @@ def compute_all_gradients(known_grads): assert x[::-1].tag.test_value.shape[0] == n outer_inp_seqs += [ x[::-1][: np.min(taps)] - for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)) + for taps, x in zip( + info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True + ) ] outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] @@ -2998,6 +3018,7 @@ def compute_all_gradients(known_grads): zip( outputs[offset : offset + info.n_seqs], type_outs[offset : offset + info.n_seqs], + strict=True, ) ): if t == "connected": @@ -3027,7 +3048,7 @@ def compute_all_gradients(known_grads): gradients.append(NullType(t)()) end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot - for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])): + for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)): if t == "connected": # If the forward scan is in as_while mode, we need to pad # the gradients, so that they match the size of the input @@ -3062,7 +3083,7 @@ def compute_all_gradients(known_grads): for idx in range(info.n_shared_outs): disconnected = True connected_flags = self.connection_pattern(node)[idx + start] - for dC_dout, connected in zip(dC_douts, connected_flags): + for dC_dout, connected in zip(dC_douts, connected_flags, strict=True): if not isinstance(dC_dout.type, DisconnectedType) and connected: disconnected = False if disconnected: @@ -3079,7 +3100,9 @@ def compute_all_gradients(known_grads): begin = end end = begin + n_sitsot_outs - for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])): + for p, (x, t) in enumerate( + zip(outputs[begin:end], type_outs[begin:end], strict=True) + ): if t == "connected": gradients.append(x[-1]) elif t == "disconnected": @@ -3156,7 +3179,7 @@ def R_op(self, inputs, eval_points): e = 1 + info.n_seqs ie = info.n_seqs clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3171,7 +3194,7 @@ def R_op(self, inputs, eval_points): ib = ie ie = ie + int(sum(len(x) for x in info.mit_mot_in_slices)) clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3186,7 +3209,7 @@ def R_op(self, inputs, eval_points): ib = ie ie = ie + int(sum(len(x) for x in info.mit_sot_in_slices)) clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3201,7 +3224,7 @@ def R_op(self, inputs, eval_points): ib = ie ie = ie + info.n_sit_sot clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3225,7 +3248,7 @@ def R_op(self, inputs, eval_points): # All other arguments clean_eval_points = [] - for inp, evp in zip(inputs[e:], eval_points[e:]): + for inp, evp in zip(inputs[e:], eval_points[e:], strict=True): if evp is not None: clean_eval_points.append(evp) else: diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index ab4f5b6a77..2ba282d8d6 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -54,7 +54,7 @@ from pytensor.tensor.basic import ( Alloc, AllocEmpty, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError @@ -71,7 +71,7 @@ get_slice_elements, set_subtensor, ) -from pytensor.tensor.variable import TensorConstant, get_unique_constant_value +from pytensor.tensor.variable import TensorConstant list_opt_slice = [ @@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): all_ins = list(graph_inputs(op_outs)) for idx in range(op_info.n_seqs): node_inp = node.inputs[idx + 1] - if ( - isinstance(node_inp, TensorConstant) - and get_unique_constant_value(node_inp) is not None - ): + if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None: try: # This works if input is a constant that has all entries # equal @@ -166,7 +163,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): # Look through non sequences nw_inner_nonseq = [] nw_outer_nonseq = [] - for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)): + for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs, strict=True)): if isinstance(nw_out, Constant): givens[nw_in] = nw_out elif nw_in in all_ins: @@ -203,7 +200,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): allow_gc=op.allow_gc, ) nw_outs = nwScan(*nw_outer, return_list=True) - return dict([("remove", [node]), *zip(node.outputs, nw_outs)]) + return dict([("remove", [node]), *zip(node.outputs, nw_outs, strict=True)]) else: return False @@ -348,7 +345,7 @@ def add_to_replace(y): nw_outer = [] nw_inner = [] for to_repl, repl_in, repl_out in zip( - clean_to_replace, clean_replace_with_in, clean_replace_with_out + clean_to_replace, clean_replace_with_in, clean_replace_with_out, strict=True ): if isinstance(repl_out, Constant): repl_in = repl_out @@ -380,7 +377,7 @@ def add_to_replace(y): # Do not call make_node for test_value nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner - replacements = dict(zip(node.outputs, nw_node.outputs)) + replacements = dict(zip(node.outputs, nw_node.outputs, strict=True)) replacements["remove"] = [node] return replacements elif not to_keep_set: @@ -584,7 +581,7 @@ def add_to_replace(y): nw_outer = [] nw_inner = [] for to_repl, repl_in, repl_out in zip( - clean_to_replace, clean_replace_with_in, clean_replace_with_out + clean_to_replace, clean_replace_with_in, clean_replace_with_out, strict=True ): if isinstance(repl_out, Constant): repl_in = repl_out @@ -616,7 +613,7 @@ def add_to_replace(y): return_list=True, )[0].owner - replacements = dict(zip(node.outputs, nw_node.outputs)) + replacements = dict(zip(node.outputs, nw_node.outputs, strict=True)) replacements["remove"] = [node] return replacements @@ -668,8 +665,10 @@ def inner_sitsot_only_last_step_used( client = fgraph.clients[outer_var][0][0] if isinstance(client, Apply) and isinstance(client.op, Subtensor): lst = get_idx_list(client.inputs, client.op.idx_list) - if len(lst) == 1 and pt.extract_constant(lst[0]) == -1: - return True + return ( + len(lst) == 1 + and get_scalar_constant_value(lst[0], raise_not_constant=False) == -1 + ) return False @@ -814,7 +813,7 @@ def add_nitsot_outputs( # replacements["remove"] = [old_scan_node] # return new_scan_node, replacements fgraph.replace_all_validate_remove( # type: ignore - list(zip(old_scan_node.outputs, new_node_old_outputs)), + list(zip(old_scan_node.outputs, new_node_old_outputs, strict=True)), remove=[old_scan_node], reason="scan_pushout_add", ) @@ -1020,7 +1019,7 @@ def attempt_scan_inplace( # This whole rewrite should be a simple local rewrite, but, because # of this awful approach, it can't be. fgraph.replace_all_validate_remove( # type: ignore - list(zip(node.outputs, new_outs)), + list(zip(node.outputs, new_outs, strict=True)), remove=[node], reason="scan_make_inplace", ) @@ -1344,10 +1343,17 @@ def scan_save_mem(fgraph, node): if isinstance(this_slice[0], slice) and this_slice[0].stop is None: global_nsteps = None if isinstance(cf_slice[0], slice): - stop = pt.extract_constant(cf_slice[0].stop) + stop = get_scalar_constant_value( + cf_slice[0].stop, raise_not_constant=False + ) else: - stop = pt.extract_constant(cf_slice[0]) + 1 - if stop == maxsize or stop == pt.extract_constant(length): + stop = ( + get_scalar_constant_value(cf_slice[0], raise_not_constant=False) + + 1 + ) + if stop == maxsize or stop == get_scalar_constant_value( + length, raise_not_constant=False + ): stop = None else: # there is a **gotcha** here ! Namely, scan returns an @@ -1451,9 +1457,13 @@ def scan_save_mem(fgraph, node): cf_slice = get_canonical_form_slice(this_slice[0], length) if isinstance(cf_slice[0], slice): - start = pt.extract_constant(cf_slice[0].start) + start = pt.get_scalar_constant_value( + cf_slice[0].start, raise_not_constant=False + ) else: - start = pt.extract_constant(cf_slice[0]) + start = pt.get_scalar_constant_value( + cf_slice[0], raise_not_constant=False + ) if start == 0 or store_steps[i] == 0: store_steps[i] = 0 @@ -1628,7 +1638,7 @@ def scan_save_mem(fgraph, node): # 3.6 Compose the new scan # TODO: currently we don't support scan with 0 step. So # don't create one. - if pt.extract_constant(node_ins[0]) == 0: + if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0: return False # Do not call make_node for test_value @@ -1941,7 +1951,7 @@ def merge(self, nodes): if not isinstance(new_outs, list | tuple): new_outs = [new_outs] - return list(zip(outer_outs, new_outs)) + return list(zip(outer_outs, new_outs, strict=True)) def belongs_to_set(self, node, set_nodes): """ @@ -1965,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes): nsteps = node.inputs[0] try: - nsteps = int(get_underlying_scalar_constant_value(nsteps)) + nsteps = int(get_scalar_constant_value(nsteps)) except NotScalarConstantError: pass rep_nsteps = rep_node.inputs[0] try: - rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) + rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) except NotScalarConstantError: pass @@ -2010,7 +2020,9 @@ def belongs_to_set(self, node, set_nodes): ] inner_inputs = op.inner_inputs rep_inner_inputs = rep_op.inner_inputs - for nominal_input, rep_nominal_input in zip(nominal_inputs, rep_nominal_inputs): + for nominal_input, rep_nominal_input in zip( + nominal_inputs, rep_nominal_inputs, strict=True + ): conds.append(node.inputs[mapping[inner_inputs.index(nominal_input)]]) rep_conds.append( rep_node.inputs[rep_mapping[rep_inner_inputs.index(rep_nominal_input)]] @@ -2067,7 +2079,7 @@ def make_equiv(lo, li): seeno = {} left = [] right = [] - for o, i in zip(lo, li): + for o, i in zip(lo, li, strict=True): if o in seeno: left += [i] right += [o] @@ -2104,7 +2116,7 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(a.outer_in_seqs): new_outer_seqs = [] new_inner_seqs = [] - for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs): + for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs, strict=True): if out_seq in new_outer_seqs: i = new_outer_seqs.index(out_seq) inp_equiv[in_seq] = new_inner_seqs[i] @@ -2117,7 +2129,9 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(a.outer_in_non_seqs): new_outer_nseqs = [] new_inner_nseqs = [] - for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs): + for out_nseq, in_nseq in zip( + a.outer_in_non_seqs, a.inner_in_non_seqs, strict=True + ): if out_nseq in new_outer_nseqs: i = new_outer_nseqs.index(out_nseq) inp_equiv[in_nseq] = new_inner_nseqs[i] @@ -2180,7 +2194,7 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(na.outer_in_mit_mot): seen = {} for omm, imm, _sl in zip( - na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices + na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices, strict=True ): sl = tuple(_sl) if (omm, sl) in seen: @@ -2193,7 +2207,7 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(na.outer_in_mit_sot): seen = {} for oms, ims, _sl in zip( - na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices + na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices, strict=True ): sl = tuple(_sl) if (oms, sl) in seen: @@ -2227,7 +2241,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.outer_out_nit_sot = [ map_out(outer_i, inner_o, outer_o, seen) for outer_i, inner_o, outer_o in zip( - na.outer_in_nit_sot, na.inner_out_nit_sot, na.outer_out_nit_sot + na.outer_in_nit_sot, na.inner_out_nit_sot, na.outer_out_nit_sot, strict=True ) ] @@ -2237,7 +2251,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.outer_out_sit_sot = [ map_out(outer_i, inner_o, outer_o, seen) for outer_i, inner_o, outer_o in zip( - na.outer_in_sit_sot, na.inner_out_sit_sot, na.outer_out_sit_sot + na.outer_in_sit_sot, na.inner_out_sit_sot, na.outer_out_sit_sot, strict=True ) ] @@ -2247,7 +2261,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.outer_out_mit_sot = [ map_out(outer_i, inner_o, outer_o, seen) for outer_i, inner_o, outer_o in zip( - na.outer_in_mit_sot, na.inner_out_mit_sot, na.outer_out_mit_sot + na.outer_in_mit_sot, na.inner_out_mit_sot, na.outer_out_mit_sot, strict=True ) ] @@ -2261,6 +2275,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.inner_out_mit_mot, na.outer_out_mit_mot, na.mit_mot_out_slices, + strict=True, ): for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: if ( @@ -2275,7 +2290,9 @@ def map_out(outer_i, inner_o, outer_o, seen): new_outer_out_mit_mot.append(outer_omm) na.outer_out_mit_mot = new_outer_out_mit_mot if remove: - return dict([("remove", remove), *zip(node.outputs, na.outer_outputs)]) + return dict( + [("remove", remove), *zip(node.outputs, na.outer_outputs, strict=True)] + ) return na.outer_outputs @@ -2300,7 +2317,7 @@ def scan_push_out_dot1(fgraph, node): sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) outer_sitsot = op.outer_sitsot_outs(node.outputs) seqs = op.inner_seqs(op.inner_inputs) - for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot): + for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot, strict=True): if ( out.owner and isinstance(out.owner.op, Elemwise) @@ -2453,10 +2470,12 @@ def scan_push_out_dot1(fgraph, node): new_out = dot(val, out_seq) pos = node.outputs.index(outer_out) - old_new = list(zip(node.outputs[:pos], new_outs[:pos])) + old_new = list(zip(node.outputs[:pos], new_outs[:pos], strict=True)) old = fgraph.clients[node.outputs[pos]][0][0].outputs[0] old_new.append((old, new_out)) - old_new += list(zip(node.outputs[pos + 1 :], new_outs[pos:])) + old_new += list( + zip(node.outputs[pos + 1 :], new_outs[pos:], strict=True) + ) replacements = dict(old_new) replacements["remove"] = [node] return replacements diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index c55820eb68..611012b97e 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -559,7 +559,7 @@ def reconstruct_graph(inputs, outputs, tag=None): tag = "" nw_inputs = [safe_new(x, tag) for x in inputs] - givens = {x: nw_x for nw_x, x in zip(nw_inputs, inputs)} + givens = {x: nw_x for nw_x, x in zip(nw_inputs, inputs, strict=True)} nw_outputs = clone_replace(outputs, replace=givens) return (nw_inputs, nw_outputs) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 1a3ca4ffdf..7f200b2a7c 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -491,6 +491,10 @@ def __str__(self): def __repr__(self): return str(self) + @property + def unique_value(self): + return None + SparseTensorType.variable_type = SparseVariable SparseTensorType.constant_type = SparseConstant @@ -2848,7 +2852,7 @@ def choose(continuous, derivative): else: return None - return [choose(c, d) for c, d in zip(is_continuous, derivative)] + return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] def infer_shape(self, fgraph, node, ins_shapes): def _get(l): @@ -2927,7 +2931,7 @@ def choose(continuous, derivative): else: return None - return [choose(c, d) for c, d in zip(is_continuous, derivative)] + return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] def infer_shape(self, fgraph, node, ins_shapes): def _get(l): @@ -3606,7 +3610,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3643,11 +3647,11 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp nnz = PyArray_DIMS({_indices})[0]; npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; @@ -3740,7 +3744,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3778,11 +3782,11 @@ def c_code(self, node, name, inputs, outputs, sub): // extract number of rows npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index bf6d6f0bc6..13735d2aca 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -158,8 +158,8 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{y}* ydata = (dtype_{y}*)PyArray_DATA({y}); dtype_{z}* zdata = (dtype_{z}*)PyArray_DATA({z}); - npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_DESCR({y})->elsize; - npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; + npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_ITEMSIZE({y}); + npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); npy_intp pos; if ({format} == 0){{ @@ -186,7 +186,7 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[3]] def c_code_cache_version(self): - return (2,) + return (3,) @node_rewriter([sparse.AddSD]) @@ -361,13 +361,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -436,7 +436,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3,) + return (4,) sd_csc = StructuredDotCSC() @@ -555,13 +555,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -614,7 +614,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (2,) + return (3,) sd_csr = StructuredDotCSR() @@ -845,12 +845,12 @@ def c_code(self, node, name, inputs, outputs, sub): const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA({x_ptr}); const dtype_{alpha} alpha = ((dtype_{alpha}*)PyArray_DATA({alpha}))[0]; - npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_DESCR({zn})->elsize; - npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_DESCR({x_val})->elsize; - npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_DESCR({x_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_DESCR({x_ptr})->elsize; - npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_DESCR({y})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_ITEMSIZE({zn}); + npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_ITEMSIZE({x_val}); + npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_ITEMSIZE({x_ind}); + npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_ITEMSIZE({x_ptr}); + npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_ITEMSIZE({y}); // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL)) @@ -896,7 +896,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3, blas.blas_header_version()) + return (4, blas.blas_header_version()) usmm_csc_dense = UsmmCscDense(inplace=False) @@ -1035,13 +1035,13 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; - npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_DESCR({b_val})->elsize; - npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_DESCR({b_ind})->elsize; - npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_DESCR({b_ptr})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); + npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_ITEMSIZE({b_val}); + npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_ITEMSIZE({b_ind}); + npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_ITEMSIZE({b_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -1086,7 +1086,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (3,) + return (4,) csm_grad_c = CSMGradC() @@ -1482,7 +1482,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (2,) + return (3,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1544,7 +1544,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over rows for (npy_intp j = 0; j < N; ++j) @@ -1655,7 +1655,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (3,) + return (4,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1723,7 +1723,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over columns for (npy_intp j = 0; j < N; ++j) @@ -1868,7 +1868,7 @@ def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): ) def c_code_cache_version(self): - return (4, blas.blas_header_version()) + return (5, blas.blas_header_version()) def c_support_code(self, **kwargs): return blas.blas_header_text() @@ -1995,14 +1995,14 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{z_ind}* __restrict__ Dzi = (dtype_{z_ind}*)PyArray_DATA({z_ind}); dtype_{z_ptr}* __restrict__ Dzp = (dtype_{z_ptr}*)PyArray_DATA({z_ptr}); - const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_DESCR({x})->elsize; - const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; - const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_DESCR({p_data})->elsize; - const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_DESCR({p_ind})->elsize; - const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_DESCR({p_ptr})->elsize; - const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_DESCR({z_data})->elsize; - const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_DESCR({z_ind})->elsize; - const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_DESCR({z_ptr})->elsize; + const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_ITEMSIZE({x}); + const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); + const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_ITEMSIZE({p_data}); + const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_ITEMSIZE({p_ind}); + const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_ITEMSIZE({p_ptr}); + const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_ITEMSIZE({z_data}); + const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_ITEMSIZE({z_ind}); + const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_ITEMSIZE({z_ptr}); memcpy(Dzi, Dpi, PyArray_DIMS({p_ind})[0]*sizeof(dtype_{p_ind})); memcpy(Dzp, Dpp, PyArray_DIMS({p_ptr})[0]*sizeof(dtype_{p_ptr})); diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 7385f02478..88d3f33199 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -123,11 +123,12 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # isort: on # Allow accessing numpy constants from pytensor.tensor -from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi +from numpy import e, euler_gamma, inf, nan, newaxis, pi from pytensor.tensor.basic import * from pytensor.tensor.blas import batched_dot, batched_tensordot from pytensor.tensor.extra_ops import * +from pytensor.tensor.interpolate import interp, interpolate1d from pytensor.tensor.io import * from pytensor.tensor.math import * from pytensor.tensor.pad import pad diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7d5236d04a..9700630b25 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -14,12 +14,20 @@ from typing import cast as type_cast import numpy as np -from numpy.core.multiarray import normalize_axis_index -from numpy.core.numeric import normalize_axis_tuple + + +try: + from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.multiarray import normalize_axis_index + from numpy.core.numeric import normalize_axis_tuple + import pytensor import pytensor.scalar.sharedvar -from pytensor import compile, config, printing +from pytensor import config, printing from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType, grad_undefined @@ -35,7 +43,7 @@ from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise, assert_op from pytensor.scalar import int32 -from pytensor.scalar.basic import ScalarConstant, ScalarVariable +from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable from pytensor.tensor import ( _as_tensor_variable, _get_vector_length, @@ -71,10 +79,10 @@ uint_dtypes, values_eq_approx_always_true, ) +from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import ( TensorConstant, TensorVariable, - get_unique_constant_value, ) @@ -228,7 +236,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: elif x_.ndim > ndim: try: x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim))) - except np.AxisError: + except np.exceptions.AxisError: raise ValueError( f"ndarray could not be cast to constant with {int(ndim)} dimensions" ) @@ -268,27 +276,7 @@ def _obj_is_wrappable_as_tensor(x): ) -def get_scalar_constant_value( - v, elemwise=True, only_process_constants=False, max_recur=10 -): - """ - Checks whether 'v' is a scalar (ndim = 0). - - If 'v' is a scalar then this function fetches the underlying constant by calling - 'get_underlying_scalar_constant_value()'. - - If 'v' is not a scalar, it raises a NotScalarConstantError. - - """ - if isinstance(v, Variable | np.ndarray): - if v.ndim != 0: - raise NotScalarConstantError() - return get_underlying_scalar_constant_value( - v, elemwise, only_process_constants, max_recur - ) - - -def get_underlying_scalar_constant_value( +def _get_underlying_scalar_constant_value( orig_v, elemwise=True, only_process_constants=False, max_recur=10 ): """Return the constant scalar(0-D) value underlying variable `v`. @@ -319,6 +307,10 @@ def get_underlying_scalar_constant_value( but I'm not sure where it is. """ + from pytensor.compile.ops import DeepCopyOp, OutputGuard + from pytensor.sparse import CSM + from pytensor.tensor.subtensor import Subtensor + v = orig_v while True: if v is None: @@ -336,40 +328,28 @@ def get_underlying_scalar_constant_value( raise NotScalarConstantError() if isinstance(v, Constant): - unique_value = get_unique_constant_value(v) - if unique_value is not None: - data = unique_value - else: - data = v.data + if isinstance(v.type, TensorType) and v.unique_value is not None: + return v.unique_value - if isinstance(data, np.ndarray): - try: - return np.array(data.item(), dtype=v.dtype) - except ValueError: - raise NotScalarConstantError() - - from pytensor.sparse.type import SparseTensorType + elif isinstance(v.type, ScalarType): + return v.data - if isinstance(v.type, SparseTensorType): - raise NotScalarConstantError() + elif isinstance(v.type, NoneTypeT): + return None - return data + raise NotScalarConstantError() if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: + op = v.owner.op max_recur -= 1 if isinstance( - v.owner.op, - Alloc - | DimShuffle - | Unbroadcast - | compile.ops.OutputGuard - | compile.DeepCopyOp, + op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp ): # OutputGuard is only used in debugmode but we # keep it here to avoid problems with old pickles v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, Shape_i): + elif isinstance(op, Shape_i): i = v.owner.op.i inp = v.owner.inputs[0] if isinstance(inp, Constant): @@ -383,19 +363,19 @@ def get_underlying_scalar_constant_value( # mess with the stabilization optimization and be too slow. # We put all the scalar Ops used by get_canonical_form_slice() # to allow it to determine the broadcast pattern correctly. - elif isinstance(v.owner.op, ScalarFromTensor | TensorFromScalar): + elif isinstance(op, ScalarFromTensor | TensorFromScalar): v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, CheckAndRaise): + elif isinstance(op, CheckAndRaise): # check if all conditions are constant and true conds = [ - get_underlying_scalar_constant_value(c, max_recur=max_recur) + _get_underlying_scalar_constant_value(c, max_recur=max_recur) for c in v.owner.inputs[1:] ] if builtins.all(0 == c.ndim and c != 0 for c in conds): v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, ps.ScalarOp): + elif isinstance(op, ps.ScalarOp): if isinstance(v.owner.op, ps.Second): # We don't need both input to be constant for second shp, val = v.owner.inputs @@ -403,7 +383,7 @@ def get_underlying_scalar_constant_value( continue if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): const = [ - get_underlying_scalar_constant_value(i, max_recur=max_recur) + _get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -412,7 +392,7 @@ def get_underlying_scalar_constant_value( # In fast_compile, we don't enable local_fill_to_alloc, so # we need to investigate Second as Alloc. So elemwise # don't disable the check for Second. - elif isinstance(v.owner.op, Elemwise): + elif isinstance(op, Elemwise): if isinstance(v.owner.op.scalar_op, ps.Second): # We don't need both input to be constant for second shp, val = v.owner.inputs @@ -422,16 +402,13 @@ def get_underlying_scalar_constant_value( v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops ): const = [ - get_underlying_scalar_constant_value(i, max_recur=max_recur) + _get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] v.owner.op.perform(v.owner, const, ret) return np.asarray(ret[0][0].copy()) - elif ( - isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor) - and v.ndim == 0 - ): + elif isinstance(op, Subtensor) and v.ndim == 0: if isinstance(v.owner.inputs[0], TensorConstant): from pytensor.tensor.subtensor import get_constant_idx @@ -468,7 +445,7 @@ def get_underlying_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) try: @@ -502,14 +479,13 @@ def get_underlying_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) - # Python 2.4 does not support indexing with numpy.integer - # So we cast it. - idx = int(idx) ret = v.owner.inputs[0].owner.inputs[idx] - ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur) + ret = _get_underlying_scalar_constant_value( + ret, max_recur=max_recur + ) # MakeVector can cast implicitly its input in some case. return np.asarray(ret, dtype=v.type.dtype) @@ -524,7 +500,7 @@ def get_underlying_scalar_constant_value( idx_list = op.idx_list idx = idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( owner.inputs[1], max_recur=max_recur ) grandparent = leftmost_parent.owner.inputs[0] @@ -534,7 +510,9 @@ def get_underlying_scalar_constant_value( grandparent.owner.op, Unbroadcast ): ggp_shape = grandparent.owner.inputs[0].type.shape - l = [get_underlying_scalar_constant_value(s) for s in ggp_shape] + l = [ + _get_underlying_scalar_constant_value(s) for s in ggp_shape + ] gp_shape = tuple(l) if not (idx < ndim): @@ -555,10 +533,105 @@ def get_underlying_scalar_constant_value( if isinstance(grandparent, Constant): return np.asarray(np.shape(grandparent.data)[idx]) + elif isinstance(op, CSM): + data = _get_underlying_scalar_constant_value( + v.owner.inputs, elemwise=elemwise, max_recur=max_recur + ) + # Sparse variable can only be constant if zero (or I guess if homogeneously dense) + if data == 0: + return data + break raise NotScalarConstantError() +def get_underlying_scalar_constant_value( + v, + *, + elemwise=True, + only_process_constants=False, + max_recur=10, + raise_not_constant=True, +): + """Return the unique constant scalar(0-D) value underlying variable `v`. + + If `v` is the output of dimshuffles, fills, allocs, etc, + cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise + and some pattern with Subtensor, this function digs through them. + + If `v` is not some view of constant scalar data, then raise a + NotScalarConstantError. + + This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by + constant folding the inputs of `v`. + + Parameters + ---------- + v: Variable + elemwise : bool + If False, we won't try to go into elemwise. So this call is faster. + But we still investigate in Second Elemwise (as this is a substitute + for Alloc) + only_process_constants : bool + If True, we only attempt to obtain the value of `orig_v` if it's + directly constant and don't try to dig through dimshuffles, fills, + allocs, and other to figure out its value. + max_recur : int + The maximum number of recursion. + raise_not_constant: bool, default True + If True, raise a NotScalarConstantError if `v` does not have an + underlying constant scalar value. If False, return `v` as is. + + + Raises + ------ + NotScalarConstantError + `v` does not have an underlying constant scalar value. + Only rasise if raise_not_constant is True. + + """ + try: + return _get_underlying_scalar_constant_value( + v, + elemwise=elemwise, + only_process_constants=only_process_constants, + max_recur=max_recur, + ) + except NotScalarConstantError: + if raise_not_constant: + raise + return v + + +def get_scalar_constant_value( + v, + elemwise=True, + only_process_constants=False, + max_recur=10, + raise_not_constant: bool = True, +): + """ + Checks whether 'v' is a scalar (ndim = 0). + + If 'v' is a scalar then this function fetches the underlying constant by calling + 'get_underlying_scalar_constant_value()'. + + If 'v' is not a scalar, it raises a NotScalarConstantError. + + """ + if isinstance(v, TensorVariable | np.ndarray): + if v.ndim != 0: + print(v, v.ndim) + raise NotScalarConstantError("Input ndim != 0") + return get_underlying_scalar_constant_value( + v, + elemwise=elemwise, + only_process_constants=only_process_constants, + max_recur=max_recur, + raise_not_constant=raise_not_constant, + ) + + class TensorFromScalar(COp): __props__ = () @@ -1544,6 +1617,7 @@ def make_node(self, value, *shape): extended_value_broadcastable, extended_value_static_shape, static_shape, + strict=True, ) ): # If value is not broadcastable and we don't know the target static shape: use value static shape @@ -1564,7 +1638,7 @@ def make_node(self, value, *shape): def _check_runtime_broadcast(node, value, shape): value_static_shape = node.inputs[0].type.shape for v_static_dim, value_dim, out_dim in zip( - value_static_shape[::-1], value.shape[::-1], shape[::-1] + value_static_shape[::-1], value.shape[::-1], shape[::-1], strict=False ): if v_static_dim is None and value_dim == 1 and out_dim != 1: raise ValueError(Alloc._runtime_broadcast_error_msg) @@ -1667,6 +1741,7 @@ def grad(self, inputs, grads): inputs[0].type.shape, # We need the dimensions corresponding to x grads[0].type.shape[-inputs[0].ndim :], + strict=False, ) ): if ib == 1 and gb != 1: @@ -1741,7 +1816,7 @@ def do_constant_folding(self, fgraph, node): @_get_vector_length.register(Alloc) def _get_vector_length_Alloc(var_inst, var): try: - return get_underlying_scalar_constant_value(var.owner.inputs[1]) + return get_scalar_constant_value(var.owner.inputs[1]) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -2013,16 +2088,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False): ScalarVariable, we convert it to a tensor with tensor_from_scalar. """ - try: - x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants) - except NotScalarConstantError: - pass - if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable): - if x.owner and isinstance(x.owner.op, ScalarFromTensor): - x = x.owner.inputs[0] - else: - x = tensor_from_scalar(x) - return x + warnings.warn( + "extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`", + FutureWarning, + ) + return get_underlying_scalar_constant_value( + x, + elemwise=elemwise, + only_process_constants=only_process_constants, + raise_not_constant=False, + ) def transpose(x, axes=None): @@ -2192,7 +2267,7 @@ def grad(self, inputs, g_outputs): ] # Else, we have to make them zeros before joining them new_g_outputs = [] - for o, g in zip(outputs, g_outputs): + for o, g in zip(outputs, g_outputs, strict=True): if isinstance(g.type, DisconnectedType): new_g_outputs.append(o.zeros_like()) else: @@ -2442,7 +2517,7 @@ def make_node(self, axis, *tensors): if not isinstance(axis, int): try: - axis = int(get_underlying_scalar_constant_value(axis)) + axis = int(get_scalar_constant_value(axis)) except NotScalarConstantError: pass @@ -2629,7 +2704,7 @@ def grad(self, axis_and_tensors, grads): else specify_broadcastable( g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1) ) - for t, g in zip(tens, split_gz) + for t, g in zip(tens, split_gz, strict=True) ] rval = rval + split_gz else: @@ -2686,7 +2761,7 @@ def infer_shape(self, fgraph, node, ishapes): def _get_vector_length_Join(op, var): axis, *arrays = var.owner.inputs try: - axis = get_underlying_scalar_constant_value(axis) + axis = get_scalar_constant_value(axis) assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) return builtins.sum(get_vector_length(a) for a in arrays) except NotScalarConstantError: @@ -2741,7 +2816,7 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs): ): batch_ndims = { batch_input.type.ndim - old_input.type.ndim - for batch_input, old_input in zip(batch_inputs, old_inputs) + for batch_input, old_input in zip(batch_inputs, old_inputs, strict=True) } if len(batch_ndims) == 1: [batch_ndim] = batch_ndims @@ -3324,7 +3399,7 @@ def __getitem__(self, *args): tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j)) for j, r in enumerate(ranges) ] - ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)] + ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes, strict=True)] if self.sparse: grids = ranges else: @@ -3396,7 +3471,7 @@ def make_node(self, x, y, inverse): out_shape = [ 1 if xb == 1 and yb == 1 else None - for xb, yb in zip(x.type.shape, y.type.shape) + for xb, yb in zip(x.type.shape, y.type.shape, strict=True) ] out_type = tensor(dtype=x.type.dtype, shape=out_shape) @@ -3461,7 +3536,8 @@ def perform(self, node, inp, out): # Make sure the output is big enough out_s = [] - for xdim, ydim in zip(x_s, y_s): + # strict=False because we are in a hot loop + for xdim, ydim in zip(x_s, y_s, strict=False): if xdim == ydim: outdim = xdim elif xdim == 1: @@ -3521,7 +3597,7 @@ def grad(self, inp, grads): assert gx.type.ndim == x.type.ndim assert all( s1 == s2 - for s1, s2 in zip(gx.type.shape, x.type.shape) + for s1, s2 in zip(gx.type.shape, x.type.shape, strict=True) if s1 == 1 or s2 == 1 ) @@ -3967,7 +4043,7 @@ def moveaxis( order = [n for n in range(a.ndim) if n not in source] - for dest, src in sorted(zip(destination, source)): + for dest, src in sorted(zip(destination, source, strict=True)): order.insert(dest, src) result = a.dimshuffle(order) @@ -4078,7 +4154,7 @@ def make_node(self, a, choices): static_out_shape = () for s in out_shape: try: - s_val = pytensor.get_underlying_scalar_constant(s) + s_val = get_scalar_constant_value(s) except (NotScalarConstantError, AttributeError): s_val = None @@ -4293,7 +4369,7 @@ def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVa axis = (axis,) out_ndim = len(axis) + a.ndim - axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim) + axis = normalize_axis_tuple(axis, out_ndim) if not axis: return a @@ -4315,7 +4391,7 @@ def _make_along_axis_idx(arr_shape, indices, axis): # build a fancy index, consisting of orthogonal aranges, with the # requested index inserted at the right location fancy_index = [] - for dim, n in zip(dest_dims, arr_shape): + for dim, n in zip(dest_dims, arr_shape, strict=True): if dim is None: fancy_index.append(indices) else: @@ -4401,7 +4477,6 @@ def ix_(*args): "split", "transpose", "matrix_transpose", - "extract_constant", "default", "tensor_copy", "transfer", diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 6170a02a98..592a4ba27c 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -79,7 +79,6 @@ import logging import os import shlex -import time from pathlib import Path import numpy as np @@ -103,58 +102,25 @@ from pytensor.tensor import basic as ptb from pytensor.tensor.basic import expand_dims from pytensor.tensor.blas_headers import blas_header_text, blas_header_version -from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import add, mul, neg, sub, variadic_add from pytensor.tensor.shape import shape_padright, specify_broadcastable -from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor +from pytensor.tensor.type import DenseTensorType, tensor _logger = logging.getLogger("pytensor.tensor.blas") -try: - import scipy.linalg.blas - - have_fblas = True - try: - fblas = scipy.linalg.blas.fblas - except AttributeError: - # A change merged in Scipy development version on 2012-12-02 replaced - # `scipy.linalg.blas.fblas` with `scipy.linalg.blas`. - # See http://github.com/scipy/scipy/pull/358 - fblas = scipy.linalg.blas - _blas_gemv_fns = { - np.dtype("float32"): fblas.sgemv, - np.dtype("float64"): fblas.dgemv, - np.dtype("complex64"): fblas.cgemv, - np.dtype("complex128"): fblas.zgemv, - } -except ImportError as e: - have_fblas = False - # This is used in Gemv and ScipyGer. We use CGemv and CGer - # when config.blas__ldflags is defined. So we don't need a - # warning in that case. - if not config.blas__ldflags: - _logger.warning( - "Failed to import scipy.linalg.blas, and " - "PyTensor flag blas__ldflags is empty. " - "Falling back on slower implementations for " - "dot(matrix, vector), dot(vector, matrix) and " - f"dot(vector, vector) ({e!s})" - ) - # If check_init_y() == True we need to initialize y when beta == 0. def check_init_y(): + # TODO: What is going on here? + from scipy.linalg.blas import get_blas_funcs + if check_init_y._result is None: - if not have_fblas: # pragma: no cover - check_init_y._result = False - else: - y = float("NaN") * np.ones((2,)) - x = np.ones((2,)) - A = np.ones((2, 2)) - gemv = _blas_gemv_fns[y.dtype] - gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) - check_init_y._result = np.isnan(y).any() + y = float("NaN") * np.ones((2,)) + x = np.ones((2,)) + A = np.ones((2, 2)) + gemv = get_blas_funcs("gemv", dtype=y.dtype) + gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) + check_init_y._result = np.isnan(y).any() return check_init_y._result @@ -211,14 +177,15 @@ def make_node(self, y, alpha, A, x, beta): return Apply(self, inputs, [y.type()]) def perform(self, node, inputs, out_storage): + from scipy.linalg.blas import get_blas_funcs + y, alpha, A, x, beta = inputs if ( - have_fblas - and y.shape[0] != 0 + y.shape[0] != 0 and x.shape[0] != 0 - and y.dtype in _blas_gemv_fns + and y.dtype in {"float32", "float64", "complex64", "complex128"} ): - gemv = _blas_gemv_fns[y.dtype] + gemv = get_blas_funcs("gemv", dtype=y.dtype) if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]: raise ValueError( @@ -531,7 +498,7 @@ def c_header_dirs(self, **kwargs): int unit = 0; int type_num = PyArray_DESCR(%(_x)s)->type_num; - int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes npy_intp* Nx = PyArray_DIMS(%(_x)s); npy_intp* Ny = PyArray_DIMS(%(_y)s); @@ -822,7 +789,7 @@ def build_gemm_call(self): ) def build_gemm_version(self): - return (13, blas_header_version()) + return (14, blas_header_version()) class Gemm(GemmRelated): @@ -1063,7 +1030,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(x_new, %(_x)s) == -1) + if(PyArray_CopyInto(x_new, %(_x)s) == -1) { %(fail)s } @@ -1089,7 +1056,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(y_new, %(_y)s) == -1) + if(PyArray_CopyInto(y_new, %(_y)s) == -1) { %(fail)s } @@ -1135,7 +1102,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): gv = self.build_gemm_version() if gv: - return (7, *gv) + return (8, *gv) else: return gv @@ -1148,322 +1115,6 @@ def c_code_cache_version(self): pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"])) -def res_is_a(fgraph, var, op, maxclients=None): - if maxclients is not None and var in fgraph.clients: - retval = len(fgraph.get_clients(var)) <= maxclients - else: - retval = True - - return var.owner and var.owner.op == op and retval - - -def _as_scalar(res, dtype=None): - """Return ``None`` or a `TensorVariable` of float type""" - if dtype is None: - dtype = config.floatX - if all(s == 1 for s in res.type.shape): - while res.owner and isinstance(res.owner.op, DimShuffle): - res = res.owner.inputs[0] - # may still have some number of True's - if res.type.ndim > 0: - rval = res.dimshuffle() - else: - rval = res - if rval.type.dtype in integer_dtypes: - # We check that the upcast of res and dtype won't change dtype. - # If dtype is float64, we will cast int64 to float64. - # This is valid when res is a scalar used as input to a dot22 - # as the cast of the scalar can be done before or after the dot22 - # and this will give the same result. - if pytensor.scalar.upcast(res.dtype, dtype) == dtype: - return ptb.cast(rval, dtype) - else: - return None - - return rval - - -def _is_real_matrix(res): - return ( - res.type.dtype in ("float16", "float32", "float64") - and res.type.ndim == 2 - and res.type.shape[0] != 1 - and res.type.shape[1] != 1 - ) # cope with tuple vs. list - - -def _is_real_vector(res): - return ( - res.type.dtype in ("float16", "float32", "float64") - and res.type.ndim == 1 - and res.type.shape[0] != 1 - ) - - -def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True): - # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip - # EXPRESSION: (beta * L) + (alpha * M) - - # we've already checked the client counts, now just make the type check. - # if res_is_a(M, _dot22, 1): - if M.owner and M.owner.op == _dot22: - Ml, Mr = M.owner.inputs - rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] - return rval, M - - # it also might be the case that there is a dimshuffle between the + - # and the dot22. local_dot_to_dot22 in particular will put in such things. - if ( - M.owner - and isinstance(M.owner.op, DimShuffle) - and M.owner.inputs[0].owner - and isinstance(M.owner.inputs[0].owner.op, Dot22) - ): - MM = M.owner.inputs[0] - if M.owner.op.new_order == (0,): - # it is making a column MM into a vector - MMl, MMr = MM.owner.inputs - g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta) - rval = [g.dimshuffle(0)] - return rval, MM - if M.owner.op.new_order == (1,): - # it is making a row MM into a vector - MMl, MMr = MM.owner.inputs - g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta) - rval = [g.dimshuffle(1)] - return rval, MM - if len(M.owner.op.new_order) == 0: - # it is making a row MM into a vector - MMl, MMr = MM.owner.inputs - g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta) - rval = [g.dimshuffle()] - return rval, MM - - if recurse_flip: - return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False) - else: - return False, False - - -def _gemm_canonicalize(fgraph, r, scale, rval, maxclients): - # Tries to interpret node as a sum of scalars * (vectors or matrices) - def scaled(thing): - if scale == 1: - return thing - if scale == -1 and thing.type.dtype != "bool": - return -thing - else: - return scale * thing - - if not isinstance(r.type, TensorType): - return None - - if (r.type.ndim not in (1, 2)) or r.type.dtype not in ( - "float16", - "float32", - "float64", - "complex64", - "complex128", - ): - rval.append(scaled(r)) - return rval - - if maxclients and len(fgraph.clients[r]) > maxclients: - rval.append((scale, r)) - return rval - - if r.owner and r.owner.op == sub: - _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1) - _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1) - - elif r.owner and r.owner.op == add: - for i in r.owner.inputs: - _gemm_canonicalize(fgraph, i, scale, rval, 1) - - elif r.owner and r.owner.op == neg: - _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1) - - elif r.owner and r.owner.op == mul: - scalars = [] - vectors = [] - matrices = [] - for i in r.owner.inputs: - if all(s == 1 for s in i.type.shape): - while i.owner and isinstance(i.owner.op, DimShuffle): - i = i.owner.inputs[0] - if i.type.ndim > 0: - scalars.append(i.dimshuffle()) - else: - scalars.append(i) - elif _is_real_vector(i): - vectors.append(i) - elif _is_real_matrix(i): - matrices.append(i) - else: - # just put the original arguments as in the base case - rval.append((scale, r)) - return rval - if len(matrices) == 1: - assert len(vectors) == 0 - m = matrices[0] - if len(scalars) == 0: - _gemm_canonicalize(fgraph, m, scale, rval, 1) - elif len(scalars) == 1: - _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1) - else: - _gemm_canonicalize( - fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1 - ) - elif len(vectors) == 1: - assert len(matrices) == 0 - v = vectors[0] - if len(scalars) == 0: - _gemm_canonicalize(fgraph, v, scale, rval, 1) - elif len(scalars) == 1: - _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1) - else: - _gemm_canonicalize( - fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1 - ) - else: # lets not open this up - rval.append((scale, r)) - else: - rval.append((scale, r)) - return rval - - -def _factor_canonicalized(lst): - # remove duplicates from canonicalized list - - # we only delete out of the right end of the list, - # once i has touched a list element, it is permantent - lst = list(lst) - # print 'FACTOR', lst - # for t in lst: - # if not isinstance(t, (list, tuple)): - # t = (t,) - # for e in t: - # try: - # pytensor.printing.debugprint(e) - # except TypeError: - # print e, type(e) - i = 0 - while i < len(lst) - 1: - try: - s_i, M_i = lst[i] - except Exception: - i += 1 - continue - - j = i + 1 - while j < len(lst): - try: - s_j, M_j = lst[j] - except Exception: - j += 1 - continue - - if M_i is M_j: - s_i = s_i + s_j - lst[i] = (s_i, M_i) - del lst[j] - else: - j += 1 - i += 1 - return lst - - -def _gemm_from_factored_list(fgraph, lst): - """ - Returns None, or a list to replace node.outputs. - - """ - lst2 = [] - # Remove the tuple that can't be cast correctly. - # This can happen when we try to cast a complex to a real - for sM in lst: - # Make every pair in list have matching dtypes - # sM can be a tuple of 2 elements or an PyTensor variable. - if isinstance(sM, tuple): - sm0, sm1 = sM - sm0 = ptb.as_tensor_variable(sm0) - if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype: - lst2.append((ptb.cast(sm0, sm1.dtype), sM[1])) - - lst = lst2 - - def item_to_var(t): - try: - s, M = t - except Exception: - return t - if s == 1: - return M - if s == -1: - return -M - return s * M - - # Try every pair in the sM_list, trying to turn it into a gemm operation - for i in range(len(lst) - 1): - s_i, M_i = lst[i] - - for j in range(i + 1, len(lst)): - s_j, M_j = lst[j] - - if not M_j.type.in_same_class(M_i.type): - continue - - # print 'TRYING', (s_i, M_i, s_j, M_j) - - gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M( - fgraph, s_i, M_i, s_j, M_j - ) - # print 'GOT IT', gemm_of_sM_list - if gemm_of_sM_list: - assert len(gemm_of_sM_list) == 1 - add_inputs = [ - item_to_var(input) for k, input in enumerate(lst) if k not in (i, j) - ] - add_inputs.extend(gemm_of_sM_list) - rval = [variadic_add(*add_inputs)] - return rval, old_dot22 - - -def _gemm_from_node2(fgraph, node): - """ - - TODO: In many expressions, there are many ways to turn it into a - gemm. For example dot(a,b) + c + d. This function should return all - of them, so that if one version of gemm causes a cycle in the graph, then - another application of gemm can be tried. - - """ - lst = [] - t0 = time.perf_counter() - _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0) - t1 = time.perf_counter() - - if len(lst) > 1: - lst = _factor_canonicalized(lst) - t2 = time.perf_counter() - rval = _gemm_from_factored_list(fgraph, lst) - t3 = time.perf_counter() - - # It can happen that _factor_canonicalized and - # _gemm_from_factored_list return a node with an incorrect - # type. This happens in particular when one of the scalar - # factors forces the upcast of the whole expression. In that - # case, we simply skip that candidate for Gemm. This was - # discussed in - # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5, - # but never made it into a trac ticket. - - if rval and rval[0][0].type.in_same_class(node.outputs[0].type): - return rval, t1 - t0, t2 - t1, t3 - t2 - - return None, t1 - t0, 0, 0 - - class Dot22(GemmRelated): """Compute a matrix-matrix product. @@ -1887,7 +1538,7 @@ def contiguous(var, ndim): return f""" int type_num = PyArray_DESCR({_x})->type_num; - int type_size = PyArray_DESCR({_x})->elsize; // in bytes + int type_size = PyArray_ITEMSIZE({_x}); // in bytes if (PyArray_NDIM({_x}) != 3) {{ PyErr_Format(PyExc_NotImplementedError, @@ -1947,7 +1598,7 @@ def contiguous(var, ndim): def c_code_cache_version(self): from pytensor.tensor.blas_headers import blas_header_version - return (5, blas_header_version()) + return (6, blas_header_version()) def grad(self, inp, grads): x, y = inp diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas_headers.py index 2806bfc41d..39c92e975b 100644 --- a/pytensor/tensor/blas_headers.py +++ b/pytensor/tensor/blas_headers.py @@ -1052,7 +1052,7 @@ def openblas_threads_text(): def blas_header_version(): # Version for the base header - version = (9,) + version = (10,) if detect_macos_sdot_bug(): if detect_macos_sdot_bug.fix_works: # Version with fix @@ -1070,7 +1070,7 @@ def ____gemm_code(check_ab, a_init, b_init): const char * error_string = NULL; int type_num = PyArray_DESCR(_x)->type_num; - int type_size = PyArray_DESCR(_x)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(_x); // in bytes npy_intp* Nx = PyArray_DIMS(_x); npy_intp* Ny = PyArray_DIMS(_y); diff --git a/pytensor/tensor/blas_scipy.py b/pytensor/tensor/blas_scipy.py index 16fb90988b..bb3ccf9354 100644 --- a/pytensor/tensor/blas_scipy.py +++ b/pytensor/tensor/blas_scipy.py @@ -2,30 +2,19 @@ Implementations of BLAS Ops based on scipy's BLAS bindings. """ -import numpy as np - -from pytensor.tensor.blas import Ger, have_fblas - - -if have_fblas: - from pytensor.tensor.blas import fblas - - _blas_ger_fns = { - np.dtype("float32"): fblas.sger, - np.dtype("float64"): fblas.dger, - np.dtype("complex64"): fblas.cgeru, - np.dtype("complex128"): fblas.zgeru, - } +from pytensor.tensor.blas import Ger class ScipyGer(Ger): def perform(self, node, inputs, output_storage): + from scipy.linalg.blas import get_blas_funcs + cA, calpha, cx, cy = inputs (cZ,) = output_storage # N.B. some versions of scipy (e.g. mine) don't actually work # in-place on a, even when I tell it to. A = cA - local_ger = _blas_ger_fns[cA.dtype] + local_ger = get_blas_funcs("ger", dtype=cA.dtype) if A.size == 0: # We don't have to compute anything, A is empty. # We need this special case because Numpy considers it diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 662ddbcdd1..b3366f21af 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -6,7 +6,8 @@ from pytensor import config from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType -from pytensor.graph.basic import Apply, Constant +from pytensor.graph import FunctionGraph +from pytensor.graph.basic import Apply, Constant, ancestors from pytensor.graph.null_type import NullType from pytensor.graph.op import Op from pytensor.graph.replace import ( @@ -91,7 +92,7 @@ def __init__( def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: core_input_types = [] - for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): if inp.type.ndim < len(sig): raise ValueError( f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" @@ -109,7 +110,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: raise ValueError( f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}" ) - for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)): + for i, (core_out, sig) in enumerate( + zip(core_node.outputs, self.outputs_sig, strict=True) + ): if core_out.type.ndim != len(sig): raise ValueError( f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" @@ -123,12 +126,13 @@ def make_node(self, *inputs): core_node = self._create_dummy_core_node(inputs) batch_ndims = max( - inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig) + inp.type.ndim - len(sig) + for inp, sig in zip(inputs, self.inputs_sig, strict=True) ) batched_inputs = [] batch_shapes = [] - for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): # Append missing dims to the left missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig)) if missing_batch_ndims: @@ -143,7 +147,7 @@ def make_node(self, *inputs): try: batch_shape = tuple( broadcast_static_dim_lengths(batch_dims) - for batch_dims in zip(*batch_shapes) + for batch_dims in zip(*batch_shapes, strict=True) ) except ValueError: raise ValueError( @@ -169,10 +173,10 @@ def infer_shape( batch_ndims = self.batch_ndim(node) core_dims: dict[str, Any] = {} batch_shapes = [input_shape[:batch_ndims] for input_shape in input_shapes] - for input_shape, sig in zip(input_shapes, self.inputs_sig): + for input_shape, sig in zip(input_shapes, self.inputs_sig, strict=True): core_shape = input_shape[batch_ndims:] - for core_dim, dim_name in zip(core_shape, sig): + for core_dim, dim_name in zip(core_shape, sig, strict=True): prev_core_dim = core_dims.get(core_dim) if prev_core_dim is None: core_dims[dim_name] = core_dim @@ -182,15 +186,40 @@ def infer_shape( batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) + # Try to extract the core shapes from the core_op + core_op_infer_shape = getattr(self.core_op, "infer_shape", None) + if core_op_infer_shape is not None: + dummy_core_node = self._create_dummy_core_node(node.inputs) + dummy_core_inputs = dummy_core_node.inputs + dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) + core_input_shapes = [ + input_shape[batch_ndims:] for input_shape in input_shapes + ] + core_output_shapes = core_op_infer_shape( + dummy_fgraph, dummy_core_node, core_input_shapes + ) + out_shapes = [] - for output, sig in zip(node.outputs, self.outputs_sig): + for o, (output, sig) in enumerate( + zip(node.outputs, self.outputs_sig, strict=True) + ): core_out_shape = [] for i, dim_name in enumerate(sig): # The output dim is the same as another input dim if dim_name in core_dims: core_out_shape.append(core_dims[dim_name]) else: - # TODO: We could try to make use of infer_shape of core_op + if core_op_infer_shape is not None: + # If the input values are needed to compute the dimension length, we can't use the infer_shape + # of the core_node as the value is not constant across batch dims of the Blockwise + core_out_dim = core_output_shapes[o][i] + if not ( + set(dummy_core_inputs) & set(ancestors([core_out_dim])) + ): + core_out_shape.append(core_out_dim) + continue + + # Fallback shape requires evaluating the Blockwise Op core_out_shape.append(Shape_i(batch_ndims + i)(output)) out_shapes.append((*batch_shape, *core_out_shape)) @@ -214,17 +243,17 @@ def as_core(t, core_t): with config.change_flags(compute_test_value="off"): safe_inputs = [ tensor(dtype=inp.type.dtype, shape=(None,) * len(sig)) - for inp, sig in zip(inputs, self.inputs_sig) + for inp, sig in zip(inputs, self.inputs_sig, strict=True) ] core_node = self._create_dummy_core_node(safe_inputs) core_inputs = [ as_core(inp, core_inp) - for inp, core_inp in zip(inputs, core_node.inputs) + for inp, core_inp in zip(inputs, core_node.inputs, strict=True) ] core_ograds = [ as_core(ograd, core_ograd) - for ograd, core_ograd in zip(ograds, core_node.outputs) + for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True) ] core_outputs = core_node.outputs @@ -233,7 +262,11 @@ def as_core(t, core_t): igrads = vectorize_graph( [core_igrad for core_igrad in core_igrads if core_igrad is not None], replace=dict( - zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds) + zip( + core_inputs + core_outputs + core_ograds, + inputs + outputs + ograds, + strict=True, + ) ), ) @@ -259,7 +292,7 @@ def L_op(self, inputs, outs, ograds): # the return value obviously zero so that gradient.grad can tell # this op did the right thing. new_rval = [] - for elem, inp in zip(rval, inputs): + for elem, inp in zip(rval, inputs, strict=True): if isinstance(elem.type, NullType | DisconnectedType): new_rval.append(elem) else: @@ -273,7 +306,7 @@ def L_op(self, inputs, outs, ograds): # Sum out the broadcasted dimensions batch_ndims = self.batch_ndim(outs[0].owner) batch_shape = outs[0].type.shape[:batch_ndims] - for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): if isinstance(rval[i].type, NullType | DisconnectedType): continue @@ -281,7 +314,9 @@ def L_op(self, inputs, outs, ograds): to_sum = [ j - for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape)) + for j, (inp_s, out_s) in enumerate( + zip(inp.type.shape, batch_shape, strict=False) + ) if inp_s == 1 and out_s != 1 ] if to_sum: @@ -333,11 +368,17 @@ def core_func( def _check_runtime_broadcast(self, node, inputs): batch_ndim = self.batch_ndim(node) + # strict=False because we are in a hot loop for dims_and_bcast in zip( *[ - zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim]) - for input, sinput in zip(inputs, node.inputs) - ] + zip( + input.shape[:batch_ndim], + sinput.type.broadcastable[:batch_ndim], + strict=False, + ) + for input, sinput in zip(inputs, node.inputs, strict=False) + ], + strict=False, ): if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: raise ValueError( @@ -360,7 +401,10 @@ def perform(self, node, inputs, output_storage): if not isinstance(res, tuple): res = (res,) - for node_out, out_storage, r in zip(node.outputs, output_storage, res): + # strict=False because we are in a hot loop + for node_out, out_storage, r in zip( + node.outputs, output_storage, res, strict=False + ): out_dtype = getattr(node_out, "dtype", None) if out_dtype and out_dtype != r.dtype: r = np.asarray(r, dtype=out_dtype) @@ -398,3 +442,11 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: class OpWithCoreShape(OpFromGraph): """Generalizes an `Op` to include core shape as an additional input.""" + + +class BlockwiseWithCoreShape(OpWithCoreShape): + """Generalizes a Blockwise `Op` to include a core shape parameter.""" + + def __str__(self): + [blockwise_node] = self.fgraph.apply_nodes + return f"[{blockwise_node.op!s}]" diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 73d402cfca..fc937bf404 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -8,6 +8,7 @@ from math import gcd import numpy as np +from numpy.exceptions import ComplexWarning try: @@ -25,7 +26,7 @@ from pytensor.raise_op import Assert from pytensor.tensor.basic import ( as_tensor_variable, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -497,8 +498,8 @@ def check_dim(given, computed): if given is None or computed is None: return True try: - given = get_underlying_scalar_constant_value(given) - computed = get_underlying_scalar_constant_value(computed) + given = get_scalar_constant_value(given) + computed = get_scalar_constant_value(computed) return int(given) == int(computed) except NotScalarConstantError: # no answer possible, accept for now @@ -506,7 +507,7 @@ def check_dim(given, computed): return all( check_dim(given, computed) - for (given, computed) in zip(output_shape, computed_output_shape) + for (given, computed) in zip(output_shape, computed_output_shape, strict=True) ) @@ -534,7 +535,7 @@ def assert_conv_shape(shape): out_shape = [] for i, n in enumerate(shape): try: - const_n = get_underlying_scalar_constant_value(n) + const_n = get_scalar_constant_value(n) if i < 2: if const_n < 0: raise ValueError( @@ -2203,9 +2204,7 @@ def __init__( if imshp_i is not None: # Components of imshp should be constant or ints try: - get_underlying_scalar_constant_value( - imshp_i, only_process_constants=True - ) + get_scalar_constant_value(imshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "imshp should be None or a tuple of constant int values" @@ -2218,9 +2217,7 @@ def __init__( if kshp_i is not None: # Components of kshp should be constant or ints try: - get_underlying_scalar_constant_value( - kshp_i, only_process_constants=True - ) + get_scalar_constant_value(kshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "kshp should be None or a tuple of constant int values" @@ -2342,7 +2339,7 @@ def conv( bval = _bvalfromboundary("fill") with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) for b in range(img.shape[0]): for g in range(self.num_groups): for n in range(output_channel_offset): diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index fb93f378bf..f35817d0b8 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -6,11 +6,27 @@ from typing import cast import numpy as np -from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore -from numpy.core.numeric import ( # type: ignore - normalize_axis_index, - normalize_axis_tuple, -) + + +try: + from numpy._core.einsumfunc import ( # type: ignore + _find_contraction, + _parse_einsum_input, + ) +except ModuleNotFoundError as e: + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.einsumfunc import ( # type: ignore + _find_contraction, + _parse_einsum_input, + ) + +try: + from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.multiarray import normalize_axis_index # type: ignore + from numpy.core.numeric import normalize_axis_tuple # type: ignore from pytensor.compile.builders import OpFromGraph from pytensor.tensor import TensorLike @@ -255,7 +271,7 @@ def _general_dot( .. testoutput:: - (3, 4, 2) + (np.int64(3), np.int64(4), np.int64(2)) """ # Shortcut for non batched case if not batch_axes[0] and not batch_axes[1]: @@ -303,7 +319,7 @@ def _general_dot( lhs_signature = [f"l{i}" for i in range(lhs.type.ndim)] rhs_signature = [f"r{i}" for i in range(rhs.type.ndim)] # Aligned axes get the same dimension name - for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes)): + for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes, strict=True)): lhs_signature[lhs_axis] = rhs_signature[rhs_axis] = f"a{i}" # Trim away the batch ndims lhs_signature = lhs_signature[lhs_n_batch_axes:] @@ -703,7 +719,10 @@ def filter_singleton_dims(operand, names, other_operand, other_names): if batch_names: lhs_batch, rhs_batch = tuple( - zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names]) + zip( + *[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names], + strict=True, + ) ) else: lhs_batch = rhs_batch = () @@ -716,7 +735,8 @@ def filter_singleton_dims(operand, names, other_operand, other_names): *[ (lhs_names.index(n), rhs_names.index(n)) for n in contracted_names - ] + ], + strict=True, ) ) else: diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a51c2034af..d3b566caeb 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1,10 +1,18 @@ +import warnings from collections.abc import Sequence from copy import copy from textwrap import dedent from typing import Literal import numpy as np -from numpy.core.numeric import normalize_axis_tuple + + +try: + from numpy.lib.array_utils import normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.numeric import normalize_axis_tuple import pytensor.tensor.basic from pytensor.configdefaults import config @@ -418,7 +426,7 @@ def get_output_info(self, *inputs): out_shapes = [ [ broadcast_static_dim_lengths(shape) - for shape in zip(*[inp.type.shape for inp in inputs]) + for shape in zip(*[inp.type.shape for inp in inputs], strict=True) ] ] * shadow.nout except ValueError: @@ -431,8 +439,7 @@ def get_output_info(self, *inputs): if inplace_pattern: for overwriter, overwritten in inplace_pattern.items(): for out_s, in_s in zip( - out_shapes[overwriter], - inputs[overwritten].type.shape, + out_shapes[overwriter], inputs[overwritten].type.shape, strict=True ): if in_s == 1 and out_s != 1: raise ValueError( @@ -463,7 +470,7 @@ def make_node(self, *inputs): out_dtypes, out_shapes, inputs = self.get_output_info(*inputs) outputs = [ TensorType(dtype=dtype, shape=shape)() - for dtype, shape in zip(out_dtypes, out_shapes) + for dtype, shape in zip(out_dtypes, out_shapes, strict=True) ] return Apply(self, inputs, outputs) @@ -485,7 +492,9 @@ def R_op(self, inputs, eval_points): bgrads = self._bgrad(inputs, outs, ograds) rop_out = None - for jdx, (inp, eval_point) in enumerate(zip(inputs, eval_points)): + for jdx, (inp, eval_point) in enumerate( + zip(inputs, eval_points, strict=True) + ): # if None, then we can just ignore this branch .. # what we do is to assume that for any non-differentiable # branch, the gradient is actually 0, which I think is not @@ -528,7 +537,7 @@ def L_op(self, inputs, outs, ograds): # the return value obviously zero so that gradient.grad can tell # this op did the right thing. new_rval = [] - for elem, ipt in zip(rval, inputs): + for elem, ipt in zip(rval, inputs, strict=True): if isinstance(elem.type, NullType | DisconnectedType): new_rval.append(elem) else: @@ -614,7 +623,7 @@ def transform(r): return new_r ret = [] - for scalar_igrad, ipt in zip(scalar_igrads, inputs): + for scalar_igrad, ipt in zip(scalar_igrads, inputs, strict=True): if scalar_igrad is None: # undefined gradient ret.append(None) @@ -666,7 +675,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): and isinstance(self.nfunc, np.ufunc) and node.inputs[0].dtype in discrete_dtypes ): - char = np.sctype2char(out_dtype) + char = np.dtype(out_dtype).char sig = char * node.nin + "->" + char * node.nout node.tag.sig = sig node.tag.fake_node = Apply( @@ -736,8 +745,9 @@ def perform(self, node, inputs, output_storage): if nout == 1: variables = [variables] + # strict=False because we are in a hot loop for i, (variable, storage, nout) in enumerate( - zip(variables, output_storage, node.outputs) + zip(variables, output_storage, node.outputs, strict=False) ): storage[0] = variable = np.asarray(variable, dtype=nout.dtype) @@ -752,11 +762,13 @@ def perform(self, node, inputs, output_storage): @staticmethod def _check_runtime_broadcast(node, inputs): + # strict=False because we are in a hot loop for dims_and_bcast in zip( *[ - zip(input.shape, sinput.type.broadcastable) - for input, sinput in zip(inputs, node.inputs) - ] + zip(input.shape, sinput.type.broadcastable, strict=False) + for input, sinput in zip(inputs, node.inputs, strict=False) + ], + strict=False, ): if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: raise ValueError( @@ -785,9 +797,11 @@ def _c_all(self, node, nodename, inames, onames, sub): # assert that inames and inputs order stay consistent. # This is to protect again futur change of uniq. assert len(inames) == len(inputs) - ii, iii = list(zip(*uniq(list(zip(_inames, node.inputs))))) - assert all(x == y for x, y in zip(ii, inames)) - assert all(x == y for x, y in zip(iii, inputs)) + ii, iii = list( + zip(*uniq(list(zip(_inames, node.inputs, strict=True))), strict=True) + ) + assert all(x == y for x, y in zip(ii, inames, strict=True)) + assert all(x == y for x, y in zip(iii, inputs, strict=True)) defines = "" undefs = "" @@ -808,9 +822,10 @@ def _c_all(self, node, nodename, inames, onames, sub): zip( *[ (r, s, r.type.dtype_specs()[1]) - for r, s in zip(node.outputs, onames) + for r, s in zip(node.outputs, onames, strict=True) if r not in dmap - ] + ], + strict=True, ) ) if real: @@ -822,7 +837,14 @@ def _c_all(self, node, nodename, inames, onames, sub): # (output, name), transposed (c type name not needed since we don't # need to allocate. aliased = list( - zip(*[(r, s) for (r, s) in zip(node.outputs, onames) if r in dmap]) + zip( + *[ + (r, s) + for (r, s) in zip(node.outputs, onames, strict=True) + if r in dmap + ], + strict=True, + ) ) if aliased: aliased_outputs, aliased_onames = aliased @@ -840,7 +862,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # dimensionality) nnested = len(orders[0]) sub = dict(sub) - for i, (input, iname) in enumerate(zip(inputs, inames)): + for i, (input, iname) in enumerate(zip(inputs, inames, strict=True)): # the c generators will substitute the input names for # references to loop variables lv0, lv1, ... sub[f"lv{i}"] = iname @@ -850,7 +872,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # Check if all inputs (except broadcasted scalar) are fortran. # In that case, create a fortran output ndarray. - z = list(zip(inames, inputs)) + z = list(zip(inames, inputs, strict=True)) alloc_fortran = " && ".join( f"PyArray_ISFORTRAN({arr})" for arr, var in z @@ -865,7 +887,9 @@ def _c_all(self, node, nodename, inames, onames, sub): # We loop over the "real" outputs, i.e., those that are not # inplace (must be allocated) and we declare/allocate/check # them - for output, oname, odtype in zip(real_outputs, real_onames, real_odtypes): + for output, oname, odtype in zip( + real_outputs, real_onames, real_odtypes, strict=True + ): i += 1 # before this loop, i = number of inputs sub[f"lv{i}"] = oname sub["olv"] = oname @@ -882,7 +906,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # inplace (overwrite the contents of one of the inputs) and # make the output pointers point to their corresponding input # pointers. - for output, oname in zip(aliased_outputs, aliased_onames): + for output, oname in zip(aliased_outputs, aliased_onames, strict=True): olv_index = inputs.index(dmap[output][0]) iname = inames[olv_index] # We make the output point to the corresponding input and @@ -943,12 +967,16 @@ def _c_all(self, node, nodename, inames, onames, sub): task_decl = "".join( f"{dtype}& {name}_i = *{name}_iter;\n" for name, dtype in zip( - inames + list(real_onames), idtypes + list(real_odtypes) + inames + list(real_onames), + idtypes + list(real_odtypes), + strict=True, ) ) preloops = {} - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate( + zip(loop_orders, dtypes, strict=True) + ): for j, index in enumerate(loop_order): if index != "x": preloops.setdefault(j, "") @@ -1020,7 +1048,9 @@ def _c_all(self, node, nodename, inames, onames, sub): # assume they will have the same size or all( len(set(inp_shape)) == 1 and None not in inp_shape - for inp_shape in zip(*(inp.type.shape for inp in node.inputs)) + for inp_shape in zip( + *(inp.type.shape for inp in node.inputs), strict=True + ) ) ): z = onames[0] @@ -1029,7 +1059,9 @@ def _c_all(self, node, nodename, inames, onames, sub): npy_intp n = PyArray_SIZE({z}); """ index = "" - for x, var in zip(inames + onames, inputs + node.outputs): + for x, var in zip( + inames + onames, inputs + node.outputs, strict=True + ): if not all(s == 1 for s in var.type.shape): contig += f""" dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}); @@ -1051,7 +1083,7 @@ def _c_all(self, node, nodename, inames, onames, sub): }} """ if contig is not None: - z = list(zip(inames + onames, inputs + node.outputs)) + z = list(zip(inames + onames, inputs + node.outputs, strict=True)) all_broadcastable = all(s == 1 for s in var.type.shape) cond1 = " && ".join( f"PyArray_ISCONTIGUOUS({arr})" diff --git a/pytensor/tensor/elemwise_cgen.py b/pytensor/tensor/elemwise_cgen.py index 7eb422aa0a..5d50f02ad5 100644 --- a/pytensor/tensor/elemwise_cgen.py +++ b/pytensor/tensor/elemwise_cgen.py @@ -10,7 +10,7 @@ def make_declare(loop_orders, dtypes, sub, compute_stride_jump=True): """ decl = "" - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): var = sub[f"lv{i}"] # input name corresponding to ith loop variable # we declare an iteration variable # and an integer for the number of dimensions @@ -35,7 +35,7 @@ def make_declare(loop_orders, dtypes, sub, compute_stride_jump=True): def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True): init = "" - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): var = sub[f"lv{i}"] # List of dimensions of var that are not broadcasted nonx = [x for x in loop_order if x != "x"] @@ -89,7 +89,7 @@ def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True): "If broadcasting was intended, use `specify_broadcastable` on the relevant input." ) - for matches in zip(*loop_orders): + for matches in zip(*loop_orders, strict=True): to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"] # elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx ) @@ -139,7 +139,7 @@ def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str: Note: We could specialize C code even further with the known static output shapes """ dims_c_code = "" - for i, candidates in enumerate(zip(*loop_orders)): + for i, candidates in enumerate(zip(*loop_orders, strict=True)): # Borrow the length of the first non-broadcastable input dimension for j, candidate in enumerate(candidates): if candidate != "x": @@ -209,7 +209,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ) -def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): +def make_loop( + loop_orders: list[tuple[int | str, ...]], + dtypes: list, + loop_tasks: list, + sub: dict[str, str], + openmp: bool = False, +): """ Make a nested loop over several arrays and associate specific code to each level of nesting. @@ -227,7 +233,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): string is code to be executed before the ith loop starts, the second one contains code to be executed just before going to the next element of the ith dimension. - The last element if loop_tasks is a single string, containing code + The last element of loop_tasks is a single string, containing code to be executed at the very end. sub : dictionary Maps 'lv#' to a suitable variable name. @@ -260,8 +266,8 @@ def loop_over(preloop, code, indices, i): }} """ - preloops = {} - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + preloops: dict[int, str] = {} + for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): for j, index in enumerate(loop_order): if index != "x": preloops.setdefault(j, "") @@ -277,9 +283,8 @@ def loop_over(preloop, code, indices, i): s = "" - for i, (pre_task, task), indices in reversed( - list(zip(range(len(loop_tasks) - 1), loop_tasks, list(zip(*loop_orders)))) - ): + tasks_indices = zip(loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True) + for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))): s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i) s += loop_tasks[-1] diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 9de2b3f938..72cb5bcd05 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2,7 +2,16 @@ from collections.abc import Collection, Iterable import numpy as np -from numpy.core.multiarray import normalize_axis_index +from numpy.exceptions import AxisError + + +try: + from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.multiarray import normalize_axis_index + from numpy.core.numeric import normalize_axis_tuple import pytensor import pytensor.scalar.basic as ps @@ -17,8 +26,9 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.link.c.type import EnumList, Generic +from pytensor.npy_2_compat import npy_2_compat_header from pytensor.raise_op import Assert -from pytensor.scalar import int32 as int_t +from pytensor.scalar import int64 as int_t from pytensor.scalar import upcast from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb @@ -298,7 +308,14 @@ def __init__(self, axis: int | None = None, mode="add"): self.axis = axis self.mode = mode - c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis) + @property + def c_axis(self) -> int: + if self.axis is None: + if np.__version__ < "2": + return 32 # value used to mark axis = None in Numpy C-API prior to version 2.0 + else: + return np.iinfo(np.int32).min # the value of "NPY_RAVEL_AXIS" + return self.axis def make_node(self, x): x = ptb.as_tensor_variable(x) @@ -355,24 +372,38 @@ def infer_shape(self, fgraph, node, shapes): return shapes + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inames, onames, sub): (x,) = inames (z,) = onames fail = sub["fail"] params = sub["params"] - code = f""" - int axis = {params}->c_axis; + if self.axis is None: + axis_code = "int axis = NPY_RAVEL_AXIS;\n" + else: + axis_code = "int axis = {params}->c_axis;\n" + + code = ( + axis_code + + """ + #undef NPY_UF_DBG_TRACING + #define NPY_UF_DBG_TRACING 1 + if (axis == 0 && PyArray_NDIM({x}) == 1) - axis = NPY_MAXDIMS; + axis = NPY_RAVEL_AXIS; npy_intp shape[1] = {{ PyArray_SIZE({x}) }}; - if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0])) + if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0])) {{ Py_XDECREF({z}); - {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x})); + {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x})); + //{z} = (PyArrayObject*) PyArray_NewLikeArray((PyArrayObject*) PyArray_Ravel({x}, NPY_ANYORDER), NPY_ANYORDER, NULL, 0); }} - else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) + else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) {{ Py_XDECREF({z}); {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x})); @@ -399,11 +430,12 @@ def c_code(self, node, name, inames, onames, sub): Py_XDECREF(t); }} """ + ).format(**locals()) return code def c_code_cache_version(self): - return (8,) + return () def __str__(self): return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}" @@ -596,9 +628,9 @@ def squeeze(x, axis=None): # scalar inputs are treated as 1D regarding axis in this `Op` try: - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=_x.ndim) + axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) + except AxisError: + raise AxisError(axis, ndim=_x.ndim) if not axis: # Nothing to do @@ -678,7 +710,7 @@ def make_node(self, x, repeats): out_shape = [None] else: try: - const_reps = ptb.get_underlying_scalar_constant_value(repeats) + const_reps = ptb.get_scalar_constant_value(repeats) except NotScalarConstantError: const_reps = None if const_reps == 1: @@ -1176,7 +1208,8 @@ class Unique(Op): >>> y = pytensor.tensor.matrix() >>> g = pytensor.function([y], Unique(True, True, False)(y)) >>> g([[1, 1, 1.0], (2, 3, 3.0)]) - [array([1., 2., 3.]), array([0, 3, 4]), array([0, 0, 0, 1, 2, 2])] + [array([1., 2., 3.]), array([0, 3, 4]), array([[0, 0, 0], + [1, 2, 2]])] """ @@ -1211,7 +1244,11 @@ def make_node(self, x): if self.return_index: outputs.append(typ()) if self.return_inverse: - outputs.append(typ()) + if axis is None: + inverse_shape = TensorType(dtype="int64", shape=x.type.shape) + else: + inverse_shape = TensorType(dtype="int64", shape=(x.type.shape[axis],)) + outputs.append(inverse_shape()) if self.return_counts: outputs.append(typ()) return Apply(self, [x], outputs) @@ -1243,9 +1280,9 @@ def infer_shape(self, fgraph, node, i0_shapes): out_shapes[0] = tuple(shape) if self.return_inverse: - shape = prod(x_shape) if self.axis is None else x_shape[axis] + shape = x_shape if self.axis is None else (x_shape[axis],) return_index_out_idx = 2 if self.return_index else 1 - out_shapes[return_index_out_idx] = (shape,) + out_shapes[return_index_out_idx] = shape return out_shapes @@ -1501,13 +1538,16 @@ def broadcast_shape_iter( array_shapes = [ (one,) * (max_dims - a.ndim) - + tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)) + + tuple( + one if t_sh == 1 else sh + for sh, t_sh in zip(a.shape, a.type.shape, strict=True) + ) for a in _arrays ] result_dims = [] - for dim_shapes in zip(*array_shapes): + for dim_shapes in zip(*array_shapes, strict=True): # Get the shapes in this dimension that are not broadcastable # (i.e. not symbolically known to be broadcastable) non_bcast_shapes = [shape for shape in dim_shapes if shape != one] diff --git a/pytensor/tensor/functional.py b/pytensor/tensor/functional.py index de35183d28..ad72fb7d52 100644 --- a/pytensor/tensor/functional.py +++ b/pytensor/tensor/functional.py @@ -89,7 +89,7 @@ def inner(*inputs): # Create dummy core inputs by stripping the batched dimensions of inputs core_inputs = [] - for input, input_sig in zip(inputs, inputs_sig): + for input, input_sig in zip(inputs, inputs_sig, strict=True): if not isinstance(input, TensorVariable): raise TypeError( f"Inputs to vectorize function must be TensorVariable, got {type(input)}" @@ -123,7 +123,9 @@ def inner(*inputs): ) # Vectorize graph by replacing dummy core inputs by original inputs - outputs = vectorize_graph(core_outputs, replace=dict(zip(core_inputs, inputs))) + outputs = vectorize_graph( + core_outputs, replace=dict(zip(core_inputs, inputs, strict=True)) + ) return outputs return inner diff --git a/pytensor/tensor/interpolate.py b/pytensor/tensor/interpolate.py new file mode 100644 index 0000000000..f598695784 --- /dev/null +++ b/pytensor/tensor/interpolate.py @@ -0,0 +1,200 @@ +from collections.abc import Callable +from difflib import get_close_matches +from typing import Literal, get_args + +from pytensor import Variable +from pytensor.tensor.basic import as_tensor_variable, switch +from pytensor.tensor.extra_ops import searchsorted +from pytensor.tensor.functional import vectorize +from pytensor.tensor.math import clip, eq, le +from pytensor.tensor.sort import argsort + + +InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"] +valid_methods = get_args(InterpolationMethod) + + +def pad_or_return(x, idx, output, left_pad, right_pad, extrapolate): + if extrapolate: + return output + + n = x.shape[0] + + return switch(eq(idx, 0), left_pad, switch(eq(idx, n), right_pad, output)) + + +def _linear_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 1, x.shape[0] - 1) + + slope = (x_hat - x[clip_idx - 1]) / (x[clip_idx] - x[clip_idx - 1]) + y_hat = y[clip_idx - 1] + slope * (y[clip_idx] - y[clip_idx - 1]) + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _nearest_neighbor_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 1, x.shape[0] - 1) + + left_distance = x_hat - x[clip_idx - 1] + right_distance = x[clip_idx] - x_hat + y_hat = switch(le(left_distance, right_distance), y[clip_idx - 1], y[clip_idx]) + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _stepwise_first_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx - 1, 0, x.shape[0] - 1) + y_hat = y[clip_idx] + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _stepwise_last_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 0, x.shape[0] - 1) + y_hat = y[clip_idx] + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _stepwise_mean_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 1, x.shape[0] - 1) + y_hat = (y[clip_idx - 1] + y[clip_idx]) / 2 + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def interpolate1d( + x: Variable, + y: Variable, + method: InterpolationMethod = "linear", + left_pad: Variable | None = None, + right_pad: Variable | None = None, + extrapolate: bool = True, +) -> Callable[[Variable], Variable]: + """ + Create a function to interpolate one-dimensional data. + + Parameters + ---------- + x : TensorLike + Input data used to create an interpolation function. Data will be sorted to be monotonically increasing. + y: TensorLike + Output data used to create an interpolation function. Must have the same shape as `x`. + method : InterpolationMethod, optional + Method for interpolation. The following methods are available: + - 'linear': Linear interpolation + - 'nearest': Nearest neighbor interpolation + - 'first': Stepwise interpolation using the closest value to the left of the query point + - 'last': Stepwise interpolation using the closest value to the right of the query point + - 'mean': Stepwise interpolation using the mean of the two closest values to the query point + left_pad: TensorLike, optional + Value to return inputs `x_hat < x[0]`. Default is `y[0]`. Ignored if ``extrapolate == True``; in this + case, values `x_hat < x[0]` will be extrapolated from the endpoints of `x` and `y`. + right_pad: TensorLike, optional + Value to return for inputs `x_hat > x[-1]`. Default is `y[-1]`. Ignored if ``extrapolate == True``; in this + case, values `x_hat > x[-1]` will be extrapolated from the endpoints of `x` and `y`. + extrapolate: bool + Whether to extend the request interpolation function beyond the range of the input-output pairs specified in + `x` and `y.` If False, constant values will be returned for such inputs. + + Returns + ------- + interpolation_func: OpFromGraph + A function that can be used to interpolate new data. The function takes a single input `x_hat` and returns + the interpolated value `y_hat`. The input `x_hat` must be a 1d array. + + """ + x = as_tensor_variable(x) + y = as_tensor_variable(y) + + sort_idx = argsort(x) + x = x[sort_idx] + y = y[sort_idx] + + if left_pad is None: + left_pad = y[0] # type: ignore + else: + left_pad = as_tensor_variable(left_pad) + if right_pad is None: + right_pad = y[-1] # type: ignore + else: + right_pad = as_tensor_variable(right_pad) + + def _scalar_interpolate1d(x_hat): + idx = searchsorted(x, x_hat) + + if x.ndim != 1 or y.ndim != 1: + raise ValueError("Inputs must be 1d") + + if method == "linear": + y_hat = _linear_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "nearest": + y_hat = _nearest_neighbor_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "first": + y_hat = _stepwise_first_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "mean": + y_hat = _stepwise_mean_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "last": + y_hat = _stepwise_last_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + else: + raise NotImplementedError( + f"Unknown interpolation method: {method}. " + f"Did you mean {get_close_matches(method, valid_methods)}?" + ) + + return y_hat + + return vectorize(_scalar_interpolate1d, signature="()->()") + + +def interp(x, xp, fp, left=None, right=None, period=None): + """ + One-dimensional linear interpolation. Similar to ``pytensor.interpolate.interpolate1d``, but with a signature that + matches ``np.interp`` + + Parameters + ---------- + x : TensorLike + The x-coordinates at which to evaluate the interpolated values. + + xp : TensorLike + The x-coordinates of the data points, must be increasing if argument `period` is not specified. Otherwise, + `xp` is internally sorted after normalizing the periodic boundaries with ``xp = xp % period``. + + fp : TensorLike + The y-coordinates of the data points, same length as `xp`. + + left : float, optional + Value to return for `x < xp[0]`. Default is `fp[0]`. + + right : float, optional + Value to return for `x > xp[-1]`. Default is `fp[-1]`. + + period : None + Not supported. Included to ensure the signature of this function matches ``numpy.interp``. + + Returns + ------- + y : Variable + The interpolated values, same shape as `x`. + """ + + xp = as_tensor_variable(xp) + fp = as_tensor_variable(fp) + x = as_tensor_variable(x) + + f = interpolate1d( + xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False + ) + + return f(x) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index d1e4dc6195..2251bbd968 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -5,7 +5,14 @@ from typing import TYPE_CHECKING, Optional import numpy as np -from numpy.core.numeric import normalize_axis_tuple + + +try: + from numpy.lib.array_utils import normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.numeric import normalize_axis_tuple from pytensor import config, printing from pytensor import scalar as ps @@ -14,6 +21,7 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import npy_2_compat_header from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.scalar.basic import BinaryScalarOp @@ -160,7 +168,10 @@ def get_params(self, node): c_axis = np.int64(self.axis[0]) else: # The value here doesn't matter, it won't be used - c_axis = np.int64(-1) + if np.__version__ < "2": + c_axis = np.int64(-1) + else: + c_axis = -2147483648 # the value of "NPY_RAVEL_AXIS" return self.params_type.get_params(c_axis=c_axis) def make_node(self, x): @@ -203,13 +214,17 @@ def perform(self, node, inp, outs): max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (argmax,) = out fail = sub["fail"] params = sub["params"] if self.axis is None: - axis_code = "axis = NPY_MAXDIMS;" + axis_code = "axis = NPY_RAVEL_AXIS;" else: if len(self.axis) != 1: raise NotImplementedError() @@ -1229,6 +1244,16 @@ def ive(v, x): """Exponentially scaled modified Bessel function of the first kind of order v (real).""" +@scalar_elemwise +def kve(v, x): + """Exponentially scaled modified Bessel function of the second kind of real order v.""" + + +def kv(v, x): + """Modified Bessel function of the second kind of real order v.""" + return kve(v, x) * exp(-x) + + @scalar_elemwise def sigmoid(x): """Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit""" @@ -1306,63 +1331,7 @@ def complex_from_polar(abs, angle): """Return complex-valued tensor from polar coordinate specification.""" -class Mean(FixedOpCAReduce): - __props__ = ("axis",) - nfunc_spec = ("mean", 1, 1) - - def __init__(self, axis=None): - super().__init__(ps.mean, axis) - assert self.axis is None or len(self.axis) == 1 - - def __str__(self): - if self.axis is not None: - args = ", ".join(str(x) for x in self.axis) - return f"Mean{{{args}}}" - else: - return "Mean" - - def _output_dtype(self, idtype): - # we want to protect against overflow - return "float64" - - def perform(self, node, inp, out): - (input,) = inp - (output,) = out - if self.axis is None: - axis = None - else: - axis = self.axis[0] - # numpy.asarray is needed as otherwise we can end up with a - # numpy scalar. - output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis)) - - def c_code(self, node, name, inames, onames, sub): - ret = super().c_code(node, name, inames, onames, sub) - - if self.axis is not None: - return ret - - # TODO: c_code perform support only axis is None - return ( - ret - + f""" - *((double *)PyArray_DATA({onames[0]})) /= PyArray_SIZE({inames[0]}); - """ - ) - - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) - return type(self)(axis=axis) - - -# TODO: implement the grad. When done and tested, you can make this the default -# version. -# def grad(self, (x,), (gout,)): -# import pdb;pdb.set_trace() -# return grad(mean(x, self.axis, op=False),[x]) - - -def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None): +def mean(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): """ Computes the mean value along the given axis(es) of a tensor `input`. @@ -1387,25 +1356,6 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None) be in a float type). If None, then we use the same rules as `sum()`. """ input = as_tensor_variable(input) - if op: - if dtype not in (None, "float64"): - raise NotImplementedError( - "The Mean op does not support the dtype argument, " - "and will always use float64. If you want to specify " - "the dtype, call tensor.mean(..., op=False).", - dtype, - ) - if acc_dtype not in (None, "float64"): - raise NotImplementedError( - "The Mean op does not support the acc_dtype argument, " - "and will always use float64. If you want to specify " - "acc_dtype, call tensor.mean(..., op=False).", - dtype, - ) - out = Mean(axis)(input) - if keepdims: - out = makeKeepDims(input, out, axis) - return out if dtype is not None: # The summation will be done with the specified dtype. @@ -3040,6 +2990,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None): "i1", "iv", "ive", + "kv", + "kve", "sigmoid", "expit", "softplus", diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index e7093a82bd..727cc8f08a 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -4,13 +4,22 @@ from typing import Literal, cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore + + +try: + from numpy.lib.array_utils import normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.numeric import normalize_axis_tuple # type: ignore + from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op +from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm from pytensor.tensor.basic import as_tensor_variable, diagonal @@ -266,7 +275,33 @@ def __str__(self): return "SLogDet" -slogdet = Blockwise(SLogDet()) +def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: + """ + Compute the sign and (natural) logarithm of the determinant of an array. + + Returns a naive graph which is optimized later using rewrites with the det operation. + + Parameters + ---------- + x : (..., M, M) tensor or tensor_like + Input tensor, has to be square. + + Returns + ------- + A tuple with the following attributes: + + sign : (...) tensor_like + A number representing the sign of the determinant. For a real matrix, + this is 1, 0, or -1. + logabsdet : (...) tensor_like + The natural log of the absolute value of the determinant. + + If the determinant is zero, then `sign` will be 0 and `logabsdet` + will be -inf. In all cases, the determinant is equal to + ``sign * exp(logabsdet)``. + """ + det_val = det(x) + return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) class Eig(Op): @@ -362,7 +397,7 @@ def grad(self, inputs, g_outputs): def _zero_disconnected(outputs, grads): l = [] - for o, g in zip(outputs, grads): + for o, g in zip(outputs, grads, strict=True): if isinstance(g.type, DisconnectedType): l.append(o.zeros_like()) else: @@ -664,7 +699,7 @@ def s_grad_only( return s_grad_only(U, VT, ds) for disconnected, output_grad, output in zip( - is_disconnected, output_grads, [U, s, VT] + is_disconnected, output_grads, [U, s, VT], strict=True ): if disconnected: new_output_grads.append(output.zeros_like()) diff --git a/pytensor/tensor/pad.py b/pytensor/tensor/pad.py index 91aef44004..2a3b8b4588 100644 --- a/pytensor/tensor/pad.py +++ b/pytensor/tensor/pad.py @@ -263,7 +263,9 @@ def _linear_ramp_pad( dtype=padded.dtype, axis=axis, ) - for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair) + for end_value, edge, width in zip( + end_value_pair, edge_pair, width_pair, strict=True + ) ) # Reverse the direction of the ramp for the "right" side diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 4a2c47b2af..bebcad55be 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1862,7 +1862,8 @@ def rng_fn(cls, rng, p, size): # to `p.shape[:-1]` in the call to `vsearchsorted` below. if len(size) < (p.ndim - 1): raise ValueError("`size` is incompatible with the shape of `p`") - for s, ps in zip(reversed(size), reversed(p.shape[:-1])): + # strict=False because we are in a hot loop + for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False): if s == 1 and ps != 1: raise ValueError("`size` is incompatible with the shape of `p`") diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 309a661c9a..a8b67dee4f 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Sequence -from copy import copy +from copy import deepcopy from typing import Any, cast import numpy as np @@ -151,11 +151,13 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): # Try to infer missing support dims from signature of params for param, param_sig, ndim_params in zip( - dist_params, self.inputs_sig, self.ndims_params + dist_params, self.inputs_sig, self.ndims_params, strict=True ): if ndim_params == 0: continue - for param_dim, dim in zip(param.shape[-ndim_params:], param_sig): + for param_dim, dim in zip( + param.shape[-ndim_params:], param_sig, strict=True + ): if dim in core_out_shape and core_out_shape[dim] is None: core_out_shape[dim] = param_dim @@ -230,7 +232,7 @@ def _infer_shape( # Fail early when size is incompatible with parameters for i, (param, param_ndim_supp) in enumerate( - zip(dist_params, self.ndims_params) + zip(dist_params, self.ndims_params, strict=True) ): param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp if param_batched_dims > size_len: @@ -254,7 +256,7 @@ def extract_batch_shape(p, ps, n): batch_shape = tuple( s if not b else constant(1, "int64") - for s, b in zip(shape[:-n], p.type.broadcastable[:-n]) + for s, b in zip(shape[:-n], p.type.broadcastable[:-n], strict=True) ) return batch_shape @@ -263,7 +265,9 @@ def extract_batch_shape(p, ps, n): # independent variate dimensions are left. params_batch_shape = tuple( extract_batch_shape(p, ps, n) - for p, ps, n in zip(dist_params, param_shapes, self.ndims_params) + for p, ps, n in zip( + dist_params, param_shapes, self.ndims_params, strict=False + ) ) if len(params_batch_shape) == 1: @@ -391,7 +395,7 @@ def perform(self, node, inputs, outputs): # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. if not self.inplace: - rng = copy(rng) + rng = deepcopy(rng) outputs[0][0] = rng outputs[1][0] = np.asarray( diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index 7ce17ade08..6de1a6b527 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -48,7 +48,7 @@ def random_make_inplace(fgraph, node): props["inplace"] = True new_op = type(op)(**props) new_outputs = new_op.make_node(*node.inputs).outputs - for old_out, new_out in zip(node.outputs, new_outputs): + for old_out, new_out in zip(node.outputs, new_outputs, strict=True): copy_stack_trace(old_out, new_out) return new_outputs @@ -171,7 +171,7 @@ def local_dimshuffle_rv_lift(fgraph, node): # Updates the params to reflect the Dimshuffled dimensions new_dist_params = [] - for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): + for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params, strict=True): # Add the parameter support dimension indexes to the batched dimensions Dimshuffle param_new_order = batched_dims_ds_order + tuple( range(batched_dims, batched_dims + param_ndim_supp) @@ -290,12 +290,12 @@ def is_nd_advanced_idx(idx, dtype) -> bool: # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size # should still correctly broadcast any degenerate parameter dims. new_dist_params = [] - for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): + for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params, strict=True): # Check which dims are broadcasted by either size or other parameters bcast_param_dims = tuple( dim for dim, (param_dim_bcast, output_dim_bcast) in enumerate( - zip(param.type.broadcastable, rv.type.broadcastable) + zip(param.type.broadcastable, rv.type.broadcastable, strict=False) ) if param_dim_bcast and not output_dim_bcast ) diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index fe170f4718..b6dcf3b5e8 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -15,7 +15,7 @@ def introduce_explicit_core_shape_rv(fgraph, node): This core_shape is used by the numba backend to pre-allocate the output array. If available, the core shape is extracted from the shape feature of the graph, - which has a higher change of having been simplified, optimized, constant-folded. + which has a higher chance of having been simplified, optimized, constant-folded. If missing, we fall back to the op._supp_shape_from_params method. This rewrite is required for the numba backend implementation of RandomVariable. diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 88d5e6197f..df8e3b691d 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -87,8 +87,8 @@ def filter(self, data, strict=False, allow_downcast=None): @staticmethod def values_eq(a, b): - sa = a if isinstance(a, dict) else a.__getstate__() - sb = b if isinstance(b, dict) else b.__getstate__() + sa = a if isinstance(a, dict) else a.bit_generator.state + sb = b if isinstance(b, dict) else b.bit_generator.state def _eq(sa, sb): for key in sa: diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 075d09b053..23b4b50265 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -44,7 +44,8 @@ def params_broadcast_shapes( max_fn = maximum if use_pytensor else max rev_extra_dims: list[int] = [] - for ndim_param, param_shape in zip(ndims_params, param_shapes): + # strict=False because we are in a hot loop + for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False): # We need this in order to use `len` param_shape = tuple(param_shape) extras = tuple(param_shape[: (len(param_shape) - ndim_param)]) @@ -63,11 +64,12 @@ def max_bcast(x, y): extra_dims = tuple(reversed(rev_extra_dims)) + # strict=False because we are in a hot loop bcast_shapes = [ (extra_dims + tuple(param_shape)[-ndim_param:]) if ndim_param > 0 else extra_dims - for ndim_param, param_shape in zip(ndims_params, param_shapes) + for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False) ] return bcast_shapes @@ -110,9 +112,12 @@ def broadcast_params( use_pytensor = False param_shapes = [] for p in params: + # strict=False because we are in a hot loop param_shape = tuple( 1 if bcast else s - for s, bcast in zip(p.shape, getattr(p, "broadcastable", (False,) * p.ndim)) + for s, bcast in zip( + p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=False + ) ) use_pytensor |= isinstance(p, Variable) param_shapes.append(param_shape) @@ -122,8 +127,10 @@ def broadcast_params( ) broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to + # strict=False because we are in a hot loop bcast_params = [ - broadcast_to_fn(param, shape) for shape, param in zip(shapes, params) + broadcast_to_fn(param, shape) + for shape, param in zip(shapes, params, strict=False) ] return bcast_params @@ -137,7 +144,8 @@ def explicit_expand_dims( """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" batch_dims = [ - param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params) + param.type.ndim - ndim_param + for param, ndim_param in zip(params, ndim_params, strict=True) ] if size_length is not None: @@ -146,7 +154,7 @@ def explicit_expand_dims( max_batch_dims = max(batch_dims, default=0) new_params = [] - for new_param, batch_dim in zip(params, batch_dims): + for new_param, batch_dim in zip(params, batch_dims, strict=True): missing_dims = max_batch_dims - batch_dim if missing_dims: new_param = shape_padleft(new_param, missing_dims) @@ -161,7 +169,7 @@ def compute_batch_shape( params = explicit_expand_dims(params, ndims_params) batch_params = [ param[(..., *(0,) * core_ndim)] - for param, core_ndim in zip(params, ndims_params) + for param, core_ndim in zip(params, ndims_params, strict=True) ] return broadcast_arrays(*batch_params)[0].shape @@ -279,7 +287,9 @@ def seed(self, seed=None): self.gen_seedgen = np.random.SeedSequence(seed) old_r_seeds = self.gen_seedgen.spawn(len(self.state_updates)) - for (old_r, new_r), old_r_seed in zip(self.state_updates, old_r_seeds): + for (old_r, new_r), old_r_seed in zip( + self.state_updates, old_r_seeds, strict=True + ): old_r.set_value(self.rng_ctor(old_r_seed), borrow=True) def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable: diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index fc5c528f2d..4e75140ceb 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -9,6 +9,7 @@ import pytensor.tensor.rewriting.jax import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.math +import pytensor.tensor.rewriting.numba import pytensor.tensor.rewriting.ofg import pytensor.tensor.rewriting.shape import pytensor.tensor.rewriting.special diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 78d00790ac..59148fae3b 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -30,8 +30,9 @@ from pytensor import compile, config from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph -from pytensor.graph.basic import Constant, Variable +from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import ( + NodeProcessingGraphRewriter, NodeRewriter, RemovalNodeRewriter, Rewriter, @@ -54,9 +55,8 @@ as_tensor_variable, atleast_Nd, cast, - extract_constant, fill, - get_underlying_scalar_constant_value, + get_scalar_constant_value, join, ones_like, register_infer_shape, @@ -98,11 +98,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool: if len(bx) < len(by): return True bx = bx[-len(by) :] - return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by)) + return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True)) def merge_broadcastables(broadcastables): - return [all(bcast) for bcast in zip(*broadcastables)] + return [all(bcast) for bcast in zip(*broadcastables, strict=True)] def alloc_like( @@ -477,7 +477,12 @@ def local_alloc_sink_dimshuffle(fgraph, node): output_shape = node.inputs[1:] num_dims_with_size_1_added_to_left = 0 for i in range(len(output_shape) - inp.ndim): - if extract_constant(output_shape[i], only_process_constants=True) == 1: + if ( + get_scalar_constant_value( + output_shape[i], only_process_constants=True, raise_not_constant=False + ) + == 1 + ): num_dims_with_size_1_added_to_left += 1 else: break @@ -537,93 +542,90 @@ def local_useless_elemwise(fgraph, node): xor(x, x) -> zeros_like(x) TODO: This implementation is painfully redundant. + TODO: Allow rewrite when useless input broadcasts output """ - if isinstance(node.op, Elemwise): - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype - - if node.op.scalar_op == ps.eq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. That will always be true - ret = ones_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif node.op.scalar_op == ps.neq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. That will always be false - ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - - elif node.op.scalar_op == ps.mul and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - - elif node.op.scalar_op == ps.add and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - elif node.op.scalar_op == ps.identity and len(node.inputs) == 1: - return [node.inputs[0]] - - elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[1].astype(node.outputs[0].dtype)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[0].astype(node.outputs[0].dtype)] - - elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[1].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[1], dtype=dtype, opt=True)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[0].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[0], dtype=dtype, opt=True)] - - elif isinstance(node.op.scalar_op, ps.XOR) and len(node.inputs) == 2: - if node.inputs[0] is node.inputs[1]: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + out_bcast = node.outputs[0].type.broadcastable + dtype = node.outputs[0].type.dtype + scalar_op = node.op.scalar_op + + if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + # it is the same var in the graph. That will always be true + ret = ones_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + # it is the same var in the graph. That will always be false + ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + + elif ( + isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity) + and len(node.inputs) == 1 + ): + # No need to copy over any stack trace + return [node.inputs[0]] + + elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: + if ( + isinstance(node.inputs[0], TensorConstant) + and node.inputs[1].type.broadcastable == out_bcast + ): + const_val = node.inputs[0].unique_value + if const_val is not None: + if const_val == 0: + return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[1].astype(node.outputs[0].dtype)] + + if ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[0].type.broadcastable == out_bcast + ): + const_val = node.inputs[1].unique_value + if const_val is not None: + if const_val == 0: + return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[0].astype(node.outputs[0].dtype)] + + elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: + if ( + isinstance(node.inputs[0], TensorConstant) + and node.inputs[1].type.broadcastable == out_bcast + ): + const_val = node.inputs[0].unique_value + if const_val is not None: + if const_val == 0: + return [node.inputs[1].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[1], dtype=dtype, opt=True)] + + if ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[0].type.broadcastable == out_bcast + ): + const_val = node.inputs[1].unique_value + if const_val is not None: + if const_val == 0: + return [node.inputs[0].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[0], dtype=dtype, opt=True)] @register_specialize @@ -736,7 +738,7 @@ def local_remove_useless_assert(fgraph, node): n_conds = len(node.inputs[1:]) for c in node.inputs[1:]: try: - const = get_underlying_scalar_constant_value(c) + const = get_scalar_constant_value(c) if 0 != const.ndim or const == 0: # Should we raise an error here? How to be sure it @@ -831,7 +833,7 @@ def local_join_empty(fgraph, node): return new_inputs = [] try: - join_idx = get_underlying_scalar_constant_value( + join_idx = get_scalar_constant_value( node.inputs[0], only_process_constants=True ) except NotScalarConstantError: @@ -987,13 +989,10 @@ def local_useless_switch(fgraph, node): left = node.inputs[1] right = node.inputs[2] cond_var = node.inputs[0] - cond = extract_constant(cond_var, only_process_constants=True) out_bcast = node.outputs[0].type.broadcastable - if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( - cond, np.number | np.bool_ - ): - if cond == 0: + if isinstance(cond_var, TensorConstant) and cond_var.unique_value is not None: + if cond_var.unique_value == 0: correct_out = right else: correct_out = left @@ -1013,7 +1012,7 @@ def local_useless_switch(fgraph, node): # if left is right -> left if equivalent_up_to_constant_casting(left, right): if left.type.broadcastable != out_bcast: - left, _ = broadcast_arrays(left, cond) + left, _ = broadcast_arrays(left, cond_var) out_dtype = node.outputs[0].type.dtype if left.type.dtype != out_dtype: @@ -1025,13 +1024,22 @@ def local_useless_switch(fgraph, node): # This case happens with scan. # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) if ( - cond_var.owner + node.outputs[0].type.ndim == 0 + and cond_var.owner and isinstance(cond_var.owner.op, Elemwise) and isinstance(cond_var.owner.op.scalar_op, ps.LE) and cond_var.owner.inputs[0].owner and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) - and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 - and extract_constant(left, only_process_constants=True) == 0 + and get_scalar_constant_value( + cond_var.owner.inputs[1], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + and get_scalar_constant_value( + left, only_process_constants=True, raise_not_constant=False + ) + == 0 and right == cond_var.owner.inputs[0] ): assert node.outputs[0].type.is_super(right.type) @@ -1101,10 +1109,7 @@ def local_useless_split(fgraph, node): @node_rewriter(None) -def constant_folding(fgraph, node): - if not node.op.do_constant_folding(fgraph, node): - return False - +def unconditional_constant_folding(fgraph, node): if not all(isinstance(inp, Constant) for inp in node.inputs): return False @@ -1151,6 +1156,23 @@ def constant_folding(fgraph, node): return rval +topo_unconditional_constant_folding = in2out( + unconditional_constant_folding, + ignore_newtrees=True, + name="topo_unconditional_constant_folding", + # Not all Ops have a perform method, so we ignore failures to constant_fold + failure_callback=NodeProcessingGraphRewriter.warn_ignore, +) + + +@node_rewriter(None) +def constant_folding(fgraph, node): + if not node.op.do_constant_folding(fgraph, node): + return False + + return unconditional_constant_folding.transform(fgraph, node) + + topo_constant_folding = in2out( constant_folding, ignore_newtrees=True, name="topo_constant_folding" ) @@ -1192,25 +1214,23 @@ def local_merge_alloc(fgraph, node): inputs_inner = node.inputs[0].owner.inputs dims_outer = inputs_outer[1:] dims_inner = inputs_inner[1:] - dims_outer_rev = dims_outer[::-1] - dims_inner_rev = dims_inner[::-1] + assert len(dims_inner) <= len(dims_outer) # check if the pattern of broadcasting is matched, in the reversed ordering. # The reverse ordering is needed when an Alloc add an implicit new # broadcasted dimensions to its inputs[0]. Eg: # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - i = 0 - for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev): - if dim_inner != dim_outer: - if isinstance(dim_inner, Constant) and dim_inner.data == 1: - pass - else: - dims_outer[-1 - i] = Assert( - "You have a shape error in your graph. To see a better" - " error message and a stack trace of where in your code" - " the error is created, use the PyTensor flags" - " optimizer=None or optimizer=fast_compile." - )(dim_outer, eq(dim_outer, dim_inner)) - i += 1 + for i, dim_inner in enumerate(reversed(dims_inner)): + dim_outer = dims_outer[-1 - i] + if dim_inner == dim_outer: + continue + if isinstance(dim_inner, Constant) and dim_inner.data == 1: + continue + dims_outer[-1 - i] = Assert( + "You have a shape error in your graph. To see a better" + " error message and a stack trace of where in your code" + " the error is created, use the PyTensor flags" + " optimizer=None or optimizer=fast_compile." + )(dim_outer, eq(dim_outer, dim_inner)) return [alloc(inputs_inner[0], *dims_outer)] @@ -1292,7 +1312,8 @@ def local_join_of_alloc(fgraph, node): ) ] for core_tensor, tensor in zip(core_tensors, tensors, strict=True) - ) + ), + strict=True, ) ) @@ -1307,7 +1328,7 @@ def local_join_of_alloc(fgraph, node): # Lift the allocated dimensions new_tensors = [] - for core_tensor, alloc_shape in zip(core_tensors, alloc_shapes): + for core_tensor, alloc_shape in zip(core_tensors, alloc_shapes, strict=True): pre_join_shape = [ 1 if i in lifteable_alloc_dims else alloc_dim for i, alloc_dim in enumerate(alloc_shape) @@ -1321,7 +1342,7 @@ def local_join_of_alloc(fgraph, node): # Reintroduce the lifted dims post_join_shape = [] - for i, alloc_dims in enumerate(zip(*alloc_shapes)): + for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)): if i == axis: # The alloc dim along the axis is the sum of all the pre-join alloc dims post_join_shape.append(add(*alloc_dims)) diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 094becd98b..d3fc0398c4 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -507,7 +507,7 @@ def on_import(new_node): ].tag.values_eq_approx = values_eq_approx_remove_inf_nan try: fgraph.replace_all_validate_remove( - list(zip(node.outputs, new_outputs)), + list(zip(node.outputs, new_outputs, strict=True)), [old_dot22], reason="GemmOptimizer", # For now we disable the warning as we know case diff --git a/pytensor/tensor/rewriting/blas_scipy.py b/pytensor/tensor/rewriting/blas_scipy.py index 610ef9b82f..2ed0279e45 100644 --- a/pytensor/tensor/rewriting/blas_scipy.py +++ b/pytensor/tensor/rewriting/blas_scipy.py @@ -1,5 +1,5 @@ from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor.blas import ger, ger_destructive, have_fblas +from pytensor.tensor.blas import ger, ger_destructive from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb @@ -19,19 +19,19 @@ def make_ger_destructive(fgraph, node): use_scipy_blas = in2out(use_scipy_ger) make_scipy_blas_destructive = in2out(make_ger_destructive) -if have_fblas: - # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof - # sucks, but it is almost always present. - # C implementations should be scheduled earlier than this, so that they take - # precedence. Once the original Ger is replaced, then these optimizations - # have no effect. - blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) - - # this matches the InplaceBlasOpt defined in blas.py - optdb.register( - "make_scipy_blas_destructive", - make_scipy_blas_destructive, - "fast_run", - "inplace", - position=50.2, - ) + +# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof +# sucks [citation needed], but it is almost always present. +# C implementations should be scheduled earlier than this, so that they take +# precedence. Once the original Ger is replaced, then these optimizations +# have no effect. +blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) + +# this matches the InplaceBlasOpt defined in blas.py +optdb.register( + "make_scipy_blas_destructive", + make_scipy_blas_destructive, + "fast_run", + "inplace", + position=50.2, +) diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 97046bffe2..49bd5510ae 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -120,7 +120,7 @@ def local_blockwise_alloc(fgraph, node): new_inputs = [] batch_shapes = [] can_push_any_alloc = False - for inp, inp_sig in zip(node.inputs, op.inputs_sig): + for inp, inp_sig in zip(node.inputs, op.inputs_sig, strict=True): if not all(inp.type.broadcastable[:batch_ndim]): if inp.owner and isinstance(inp.owner.op, Alloc): # Push batch dims from Alloc @@ -146,6 +146,7 @@ def local_blockwise_alloc(fgraph, node): :squeezed_value_batch_ndim ], tuple(squeezed_value.shape)[:squeezed_value_batch_ndim], + strict=True, ) ] squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape) @@ -159,7 +160,7 @@ def local_blockwise_alloc(fgraph, node): tuple( 1 if broadcastable else dim for broadcastable, dim in zip( - inp.type.broadcastable, shape[:batch_ndim] + inp.type.broadcastable, shape[:batch_ndim], strict=False ) ) ) @@ -182,7 +183,9 @@ def local_blockwise_alloc(fgraph, node): # We pick the most parsimonious batch dim from the pushed Alloc missing_ndim = old_out_type.ndim - new_out_type.ndim batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim] - for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples + for i, batch_dims in enumerate( + zip(*batch_shapes, strict=True) + ): # Transpose shape tuples if old_out_type.broadcastable[i]: continue for batch_dim in batch_dims: diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 277b8bdb55..3226f9b5a7 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -41,7 +41,7 @@ register_specialize, ) from pytensor.tensor.shape import shape_padleft -from pytensor.tensor.variable import TensorConstant, get_unique_constant_value +from pytensor.tensor.variable import TensorConstant class InplaceElemwiseOptimizer(GraphRewriter): @@ -299,7 +299,7 @@ def apply(self, fgraph): ) new_node = new_outputs[0].owner - for r, new_r in zip(node.outputs, new_outputs): + for r, new_r in zip(node.outputs, new_outputs, strict=True): prof["nb_call_replace"] += 1 fgraph.replace( r, new_r, reason="inplace_elemwise_optimizer" @@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): new_inputs.append(i) else: try: - # works only for scalars cval_i = get_underlying_scalar_constant_value( i, only_process_constants=True ) @@ -1033,12 +1032,12 @@ def update_fuseable_mappings_after_fg_replace( ) if not isinstance(composite_outputs, list): composite_outputs = [composite_outputs] - for old_out, composite_out in zip(outputs, composite_outputs): + for old_out, composite_out in zip(outputs, composite_outputs, strict=True): if old_out.name: composite_out.name = old_out.name fgraph.replace_all_validate( - list(zip(outputs, composite_outputs)), + list(zip(outputs, composite_outputs, strict=True)), reason=self.__class__.__name__, ) nb_replacement += 1 @@ -1114,7 +1113,7 @@ def local_useless_composite_outputs(fgraph, node): used_inputs = [node.inputs[i] for i in used_inputs_idxs] c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) - return dict(zip([node.outputs[i] for i in used_outputs_idxs], e)) + return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True)) @node_rewriter([CAReduce]) @@ -1214,13 +1213,17 @@ def local_inline_composite_constants(fgraph, node): new_outer_inputs = [] new_inner_inputs = [] inner_replacements = {} - for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs): + for outer_inp, inner_inp in zip( + node.inputs, composite_op.fgraph.inputs, strict=True + ): # Complex variables don't have a `c_literal` that can be inlined - if "complex" not in outer_inp.type.dtype: - unique_value = get_unique_constant_value(outer_inp) - if unique_value is not None: + if ( + isinstance(outer_inp, TensorConstant) + and "complex" not in outer_inp.type.dtype + ): + if outer_inp.unique_value is not None: inner_replacements[inner_inp] = ps.constant( - unique_value, dtype=inner_inp.dtype + outer_inp.unique_value, dtype=inner_inp.dtype ) continue new_outer_inputs.append(outer_inp) @@ -1351,7 +1354,7 @@ def local_useless_2f1grad_loop(fgraph, node): replacements = {converges: new_converges} i = 0 - for grad_var, is_used in zip(grad_vars, grad_var_is_used): + for grad_var, is_used in zip(grad_vars, grad_var_is_used, strict=True): if not is_used: continue replacements[grad_var] = new_outs[i] diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index a2418147cf..cd202fe3ed 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,6 +2,8 @@ from collections.abc import Callable from typing import cast +import numpy as np + from pytensor import Variable from pytensor import tensor as pt from pytensor.compile import optdb @@ -11,7 +13,7 @@ in2out, node_rewriter, ) -from pytensor.scalar.basic import Mul +from pytensor.scalar.basic import Abs, Log, Mul, Sign from pytensor.tensor.basic import ( AllocDiag, ExtractDiag, @@ -30,11 +32,11 @@ KroneckerProduct, MatrixInverse, MatrixPinv, + SLogDet, det, inv, kron, pinv, - slogdet, svd, ) from pytensor.tensor.rewriting.basic import ( @@ -785,45 +787,6 @@ def rewrite_det_blockdiag(fgraph, node): return [prod(det_sub_matrices)] -@register_canonicalize -@register_stabilize -@node_rewriter([slogdet]) -def rewrite_slogdet_blockdiag(fgraph, node): - """ - This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those - - slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner block_diag operation - potential_block_diag = node.inputs[0].owner - if not ( - potential_block_diag - and isinstance(potential_block_diag.op, Blockwise) - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) - ): - return None - - # Find the composing sub_matrices - sub_matrices = potential_block_diag.inputs - sign_sub_matrices, logdet_sub_matrices = zip( - *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] - ) - - return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] - - @register_canonicalize @register_stabilize @node_rewriter([ExtractDiag]) @@ -860,10 +823,10 @@ def rewrite_diag_kronecker(fgraph, node): @register_canonicalize @register_stabilize -@node_rewriter([slogdet]) -def rewrite_slogdet_kronecker(fgraph, node): +@node_rewriter([det]) +def rewrite_det_kronecker(fgraph, node): """ - This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those Parameters ---------- @@ -884,13 +847,12 @@ def rewrite_slogdet_kronecker(fgraph, node): # Find the matrices a, b = potential_kron.inputs - signs, logdets = zip(*[slogdet(a), slogdet(b)]) + dets = [det(a), det(b)] sizes = [a.shape[-1], b.shape[-1]] prod_sizes = prod(sizes, no_zeros_in_input=True) - signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] - logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] + det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)]) - return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] + return [det_final] @register_canonicalize @@ -989,3 +951,65 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): "jax", position=0.9, # Run before canonicalization ) + + +@register_specialize +@node_rewriter([det]) +def slogdet_specialization(fgraph, node): + """ + This rewrite targets specific operations related to slogdet i.e sign(det), log(det) and log(abs(det)) and rewrites them using the SLogDet operation. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + dictionary of Variables, optional + Dictionary of nodes and what they should be replaced with, or None if no optimization was performed + """ + dummy_replacements = {} + for client, _ in fgraph.clients[node.outputs[0]]: + # Check for sign(det) + if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign): + dummy_replacements[client.outputs[0]] = "sign" + + # Check for log(abs(det)) + elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs): + potential_log = None + for client_2, _ in fgraph.clients[client.outputs[0]]: + if isinstance(client_2.op, Elemwise) and isinstance( + client_2.op.scalar_op, Log + ): + potential_log = client_2 + if potential_log: + dummy_replacements[potential_log.outputs[0]] = "log_abs_det" + else: + return None + + # Check for log(det) + elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log): + dummy_replacements[client.outputs[0]] = "log_det" + + # Det is used directly for something else, don't rewrite to avoid computing two dets + else: + return None + + if not dummy_replacements: + return None + else: + [x] = node.inputs + sign_det_x, log_abs_det_x = SLogDet()(x) + log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x) + slogdet_specialization_map = { + "sign": sign_det_x, + "log_abs_det": log_abs_det_x, + "log_det": log_det_x, + } + replacements = { + k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() + } + return replacements diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b230f035cc..aa2d279f43 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -28,7 +28,6 @@ as_tensor_variable, cast, constant, - extract_constant, get_underlying_scalar_constant_value, moveaxis, ones_like, @@ -56,6 +55,7 @@ ge, int_div, isinf, + kve, le, log, log1mexp, @@ -105,7 +105,6 @@ from pytensor.tensor.variable import ( TensorConstant, TensorVariable, - get_unique_constant_value, ) @@ -127,32 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): return consts, origconsts, nonconsts -def get_constant(v): - """ - - Returns - ------- - object - A numeric constant if v is a Constant or, well, a - numeric constant. If v is a plain Variable, returns None. - - """ - if isinstance(v, Constant): - unique_value = get_unique_constant_value(v) - if unique_value is not None: - data = unique_value - else: - data = v.data - if data.ndim == 0: - return data - else: - return None - elif isinstance(v, Variable): - return None - else: - return v - - @register_canonicalize @register_stabilize @node_rewriter([Dot]) @@ -162,18 +135,16 @@ def local_0_dot_x(fgraph, node): x = node.inputs[0] y = node.inputs[1] - replace = False - try: - if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass - - try: - if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass + replace = ( + get_underlying_scalar_constant_value( + x, only_process_constants=True, raise_not_constant=False + ) + == 0 + or get_underlying_scalar_constant_value( + y, only_process_constants=True, raise_not_constant=False + ) + == 0 + ) if replace: constant_zero = constant(0, dtype=node.outputs[0].type.dtype) @@ -564,27 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node): @register_stabilize @register_specialize @register_canonicalize -@node_rewriter([sub]) +@node_rewriter([add, sub]) def local_expm1(fgraph, node): - """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``.""" - in1, in2 = node.inputs - out = node.outputs[0] + """Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``.""" + if len(node.inputs) != 2: + # TODO: handle more than two inputs in add + return None - if ( - in1.owner - and isinstance(in1.owner.op, Elemwise) - and isinstance(in1.owner.op.scalar_op, ps.Exp) - and extract_constant(in2, only_process_constants=False) == 1 - ): - in11 = in1.owner.inputs[0] - new_out = expm1(in11) + if isinstance(node.op.scalar_op, ps.Sub): + exp_x, other_inp = node.inputs + if not ( + exp_x.owner + and isinstance(exp_x.owner.op, Elemwise) + and isinstance(exp_x.owner.op.scalar_op, ps.Exp) + and get_underlying_scalar_constant_value( + other_inp, raise_not_constant=False + ) + == 1 + ): + return None + else: + # Try both orders + other_inp, exp_x = node.inputs + for i in range(2): + if i == 1: + other_inp, exp_x = exp_x, other_inp + if ( + exp_x.owner + and isinstance(exp_x.owner.op, Elemwise) + and isinstance(exp_x.owner.op.scalar_op, ps.Exp) + and get_underlying_scalar_constant_value( + other_inp, raise_not_constant=False + ) + == -1 + ): + break + else: # no break + return None - if new_out.dtype != out.dtype: - new_out = cast(new_out, dtype=out.dtype) + [old_out] = node.outputs - if not out.type.is_super(new_out.type): - return - return [new_out] + [x] = exp_x.owner.inputs + if x.type.broadcastable != old_out.type.broadcastable: + x = broadcast_arrays(x, other_inp)[0] + + new_out = expm1(x) + + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, dtype=old_out.dtype) + + if not old_out.type.is_super(new_out.type): + return None + + return [new_out] @register_specialize @@ -627,7 +630,14 @@ def local_mul_switch_sink(fgraph, node): # Look for a zero as the first or second branch of the switch for branch in range(2): zero_switch_input = switch_node.inputs[1 + branch] - if not get_unique_constant_value(zero_switch_input) == 0.0: + if ( + not get_underlying_scalar_constant_value( + zero_switch_input, + only_process_constants=True, + raise_not_constant=False, + ) + == 0.0 + ): continue switch_cond = switch_node.inputs[0] @@ -684,7 +694,14 @@ def local_div_switch_sink(fgraph, node): # Look for a zero as the first or second branch of the switch for branch in range(2): zero_switch_input = switch_node.inputs[1 + branch] - if not get_unique_constant_value(zero_switch_input) == 0.0: + if ( + not get_underlying_scalar_constant_value( + zero_switch_input, + only_process_constants=True, + raise_not_constant=False, + ) + == 0.0 + ): continue switch_cond = switch_node.inputs[0] @@ -699,7 +716,10 @@ def local_div_switch_sink(fgraph, node): # will point to the new division op. copy_stack_trace(node.outputs, fdiv) - fct = switch(switch_cond, zero_switch_input, fdiv) + if branch == 0: + fct = switch(switch_cond, zero_switch_input, fdiv) + else: + fct = switch(switch_cond, fdiv, zero_switch_input) # Tell debug_mode than the output is correct, even if nan disappear fct.tag.values_eq_approx = values_eq_approx_remove_nan @@ -985,8 +1005,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): """ Find all constants and put them together into a single constant. - Finds all constants in orig_num and orig_denum (using - get_constant) and puts them together into a single + Finds all constants in orig_num and orig_denum + and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator. @@ -1007,17 +1027,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): numct, denumct = [], [] for v in orig_num: - ct = get_constant(v) - if ct is not None: + if isinstance(v, TensorConstant) and v.unique_value is not None: # We found a constant in the numerator! # We add it to numct - numct.append(ct) + numct.append(v.unique_value) else: num.append(v) for v in orig_denum: - ct = get_constant(v) - if ct is not None: - denumct.append(ct) + if isinstance(v, TensorConstant) and v.unique_value is not None: + denumct.append(v.unique_value) else: denum.append(v) @@ -1041,10 +1059,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: # In that case we should only have one constant in `ct`. - assert len(ct) == 1 - first_num_ct = get_constant(orig_num[0]) - if first_num_ct is not None and ct[0].type.values_eq( - ct[0].data, first_num_ct + [var_ct] = ct + first_num_var = orig_num[0] + first_num_ct = ( + first_num_var.unique_value + if isinstance(first_num_var, TensorConstant) + else None + ) + if first_num_ct is not None and var_ct.type.values_eq( + var_ct.data, first_num_ct ): # This is an important trick :( if it so happens that: # * there's exactly one constant on the numerator and none on @@ -1095,7 +1118,9 @@ def transform(self, fgraph, node): num, denum = self.simplify(list(orig_num), list(orig_denum), out.type) def same(x, y): - return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in zip(x, y)) + return len(x) == len(y) and all( + np.all(xe == ye) for xe, ye in zip(x, y, strict=True) + ) if ( same(orig_num, num) @@ -1334,12 +1359,13 @@ def local_useless_elemwise_comparison(fgraph, node): the graph easier to read. """ + # TODO: Refactor this function. So much repeated code! + if node.op.scalar_op.nin != 2: return - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype + dtype = node.outputs[0].type.dtype + out_bcast = node.outputs[0].type.broadcastable # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) if ( @@ -1350,6 +1376,7 @@ def local_useless_elemwise_comparison(fgraph, node): # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.LE | ps.GE) @@ -1360,6 +1387,7 @@ def local_useless_elemwise_comparison(fgraph, node): # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[{minimum,maximum}](X, X) -> X if ( isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum) @@ -1375,64 +1403,72 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(node.op.scalar_op, ps.LT) and node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.GE) and node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = ones_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[maximum](X.shape[i], 0) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, ps.ScalarMaximum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - # No need to copy over stacktrace. - return [node.inputs[0]] - # Elemwise[maximum](0, X.shape[i]) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, ps.ScalarMaximum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - # No need to copy over stacktrace. - return [node.inputs[1]] - # Elemwise[minimum](X.shape[i], 0) -> 0 - if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] + if isinstance(node.op.scalar_op, ps.ScalarMaximum): + for idx in range(2): + if ( + node.inputs[idx].owner + and isinstance(node.inputs[idx].owner.op, Shape_i) + and get_underlying_scalar_constant_value( + node.inputs[1 - idx], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + ): + res = node.inputs[idx] + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1 - idx])[0] + # No need to copy over stacktrace. + return [res] - # Elemwise[minimum](0, X.shape[i]) -> 0 - if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - res = zeros_like(node.inputs[1], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] + # Elemwise[minimum](X.shape[i], 0) -> 0 + if isinstance(node.op.scalar_op, ps.ScalarMinimum): + for idx in range(2): + if ( + node.inputs[idx].owner + and isinstance(node.inputs[idx].owner.op, Shape_i) + and get_underlying_scalar_constant_value( + node.inputs[1 - idx], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + ): + res = zeros_like(node.inputs[idx], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1 - idx])[0] + # No need to copy over stacktrace. + return [res] # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) if ( @@ -1444,12 +1480,18 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(var.owner and var.owner.op, Shape_i) for var in node.inputs[0].owner.inputs ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.GE) @@ -1460,57 +1502,61 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(var.owner and var.owner.op, Shape_i) for var in node.inputs[0].owner.inputs ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = ones_like(node.inputs[0], dtype=dtype, opt=True) - + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] - # Elemwise[EQ](Subtensor(Shape(x)), -N) - # Elemwise[EQ](somegraph that only depend of shape, -N) - # TODO: handle the case where the -N is on either side - """ - |Elemwise{eq,no_inplace} [id B] '' - | |Subtensor{int64} [id C] '' - | | |Join [id D] '' - | | | |TensorConstant{0} [id E] - | | | |Subtensor{int64:int64:} [id F] '' - | | | | |Shape [id G] '' - """ + # Elemwise[EQ](Subtensor(Shape(x)), -N) + # Elemwise[EQ](somegraph that only depend of shape, -N) + # TODO: handle the case where the -N is on either side + """ +|Elemwise{eq,no_inplace} [id B] '' +| |Subtensor{int64} [id C] '' +| | |Join [id D] '' +| | | |TensorConstant{0} [id E] +| | | |Subtensor{int64:int64:} [id F] '' +| | | | |Shape [id G] '' + """ - def investigate(node): + def investigate_if_shape(node) -> bool: "Return True if values will be shapes, so >= 0" if isinstance(node.op, Shape | Shape_i): return True elif isinstance(node.op, Subtensor) and node.inputs[0].owner: - return investigate(node.inputs[0].owner) + return investigate_if_shape(node.inputs[0].owner) elif isinstance(node.op, Join): - return all(v.owner and investigate(v.owner) for v in node.inputs[1:]) + return all( + v.owner and investigate_if_shape(v.owner) for v in node.inputs[1:] + ) elif isinstance(node.op, MakeVector): - return all(v.owner and investigate(v.owner) for v in node.inputs) + return all(v.owner and investigate_if_shape(v.owner) for v in node.inputs) + return False if ( isinstance(node.op.scalar_op, ps.EQ) and node.inputs[0].owner - and investigate(node.inputs[0].owner) + and investigate_if_shape(node.inputs[0].owner) + and ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[1].unique_value is not None + and node.inputs[1].unique_value < 0 + ) ): - try: - cst = get_underlying_scalar_constant_value( - node.inputs[1], only_process_constants=True - ) - - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - if cst < 0: - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - - return [res] + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] - except NotScalarConstantError: - pass return @@ -1807,12 +1853,6 @@ def local_add_neg_to_sub(fgraph, node): new_out = sub(first, pre_neg) return [new_out] - # Check if it is a negative constant - const = get_constant(second) - if const is not None and const < 0: - new_out = sub(first, np.abs(const)) - return [new_out] - @register_canonicalize @node_rewriter([mul]) @@ -1839,7 +1879,12 @@ def local_mul_zero(fgraph, node): @register_specialize @node_rewriter([true_div]) def local_div_to_reciprocal(fgraph, node): - if np.all(get_constant(node.inputs[0]) == 1.0): + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 1.0 + ): out = node.outputs[0] new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) # The ones could have forced upcasting @@ -1860,7 +1905,9 @@ def local_reciprocal_canon(fgraph, node): @register_canonicalize @node_rewriter([pt_pow]) def local_pow_canonicalize(fgraph, node): - cst = get_constant(node.inputs[1]) + cst = get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) if cst == 0: return [alloc_like(1, node.outputs[0], fgraph)] if cst == 1: @@ -1891,7 +1938,12 @@ def local_intdiv_by_one(fgraph, node): @node_rewriter([int_div, true_div]) def local_zero_div(fgraph, node): """0 / x -> 0""" - if get_constant(node.inputs[0]) == 0: + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 0 + ): ret = alloc_like(0, node.outputs[0], fgraph) ret.tag.values_eq_approx = values_eq_approx_remove_nan return [ret] @@ -1904,8 +1956,12 @@ def local_pow_specialize(fgraph, node): odtype = node.outputs[0].dtype xsym = node.inputs[0] ysym = node.inputs[1] - y = get_constant(ysym) - if (y is not None) and not broadcasted_by(xsym, ysym): + try: + y = get_underlying_scalar_constant_value(ysym, only_process_constants=True) + except NotScalarConstantError: + return + + if not broadcasted_by(xsym, ysym): rval = None if np.all(y == 2): @@ -1939,10 +1995,14 @@ def local_pow_to_nested_squaring(fgraph, node): """ # the idea here is that we have pow(x, y) + xsym, ysym = node.inputs + + try: + y = get_underlying_scalar_constant_value(ysym, only_process_constants=True) + except NotScalarConstantError: + return + odtype = node.outputs[0].dtype - xsym = node.inputs[0] - ysym = node.inputs[1] - y = get_constant(ysym) # the next line is needed to fix a strange case that I don't # know how to make a separate test. @@ -1958,7 +2018,7 @@ def local_pow_to_nested_squaring(fgraph, node): y = y[0] except IndexError: pass - if (y is not None) and not broadcasted_by(xsym, ysym): + if not broadcasted_by(xsym, ysym): rval = None # 512 is too small for the cpu and too big for some gpu! if abs(y) == int(abs(y)) and abs(y) <= 512: @@ -2025,7 +2085,9 @@ def local_mul_specialize(fgraph, node): nb_neg_node += 1 # remove special case arguments of 1, -1 or 0 - y = get_constant(inp) + y = get_underlying_scalar_constant_value( + inp, only_process_constants=True, raise_not_constant=False + ) if y == 1.0: nb_cst += 1 elif y == -1.0: @@ -2077,7 +2139,7 @@ def local_add_remove_zeros(fgraph, node): y = get_underlying_scalar_constant_value(inp) except NotScalarConstantError: y = inp - if np.all(y == 0.0): + if y == 0.0: continue new_inputs.append(inp) @@ -2175,7 +2237,7 @@ def local_abs_merge(fgraph, node): ) except NotScalarConstantError: return False - if not (const >= 0).all(): + if not const >= 0: return False inputs.append(i) else: @@ -2212,12 +2274,21 @@ def local_log1p(fgraph, node): return [alloc_like(log1p(ninp), node.outputs[0], fgraph)] elif log_arg.owner and log_arg.owner.op == sub: - one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) + one, other = log_arg.owner.inputs + try: + one = get_underlying_scalar_constant_value(one, only_process_constants=True) + except NotScalarConstantError: + return + if one != 1: return - other = log_arg.owner.inputs[1] - if other.dtype != log_arg.dtype: + + if other.type.broadcastable != log_arg.type.broadcastable: + other = broadcast_arrays(other, one)[0] + + if other.type.dtype != log_arg.type.dtype: other = other.astype(log_arg.dtype) + return [log1p(neg(other))] @@ -2368,7 +2439,9 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0): [(n + num, d + denum, out_type) for (n, d) in neg_pairs], ) ) - for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs): + for (n, d), (nn, dd) in zip( + pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs, strict=True + ): # We calculate how many operations we are saving with the new # num and denum score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd) @@ -2553,9 +2626,9 @@ def local_greedy_distributor(fgraph, node): register_stabilize(local_one_minus_erfc) register_specialize(local_one_minus_erfc) -# erfc(-x)-1=>erf(x) +# -1 + erfc(-x)=>erf(x) local_erf_neg_minus_one = PatternNodeRewriter( - (sub, (erfc, (neg, "x")), 1), + (add, -1, (erfc, (neg, "x"))), (erf, "x"), allow_multiple_clients=True, name="local_erf_neg_minus_one", @@ -2816,7 +2889,7 @@ def _is_1(expr): """ try: v = get_underlying_scalar_constant_value(expr) - return np.allclose(v, 1) + return np.isclose(v, 1) except NotScalarConstantError: return False @@ -2984,7 +3057,7 @@ def is_neg(var): for idx, mul_input in enumerate(var_node.inputs): try: constant = get_underlying_scalar_constant_value(mul_input) - is_minus_1 = np.allclose(constant, -1) + is_minus_1 = np.isclose(constant, -1) except NotScalarConstantError: is_minus_1 = False if is_minus_1: @@ -3491,3 +3564,18 @@ def local_useless_conj(fgraph, node): ) register_specialize(local_polygamma_to_tri_gamma) + + +local_log_kv = PatternNodeRewriter( + # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x + # During stabilize -x is converted to -1.0 * x + (log, (mul, (kve, "v", "x"), (exp, (mul, -1.0, "x")))), + (sub, (log, (kve, "v", "x")), "x"), + allow_multiple_clients=True, + name="local_log_kv", + # Start the rewrite from the less likely kve node + tracks=[kve], + get_nodes=get_clients_at_depth2, +) + +register_stabilize(local_log_kv) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py new file mode 100644 index 0000000000..91ab131424 --- /dev/null +++ b/pytensor/tensor/rewriting/numba.py @@ -0,0 +1,108 @@ +from pytensor.compile import optdb +from pytensor.graph import node_rewriter +from pytensor.graph.basic import applys_between +from pytensor.graph.rewriting.basic import out2in +from pytensor.tensor.basic import as_tensor, constant +from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape +from pytensor.tensor.rewriting.shape import ShapeFeature + + +@node_rewriter([Blockwise]) +def introduce_explicit_core_shape_blockwise(fgraph, node): + """Introduce the core shape of a Blockwise. + + We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph + that has an extra "non-functional" input that represents the core shape of the Blockwise variable. + This core_shape is used by the numba backend to pre-allocate the output array. + + If available, the core shape is extracted from the shape feature of the graph, + which has a higher change of having been simplified, optimized, constant-folded. + If missing, we fall back to the op._supp_shape_from_params method. + + This rewrite is required for the numba backend implementation of Blockwise. + + Example + ------- + + .. code-block:: python + + import pytensor + import pytensor.tensor as pt + + x = pt.tensor("x", shape=(5, None, None)) + outs = pt.linalg.svd(x, compute_uv=True) + pytensor.dprint(outs) + # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A] + # โ””โ”€ x [id B] + # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A] + # โ””โ”€ ยทยทยท + # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A] + # โ””โ”€ ยทยทยท + + # After the rewrite, note the new 3 core shape inputs + fn = pytensor.function([x], outs, mode="NUMBA") + fn.dprint(print_type=False) + # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6 + # โ”œโ”€ x [id B] + # โ”œโ”€ MakeVector{dtype='int64'} [id C] 5 + # โ”‚ โ”œโ”€ Shape_i{1} [id D] 2 + # โ”‚ โ”‚ โ””โ”€ x [id B] + # โ”‚ โ””โ”€ Shape_i{1} [id D] 2 + # โ”‚ โ””โ”€ ยทยทยท + # โ”œโ”€ MakeVector{dtype='int64'} [id E] 4 + # โ”‚ โ””โ”€ Minimum [id F] 3 + # โ”‚ โ”œโ”€ Shape_i{1} [id D] 2 + # โ”‚ โ”‚ โ””โ”€ ยทยทยท + # โ”‚ โ””โ”€ Shape_i{2} [id G] 0 + # โ”‚ โ””โ”€ x [id B] + # โ””โ”€ MakeVector{dtype='int64'} [id H] 1 + # โ”œโ”€ Shape_i{2} [id G] 0 + # โ”‚ โ””โ”€ ยทยทยท + # โ””โ”€ Shape_i{2} [id G] 0 + # โ””โ”€ ยทยทยท + # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6 + # โ””โ”€ ยทยทยท + # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6 + # โ””โ”€ ยทยทยท + """ + op: Blockwise = node.op # type: ignore[annotation-unchecked] + batch_ndim = op.batch_ndim(node) + + shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked] + if shape_feature: + core_shapes = [ + [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)] + for out in node.outputs + ] + else: + input_shapes = [tuple(inp.shape) for inp in node.inputs] + core_shapes = [ + out_shape[batch_ndim:] + for out_shape in op.infer_shape(None, node, input_shapes) + ] + + core_shapes = [ + as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64") + for core_shape in core_shapes + ] + + if any( + isinstance(node.op, Blockwise) + for node in applys_between(node.inputs, core_shapes) + ): + # If Blockwise shows up in the shape graph we can't introduce the core shape + return None + + return BlockwiseWithCoreShape( + [*node.inputs, *core_shapes], + node.outputs, + destroy_map=op.destroy_map, + )(*node.inputs, *core_shapes, return_list=True) + + +optdb.register( + introduce_explicit_core_shape_blockwise.__name__, + out2in(introduce_explicit_core_shape_blockwise), + "numba", + position=100, +) diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py index 2c4dfc4f70..52472de47b 100644 --- a/pytensor/tensor/rewriting/ofg.py +++ b/pytensor/tensor/rewriting/ofg.py @@ -13,7 +13,7 @@ def inline_ofg_node(node: Apply) -> list[Variable]: op = node.op assert isinstance(op, OpFromGraph) inlined_outs = clone_replace( - op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)) + op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True)) ) copy_stack_trace(op.inner_outputs, inlined_outs) return cast(list[Variable], inlined_outs) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 91c731a4ff..e277772ad4 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -22,8 +22,7 @@ as_tensor_variable, cast, constant, - extract_constant, - get_underlying_scalar_constant_value, + get_scalar_constant_value, register_infer_shape, stack, ) @@ -185,7 +184,7 @@ def get_shape(self, var, idx): # Only change the variables and dimensions that would introduce # extra computation - for new_shps, out in zip(o_shapes, node.outputs): + for new_shps, out in zip(o_shapes, node.outputs, strict=True): if not hasattr(out.type, "ndim"): continue @@ -213,7 +212,7 @@ def shape_ir(self, i, r): # Do not call make_node for test_value s = Shape_i(i)(r) try: - s = get_underlying_scalar_constant_value(s) + s = get_scalar_constant_value(s) except NotScalarConstantError: pass return s @@ -297,7 +296,7 @@ def unpack(self, s_i, var): assert len(idx) == 1 idx = idx[0] try: - i = get_underlying_scalar_constant_value(idx) + i = get_scalar_constant_value(idx) except NotScalarConstantError: pass else: @@ -354,7 +353,9 @@ def set_shape(self, r, s, override=False): not hasattr(r.type, "shape") or r.type.shape[i] != 1 or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals(extract_constant(shape_vars[i])) + or self.lscalar_one.equals( + get_scalar_constant_value(shape_vars[i], raise_not_constant=False) + ) for i in range(r.type.ndim) ) self.shape_of[r] = tuple(shape_vars) @@ -450,7 +451,11 @@ def update_shape(self, r, other_r): ) or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals( - extract_constant(merged_shape[i], only_process_constants=True) + get_scalar_constant_value( + merged_shape[i], + only_process_constants=True, + raise_not_constant=False, + ) ) for i in range(r.type.ndim) ) @@ -474,7 +479,9 @@ def set_shape_i(self, r, i, s_i): not hasattr(r.type, "shape") or r.type.shape[idx] != 1 or self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals(extract_constant(new_shape[idx])) + or self.lscalar_one.equals( + get_scalar_constant_value(new_shape[idx], raise_not_constant=False) + ) for idx in range(r.type.ndim) ) self.shape_of[r] = tuple(new_shape) @@ -577,7 +584,7 @@ def on_import(self, fgraph, node, reason): new_shape += sh[len(new_shape) :] o_shapes[sh_idx] = tuple(new_shape) - for r, s in zip(node.outputs, o_shapes): + for r, s in zip(node.outputs, o_shapes, strict=True): self.set_shape(r, s) def on_change_input(self, fgraph, node, i, r, new_r, reason): @@ -708,7 +715,7 @@ def same_shape( sx = canon_shapes[: len(sx)] sy = canon_shapes[len(sx) :] - for dx, dy in zip(sx, sy): + for dx, dy in zip(sx, sy, strict=True): if not equal_computations([dx], [dy]): return False @@ -776,7 +783,7 @@ def local_reshape_chain(fgraph, node): # rewrite. if rval.type.ndim == node.outputs[0].type.ndim and all( s1 == s2 - for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape) + for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True) if s1 == 1 or s2 == 1 ): return [rval] @@ -847,7 +854,10 @@ def local_useless_reshape(fgraph, node): outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and len(outshp_i.owner.inputs) == 2 - and extract_constant(outshp_i.owner.inputs[1]) == dim + and get_scalar_constant_value( + outshp_i.owner.inputs[1], raise_not_constant=False + ) + == dim ): subtensor_inp = outshp_i.owner.inputs[0] if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): @@ -857,7 +867,9 @@ def local_useless_reshape(fgraph, node): continue # Match constant if input.type.shape[dim] == constant - cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) + cst_outshp_i = get_scalar_constant_value( + outshp_i, only_process_constants=True, raise_not_constant=False + ) if inp.type.shape[dim] == cst_outshp_i: shape_match[dim] = True continue @@ -872,8 +884,12 @@ def local_useless_reshape(fgraph, node): if shape_feature: inpshp_i = shape_feature.get_shape(inp, dim) if inpshp_i == outshp_i or ( - extract_constant(inpshp_i, only_process_constants=True) - == extract_constant(outshp_i, only_process_constants=True) + get_scalar_constant_value( + inpshp_i, only_process_constants=True, raise_not_constant=False + ) + == get_scalar_constant_value( + outshp_i, only_process_constants=True, raise_not_constant=False + ) ): shape_match[dim] = True continue @@ -909,11 +925,14 @@ def local_reshape_to_dimshuffle(fgraph, node): new_output_shape = [] index = 0 # index over the output of the new reshape for i in range(output.ndim): - # Since output_shape is a symbolic vector, we trust extract_constant + # Since output_shape is a symbolic vector, we trust get_scalar_constant_value # to go through however it is formed to see if its i-th element is 1. # We need only_process_constants=False for that. - dim = extract_constant( - output_shape[i], only_process_constants=False, elemwise=False + dim = get_scalar_constant_value( + output_shape[i], + only_process_constants=False, + elemwise=False, + raise_not_constant=False, ) if dim == 1: dimshuffle_new_order.append("x") @@ -1087,7 +1106,9 @@ def local_specify_shape_lift(fgraph, node): nonbcast_dims = { i - for i, (dim, bcast) in enumerate(zip(shape, out_broadcastable)) + for i, (dim, bcast) in enumerate( + zip(shape, out_broadcastable, strict=True) + ) if (not bcast and not NoneConst.equals(dim)) } new_elem_inps = elem_inps.copy() @@ -1189,7 +1210,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): new_order = node.inputs[0].owner.op.new_order inp = node.inputs[0].owner.inputs[0] new_order_of_nonbroadcast = [] - for i, s in zip(new_order, node.inputs[0].type.shape): + for i, s in zip(new_order, node.inputs[0].type.shape, strict=True): if s != 1: new_order_of_nonbroadcast.append(i) no_change_in_order = all( @@ -1213,7 +1234,7 @@ def local_useless_unbroadcast(fgraph, node): x = node.inputs[0] if x.type.ndim == node.outputs[0].type.ndim and all( s1 == s2 - for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape) + for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True) if s1 == 1 or s2 == 1 ): # No broadcastable flag was modified diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index cb453a44e4..4b824e46cf 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -26,7 +26,7 @@ as_tensor, cast, concatenate, - extract_constant, + get_scalar_constant_value, get_underlying_scalar_constant_value, register_infer_shape, switch, @@ -85,7 +85,7 @@ inc_subtensor, indices_from_subtensor, ) -from pytensor.tensor.type import TensorType +from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -249,7 +249,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): This is only done when there's a single vector index. """ - if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates: + if node.op.ignore_duplicates: # `AdvancedIncSubtensor1` does not ignore duplicate index values return @@ -390,8 +390,8 @@ def local_useless_slice(fgraph, node): start = s.start stop = s.stop - if start is not None and extract_constant( - start, only_process_constants=True + if start is not None and get_scalar_constant_value( + start, only_process_constants=True, raise_not_constant=False ) == (0 if positive_step else -1): change_flag = True start = None @@ -399,7 +399,9 @@ def local_useless_slice(fgraph, node): if ( stop is not None and x.type.shape[dim] is not None - and extract_constant(stop, only_process_constants=True) + and get_scalar_constant_value( + stop, only_process_constants=True, raise_not_constant=False + ) == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1) ): change_flag = True @@ -683,7 +685,7 @@ def local_subtensor_of_alloc(fgraph, node): # Slices to take from val val_slices = [] - for i, (sl, dim) in enumerate(zip(slices, dims)): + for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)): # If val was not copied over that dim, # we need to take the appropriate subtensor on it. if i >= n_added_dims: @@ -889,7 +891,10 @@ def local_useless_inc_subtensor(fgraph, node): and e.stop is None and ( e.step is None - or extract_constant(e.step, only_process_constants=True) == -1 + or get_scalar_constant_value( + e.step, only_process_constants=True, raise_not_constant=False + ) + == -1 ) for e in idx_cst ): @@ -994,7 +999,7 @@ def local_useless_subtensor(fgraph, node): if isinstance(idx.stop, int | np.integer): length_pos_data = sys.maxsize try: - length_pos_data = get_underlying_scalar_constant_value( + length_pos_data = get_scalar_constant_value( length_pos, only_process_constants=True ) except NotScalarConstantError: @@ -1059,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node): # get length of the indexed tensor along the first axis try: - length = get_underlying_scalar_constant_value( + length = get_scalar_constant_value( shape_of[node.inputs[0]][0], only_process_constants=True ) except NotScalarConstantError: @@ -1490,7 +1495,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node): and # Don't use only_process_constants=True. We need to # investigate Alloc of 0s but with non constant shape. - extract_constant(x, elemwise=False) != 0 + get_underlying_scalar_constant_value( + x, elemwise=False, raise_not_constant=False + ) + != 0 ): return @@ -1728,7 +1736,7 @@ def local_join_subtensors(fgraph, node): axis, tensors = node.inputs[0], node.inputs[1:] try: - axis = get_underlying_scalar_constant_value(axis) + axis = get_scalar_constant_value(axis) except NotScalarConstantError: return @@ -1789,12 +1797,7 @@ def local_join_subtensors(fgraph, node): if step is None: continue try: - if ( - get_underlying_scalar_constant_value( - step, only_process_constants=True - ) - != 1 - ): + if get_scalar_constant_value(step, only_process_constants=True) != 1: return None except NotScalarConstantError: return None @@ -1803,7 +1806,7 @@ def local_join_subtensors(fgraph, node): if all( idxs_nonaxis_subtensor1 == idxs_nonaxis_subtensor2 for i, (idxs_nonaxis_subtensor1, idxs_nonaxis_subtensor2) in enumerate( - zip(idxs_subtensor1, idxs_subtensor2) + zip(idxs_subtensor1, idxs_subtensor2, strict=True) ) if i != axis ): @@ -1945,7 +1948,7 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): x_batch_bcast = x.type.broadcastable[:batch_ndim] y_batch_bcast = y.type.broadcastable[:batch_ndim] - if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)): + if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)): # Need to broadcast batch x dims batch_shape = tuple( x_dim if (not xb or yb) else y_dim @@ -1954,6 +1957,7 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): tuple(x.shape)[:batch_ndim], y_batch_bcast, tuple(y.shape)[:batch_ndim], + strict=True, ) ) core_shape = tuple(x.shape)[batch_ndim:] @@ -1966,19 +1970,26 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): return new_out -@node_rewriter(tracks=[AdvancedSubtensor]) +@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor]) def ravel_multidimensional_bool_idx(fgraph, node): """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()] + x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape) """ - x, *idxs = node.inputs + if isinstance(node.op, AdvancedSubtensor): + x, *idxs = node.inputs + else: + x, y, *idxs = node.inputs if any( - isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int") + ( + (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes) + or isinstance(idx.type, NoneTypeT) + ) for idx in idxs ): - # Get out if there are any other advanced indexes + # Get out if there are any other advanced indexes or np.newaxis return None bool_idxs = [ @@ -2006,7 +2017,16 @@ def ravel_multidimensional_bool_idx(fgraph, node): new_idxs = list(idxs) new_idxs[bool_idx_pos] = raveled_bool_idx - return [raveled_x[tuple(new_idxs)]] + if isinstance(node.op, AdvancedSubtensor): + new_out = node.op(raveled_x, *new_idxs) + else: + # The dimensions of y that correspond to the boolean indices + # must already be raveled in the original graph, so we don't need to do anything to it + new_out = node.op(raveled_x, y, *new_idxs) + # But we must reshape the output to math the original shape + new_out = new_out.reshape(x_shape) + + return [copy_stack_trace(node.outputs[0], new_out)] @node_rewriter(tracks=[AdvancedSubtensor]) @@ -2023,16 +2043,19 @@ def ravel_multidimensional_int_idx(fgraph, node): x, *idxs = node.inputs if any( - isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool") + ( + (isinstance(idx.type, TensorType) and idx.type.dtype == "bool") + or isinstance(idx.type, NoneTypeT) + ) for idx in idxs ): - # Get out if there are any other advanced indexes + # Get out if there are any other advanced indexes or np.newaxis return None int_idxs = [ (i, idx) for i, idx in enumerate(idxs) - if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int")) + if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes) ] if len(int_idxs) != 1: @@ -2058,7 +2081,8 @@ def ravel_multidimensional_int_idx(fgraph, node): *int_idx.shape, *raveled_shape[int_idx_pos + 1 :], ) - return [raveled_subtensor.reshape(unraveled_shape)] + new_out = raveled_subtensor.reshape(unraveled_shape) + return [copy_stack_trace(node.outputs[0], new_out)] optdb["specialize"].register( diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 2193c11575..82ac260085 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -4,7 +4,14 @@ from typing import cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore + + +try: + from numpy.lib.array_utils import normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.gradient import DisconnectedType @@ -20,7 +27,7 @@ from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor -from pytensor.tensor.type_other import NoneConst +from pytensor.tensor.type_other import NoneConst, NoneTypeT from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -401,8 +408,6 @@ class SpecifyShape(COp): _output_type_depends_on_input_value = True def make_node(self, x, *shape): - from pytensor.tensor.basic import get_underlying_scalar_constant_value - x = ptb.as_tensor_variable(x) shape = tuple( @@ -425,14 +430,12 @@ def make_node(self, x, *shape): ) type_shape = [None] * x.ndim - for i, (xts, s) in enumerate(zip(x.type.shape, shape)): + for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)): if xts is not None: type_shape[i] = xts - else: + elif not isinstance(s.type, NoneTypeT): try: - type_s = get_underlying_scalar_constant_value(s) - if type_s is not None: - type_shape[i] = int(type_s) + type_shape[i] = int(ptb.get_scalar_constant_value(s)) except NotScalarConstantError: pass @@ -448,7 +451,10 @@ def perform(self, node, inp, out_): raise AssertionError( f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." ) - if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None): + # strict=False because we are in a hot loop + if not all( + xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None + ): raise AssertionError( f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." ) @@ -457,22 +463,13 @@ def perform(self, node, inp, out_): def infer_shape(self, fgraph, node, shapes): xshape, *_ = shapes shape = node.inputs[1:] - new_shape = [] - for dim in range(node.inputs[0].type.ndim): - s = shape[dim] - try: - s = ptb.get_underlying_scalar_constant_value(s) - # We assume that `None` shapes are always retrieved by - # `get_underlying_scalar_constant_value`, and only in that case do we default to - # the shape of the input variable - if s is None: - s = xshape[dim] - except NotScalarConstantError: - pass - new_shape.append(ptb.as_tensor_variable(s)) - - assert len(new_shape) == len(xshape) - return [new_shape] + # Use x shape if specified dim is None, otherwise the specified shape + return [ + [ + xshape[i] if isinstance(dim.type, NoneTypeT) else dim + for i, dim in enumerate(shape) + ] + ] def connection_pattern(self, node): return [[True], *[[False]] * len(node.inputs[1:])] @@ -512,7 +509,9 @@ def c_code(self, node, name, i_names, o_names, sub): """ ) - for i, (shp_name, shp) in enumerate(zip(shape_names, node.inputs[1:])): + for i, (shp_name, shp) in enumerate( + zip(shape_names, node.inputs[1:], strict=True) + ): if NoneConst.equals(shp): continue code += dedent( @@ -575,8 +574,9 @@ def specify_shape( # The above is a type error in Python 3.9 but not 3.12. # Thus we need to ignore unused-ignore on 3.12. new_shape_info = any( - s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None + s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None ) + # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError if not new_shape_info and len(shape) == x.type.ndim: return x @@ -587,7 +587,7 @@ def specify_shape( @_get_vector_length.register(SpecifyShape) # type: ignore def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int: try: - return int(ptb.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()) + return int(ptb.get_scalar_constant_value(var.owner.inputs[1]).item()) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -668,7 +668,7 @@ def make_node(self, x, shp): y = shp_list[index] y = ptb.as_tensor_variable(y) try: - s_val = ptb.get_underlying_scalar_constant_value(y).item() + s_val = ptb.get_scalar_constant_value(y).item() if s_val >= 0: out_shape[index] = s_val except NotScalarConstantError: diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 802ca6e543..98ec43ba0c 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -6,6 +6,7 @@ import numpy as np import scipy.linalg +from numpy.exceptions import ComplexWarning import pytensor import pytensor.tensor as pt @@ -259,9 +260,10 @@ def make_node(self, A, b): raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.") # Infer dtype by solving the most simple case with 1x1 matrices - o_dtype = scipy.linalg.solve( - np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype) - ).dtype + inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)] + out_arr = [[None]] + self.perform(None, inp_arr, out_arr) + o_dtype = out_arr[0][0].dtype x = tensor(dtype=o_dtype, shape=b.type.shape) return Apply(self, [A, b], [x]) @@ -766,7 +768,7 @@ def perform(self, node, inputs, outputs): Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T) with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) out[0] = Y.astype(A.dtype) @@ -1092,7 +1094,7 @@ def grad(self, inputs, gout): return [gout[0][slc] for slc in slices] def infer_shape(self, fgraph, nodes, shapes): - first, second = zip(*shapes) + first, second = zip(*shapes, strict=True) return [(pt.add(*first), pt.add(*second))] def _validate_and_prepare_inputs(self, matrices, as_tensor_func): diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index a2f02fabd8..5b05ad03f4 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -6,6 +6,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp +from pytensor.npy_2_compat import npy_2_compat_header from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.math import gamma, gammaln, log, neg, sum @@ -60,12 +61,16 @@ def infer_shape(self, fgraph, node, shape): return [shape[1]] def c_code_cache_version(self): - return (4,) + return (5,) + + def c_support_code_apply(self, node: Apply, name: str) -> str: + # return super().c_support_code_apply(node, name) + return npy_2_compat_header() def c_code(self, node, name, inp, out, sub): dy, sm = inp (dx,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -79,7 +84,7 @@ def c_code(self, node, name, inp, out, sub): int sm_ndim = PyArray_NDIM({sm}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || sm_ndim == 1); // Validate inputs if ((PyArray_TYPE({dy}) != NPY_DOUBLE) && @@ -95,13 +100,15 @@ def c_code(self, node, name, inp, out, sub): {fail}; }} - if (axis < 0) axis = sm_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); - {fail}; + if (axis < 0) axis = sm_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); + {fail}; + }} }} - if (({dx} == NULL) || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim))) {{ @@ -289,10 +296,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return ["", ""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] # dtype = node.inputs[0].type.dtype_specs()[1] # TODO: put this into a templated function, in the support code @@ -309,7 +320,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -319,11 +330,14 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax"); + {fail} + }} }} // Allocate Output Array @@ -481,7 +495,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (4,) + return (5,) def softmax(c, axis=None): @@ -541,10 +555,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return [""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -558,7 +576,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -568,13 +586,15 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); + {fail} + }} }} - // Allocate Output Array if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim))) {{ @@ -730,7 +750,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (1,) + return (2,) def log_softmax(c, axis=None): diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 87a62cad81..882daea9f3 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,6 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import npy_2_compat_header from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -29,7 +30,7 @@ from pytensor.tensor.basic import ( ScalarFromTensor, alloc, - get_underlying_scalar_constant_value, + get_scalar_constant_value, nonzero, scalar_from_tensor, ) @@ -522,7 +523,7 @@ def basic_shape(shape, indices): """ res_shape = () - for idx, n in zip(indices, shape): + for n, idx in zip(shape[: len(indices)], indices, strict=True): if isinstance(idx, slice): res_shape += (slice_len(idx, n),) elif isinstance(getattr(idx, "type", None), SliceType): @@ -610,7 +611,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ) for basic, grp_dim_indices in idx_groups: - dim_nums, grp_indices = zip(*grp_dim_indices) + dim_nums, grp_indices = zip(*grp_dim_indices, strict=True) remaining_dims = tuple(dim for dim in remaining_dims if dim not in dim_nums) if basic: @@ -756,13 +757,15 @@ def get_constant_idx( Example usage where `v` and `a` are appropriately typed PyTensor variables : >>> from pytensor.scalar import int64 >>> from pytensor.tensor import matrix + >>> import numpy as np + >>> >>> v = int64("v") >>> a = matrix("a") >>> b = a[v, 1:3] >>> b.owner.op.idx_list (ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None)) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True) - [v, slice(1, 3, None)] + [v, slice(np.int64(1), np.int64(3), None)] >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs) Traceback (most recent call last): pytensor.tensor.exceptions.NotScalarConstantError @@ -778,7 +781,7 @@ def conv(val): return slice(conv(val.start), conv(val.stop), conv(val.step)) else: try: - return get_underlying_scalar_constant_value( + return get_scalar_constant_value( val, only_process_constants=only_process_constants, elemwise=elemwise, @@ -838,7 +841,7 @@ def make_node(self, x, *inputs): assert len(inputs) == len(input_types) - for input, expected_type in zip(inputs, input_types): + for input, expected_type in zip(inputs, input_types, strict=True): if not expected_type.is_super(input.type): raise TypeError( f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}." @@ -855,12 +858,12 @@ def extract_const(value): if value is None: return value, True try: - value = get_underlying_scalar_constant_value(value) + value = get_scalar_constant_value(value) return value, True except NotScalarConstantError: return value, False - for the_slice, length in zip(padded, x.type.shape): + for the_slice, length in zip(padded, x.type.shape, strict=True): if not isinstance(the_slice, slice): continue @@ -915,7 +918,7 @@ def infer_shape(self, fgraph, node, shapes): len(xshp) - len(self.idx_list) ) i = 0 - for idx, xl in zip(padded, xshp): + for idx, xl in zip(padded, xshp, strict=True): if isinstance(idx, slice): # If it is the default (None, None, None) slice, or a variant, # the shape will be xl @@ -1456,11 +1459,8 @@ def inc_subtensor( views; if they overlap, the result of this `Op` will generally be incorrect. This value has no effect if ``inplace=False``. ignore_duplicates - This determines whether or not ``x[indices] += y`` is used or - ``np.add.at(x, indices, y)``. When the special duplicates handling of - ``np.add.at`` isn't required, setting this option to ``True`` - (i.e. using ``x[indices] += y``) can resulting in faster compiled - graphs. + This determines whether ``x[indices] += y`` is used or + ``np.add.at(x, indices, y)``. Examples -------- @@ -1687,7 +1687,7 @@ def make_node(self, x, y, *inputs): raise IndexError( "Not enough inputs to fill in the Subtensor template.", inputs, idx_list ) - for input, expected_type in zip(inputs, input_types): + for input, expected_type in zip(inputs, input_types, strict=True): if not expected_type.is_super(input.type): raise TypeError( f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}." @@ -2151,7 +2151,7 @@ def infer_shape(self, fgraph, node, ishapes): def c_support_code(self, **kwargs): # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG, # which is not defined. It should be NPY_MIN_LONG instead in that case. - return dedent( + return npy_2_compat_header() + dedent( """\ #ifndef MIN_LONG #define MIN_LONG NPY_MIN_LONG @@ -2176,7 +2176,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (!PyArray_CanCastSafely(i_type, NPY_INTP) && PyArray_SIZE({i_name}) > 0) {{ npy_int64 min_val, max_val; - PyObject* py_min_val = PyArray_Min({i_name}, NPY_MAXDIMS, + PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS, NULL); if (py_min_val == NULL) {{ {fail}; @@ -2186,7 +2186,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (min_val == -1 && PyErr_Occurred()) {{ {fail}; }} - PyObject* py_max_val = PyArray_Max({i_name}, NPY_MAXDIMS, + PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS, NULL); if (py_max_val == NULL) {{ {fail}; @@ -2245,7 +2245,7 @@ def c_code(self, node, name, input_names, output_names, sub): """ def c_code_cache_version(self): - return (0, 1, 2) + return (0, 1, 2, 3) advanced_subtensor1 = AdvancedSubtensor1() @@ -2525,6 +2525,9 @@ def c_code(self, node, name, input_names, output_names, sub): numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] if bool(numpy_ver < [1, 8]): raise NotImplementedError + if bool(numpy_ver >= [2, 0]): + raise NotImplementedError + x, y, idx = input_names out = output_names[0] copy_of_x = self.copy_of_x(x) @@ -2713,7 +2716,7 @@ def is_bool_index(idx): indices = node.inputs[1:] index_shapes = [] - for idx, ishape in zip(indices, ishapes[1:]): + for idx, ishape in zip(indices, ishapes[1:], strict=True): # Mixed bool indexes are converted to nonzero entries shape0_op = Shape_i(0) if is_bool_index(idx): @@ -2816,7 +2819,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): x_is_batched = x.type.ndim < batch_x.type.ndim idxs_are_batched = any( batch_idx.type.ndim > idx.type.ndim - for batch_idx, idx in zip(batch_idxs, idxs) + for batch_idx, idx in zip(batch_idxs, idxs, strict=True) if isinstance(batch_idx, TensorVariable) ) @@ -2937,6 +2940,31 @@ def grad(self, inpt, output_gradients): gy = _sum_grad_over_bcasted_dims(y, gy) return [gx, gy] + [DisconnectedType()() for _ in idxs] + @staticmethod + def non_contiguous_adv_indexing(node: Apply) -> bool: + """ + Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing). + + This function checks if the advanced indexing is non-contiguous, + in which case the advanced index dimensions are placed on the left of the + output array, regardless of their opriginal position. + + See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + + + Parameters + ---------- + node : Apply + The node of the AdvancedSubtensor operation. + + Returns + ------- + bool + True if the advanced indexing is non-contiguous, False otherwise. + """ + _, _, *idxs = node.inputs + return _non_contiguous_adv_indexing(idxs) + advanced_inc_subtensor = AdvancedIncSubtensor() advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True) @@ -3000,17 +3028,17 @@ def _get_vector_length_Subtensor(op, var): start = ( None if indices[0].start is None - else get_underlying_scalar_constant_value(indices[0].start) + else get_scalar_constant_value(indices[0].start) ) stop = ( None if indices[0].stop is None - else get_underlying_scalar_constant_value(indices[0].stop) + else get_scalar_constant_value(indices[0].stop) ) step = ( None if indices[0].step is None - else get_underlying_scalar_constant_value(indices[0].step) + else get_scalar_constant_value(indices[0].step) ) if start == stop: diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 5fdaba8fd8..0dc0b5cce2 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -101,7 +101,7 @@ def __init__( if str(dtype) == "floatX": self.dtype = config.floatX else: - if np.obj2sctype(dtype) is None: + if np.dtype(dtype).type is None: raise TypeError(f"Invalid dtype: {dtype}") self.dtype = np.dtype(dtype).name @@ -177,7 +177,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: else: if allow_downcast: # Convert to self.dtype, regardless of the type of data - data = np.asarray(data, dtype=self.dtype) + data = np.asarray(data).astype(self.dtype) # TODO: consider to pad shape with ones to make it consistent # with self.broadcastable... like vector->row type thing else: @@ -248,9 +248,10 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: " PyTensor C code does not support that.", ) + # strict=False because we are in a hot loop if not all( ds == ts if ts is not None else True - for ds, ts in zip(data.shape, self.shape) + for ds, ts in zip(data.shape, self.shape, strict=False) ): raise TypeError( f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" @@ -319,13 +320,17 @@ def in_same_class(self, otype): return False def is_super(self, otype): + # strict=False because we are in a hot loop if ( isinstance(otype, type(self)) and otype.dtype == self.dtype and otype.ndim == self.ndim # `otype` is allowed to be as or more shape-specific than `self`, # but not less - and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape)) + and all( + sb == ob or sb is None + for sb, ob in zip(self.shape, otype.shape, strict=False) + ) ): return True @@ -784,14 +789,16 @@ def tensor( **kwargs, ) -> "TensorVariable": if name is not None: - # Help catching errors with the new tensor API - # Many single letter strings are valid sctypes - if str(name) == "floatX" or (len(str(name)) > 1 and np.obj2sctype(name)): - np.obj2sctype(name) - raise ValueError( - f"The first and only positional argument of tensor is now `name`. Got {name}.\n" - "This name looks like a dtype, which you should pass as a keyword argument only." - ) + try: + # Help catching errors with the new tensor API + # Many single letter strings are valid sctypes + if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type): + raise ValueError( + f"The first and only positional argument of tensor is now `name`. Got {name}.\n" + "This name looks like a dtype, which you should pass as a keyword argument only." + ) + except TypeError: + pass if dtype is None: dtype = config.floatX diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 8f4d0738f8..52cb8c25f6 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -1,9 +1,17 @@ import re +import warnings from collections.abc import Sequence from typing import cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore + + +try: + from numpy.lib.array_utils import normalize_axis_tuple +except ModuleNotFoundError as e: + # numpy < 2.0 + warnings.warn(f"Importing from numpy version < 2.0.0 location: {e}") + from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.graph import FunctionGraph, Variable @@ -99,7 +107,7 @@ def shape_of_variables( numeric_input_dims = [dim for inp in fgraph.inputs for dim in input_shapes[inp]] numeric_output_dims = compute_shapes(*numeric_input_dims) - sym_to_num_dict = dict(zip(output_dims, numeric_output_dims)) + sym_to_num_dict = dict(zip(output_dims, numeric_output_dims, strict=True)) l = {} for var in shape_feature.shape_of: @@ -236,8 +244,8 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None: if axis is not None: try: axis = normalize_axis_tuple(axis, ndim=max(1, ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=ndim) + except np.exceptions.AxisError: + raise np.exceptions.AxisError(axis, ndim=ndim) # TODO: If axis tuple is equivalent to None, return None for more canonicalization? return cast(tuple, axis) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index a35404cdd5..ac89283bb6 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -11,7 +11,10 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Constant, OptionalApplyType, Variable from pytensor.graph.utils import MetaType -from pytensor.scalar import ComplexError, IntegerDivisionError +from pytensor.scalar import ( + ComplexError, + IntegerDivisionError, +) from pytensor.tensor import _get_vector_length from pytensor.tensor.exceptions import AdvancedIndexingError from pytensor.tensor.type import TensorType @@ -1042,17 +1045,9 @@ def no_nan(self): def get_unique_constant_value(x: TensorVariable) -> Number | None: """Return the unique value of a tensor, if there is one""" - if isinstance(x, Constant): - data = x.data - - if isinstance(data, np.ndarray) and data.size > 0: - if data.size == 1: - return data.squeeze() - - flat_data = data.ravel() - if (flat_data == flat_data[0]).all(): - return flat_data[0] - + warnings.warn("get_unique_constant_value is deprecated.", FutureWarning) + if isinstance(x, TensorConstant): + return x.unique_value return None @@ -1063,7 +1058,9 @@ def __init__(self, type: _TensorTypeType, data, name=None): data_shape = np.shape(data) if len(data_shape) != type.ndim or any( - ds != ts for ds, ts in zip(np.shape(data), type.shape) if ts is not None + ds != ts + for ds, ts in zip(np.shape(data), type.shape, strict=True) + if ts is not None ): raise ValueError( f"Shape of data ({data_shape}) does not match shape of type ({type.shape})" @@ -1079,6 +1076,30 @@ def __init__(self, type: _TensorTypeType, data, name=None): def signature(self): return TensorConstantSignature((self.type, self.data)) + @property + def unique_value(self) -> Number | None: + """Return the unique value of a tensor, if there is one""" + try: + return self._unique_value + except AttributeError: + data = self.data + unique_value = None + if data.size > 0: + if data.size == 1: + unique_value = data.squeeze() + else: + flat_data = data.ravel() + if (flat_data == flat_data[0]).all(): + unique_value = flat_data[0] + + if unique_value is not None: + # Don't allow the unique value to be changed + unique_value.setflags(write=False) + + self._unique_value = unique_value + + return self._unique_value + def equals(self, other): # Override Constant.equals to allow to compare with # numpy.ndarray, and python type. diff --git a/scripts/generate_gallery.py b/scripts/generate_gallery.py new file mode 100644 index 0000000000..5cd78d8494 --- /dev/null +++ b/scripts/generate_gallery.py @@ -0,0 +1,185 @@ +""" +Sphinx plugin to run generate a gallery for notebooks + +Modified from the pymc project, which modified the seaborn project, which modified the mpld3 project. +""" + +import base64 +import json +import os +import shutil +from pathlib import Path + +import matplotlib + + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import sphinx +from matplotlib import image + + +logger = sphinx.util.logging.getLogger(__name__) + +DOC_SRC = Path(__file__).resolve().parent.parent +DEFAULT_IMG_LOC = DOC_SRC / "doc" / "images" / "PyTensor_logo.png" + +external_nbs = {} + +HEAD = """ +Example Gallery +=============== + +.. toctree:: + :hidden: + +""" + +SECTION_TEMPLATE = """ +.. _{section_id}: + +{section_title} +{underlines} + +.. grid:: 1 2 3 3 + :gutter: 4 + +""" + +ITEM_TEMPLATE = """ + .. grid-item-card:: :doc:`{doc_name}` + :img-top: {image} + :link: {doc_reference} + :link-type: {link_type} + :shadow: none +""" + +folder_title_map = { + "introduction": "Introduction", + "rewrites": "Graph Rewriting", + "scan": "Looping in Pytensor", +} + + +def create_thumbnail(infile, width=275, height=275, cx=0.5, cy=0.5, border=4): + """Overwrites `infile` with a new file of the given size""" + im = image.imread(infile) + rows, cols = im.shape[:2] + size = min(rows, cols) + if size == cols: + xslice = slice(0, size) + ymin = min(max(0, int(cx * rows - size // 2)), rows - size) + yslice = slice(ymin, ymin + size) + else: + yslice = slice(0, size) + xmin = min(max(0, int(cx * cols - size // 2)), cols - size) + xslice = slice(xmin, xmin + size) + thumb = im[yslice, xslice] + thumb[:border, :, :3] = thumb[-border:, :, :3] = 0 + thumb[:, :border, :3] = thumb[:, -border:, :3] = 0 + + dpi = 100 + fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi) + + ax = fig.add_axes([0, 0, 1, 1], aspect="auto", frameon=False, xticks=[], yticks=[]) + ax.imshow(thumb, aspect="auto", resample=True, interpolation="bilinear") + fig.savefig(infile, dpi=dpi) + plt.close(fig) + return fig + + +class NotebookGenerator: + """Tools for generating an example page from a file""" + + def __init__(self, filename, root_dir, folder): + self.folder = folder + + self.basename = Path(filename).name + self.stripped_name = Path(filename).stem + self.image_dir = Path(root_dir) / "doc" / "_thumbnails" / folder + self.png_path = self.image_dir / f"{self.stripped_name}.png" + + with filename.open(encoding="utf-8") as fid: + self.json_source = json.load(fid) + self.default_image_loc = DEFAULT_IMG_LOC + + def extract_preview_pic(self): + """By default, just uses the last image in the notebook.""" + pic = None + for cell in self.json_source["cells"]: + for output in cell.get("outputs", []): + if "image/png" in output.get("data", []): + pic = output["data"]["image/png"] + if pic is not None: + return base64.b64decode(pic) + return None + + def gen_previews(self): + preview = self.extract_preview_pic() + if preview is not None: + with self.png_path.open("wb") as buff: + buff.write(preview) + else: + logger.warning( + f"Didn't find any pictures in {self.basename}", + type="thumbnail_extractor", + ) + shutil.copy(self.default_image_loc, self.png_path) + create_thumbnail(self.png_path) + + +def main(app): + logger.info("Starting thumbnail extractor.") + + working_dir = Path.cwd() + os.chdir(app.builder.srcdir) + + file = [HEAD] + + for folder, title in folder_title_map.items(): + file.append( + SECTION_TEMPLATE.format( + section_title=title, section_id=folder, underlines="-" * len(title) + ) + ) + + thumbnail_dir = Path("_thumbnails") / folder + if not thumbnail_dir.exists(): + Path.mkdir(thumbnail_dir, parents=True) + + if folder in external_nbs.keys(): + file += [ + ITEM_TEMPLATE.format( + doc_name=descr["doc_name"], + image=descr["image"], + doc_reference=descr["doc_reference"], + link_type=descr["link_type"], + ) + for descr in external_nbs[folder] + ] + + nb_paths = sorted(Path("gallery", folder).glob("*.ipynb")) + + for nb_path in nb_paths: + nbg = NotebookGenerator( + filename=nb_path, root_dir=Path(".."), folder=folder + ) + nbg.gen_previews() + + file.append( + ITEM_TEMPLATE.format( + doc_name=Path(folder) / nbg.stripped_name, + image="/" + str(nbg.png_path), + doc_reference=Path(folder) / nbg.stripped_name, + link_type="doc", + ) + ) + + with Path("gallery", "gallery.rst").open("w", encoding="utf-8") as f: + f.write("\n".join(file)) + + os.chdir(working_dir) + + +def setup(app): + app.connect("builder-inited", main) diff --git a/scripts/slowest_tests/extract-slow-tests.py b/scripts/slowest_tests/extract-slow-tests.py new file mode 100644 index 0000000000..3a06e4a68b --- /dev/null +++ b/scripts/slowest_tests/extract-slow-tests.py @@ -0,0 +1,80 @@ +"""This script parses the GitHub action log for test times. + +Taken from https://github.com/pymc-labs/pymc-marketing/tree/main/scripts/slowest_tests/extract-slow-tests.py + +""" + +import re +import sys +from pathlib import Path + + +start_pattern = re.compile(r"==== slow") +separator_pattern = re.compile(r"====") +time_pattern = re.compile(r"(\d+\.\d+)s ") + + +def extract_lines(lines: list[str]) -> list[str]: + times = [] + + in_section = False + for line in lines: + detect_start = start_pattern.search(line) + detect_end = separator_pattern.search(line) + + if detect_start: + in_section = True + + if in_section: + times.append(line) + + if not detect_start and in_section and detect_end: + break + + return times + + +def trim_up_to_match(pattern, string: str) -> str: + match = pattern.search(string) + if not match: + return "" + + return string[match.start() :] + + +def trim(pattern, lines: list[str]) -> list[str]: + return [trim_up_to_match(pattern, line) for line in lines] + + +def strip_ansi(text: str) -> str: + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) + + +def format_times(times: list[str]) -> list[str]: + return ( + trim(separator_pattern, times[:1]) + + trim(time_pattern, times[1:-1]) + + [strip_ansi(line) for line in trim(separator_pattern, times[-1:])] + ) + + +def read_lines_from_stdin(): + return sys.stdin.read().splitlines() + + +def read_from_file(file: Path): + """For testing purposes.""" + return file.read_text().splitlines() + + +def main(read_lines): + lines = read_lines() + times = extract_lines(lines) + parsed_times = format_times(times) + print("\n".join(parsed_times)) + + +if __name__ == "__main__": + read_lines = read_lines_from_stdin + main(read_lines) diff --git a/scripts/slowest_tests/update-slowest-times-issue.sh b/scripts/slowest_tests/update-slowest-times-issue.sh new file mode 100644 index 0000000000..b1c0c15789 --- /dev/null +++ b/scripts/slowest_tests/update-slowest-times-issue.sh @@ -0,0 +1,113 @@ +#!/bin/zsh + +DRY_RUN=false + +owner=pymc-devs +repo=pytensor +issue_number=1124 +title="Speed up test times :rocket:" +workflow=Tests +latest_id=$(gh run list --workflow $workflow --status success --limit 1 --json databaseId --jq '.[0].databaseId') +jobs=$(gh api /repos/$owner/$repo/actions/runs/$latest_id/jobs --jq '.jobs | map({name: .name, run_id: .run_id, id: .id, started_at: .started_at, completed_at: .completed_at})') + +# Skip 3.10, float32, and Benchmark tests +function skip_job() { + name=$1 + if [[ $name == *"py3.10"* ]]; then + return 0 + fi + + if [[ $name == *"float32 1"* ]]; then + return 0 + fi + + if [[ $name == *"Benchmark"* ]]; then + return 0 + fi + + return 1 +} + +# Remove common prefix from the name +function remove_prefix() { + name=$1 + echo $name | sed -e 's/^ubuntu-latest test py3.12 : fast-compile 0 : float32 0 : //' +} + +function human_readable_time() { + started_at=$1 + completed_at=$2 + + start_seconds=$(date -d "$started_at" +%s) + end_seconds=$(date -d "$completed_at" +%s) + + seconds=$(($end_seconds - $start_seconds)) + + if [ $seconds -lt 60 ]; then + echo "$seconds seconds" + else + echo "$(date -u -d @$seconds +'%-M minutes %-S seconds')" + fi +} + +all_times="" +echo "$jobs" | jq -c '.[]' | while read -r job; do + id=$(echo $job | jq -r '.id') + name=$(echo $job | jq -r '.name') + run_id=$(echo $job | jq -r '.run_id') + started_at=$(echo $job | jq -r '.started_at') + completed_at=$(echo $job | jq -r '.completed_at') + + if skip_job $name; then + echo "Skipping $name" + continue + fi + + echo "Processing job: $name (ID: $id, Run ID: $run_id)" + times=$(gh run view --job $id --log | python extract-slow-tests.py) + + if [ -z "$times" ]; then + # Some of the jobs are non-test jobs, so we skip them + echo "No tests found for '$name', skipping" + continue + fi + + echo $times + + human_readable=$(human_readable_time $started_at $completed_at) + name=$(remove_prefix $name) + + top="
($human_readable) $name\n\n\n\`\`\`" + bottom="\`\`\`\n\n
" + + formatted_times="$top\n$times\n$bottom" + + if [ -n "$all_times" ]; then + all_times="$all_times\n$formatted_times" + else + all_times="$formatted_times" + fi +done + +run_date=$(date +"%Y-%m-%d") +body=$(cat << EOF +If you are motivated to help speed up some tests, we would appreciate it! + +Here are some of the slowest test times: + +$all_times + +You can find more information on how to contribute [here](https://pytensor.readthedocs.io/en/latest/dev_start_guide.html) + +Automatically generated by [GitHub Action](https://github.com/pymc-devs/pytensor/blob/main/.github/workflows/slow-tests-issue.yml) +Latest run date: $run_date +EOF +) + +if [ "$DRY_RUN" = true ]; then + echo "Dry run, not updating issue" + echo $body + exit +fi +echo $body | gh issue edit $issue_number --body-file - --title "$title" +echo "Updated issue $issue_number with all times" diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py index f835953b19..8bf54a50e4 100644 --- a/tests/compile/function/test_function.py +++ b/tests/compile/function/test_function.py @@ -166,12 +166,12 @@ def test_in_allow_downcast_int(self): # Value too big for a, silently ignored assert np.array_equal(f([2**20], np.ones(1, dtype="int8"), 1), [2]) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError + with pytest.raises(OverflowError): f([3], [312], 1) - # Value too big for c, raises TypeError - with pytest.raises(TypeError): + # Value too big for c, raises OverflowError + with pytest.raises(OverflowError): f([3], [6], 806) def test_in_allow_downcast_floatX(self): diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 0a9bda9846..b2141b35a2 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -238,11 +238,11 @@ def test_param_allow_downcast_int(self): assert np.all(f([2**20], np.ones(1, dtype="int8"), 1) == 2) # Value too big for b, raises TypeError - with pytest.raises(TypeError): + with pytest.raises(OverflowError): f([3], [312], 1) # Value too big for c, raises TypeError - with pytest.raises(TypeError): + with pytest.raises(OverflowError): f([3], [6], 806) def test_param_allow_downcast_floatX(self): @@ -328,15 +328,17 @@ def test_allow_input_downcast_int(self): g([3], np.array([6], dtype="int16"), 0) # Value too big for b, raises TypeError - with pytest.raises(TypeError): + with pytest.raises(OverflowError): g([3], [312], 0) h = pfunc([a, b, c], (a + b + c)) # Default: allow_input_downcast=None # Everything here should behave like with False assert np.all(h([3], [6], 0) == 9) + with pytest.raises(TypeError): h([3], np.array([6], dtype="int16"), 0) - with pytest.raises(TypeError): + + with pytest.raises(OverflowError): h([3], [312], 0) def test_allow_downcast_floatX(self): diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 4b6537d328..0990dbeca0 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -388,7 +388,7 @@ def test_copy_share_memory(self): # Assert storages of SharedVariable without updates are shared for (input, _1, _2), here, there in zip( - ori.indices, ori.input_storage, cpy.input_storage + ori.indices, ori.input_storage, cpy.input_storage, strict=True ): assert here.data is there.data @@ -484,7 +484,7 @@ def test_swap_SharedVariable_with_given(self): swap={train_x: test_x, train_y: test_y}, delete_updates=True ) - for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs): + for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs, strict=True): assert in1.value is in2.value def test_copy_delete_updates(self): @@ -950,7 +950,7 @@ def test_deepcopy(self): # print(f"{f.defaults = }") # print(f"{g.defaults = }") for (f_req, f_feed, f_val), (g_req, g_feed, g_val) in zip( - f.defaults, g.defaults + f.defaults, g.defaults, strict=True ): assert f_req == g_req and f_feed == g_feed and f_val == g_val @@ -1105,14 +1105,10 @@ def test_optimizations_preserved(self): ((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s), ) old_default_mode = config.mode - old_default_opt = config.optimizer - old_default_link = config.linker try: try: str_f = pickle.dumps(f, protocol=-1) - config.mode = "Mode" - config.linker = "py" - config.optimizer = "None" + config.mode = "NUMBA" g = pickle.loads(str_f) # print g.maker.mode # print compile.mode.default_mode @@ -1121,8 +1117,6 @@ def test_optimizations_preserved(self): g = "ok" finally: config.mode = old_default_mode - config.optimizer = old_default_opt - config.linker = old_default_link if g == "ok": return @@ -1132,7 +1126,7 @@ def test_optimizations_preserved(self): tf = f.maker.fgraph.toposort() tg = f.maker.fgraph.toposort() assert len(tf) == len(tg) - for nf, ng in zip(tf, tg): + for nf, ng in zip(tf, tg, strict=True): assert nf.op == ng.op assert len(nf.inputs) == len(ng.inputs) assert len(nf.outputs) == len(ng.outputs) diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index d99b13edfc..8fc2a529df 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -722,5 +722,5 @@ def test_debugprint(): โ””โ”€ *2- [id I] """ - for truth, out in zip(exp_res.split("\n"), lines): + for truth, out in zip(exp_res.split("\n"), lines, strict=True): assert truth.strip() == out.strip() diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index 95e52d6b53..fae76fab0d 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -146,7 +146,7 @@ def dontuse_perform(self, node, inp, out_): raise ValueError(self.behaviour) def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inp, out, sub): (a,) = inp @@ -165,8 +165,8 @@ def c_code(self, node, name, inp, out, sub): prep_vars = f""" //the output array has size M x N npy_intp M = PyArray_DIMS({a})[0]; - npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_DESCR({a})->elsize; - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; + npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_ITEMSIZE({a}); + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); npy_double * Da = (npy_double*)PyArray_BYTES({a}); npy_double * Dz = (npy_double*)PyArray_BYTES({z}); diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index c965087ea2..291eac0782 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -13,6 +13,7 @@ from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.link.basic import LocalLinker +from pytensor.link.jax import JAXLinker from pytensor.tensor.math import dot, tanh from pytensor.tensor.type import matrix, vector @@ -142,3 +143,15 @@ class MyLinker(LocalLinker): test_mode = Mode(linker=MyLinker()) with pytest.raises(Exception): get_target_language(test_mode) + + +def test_predefined_modes_respected(): + default_mode = get_default_mode() + assert not isinstance(default_mode.linker, JAXLinker) + + with config.change_flags(mode="JAX"): + jax_mode = get_default_mode() + assert isinstance(jax_mode.linker, JAXLinker) + + default_mode_again = get_default_mode() + assert not isinstance(default_mode_again.linker, JAXLinker) diff --git a/tests/d3viz/test_d3viz.py b/tests/d3viz/test_d3viz.py index b6b6479a1b..7e4b0426a0 100644 --- a/tests/d3viz/test_d3viz.py +++ b/tests/d3viz/test_d3viz.py @@ -9,12 +9,14 @@ from pytensor import compile from pytensor.compile.function import function from pytensor.configdefaults import config -from pytensor.printing import pydot_imported, pydot_imported_msg +from pytensor.printing import _try_pydot_import from tests.d3viz import models -if not pydot_imported: - pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True) +try: + _try_pydot_import() +except Exception as e: + pytest.skip(f"pydot not available: {e!s}", allow_module_level=True) class TestD3Viz: diff --git a/tests/d3viz/test_formatting.py b/tests/d3viz/test_formatting.py index f0cbd3fdd7..7d1149be0e 100644 --- a/tests/d3viz/test_formatting.py +++ b/tests/d3viz/test_formatting.py @@ -3,11 +3,13 @@ from pytensor import config, function from pytensor.d3viz.formatting import PyDotFormatter -from pytensor.printing import pydot_imported, pydot_imported_msg +from pytensor.printing import _try_pydot_import -if not pydot_imported: - pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True) +try: + _try_pydot_import() +except Exception as e: + pytest.skip(f"pydot not available: {e!s}", allow_module_level=True) from tests.d3viz import models @@ -19,7 +21,7 @@ def setup_method(self): def node_counts(self, graph): node_types = [node.get_attributes()["node_type"] for node in graph.get_nodes()] a, b = np.unique(node_types, return_counts=True) - nc = dict(zip(a, b)) + nc = dict(zip(a, b, strict=True)) return nc @pytest.mark.parametrize("mode", ["FAST_RUN", "FAST_COMPILE"]) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 08c352ab71..84ffb365b5 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -367,6 +367,10 @@ def test_eval_kwargs(self): self.w.eval({self.z: 3, self.x: 2.5}) assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0 + # regression test for https://github.com/pymc-devs/pytensor/issues/1084 + q = self.x + 1 + assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0 + @pytest.mark.filterwarnings("error") def test_eval_unashable_kwargs(self): y_repl = constant(2.0, dtype="floatX") diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index f2550d348e..e82a59e790 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -32,13 +32,22 @@ def test_pickle(self): s = pickle.dumps(func) new_func = pickle.loads(s) - assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs)) - assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs)) + assert all( + type(a) is type(b) + for a, b in zip(func.inputs, new_func.inputs, strict=True) + ) + assert all( + type(a) is type(b) + for a, b in zip(func.outputs, new_func.outputs, strict=True) + ) assert all( type(a.op) is type(b.op) - for a, b in zip(func.apply_nodes, new_func.apply_nodes) + for a, b in zip(func.apply_nodes, new_func.apply_nodes, strict=True) + ) + assert all( + a.type == b.type + for a, b in zip(func.variables, new_func.variables, strict=True) ) - assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables)) def test_validate_inputs(self): var1 = op1() diff --git a/tests/graph/utils.py b/tests/graph/utils.py index d48e0b2a35..86b52a7ed1 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py @@ -137,7 +137,9 @@ def __init__(self, inner_inputs, inner_outputs): if not isinstance(v, Constant) ] outputs = clone_replace(inner_outputs, replace=input_replacements) - _, inputs = zip(*input_replacements) if input_replacements else (None, []) + _, inputs = ( + zip(*input_replacements, strict=True) if input_replacements else (None, []) + ) self.fgraph = FunctionGraph(inputs, outputs, clone=False) def make_node(self, *inputs): diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index 5e783984e0..d0f748f3e7 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -76,13 +76,13 @@ def compare_jax_and_py( if isinstance(jax_res, list): assert all(isinstance(res, jax.Array) for res in jax_res) else: - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + assert isinstance(jax_res, jax.Array) pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) py_res = pytensor_py_fn(*test_inputs) if len(fgraph.outputs) > 1: - for j, p in zip(jax_res, py_res): + for j, p in zip(jax_res, py_res, strict=True): assert_fn(j, p) else: assert_fn(jax_res, py_res) diff --git a/tests/link/jax/test_einsum.py b/tests/link/jax/test_einsum.py index 5761563066..4f1d25acfe 100644 --- a/tests/link/jax/test_einsum.py +++ b/tests/link/jax/test_einsum.py @@ -15,10 +15,12 @@ def test_jax_einsum(): y = np.random.rand(5, 2) z = np.random.rand(2, 4) - shapes = ((3, 5), (5, 2), (2, 4)) - x_pt, y_pt, z_pt = ( - pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes) - ) + shapes = { + "x": (3, 5), + "y": (5, 2), + "z": (2, 4), + } + x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items()) out = pt.einsum(subscripts, x_pt, y_pt, z_pt) fg = FunctionGraph([x_pt, y_pt, z_pt], [out]) compare_jax_and_py(fg, [x, y, z]) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index 1427413379..0c8fb92810 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -6,6 +6,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value from pytensor.tensor import extra_ops as pt_extra_ops +from pytensor.tensor.sort import argsort from pytensor.tensor.type import matrix, tensor from tests.link.jax.test_basic import compare_jax_and_py @@ -55,6 +56,13 @@ def test_extra_ops(): fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False ) + v = ptb.as_tensor_variable(6.0) + sorted_idx = argsort(a.ravel()) + + out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [a_test]) + @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") def test_bartlett_dynamic_shape(): diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index f9ae5d00c1..c5f2d29928 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -61,7 +61,11 @@ def test_random_updates(rng_ctor): # Check that original rng variable content was not overwritten when calling jax_typify assert all( a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b) - for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__()) + for a, b in zip( + rng.get_value().bit_generator.state, + original_value.bit_generator.state, + strict=True, + ) ) @@ -92,7 +96,9 @@ def test_replaced_shared_rng_storage_order(noise_first): ), "Test may need to be tweaked" # Confirm that input_storage type and fgraph input order are aligned - for storage, fgrapn_input in zip(f.input_storage, f.maker.fgraph.inputs): + for storage, fgrapn_input in zip( + f.input_storage, f.maker.fgraph.inputs, strict=True + ): assert storage.type == fgrapn_input.type assert mu.get_value() == 1 diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 0469301791..475062e86c 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -21,6 +21,7 @@ gammainccinv, gammaincinv, iv, + kve, log, log1mexp, polygamma, @@ -157,6 +158,7 @@ def test_erfinv(): (erfcx, (0.7,)), (erfcinv, (0.7,)), (iv, (0.3, 0.7)), + (kve, (-2.5, 2.0)), ], ) @pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability") diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index cfbc61eaca..1b0fa8fd52 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -228,9 +228,11 @@ def compare_numba_and_py( fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]], inputs: Sequence["TensorLike"], assert_fn: Callable | None = None, + *, numba_mode=numba_mode, py_mode=py_mode, updates=None, + inplace: bool = False, eval_obj_mode: bool = True, ) -> tuple[Callable, Any]: """Function to compare python graph output and Numba compiled output for testing equality @@ -242,7 +244,7 @@ def compare_numba_and_py( Parameters ---------- fgraph - `FunctionGraph` or inputs to compare. + `FunctionGraph` or tuple(inputs, outputs) to compare. inputs Numeric inputs to be passed to the compiled graphs. assert_fn @@ -265,18 +267,25 @@ def assert_fn(x, y): x, y ) - if isinstance(fgraph, tuple): - fn_inputs, fn_outputs = fgraph - else: + if isinstance(fgraph, FunctionGraph): fn_inputs = fgraph.inputs fn_outputs = fgraph.outputs + else: + fn_inputs, fn_outputs = fgraph fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)] pytensor_py_fn = function( fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates ) - py_res = pytensor_py_fn(*inputs) + + test_inputs = (inp.copy() for inp in inputs) if inplace else inputs + py_res = pytensor_py_fn(*test_inputs) + + # Get some coverage (and catch errors in python mode before unreadable numba ones) + if eval_obj_mode: + test_inputs = (inp.copy() for inp in inputs) if inplace else inputs + eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode) pytensor_numba_fn = function( fn_inputs, @@ -285,14 +294,12 @@ def assert_fn(x, y): accept_inplace=True, updates=updates, ) - numba_res = pytensor_numba_fn(*inputs) - # Get some coverage - if eval_obj_mode: - eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) + test_inputs = (inp.copy() for inp in inputs) if inplace else inputs + numba_res = pytensor_numba_fn(*test_inputs) if len(fn_outputs) > 1: - for j, p in zip(numba_res, py_res): + for j, p in zip(numba_res, py_res, strict=True): assert_fn(j, p) else: assert_fn(numba_res[0], py_res[0]) @@ -831,7 +838,13 @@ def test_config_options_fastmath(): pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__)) numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] - assert numba_mul_fn.targetoptions["fastmath"] is True + assert numba_mul_fn.targetoptions["fastmath"] == { + "afn", + "arcp", + "contract", + "nsz", + "reassoc", + } def test_config_options_cached(): @@ -882,3 +895,20 @@ def test_cache_warning_suppressed(): x_test = np.random.uniform(size=5) np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) + + +@pytest.mark.parametrize("mode", ("default", "trust_input", "direct")) +def test_function_overhead(mode, benchmark): + x = pt.vector("x") + out = pt.exp(x) + + fn = function([x], out, mode="NUMBA") + if mode == "trust_input": + fn.trust_input = True + elif mode == "direct": + fn = fn.vm.jit_fn + + test_x = np.zeros(1000) + assert np.sum(fn(test_x)) == 1000 + + benchmark(fn, test_x) diff --git a/tests/link/numba/test_blockwise.py b/tests/link/numba/test_blockwise.py new file mode 100644 index 0000000000..ced4185e14 --- /dev/null +++ b/tests/link/numba/test_blockwise.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest + +from pytensor import function +from pytensor.tensor import tensor +from pytensor.tensor.basic import ARange +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.nlinalg import SVD, Det +from pytensor.tensor.slinalg import Cholesky, cholesky +from tests.link.numba.test_basic import compare_numba_and_py, numba_mode + + +# Fails if object mode warning is issued when not expected +pytestmark = pytest.mark.filterwarnings("error") + + +@pytest.mark.parametrize("shape_opt", [True, False], ids=str) +@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str) +def test_blockwise(core_op, shape_opt): + x = tensor(shape=(5, None, None)) + outs = Blockwise(core_op=core_op)(x, return_list=True) + + mode = ( + numba_mode.including("ShapeOpt") + if shape_opt + else numba_mode.excluding("ShapeOpt") + ) + x_test = np.eye(3) * np.arange(1, 6)[:, None, None] + compare_numba_and_py( + ([x], outs), + [x_test], + numba_mode=mode, + eval_obj_mode=False, + ) + + +def test_non_square_blockwise(): + """Test that Op that cannot always be blockwised at runtime fails gracefully.""" + x = tensor(shape=(3,), dtype="int64") + out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1) + + with pytest.warns(UserWarning, match="Numba will use object mode"): + fn = function([x], out, mode="NUMBA") + + np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5))) + + with pytest.raises(ValueError): + fn([3, 4, 5]) + + +def test_blockwise_benchmark(benchmark): + x = tensor(shape=(5, 3, 3)) + out = cholesky(x) + assert isinstance(out.owner.op, Blockwise) + + fn = function([x], out, mode="NUMBA") + x_test = np.eye(3) * np.arange(1, 6)[:, None, None] + fn(x_test) # JIT compile + benchmark(fn, x_test) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 4c13004409..862ea1a2e2 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -15,15 +15,16 @@ from pytensor.gradient import grad from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph -from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum +from pytensor.scalar import float64 +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from tests.link.numba.test_basic import ( compare_numba_and_py, scalar_my_multi_out, set_test_value, ) -from tests.tensor.test_elemwise import TestElemwise +from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester rng = np.random.default_rng(42849) @@ -249,24 +250,12 @@ def test_Dimshuffle_non_contiguous(): ( lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x), 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), + set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x), 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), - 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), - 0, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Sum( @@ -313,6 +302,24 @@ def test_Dimshuffle_non_contiguous(): pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) ), ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Prod( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + (), # Empty axes would normally be rewritten away, but we want to test it still works + set_test_value( + pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Prod( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + None, + set_test_value( + pt.scalar(), np.array(99.0, dtype=config.floatX) + ), # Scalar input would normally be rewritten away, but we want to test it still works + ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Prod( axis=axis, dtype=dtype, acc_dtype=acc_dtype @@ -379,7 +386,7 @@ def test_CAReduce(careduce_fn, axis, v): g = careduce_fn(v, axis=axis) g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( + fn, _ = compare_numba_and_py( g_fg, [ i.tag.test_value @@ -387,6 +394,10 @@ def test_CAReduce(careduce_fn, axis, v): if not isinstance(i, SharedVariable | Constant) ], ) + # Confirm CAReduce is in the compiled function + fn.dprint() + [node] = fn.maker.fgraph.apply_nodes + assert isinstance(node.op, CAReduce) def test_scalar_Elemwise_Clip(): @@ -631,10 +642,10 @@ def test_logsumexp_benchmark(size, axis, benchmark): X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA") # JIT compile first - _ = X_lse_fn(X_val) - res = benchmark(X_lse_fn, X_val) + res = X_lse_fn(X_val) exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) np.testing.assert_array_almost_equal(res, exp_res) + benchmark(X_lse_fn, X_val) def test_fused_elemwise_benchmark(benchmark): @@ -665,3 +676,33 @@ def test_elemwise_out_type(): x_val = np.broadcast_to(np.zeros((3,)), (6, 3)) assert func(x_val).shape == (18,) + + +@pytest.mark.parametrize( + "axis", + (0, 1, 2, (0, 1), (0, 2), (1, 2), None), + ids=lambda x: f"axis={x}", +) +@pytest.mark.parametrize( + "c_contiguous", + (True, False), + ids=lambda x: f"c_contiguous={x}", +) +def test_numba_careduce_benchmark(axis, c_contiguous, benchmark): + return careduce_benchmark_tester( + axis, c_contiguous, mode="NUMBA", benchmark=benchmark + ) + + +def test_scalar_loop(): + a = float64("a") + scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a]) + + x = pt.tensor("x", shape=(3,)) + elemwise_loop = Elemwise(scalar_loop)(3, x) + + with pytest.warns(UserWarning, match="object mode"): + compare_numba_and_py( + ([x], [elemwise_loop]), + (np.array([1, 2, 3], dtype="float64"),), + ) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index 437956bdc0..655e507da6 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -9,6 +9,7 @@ from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.scalar.basic import Composite +from pytensor.tensor import tensor from pytensor.tensor.elemwise import Elemwise from tests.link.numba.test_basic import compare_numba_and_py, set_test_value @@ -140,3 +141,21 @@ def test_reciprocal(v, dtype): if not isinstance(i, SharedVariable | Constant) ], ) + + +@pytest.mark.parametrize("composite", (False, True)) +def test_isnan(composite): + # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath + x = tensor(shape=(2,), dtype="float64") + + if composite: + x_scalar = psb.float64() + scalar_out = ~psb.isnan(x_scalar) + out = Elemwise(Composite([x_scalar], [scalar_out]))(x) + else: + out = pt.isnan(x) + + compare_numba_and_py( + ([x], [out]), + [np.array([1, 0], dtype="float64")], + ) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 5db0f24222..5b9436688b 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -488,7 +488,7 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1): ref_fn = pytensor.function(list(test), outs, mode=get_mode("FAST_COMPILE")) ref_res = ref_fn(*test.values()) - for numba_r, ref_r in zip(numba_res, ref_res): + for numba_r, ref_r in zip(numba_res, ref_res, strict=True): np.testing.assert_array_almost_equal(numba_r, ref_r) benchmark(numba_fn, *test.values()) diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index ff335e30dc..d63445bf77 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -85,7 +85,11 @@ def test_AdvancedSubtensor1_out_of_bounds(): (np.array([True, False, False])), False, ), - (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True), + ( + pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + ([1, 2], [2, 3]), + False, + ), # Single multidimensional indexing (supported after specialization rewrites) ( as_tensor(np.arange(3 * 3).reshape((3, 3))), @@ -117,17 +121,23 @@ def test_AdvancedSubtensor1_out_of_bounds(): (slice(2, None), np.eye(3).astype(bool)), False, ), - # Multiple advanced indexing, only supported in obj mode ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (slice(None), [1, 2], [3, 4]), - True, + False, ), + ( + as_tensor(np.arange(3 * 5 * 7).reshape((3, 5, 7))), + ([1, 2], [3, 4], [5, 6]), + False, + ), + # Non-contiguous vector indexing, only supported in obj mode ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], slice(None), [3, 4]), True, ), + # >1d vector indexing, only supported in obj mode ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([[1, 2], [2, 1]], [0, 0]), @@ -135,7 +145,7 @@ def test_AdvancedSubtensor1_out_of_bounds(): ), ], ) -@pytest.mark.filterwarnings("error") +@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed def test_AdvancedSubtensor(x, indices, objmode_needed): """Test NumPy's advanced indexing in more than one dimension.""" x_pt = x.type() @@ -268,94 +278,173 @@ def test_AdvancedIncSubtensor1(x, y, indices): "x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode", [ ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), -np.arange(3 * 5).reshape(3, 5), - (slice(None, None, 2), [1, 2, 3]), + (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector index False, False, False, ), ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - -99, - (slice(None, None, 2), [1, 2, 3], -1), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + np.array(-99), # Broadcasted value + ( + slice(None, None, 2), + [1, 2, 3], + -1, + ), # Mixed basic and broadcasted vector idx False, False, False, ), ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - -99, # Broadcasted value - (slice(None, None, 2), [1, 2, 3]), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + np.array(-99), # Broadcasted value + (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector idx False, False, False, ), ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), -np.arange(4 * 5).reshape(4, 5), - (0, [1, 2, 2, 3]), + (0, [1, 2, 2, 3]), # Broadcasted vector index True, False, True, ), ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - [-99], # Broadcsasted value - (0, [1, 2, 2, 3]), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + np.array([-99]), # Broadcasted value + (0, [1, 2, 2, 3]), # Broadcasted vector index True, False, True, ), ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), -np.arange(1 * 4 * 5).reshape(1, 4, 5), - (np.array([True, False, False])), + (np.array([True, False, False])), # Broadcasted boolean index False, False, False, ), ( - as_tensor(np.arange(3 * 3).reshape((3, 3))), + np.arange(3 * 3).reshape((3, 3)), -np.arange(3), - (np.eye(3).astype(bool)), + (np.eye(3).astype(bool)), # Boolean index + False, + False, + False, + ), + ( + np.arange(3 * 3 * 5).reshape((3, 3, 5)), + rng.poisson(size=(3, 2)), + ( + np.eye(3).astype(bool), + slice(-2, None), + ), # Boolean index, mixed with basic index + False, + False, + False, + ), + ( + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + rng.poisson(size=(2, 5)), + ([1, 2], [2, 3]), # 2 vector indices + False, + False, + False, + ), + ( + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + rng.poisson(size=(3, 2)), + (slice(None), [1, 2], [2, 3]), # 2 vector indices + False, + False, + False, + ), + ( + np.arange(3 * 4 * 6).reshape((3, 4, 6)), + rng.poisson(size=(2,)), + ([1, 2], [2, 3], [4, 5]), # 3 vector indices + False, + False, + False, + ), + ( + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + np.array(-99), # Broadcasted value + ([1, 2], [2, 3]), # 2 vector indices False, True, True, ), ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - as_tensor(rng.poisson(size=(2, 5))), - ([1, 2], [2, 3]), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + rng.poisson(size=(2, 4)), + ([1, 2], slice(None), [3, 4]), # Non-contiguous vector indices False, True, True, ), ( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - as_tensor(rng.poisson(size=(2, 4))), - ([1, 2], slice(None), [3, 4]), + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + rng.poisson(size=(2, 2)), + ( + slice(1, None), + [1, 2], + [3, 4], + ), # Mixed double vector index and basic index False, True, True, ), - pytest.param( - as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - as_tensor(rng.poisson(size=(2, 5))), - ([1, 1], [2, 2]), + ( + np.arange(5), + rng.poisson(size=(2, 2)), + ([[1, 2], [2, 3]]), # matrix indices False, + False, # Gets converted to AdvancedIncSubtensor1 + True, # This is actually supported with the default `ignore_duplicates=False` + ), + ( + np.arange(3 * 5).reshape((3, 5)), + rng.poisson(size=(1, 2, 2)), + (slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index + False, + True, True, + ), + ( + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + rng.poisson(size=(2, 5)), + ([1, 1], [2, 2]), # Repeated indices True, + False, + False, ), ], ) -@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("inplace", (False, True)) +@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed def test_AdvancedIncSubtensor( - x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode + x, + y, + indices, + duplicate_indices, + set_requires_objmode, + inc_requires_objmode, + inplace, ): - out_pt = set_subtensor(x[indices], y) + # Need rewrite to support certain forms of advanced indexing without object mode + mode = numba_mode.including("specialize") + + x_pt = pt.as_tensor(x).type("x") + y_pt = pt.as_tensor(y).type("y") + + out_pt = set_subtensor(x_pt[indices], y_pt, inplace=inplace) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) with ( pytest.warns( @@ -365,11 +454,18 @@ def test_AdvancedIncSubtensor( if set_requires_objmode else contextlib.nullcontext() ): - compare_numba_and_py(out_fg, []) + fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) + + if inplace: + # Test updates inplace + x_orig = x.copy() + fn(x, y + 1) + assert not np.all(x == x_orig) - out_pt = inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices) + out_pt = inc_subtensor( + x_pt[indices], y_pt, ignore_duplicates=not duplicate_indices, inplace=inplace + ) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) with ( pytest.warns( UserWarning, @@ -378,21 +474,9 @@ def test_AdvancedIncSubtensor( if inc_requires_objmode else contextlib.nullcontext() ): - compare_numba_and_py(out_fg, []) - - x_pt = x.type() - out_pt = set_subtensor(x_pt[indices], y) - # Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just - # hack it on here - out_pt.owner.op.inplace = True - assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - with ( - pytest.warns( - UserWarning, - match="Numba will use object mode to run AdvancedSetSubtensor's perform method", - ) - if set_requires_objmode - else contextlib.nullcontext() - ): - compare_numba_and_py(out_fg, [x.data]) + fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) + if inplace: + # Test updates inplace + x_orig = x.copy() + fn(x, y) + assert not np.all(x == x_orig) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 93035f52f4..2ac8ee7c3b 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -4,6 +4,7 @@ import numpy as np import pytest +import pytensor.tensor as pt import pytensor.tensor.basic as ptb from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function @@ -17,11 +18,15 @@ from pytensor.ifelse import ifelse from pytensor.link.pytorch.linker import PytorchLinker from pytensor.raise_op import CheckAndRaise -from pytensor.tensor import alloc, arange, as_tensor, empty, eye +from pytensor.scalar import float64, int64 +from pytensor.scalar.loop import ScalarLoop +from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.type import matrices, matrix, scalar, vector torch = pytest.importorskip("torch") +torch_dispatch = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") optimizer = RewriteDatabaseQuery( @@ -52,8 +57,6 @@ def compare_pytorch_and_py( assert_fn: func, opt Assert function used to check for equality between python and pytorch. If not provided uses np.testing.assert_allclose - must_be_device_array: Bool - Checks if torch.device.type is cuda """ @@ -65,20 +68,19 @@ def compare_pytorch_and_py( pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode) pytorch_res = pytensor_torch_fn(*test_inputs) - if must_be_device_array: - if isinstance(pytorch_res, list): - assert all(isinstance(res, torch.Tensor) for res in pytorch_res) - else: - assert pytorch_res.device.type == "cuda" + if isinstance(pytorch_res, list): + assert all(isinstance(res, np.ndarray) for res in pytorch_res) + else: + assert isinstance(pytorch_res, np.ndarray) pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) py_res = pytensor_py_fn(*test_inputs) if len(fgraph.outputs) > 1: - for pytorch_res_i, py_res_i in zip(pytorch_res, py_res): - assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i) + for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True): + assert_fn(pytorch_res_i, py_res_i) else: - assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0]) + assert_fn(pytorch_res[0], py_res[0]) return pytensor_torch_fn, pytorch_res @@ -161,23 +163,23 @@ def test_shared(device): pytensor_torch_fn = function([], a, mode="PYTORCH") pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) + assert isinstance(pytorch_res, np.ndarray) assert isinstance(a.get_value(), np.ndarray) - np.testing.assert_allclose(pytorch_res.cpu(), a.get_value()) + np.testing.assert_allclose(pytorch_res, a.get_value()) pytensor_torch_fn = function([], a * 2, mode="PYTORCH") pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) + assert isinstance(pytorch_res, np.ndarray) assert isinstance(a.get_value(), np.ndarray) - np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2) + np.testing.assert_allclose(pytorch_res, a.get_value() * 2) new_a_value = np.array([3, 4, 5], dtype=config.floatX) a.set_value(new_a_value) pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2) + assert isinstance(pytorch_res, np.ndarray) + np.testing.assert_allclose(pytorch_res, new_a_value * 2) @pytest.mark.parametrize("device", ["cpu", "cuda"]) @@ -224,7 +226,7 @@ def test_alloc_and_empty(): fn = function([dim1], out, mode=pytorch_mode) res = fn(7) assert res.shape == (5, 7, 3) - assert res.dtype == torch.float32 + assert res.dtype == np.float32 v = vector("v", shape=(3,), dtype="float64") out = alloc(v, dim0, dim1, 3) @@ -335,7 +337,7 @@ def test_pytorch_OpFromGraph(): ofg_2 = OpFromGraph([x, y], [x * y, x - y]) o1, o2 = ofg_2(y, z) - out = ofg_1(x, o1) + o2 + out = ofg_1(x, o1) / o2 xv = np.ones((2, 2), dtype=config.floatX) yv = np.ones((2, 2), dtype=config.floatX) * 3 @@ -343,3 +345,129 @@ def test_pytorch_OpFromGraph(): f = FunctionGraph([x, y, z], [out]) compare_pytorch_and_py(f, [xv, yv, zv]) + + +def test_pytorch_link_references(): + import pytensor.link.utils as m + + class BasicOp(Op): + def __init__(self): + super().__init__() + + def make_node(self, *x): + return Apply(self, list(x), [xi.type() for xi in x]) + + def perform(self, *_): + raise RuntimeError("In perform") + + @torch_dispatch.pytorch_funcify.register(BasicOp) + def fn(op, node, **kwargs): + def inner_fn(x): + assert "inner_fn" in dir(m), "not available during dispatch" + return x + + return inner_fn + + x = vector("x") + op = BasicOp() + out = op(x) + + f = function([x], out, mode="PYTORCH") + f(torch.ones(3)) + assert "inner_fn" not in dir(m), "function call reference leaked" + + +def test_pytorch_scipy(): + x = vector("a", shape=(3,)) + out = expit(x) + f = FunctionGraph([x], [out]) + compare_pytorch_and_py(f, [np.random.rand(3)]) + + +def test_pytorch_softplus(): + x = vector("a", shape=(3,)) + out = softplus(x) + f = FunctionGraph([x], [out]) + compare_pytorch_and_py(f, [np.random.rand(3)]) + + +def test_ScalarLoop(): + n_steps = int64("n_steps") + x0 = float64("x0") + const = float64("const") + x = x0 + const + + op = ScalarLoop(init=[x0], constant=[const], update=[x]) + x = op(n_steps, x0, const) + + fn = function([n_steps, x0, const], x, mode=pytorch_mode) + np.testing.assert_allclose(fn(5, 0, 1), 5) + np.testing.assert_allclose(fn(5, 0, 2), 10) + np.testing.assert_allclose(fn(4, 3, -1), -1) + + +def test_ScalarLoop_while(): + n_steps = int64("n_steps") + x0 = float64("x0") + x = x0 + 1 + until = x >= 10 + + op = ScalarLoop(init=[x0], update=[x], until=until) + fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode) + for res, expected in zip( + [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)], + [[10, True], [10, True], [6, False]], + strict=True, + ): + np.testing.assert_allclose(res[0], np.array(expected[0])) + np.testing.assert_allclose(res[1], np.array(expected[1])) + + +def test_ScalarLoop_Elemwise_single_carries(): + n_steps = int64("n_steps") + x0 = float64("x0") + x = x0 * 2 + until = x >= 10 + + scalarop = ScalarLoop(init=[x0], update=[x], until=until) + op = Elemwise(scalarop) + + n_steps = pt.scalar("n_steps", dtype="int32") + x0 = pt.vector("x0", dtype="float32") + state, done = op(n_steps, x0) + + f = FunctionGraph([n_steps, x0], [state, done]) + args = [ + np.array(10).astype("int32"), + np.arange(0, 5).astype("float32"), + ] + compare_pytorch_and_py( + f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) + ) + + +def test_ScalarLoop_Elemwise_multi_carries(): + n_steps = int64("n_steps") + x0 = float64("x0") + x1 = float64("x1") + x = x0 * 2 + x1_n = x1 * 3 + until = x >= 10 + + scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until) + op = Elemwise(scalarop) + + n_steps = pt.scalar("n_steps", dtype="int32") + x0 = pt.vector("x0", dtype="float32") + x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1)) + *states, done = op(n_steps, x0, x1) + + f = FunctionGraph([n_steps, x0, x1], [*states, done]) + args = [ + np.array(10).astype("int32"), + np.arange(0, 5).astype("float32"), + np.random.rand(7, 3, 1).astype("float32"), + ] + compare_pytorch_and_py( + f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) + ) diff --git a/tests/link/pytorch/test_blockwise.py b/tests/link/pytorch/test_blockwise.py index 75f207e544..d0678fd2c4 100644 --- a/tests/link/pytorch/test_blockwise.py +++ b/tests/link/pytorch/test_blockwise.py @@ -12,7 +12,7 @@ basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") -class TestOp(Op): +class BatchedTestOp(Op): gufunc_signature = "(m,n),(n,p)->(m,p)" def __init__(self, final_shape): @@ -27,9 +27,8 @@ def perform(self, *_): raise RuntimeError("In perform") -@basic.pytorch_funcify.register(TestOp) +@basic.pytorch_funcify.register(BatchedTestOp) def evaluate_test_op(op, **_): - @torch.compiler.disable(recursive=False) def func(a, b): op.call_shapes.extend(map(torch.Tensor.size, [a, b])) return a @ b @@ -43,7 +42,7 @@ def test_blockwise_broadcast(): x = pt.tensor4("x", shape=(5, 1, 2, 3)) y = pt.tensor3("y", shape=(3, 3, 2)) - op = TestOp((2, 2)) + op = BatchedTestOp((2, 2)) z = Blockwise(op)(x, y) f = pytensor.function([x, y], z, mode="PYTORCH") diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 86089cc921..2a9cf39c99 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -1,10 +1,13 @@ import numpy as np import pytest +import pytensor import pytensor.tensor as pt import pytensor.tensor.math as ptm from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph +from pytensor.scalar.basic import ScalarOp, get_scalar_type +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax from pytensor.tensor.type import matrix, tensor, tensor3, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -149,4 +152,34 @@ def test_cast(): _, [res] = compare_pytorch_and_py( fgraph, [np.arange(6, dtype="float32").reshape(2, 3)] ) - assert res.dtype == torch.int32 + assert res.dtype == np.int32 + + +def test_vmap_elemwise(): + from pytensor.link.pytorch.dispatch.basic import pytorch_funcify + + class TestOp(ScalarOp): + def __init__(self): + super().__init__( + output_types_preference=lambda *_: [get_scalar_type("float32")] + ) + self.call_shapes = [] + self.nin = 1 + + def perform(self, *_): + raise RuntimeError("In perform") + + @pytorch_funcify.register(TestOp) + def relu(op, node, **kwargs): + def relu(row): + op.call_shapes.append(row.size()) + return torch.max(torch.zeros_like(row), row) + + return relu + + x = matrix("x", shape=(2, 3)) + op = TestOp() + f = pytensor.function([x], Elemwise(op)(x), mode="PYTORCH") + vals = torch.zeros(2, 3).normal_() + np.testing.assert_allclose(f(vals), torch.relu(vals)) + assert op.call_shapes == [torch.Size([])], op.call_shapes diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 7d69ac0500..55e7c447e3 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import numpy as np import pytest @@ -22,13 +24,13 @@ def matrix_test(): @pytest.mark.parametrize( "func", - (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det), + (pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det), ) def test_lin_alg_no_params(func, matrix_test): x, test_value = matrix_test out = func(x) - out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) + out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out]) def assert_fn(x, y): np.testing.assert_allclose(x, y, rtol=1e-3) diff --git a/tests/link/test_link.py b/tests/link/test_link.py index a2e264759b..7d84c2a478 100644 --- a/tests/link/test_link.py +++ b/tests/link/test_link.py @@ -44,7 +44,7 @@ def execute(*args): got = len(args) if got != takes: raise TypeError(f"Function call takes exactly {takes} args ({got} given)") - for arg, variable in zip(args, inputs): + for arg, variable in zip(args, inputs, strict=True): variable.data = arg thunk() if unpack_single: diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index e648869d4c..5aab9a95cc 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -43,7 +43,6 @@ log1p, log2, log10, - mean, mul, neg, neq, @@ -58,7 +57,7 @@ true_div, uint8, ) -from pytensor.tensor.type import fscalar, imatrix, iscalar, matrix +from pytensor.tensor.type import fscalar, imatrix, matrix from tests.link.test_link import make_function @@ -521,34 +520,6 @@ def test_constant(): assert c.dtype == "float32" -@pytest.mark.parametrize("mode", [Mode("py"), Mode("cvm")]) -def test_mean(mode): - a = iscalar("a") - b = iscalar("b") - z = mean(a, b) - z_fn = pytensor.function([a, b], z, mode=mode) - res = z_fn(1, 1) - assert np.allclose(res, 1.0) - - a = fscalar("a") - b = fscalar("b") - c = fscalar("c") - - z = mean(a, b, c) - - z_fn = pytensor.function([a, b, c], pytensor.grad(z, [a]), mode=mode) - res = z_fn(3, 4, 5) - assert np.allclose(res, 1 / 3) - - z_fn = pytensor.function([a, b, c], pytensor.grad(z, [b]), mode=mode) - res = z_fn(3, 4, 5) - assert np.allclose(res, 1 / 3) - - z = mean() - z_fn = pytensor.function([], z, mode=mode) - assert z_fn() == 0 - - def test_shape(): a = float32("a") assert isinstance(a.type, ScalarType) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 7bdf490b68..23f7cc5a19 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -173,7 +173,7 @@ def max_err(self, _g_pt): raise ValueError("argument has wrong number of elements", len(g_pt)) errs = [] - for i, (a, b) in enumerate(zip(g_pt, self.gx)): + for i, (a, b) in enumerate(zip(g_pt, self.gx, strict=True)): if a.shape != b.shape: raise ValueError( f"argument element {i} has wrong shape {(a.shape, b.shape)}" @@ -201,7 +201,10 @@ def scan_project_sum(*args, **kwargs): rng.add_default_updates = False factors = [rng.uniform(0.1, 0.9, size=s.shape) for s in scan_outputs] # Random values (?) - return (sum((s * f).sum() for s, f in zip(scan_outputs, factors)), updates) + return ( + sum((s * f).sum() for s, f in zip(scan_outputs, factors, strict=True)), + updates, + ) def asarrayX(value): @@ -280,7 +283,7 @@ def inner_fn(x): assert y.default_update is not None assert z_rng.default_update is not None - out_fn = function([], out, mode=Mode(optimizer=None)) + out_fn = function([], out) res, z_res = out_fn() assert len(set(res)) == 4 assert len(set(z_res)) == 1 @@ -3843,7 +3846,7 @@ def one_step(x_t, h_tm2, h_tm1, W_ih, W_hh, b_h, W_ho, b_o): gparams = grad(cost, params) updates = [ (param, param - gparam * learning_rate) - for param, gparam in zip(params, gparams) + for param, gparam in zip(params, gparams, strict=True) ] learn_rnn_fn = function(inputs=[x, t], outputs=cost, updates=updates, mode=mode) function(inputs=[x], outputs=y, mode=mode) diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 9df0966b78..9bf32af48f 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -5,7 +5,7 @@ import pytensor.tensor as pt from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph -from pytensor.printing import debugprint, pydot_imported, pydotprint +from pytensor.printing import _try_pydot_import, debugprint, pydotprint from pytensor.tensor.type import dvector, iscalar, scalar, vector @@ -62,9 +62,10 @@ def test_debugprint_sitsot(): Scan{scan_fn, while_loop=False, inplace=none} [id C] โ† Mul [id W] (inner_out_sit_sot-0) โ”œโ”€ *0- [id X] -> [id E] (inner_in_sit_sot-0) - โ””โ”€ *1- [id Y] -> [id M] (inner_in_non_seqs-0)""" + โ””โ”€ *1- [id Y] -> [id M] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -120,9 +121,10 @@ def test_debugprint_sitsot_no_extra_info(): Scan{scan_fn, while_loop=False, inplace=none} [id C] โ† Mul [id W] โ”œโ”€ *0- [id X] -> [id E] - โ””โ”€ *1- [id Y] -> [id M]""" + โ””โ”€ *1- [id Y] -> [id M] + """ - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -188,9 +190,10 @@ def test_debugprint_nitsot(): โ”œโ”€ *0- [id Y] -> [id S] (inner_in_seqs-0) โ””โ”€ Pow [id Z] โ”œโ”€ *2- [id BA] -> [id W] (inner_in_non_seqs-0) - โ””โ”€ *1- [id BB] -> [id U] (inner_in_seqs-1)""" + โ””โ”€ *1- [id BB] -> [id U] (inner_in_seqs-1) + """ - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -303,9 +306,10 @@ def compute_A_k(A, k): Scan{scan_fn, while_loop=False, inplace=none} [id BE] โ† Mul [id CA] (inner_out_sit_sot-0) โ”œโ”€ *0- [id CB] -> [id BG] (inner_in_sit_sot-0) - โ””โ”€ *1- [id CC] -> [id BO] (inner_in_non_seqs-0)""" + โ””โ”€ *1- [id CC] -> [id BO] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() fg = FunctionGraph([c, k, A], [final_result]) @@ -402,9 +406,10 @@ def compute_A_k(A, k): โ†’ *1- [id CB] -> [id BA] (inner_in_non_seqs-0) โ† Mul [id CC] (inner_out_sit_sot-0) โ”œโ”€ *0- [id CA] (inner_in_sit_sot-0) - โ””โ”€ *1- [id CB] (inner_in_non_seqs-0)""" + โ””โ”€ *1- [id CB] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -477,9 +482,10 @@ def fn(a_m2, a_m1, b_m2, b_m1): โ””โ”€ *0- [id BD] -> [id E] (inner_in_mit_sot-0-0) โ† Add [id BE] (inner_out_mit_sot-1) โ”œโ”€ *3- [id BF] -> [id O] (inner_in_mit_sot-1-1) - โ””โ”€ *2- [id BG] -> [id O] (inner_in_mit_sot-1-0)""" + โ””โ”€ *2- [id BG] -> [id O] (inner_in_mit_sot-1-0) + """ - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -613,9 +619,10 @@ def test_debugprint_mitmot(): Scan{scan_fn, while_loop=False, inplace=none} [id F] โ† Mul [id CV] (inner_out_sit_sot-0) โ”œโ”€ *0- [id CT] -> [id H] (inner_in_sit_sot-0) - โ””โ”€ *1- [id CW] -> [id P] (inner_in_non_seqs-0)""" + โ””โ”€ *1- [id CW] -> [id P] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -675,10 +682,17 @@ def no_shared_fn(n, x_tm1, M): output_str = debugprint(out, file="str", print_op_info=True) lines = output_str.split("\n") - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() +try: + _try_pydot_import() + pydot_imported = True +except Exception: + pydot_imported = False + + @pytest.mark.skipif(not pydot_imported, reason="pydot not available") def test_pydotprint(): def f_pow2(x_tm1): diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index 6f77625f2f..fd9c43b129 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -673,7 +673,7 @@ def test_machine_translation(self): zi = tensor3("zi") zi_value = x_value - init = pt.alloc(np.cast[config.floatX](0), batch_size, dim) + init = pt.alloc(np.asarray(0, dtype=config.floatX), batch_size, dim) def rnn_step1( # sequences diff --git a/tests/scan/test_utils.py b/tests/scan/test_utils.py index a26c2cbd4b..3586101ada 100644 --- a/tests/scan/test_utils.py +++ b/tests/scan/test_utils.py @@ -220,7 +220,7 @@ def test_ScanArgs_remove_inner_input(): test_v = sigmas_t rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=False) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert sigmas_t in removed_nodes assert sigmas_t not in scan_args_copy.inner_in_seqs @@ -232,7 +232,7 @@ def test_ScanArgs_remove_inner_input(): # This removal includes dependents rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `sigmas[t]` (i.e. inner-graph input) should be gone assert sigmas_t in removed_nodes @@ -288,7 +288,7 @@ def test_ScanArgs_remove_outer_input(): scan_args_copy = copy(scan_args) test_v = sigmas_in rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `sigmas_in` (i.e. outer-graph input) should be gone assert scan_args.outer_in_seqs[-1] in removed_nodes @@ -334,7 +334,7 @@ def test_ScanArgs_remove_inner_output(): scan_args_copy = copy(scan_args) test_v = Y_t rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `Y_t` (i.e. inner-graph output) should be gone assert Y_t in removed_nodes @@ -371,7 +371,7 @@ def test_ScanArgs_remove_outer_output(): scan_args_copy = copy(scan_args) test_v = Y_rv rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `Y_t` (i.e. inner-graph output) should be gone assert Y_t in removed_nodes @@ -409,7 +409,7 @@ def test_ScanArgs_remove_nonseq_outer_input(): scan_args_copy = copy(scan_args) test_v = Gamma_rv rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert Gamma_rv in removed_nodes assert Gamma_in in removed_nodes @@ -447,7 +447,7 @@ def test_ScanArgs_remove_nonseq_inner_input(): scan_args_copy = copy(scan_args) test_v = Gamma_in rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert Gamma_in in removed_nodes assert Gamma_rv in removed_nodes @@ -482,7 +482,7 @@ def test_ScanArgs_remove_shared_inner_output(): scan_update = scan_args.inner_out_shared[0] scan_args_copy = copy(scan_args) rm_info = scan_args_copy.remove_from_fields(scan_update, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert rng_in in removed_nodes assert all(v in removed_nodes for v in scan_args.inner_out_shared) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index afae9b2187..4075ed3ed6 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -334,7 +334,7 @@ def f(spdata): oconv = conv_none def conv_op(*inputs): - ipt = [conv(i) for i, conv in zip(inputs, iconv)] + ipt = [conv(i) for i, conv in zip(inputs, iconv, strict=True)] out = op(*ipt) return oconv(out) @@ -2192,7 +2192,7 @@ def setup_method(self): def test_op(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) f = pytensor.function(variable, self.op(*variable)) @@ -2203,7 +2203,7 @@ def test_op(self): def test_infer_shape(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) self._compile_and_check( variable, [self.op(*variable)], data, self.op_class @@ -2211,7 +2211,7 @@ def test_infer_shape(self): def test_grad(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) verify_grad_sparse(self.op, data, structured=False) @@ -2223,7 +2223,7 @@ def setup_method(self): def test_op(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) data[0][0, 0] = data[0][1, 1] = 0 @@ -2242,7 +2242,7 @@ def test_op(self): def test_grad(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) verify_grad_sparse(self.op, data, structured=False) diff --git a/tests/tensor/conv/test_abstract_conv.py b/tests/tensor/conv/test_abstract_conv.py index 223e3774c2..23ba23e1e9 100644 --- a/tests/tensor/conv/test_abstract_conv.py +++ b/tests/tensor/conv/test_abstract_conv.py @@ -461,7 +461,8 @@ def get_output_shape( self, inputs_shape, filters_shape, subsample, border_mode, filter_dilation ): dil_filters = tuple( - (s - 1) * d + 1 for s, d in zip(filters_shape[2:], filter_dilation) + (s - 1) * d + 1 + for s, d in zip(filters_shape[2:], filter_dilation, strict=True) ) if border_mode == "valid": border_mode = (0,) * (len(inputs_shape) - 2) @@ -484,6 +485,7 @@ def get_output_shape( subsample, border_mode, filter_dilation, + strict=True, ) ), ) @@ -760,7 +762,7 @@ def test_all(self): db = self.default_border_mode dflip = self.default_filter_flip dprovide_shape = self.default_provide_shape - for i, f in zip(self.inputs_shapes, self.filters_shapes): + for i, f in zip(self.inputs_shapes, self.filters_shapes, strict=True): for provide_shape in self.provide_shape: self.run_test_case(i, f, ds, db, dflip, provide_shape) if min(i) > 0 and min(f) > 0: @@ -1743,7 +1745,7 @@ def setup_method(self): self.random_stream = np.random.default_rng(utt.fetch_seed()) self.inputs_shapes = [(8, 1, 12, 12), (1, 1, 5, 5), (1, 1, 5, 6), (1, 1, 6, 6)] - self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] + self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] * 2 self.subsamples = [(1, 1), (2, 2)] self.border_modes = ["valid", "full"] @@ -1761,7 +1763,9 @@ def test_conv2d_grad_wrt_inputs(self): # the outputs of `pytensor.tensor.conv` forward grads to make sure the # results are the same. - for in_shape, fltr_shape in zip(self.inputs_shapes, self.filters_shapes): + for in_shape, fltr_shape in zip( + self.inputs_shapes, self.filters_shapes, strict=True + ): for bm in self.border_modes: for ss in self.subsamples: for ff in self.filter_flip: @@ -1823,7 +1827,9 @@ def test_conv2d_grad_wrt_weights(self): # the outputs of `pytensor.tensor.conv` forward grads to make sure the # results are the same. - for in_shape, fltr_shape in zip(self.inputs_shapes, self.filters_shapes): + for in_shape, fltr_shape in zip( + self.inputs_shapes, self.filters_shapes, strict=True + ): for bm in self.border_modes: for ss in self.subsamples: for ff in self.filter_flip: @@ -1915,7 +1921,7 @@ def test_fwd(self): kern_sym = tensor5("kern") for imshp, kshp, groups in zip( - self.img_shape, self.kern_shape, self.num_groups + self.img_shape, self.kern_shape, self.num_groups, strict=True ): img = np.random.random(imshp).astype(config.floatX) kern = np.random.random(kshp).astype(config.floatX) @@ -1951,7 +1957,7 @@ def test_fwd(self): ) ref_concat_output = [ ref_func(img_arr, kern_arr) - for img_arr, kern_arr in zip(split_imgs, split_kern) + for img_arr, kern_arr in zip(split_imgs, split_kern, strict=True) ] ref_concat_output = np.concatenate(ref_concat_output, axis=1) @@ -1967,7 +1973,11 @@ def test_gradweights(self): img_sym = tensor5("img") top_sym = tensor5("kern") for imshp, kshp, tshp, groups in zip( - self.img_shape, self.kern_shape, self.top_shape, self.num_groups + self.img_shape, + self.kern_shape, + self.top_shape, + self.num_groups, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) top = np.random.random(tshp).astype(config.floatX) @@ -2005,7 +2015,7 @@ def test_gradweights(self): ) ref_concat_output = [ ref_func(img_arr, top_arr) - for img_arr, top_arr in zip(split_imgs, split_top) + for img_arr, top_arr in zip(split_imgs, split_top, strict=True) ] ref_concat_output = np.concatenate(ref_concat_output, axis=0) @@ -2028,7 +2038,11 @@ def test_gradinputs(self): kern_sym = tensor5("kern") top_sym = tensor5("top") for imshp, kshp, tshp, groups in zip( - self.img_shape, self.kern_shape, self.top_shape, self.num_groups + self.img_shape, + self.kern_shape, + self.top_shape, + self.num_groups, + strict=True, ): kern = np.random.random(kshp).astype(config.floatX) top = np.random.random(tshp).astype(config.floatX) @@ -2066,7 +2080,7 @@ def test_gradinputs(self): ) ref_concat_output = [ ref_func(kern_arr, top_arr) - for kern_arr, top_arr in zip(split_kerns, split_top) + for kern_arr, top_arr in zip(split_kerns, split_top, strict=True) ] ref_concat_output = np.concatenate(ref_concat_output, axis=1) @@ -2368,6 +2382,7 @@ def test_fwd(self): self.subsample, self.num_groups, self.verify_flags, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) kern = np.random.random(kshp).astype(config.floatX) @@ -2426,6 +2441,7 @@ def test_gradweight(self): self.subsample, self.num_groups, self.verify_flags, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) top = np.random.random(topshp).astype(config.floatX) @@ -2494,6 +2510,7 @@ def test_gradinput(self): self.subsample, self.num_groups, self.verify_flags, + strict=True, ): single_kshp = kshp[:1] + kshp[3:] @@ -2576,7 +2593,9 @@ def test_fwd(self): img_sym = tensor4("img") kern_sym = tensor4("kern") - for imshp, kshp, pad in zip(self.img_shape, self.kern_shape, self.border_mode): + for imshp, kshp, pad in zip( + self.img_shape, self.kern_shape, self.border_mode, strict=True + ): img = np.random.random(imshp).astype(config.floatX) kern = np.random.random(kshp).astype(config.floatX) @@ -2627,7 +2646,11 @@ def test_gradweight(self): top_sym = tensor4("top") for imshp, kshp, topshp, pad in zip( - self.img_shape, self.kern_shape, self.topgrad_shape, self.border_mode + self.img_shape, + self.kern_shape, + self.topgrad_shape, + self.border_mode, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) top = np.random.random(topshp).astype(config.floatX) @@ -2684,7 +2707,11 @@ def test_gradinput(self): top_sym = tensor4("top") for imshp, kshp, topshp, pad in zip( - self.img_shape, self.kern_shape, self.topgrad_shape, self.border_mode + self.img_shape, + self.kern_shape, + self.topgrad_shape, + self.border_mode, + strict=True, ): kern = np.random.random(kshp).astype(config.floatX) top = np.random.random(topshp).astype(config.floatX) diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index f342d5b81c..acc793156f 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -140,7 +140,7 @@ def test_inplace_rewrites(rv_op): assert new_op._props_dict() == (op._props_dict() | {"inplace": True}) assert all( np.array_equal(a.data, b.data) - for a, b in zip(new_op.dist_params(new_node), op.dist_params(node)) + for a, b in zip(new_op.dist_params(new_node), op.dist_params(node), strict=True) ) assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data) assert check_stack_trace(f) diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 7d24a49228..d10e384339 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1,6 +1,6 @@ import pickle import re -from copy import copy +from copy import deepcopy import numpy as np import pytest @@ -113,7 +113,9 @@ def test_fn(*args, random_state=None, **kwargs): pt_rng = shared(rng, borrow=True) - numpy_res = np.asarray(test_fn(*param_vals, random_state=copy(rng), **kwargs_vals)) + numpy_res = np.asarray( + test_fn(*param_vals, random_state=deepcopy(rng), **kwargs_vals) + ) pytensor_res = rv(*params, rng=pt_rng, **kwargs) diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 8e74b06bd4..edec9a4389 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -74,16 +74,16 @@ def test_RandomVariable_basics(strict_test_value_flags): # `dtype` is respected rv = RandomVariable("normal", signature="(),()->()", dtype="int32") with config.change_flags(compute_test_value="off"): - rv_out = rv() + rv_out = rv(0, 0) assert rv_out.dtype == "int32" - rv_out = rv(dtype="int64") + rv_out = rv(0, 0, dtype="int64") assert rv_out.dtype == "int64" with pytest.raises( ValueError, match="Cannot change the dtype of a normal RV from int32 to float32", ): - assert rv(dtype="float32").dtype == "float32" + assert rv(0, 0, dtype="float32").dtype == "float32" def test_RandomVariable_bcast(strict_test_value_flags): diff --git a/tests/tensor/random/test_type.py b/tests/tensor/random/test_type.py index d289862347..d358f2a93a 100644 --- a/tests/tensor/random/test_type.py +++ b/tests/tensor/random/test_type.py @@ -52,7 +52,7 @@ def test_filter(self): with pytest.raises(TypeError): rng_type.filter(1) - rng_dict = rng.__getstate__() + rng_dict = rng.bit_generator.state assert rng_type.is_valid_value(rng_dict) is False assert rng_type.is_valid_value(rng_dict, strict=False) @@ -88,13 +88,13 @@ def test_values_eq(self): assert rng_type.values_eq(bitgen_g, bitgen_h) assert rng_type.is_valid_value(bitgen_a, strict=True) - assert rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_b.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_c, strict=True) - assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_d.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_e, strict=True) - assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_f.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_g, strict=True) - assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_h.bit_generator.state, strict=False) def test_may_share_memory(self): bg_a = np.random.PCG64() diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 3616b2fd24..f7d8731c1b 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -165,14 +165,20 @@ def test_seed(self, rng_ctor): state_rng = random.state_updates[0][0].get_value(borrow=True) if hasattr(state_rng, "get_state"): - ref_state = ref_rng.get_state() random_state = state_rng.get_state() + + # hack to try to get something reasonable for ref_rng + try: + ref_state = ref_rng.get_state() + except AttributeError: + ref_state = list(ref_rng.bit_generator.state.values()) + assert np.array_equal(random_state[1], ref_state[1]) assert random_state[0] == ref_state[0] assert random_state[2:] == ref_state[2:] else: - ref_state = ref_rng.__getstate__() - random_state = state_rng.__getstate__() + ref_state = ref_rng.bit_generator.state + random_state = state_rng.bit_generator.state assert random_state["bit_generator"] == ref_state["bit_generator"] assert random_state["state"] == ref_state["state"] @@ -271,7 +277,7 @@ def __init__(self, seed=123): g2 = Graph(seed=987) f2 = function([], g2.y) - for su1, su2 in zip(g1.rng.state_updates, g2.rng.state_updates): + for su1, su2 in zip(g1.rng.state_updates, g2.rng.state_updates, strict=True): su2[0].set_value(su1[0].get_value()) np.testing.assert_array_almost_equal(f1(), f2(), decimal=6) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 4ff773dbb8..8911f56630 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -12,7 +12,8 @@ from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.configdefaults import config -from pytensor.graph.basic import equal_computations +from pytensor.graph import Op +from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.db import RewriteDatabaseQuery @@ -29,6 +30,7 @@ TensorFromScalar, as_tensor, cast, + constant, join, tile, ) @@ -65,6 +67,8 @@ local_merge_alloc, local_useless_alloc, local_useless_elemwise, + topo_constant_folding, + topo_unconditional_constant_folding, topological_fill_sink, ) from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot @@ -742,56 +746,92 @@ def test_upcast(self): ) or (len(topo) > 1) -def test_constant_folding(): - # Test that constant folding get registered at fast_compile - # An error removed that registration during the registration. - x = dvector() - mode = get_mode("FAST_COMPILE").excluding("fusion") - f = function([x], [x * 2, x + x], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - - # Test that we do not crash when constant folding elemwise scalar - # as they should not generate c code. +class TestConstantFolding: + def test_constant_folding(self): + # Test that constant folding get registered at fast_compile + # An error removed that registration during the registration. + x = dvector() + mode = get_mode("FAST_COMPILE").excluding("fusion") + f = function([x], [x * 2, x + x], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 2 - x = pt.constant(3) - assert x.ndim == 0 - mode = get_mode("FAST_COMPILE").excluding("fusion") - f = function([], [x * 2, x + x], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - assert all(isinstance(n.op, DeepCopyOp) for n in topo) + # Test that we do not crash when constant folding elemwise scalar + # as they should not generate c code. + x = pt.constant(3) + assert x.ndim == 0 + mode = get_mode("FAST_COMPILE").excluding("fusion") + f = function([], [x * 2, x + x], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 2 + assert all(isinstance(n.op, DeepCopyOp) for n in topo) -@pytest.mark.xfail( - reason="PyTensor rewrites constants before stabilization. " - "This breaks stabilization rewrites in some cases. See #504.", - raises=AssertionError, -) -def test_constant_get_stabilized(): - # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites. - # This caused some stabilization rewrites to not be activated and that - # caused inf values to appear when they should not. + @pytest.mark.xfail( + reason="PyTensor rewrites constants before stabilization. " + "This breaks stabilization rewrites in some cases. See #504.", + raises=AssertionError, + ) + def test_constant_get_stabilized(self): + # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites. + # This caused some stabilization rewrites to not be activated and that + # caused inf values to appear when they should not. - # We can't simply move the `constant_folding` rewrite to - # specialize since this will break other rewrites. We will need to - # partially duplicate some canonicalize rewrites to fix this issue. + # We can't simply move the `constant_folding` rewrite to + # specialize since this will break other rewrites. We will need to + # partially duplicate some canonicalize rewrites to fix this issue. - x2 = scalar() - y2 = log(1 + exp(x2)) - mode = get_default_mode() - mode.check_isfinite = False - f2 = function([x2], y2, mode=mode) - - assert len(f2.maker.fgraph.toposort()) == 1 - assert f2.maker.fgraph.toposort()[0].op == softplus - assert f2(800) == 800 - - x = pt.as_tensor_variable(800) - y = log(1 + exp(x)) - f = function([], y, mode=mode) - # When this error is fixed, the following line should be ok. - assert f() == 800, f() + x2 = scalar() + y2 = log(1 + exp(x2)) + mode = get_default_mode() + mode.check_isfinite = False + f2 = function([x2], y2, mode=mode) + + assert len(f2.maker.fgraph.toposort()) == 1 + assert f2.maker.fgraph.toposort()[0].op == softplus + assert f2(800) == 800 + + x = pt.as_tensor_variable(800) + y = log(1 + exp(x)) + f = function([], y, mode=mode) + # When this error is fixed, the following line should be ok. + assert f() == 800, f() + + def test_unconditional(self): + x = pt.alloc(np.e, *(3, 5)) + fg = FunctionGraph(outputs=[x], clone=False) + + # Default constant folding doesn't apply to Alloc used as outputs + topo_constant_folding.apply(fg) + assert not isinstance(fg.outputs[0], Constant) + + # Unconditional constant folding does apply + topo_unconditional_constant_folding.apply(fg) + assert isinstance(fg.outputs[0], Constant) + np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e)) + + def test_unconditional_no_perform_method(self): + """Test that errors are caught when the Op does not have a perform method.""" + + class OpNoPerform(Op): + itypes = [scalar(dtype="float64").type] + otypes = [scalar(dtype="float64").type] + + def perform(self, *args, **kwargs): + raise NotImplementedError("This Op cannot be evaluated") + + x = constant(np.array(5.0)) + out = OpNoPerform()(x) + + fg = FunctionGraph(outputs=[out], clone=False) + # Default constant_folding will raise + with pytest.raises(NotImplementedError): + topo_constant_folding.apply(fg) + + # Unconditional constant folding will be silent + topo_unconditional_constant_folding.apply(fg) + assert not isinstance(fg.outputs[0], Constant) + assert isinstance(fg.outputs[0].owner.op, OpNoPerform) class TestLocalSwitchSink: diff --git a/tests/tensor/rewriting/test_blas.py b/tests/tensor/rewriting/test_blas.py index efd18c3831..d939ceedce 100644 --- a/tests/tensor/rewriting/test_blas.py +++ b/tests/tensor/rewriting/test_blas.py @@ -2,11 +2,39 @@ import pytest from pytensor import function +from pytensor import tensor as pt from pytensor.compile import get_default_mode -from pytensor.tensor import matmul, tensor, vectorize +from pytensor.graph import FunctionGraph +from pytensor.tensor import ( + col, + dscalar, + dvector, + matmul, + matrix, + mul, + neg, + row, + scalar, + sqrt, + tensor, + vector, + vectorize, +) from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.rewriting.blas import ( + _as_scalar, + _factor_canonicalized, + _gemm_canonicalize, + _is_real_matrix, + res_is_a, + specialize_matmul_to_batched_dot, +) + + +def XYZab(): + return matrix(), matrix(), matrix(), scalar(), scalar() @pytest.mark.parametrize("valid_case", (True, False)) @@ -46,3 +74,136 @@ def core_np(x, y): vectorize_pt(x_test, y_test), vectorize_np(x_test, y_test), ) + + +def test_gemm_factor(): + X, Y = matrix("X"), matrix("Y") + + assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)]) + assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)]) + + +def test_gemm_canonicalize(): + X, Y, Z, a, b = ( + matrix("X"), + matrix("Y"), + matrix("Z"), + scalar("a"), + scalar("b"), + ) + c, d = scalar("c"), scalar("d") + u = row("u") + v = vector("v") + w = col("w") + + can = [] + fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False) + _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) + assert can == [(1.0, X), (1.0, Y), (1.0, Z)] + + can = [] + fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False) + _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) + assert can == [(1.0, X), (1.0, Y), (1.0, u)], can + + can = [] + fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False) + _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) + # [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))] + assert can[:2] == [(1.0, X), (1.0, Y)] + assert isinstance(can[2], tuple) + assert len(can[2]) == 2 + assert can[2][0] == 1.0 + assert can[2][1].owner + assert isinstance(can[2][1].owner.op, DimShuffle) + assert can[2][1].owner.inputs == [v] + + can = [] + fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False) + _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) + assert can == [(1.0, X), (1.0, Y), (1.0, w)], can + + can = [] + fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False) + _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) + assert can[0] == (a, X) + assert can[1] == (1.0, Y) + assert can[2][0].owner.op == mul + assert can[2][0].owner.inputs[0].owner.op == neg + assert can[2][0].owner.inputs[0].owner.inputs[0] == c + assert can[2][0].owner.inputs[1] == b + + can = [] + fg = FunctionGraph( + [a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False + ) + _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) + assert can[0][0].owner.op == neg + assert can[0][0].owner.inputs[0] == d + assert can[0][1] == X + assert can[1][0].owner.op == neg + assert can[1][0].owner.inputs[0] == a + assert can[2] == (-1.0, Y) + assert can[3][0].owner.op == mul + assert can[3][0].owner.inputs == [c, b] + + +def test_res_is_a(): + X, Y, Z, a, b = XYZab() + + assert not res_is_a(None, a, sqrt) + assert not res_is_a(None, a + a, sqrt) + assert res_is_a(None, sqrt(a + a), sqrt) + + sqrt_term = sqrt(a + a) + fg = FunctionGraph([a], [2 * sqrt_term], clone=False) + assert res_is_a(fg, sqrt_term, sqrt, 2) + assert not res_is_a(fg, sqrt_term, sqrt, 0) + + +class TestAsScalar: + def test_basic(self): + # Test that it works on scalar constants + a = pt.constant(2.5) + b = pt.constant(np.asarray([[[0.5]]])) + b2 = b.dimshuffle() + assert b2.ndim == 0 + d_a = DimShuffle(input_ndim=0, new_order=[])(a) + d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b) + d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a) + + assert _as_scalar(a) == a + assert _as_scalar(b) != b + assert _as_scalar(d_a) != d_a + assert _as_scalar(d_b) != d_b + assert _as_scalar(d_a2) != d_a2 + + def test_basic_1(self): + # Test that it fails on nonscalar constants + a = pt.constant(np.ones(5)) + assert _as_scalar(a) is None + assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None + + def test_basic_2(self): + # Test that it works on scalar variables + a = dscalar() + d_a = DimShuffle(input_ndim=0, new_order=[])(a) + d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a) + + assert _as_scalar(a) is a + assert _as_scalar(d_a) is a + assert _as_scalar(d_a2) is a + + def test_basic_3(self): + # Test that it fails on nonscalar variables + a = matrix() + assert _as_scalar(a) is None + assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None + + +class TestRealMatrix: + def test_basic(self): + assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix())) + assert not _is_real_matrix( + DimShuffle(input_ndim=1, new_order=["x", 0])(dvector()) + ) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 9488a9f688..f1b71949d1 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -987,10 +987,12 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): else: out = [ self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out") - for g_, od in zip(g, out_dtype) + for g_, od in zip(g, out_dtype, strict=True) ] - assert all(o.dtype == g_.dtype for o, g_ in zip(out, g)) - f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode) + assert all(o.dtype == g_.dtype for o, g_ in zip(out, g, strict=True)) + f = function( + sym_inputs, [], updates=list(zip(out, g, strict=True)), mode=self.mode + ) for x in range(nb_repeat): f(*val_inputs) out = [o.get_value() for o in out] @@ -1000,7 +1002,7 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): if any(o == "float32" for o in out_dtype): atol = 1e-6 - for o, a in zip(out, answer): + for o, a in zip(out, answer, strict=True): np.testing.assert_allclose(o, a * nb_repeat, atol=atol) topo = f.maker.fgraph.toposort() @@ -1020,7 +1022,7 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): ) assert expected_len_sym_inputs == len(sym_inputs) - for od, o in zip(out_dtype, out): + for od, o in zip(out_dtype, out, strict=True): assert od == o.dtype def test_fusion_35_inputs(self): diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9dd2a247a8..c9b9afff19 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -21,6 +21,7 @@ KroneckerProduct, MatrixInverse, MatrixPinv, + SLogDet, matrix_inverse, svd, ) @@ -719,7 +720,7 @@ def test_det_blockdiag_rewrite(): def test_slogdet_blockdiag_rewrite(): - n_matrices = 100 + n_matrices = 10 matrix_size = (5, 5) sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size)) bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)]) @@ -776,11 +777,34 @@ def test_diag_kronecker_rewrite(): ) +def test_det_kronecker_rewrite(): + a, b = pt.dmatrices("a", "b") + kron_prod = pt.linalg.kron(a, b) + det_output = pt.linalg.det(kron_prod) + f_rewritten = function([a, b], [det_output], mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # Value Test + a_test, b_test = np.random.rand(2, 20, 20) + kron_prod_test = np.kron(a_test, b_test) + det_output_test = np.linalg.det(kron_prod_test) + rewritten_det_val = f_rewritten(a_test, b_test) + assert_allclose( + det_output_test, + rewritten_det_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + def test_slogdet_kronecker_rewrite(): a, b = pt.dmatrices("a", "b") kron_prod = pt.linalg.kron(a, b) sign_output, logdet_output = pt.linalg.slogdet(kron_prod) - f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN") + f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN") # Rewrite Test nodes = f_rewritten.maker.fgraph.apply_nodes @@ -790,7 +814,7 @@ def test_slogdet_kronecker_rewrite(): a_test, b_test = np.random.rand(2, 20, 20) kron_prod_test = np.kron(a_test, b_test) sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test) - rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test) + rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test) assert_allclose( sign_output_test, rewritten_sign_val, @@ -906,3 +930,69 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): f_rewritten = function([x], z_cholesky, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes assert any(isinstance(node.op, Cholesky) for node in nodes) + + +def test_slogdet_specialization(): + x, a = pt.dmatrix("x"), np.random.rand(20, 20) + det_x, det_a = pt.linalg.det(x), np.linalg.det(a) + log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a)) + log_det_x, log_det_a = pt.log(det_x), np.log(det_a) + sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a) + exp_det_x = pt.exp(det_x) + + # REWRITE TESTS + # sign(det(x)) + f = function([x], [sign_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + rw_sign_det_a = f(a) + assert_allclose( + sign_det_a, + rw_sign_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + # log(abs(det(x))) + f = function([x], [log_abs_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + rw_log_abs_det_a = f(a) + assert_allclose( + log_abs_det_a, + rw_log_abs_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + # log(det(x)) + f = function([x], [log_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + rw_log_det_a = f(a) + assert_allclose( + log_det_a, + rw_log_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + # More than 1 valid function + f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + + # Other functions (rewrite shouldnt be applied to these) + # Only invalid functions + f = function([x], [exp_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, SLogDet) for node in nodes) + + # Invalid + Valid function + f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, SLogDet) for node in nodes) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index e4a08cdf81..debcf44c64 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -61,6 +61,7 @@ ge, gt, int_div, + kv, le, log, log1mexp, @@ -1382,11 +1383,11 @@ def assert_eqs_const(self, f, val, op=deep_copy_op): if op == deep_copy_op: assert len(elem.inputs) == 1, elem.inputs assert isinstance(elem.inputs[0], TensorConstant), elem - assert pt.extract_constant(elem.inputs[0]) == val, val + assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val else: assert len(elem.inputs) == 2, elem.inputs assert isinstance(elem.inputs[0], TensorConstant), elem - assert pt.extract_constant(elem.inputs[0]) == val, val + assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val def assert_identity(self, f): topo = f.maker.fgraph.toposort() @@ -2163,7 +2164,7 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite): # The zero branch upcasts the output, so we can't ignore its dtype zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch") other_branch = scalar("other_branch", dtype="float32") - outer_var = scalar("mul_var", dtype="bool") + outer_var = scalar("outer_var", dtype="bool") out = op(switch(cond, zero_branch, other_branch), outer_var) fgraph = FunctionGraph(outputs=[out], clone=False) @@ -2173,6 +2174,27 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite): expected_out = switch(cond, zero_branch, op(other_branch, outer_var)) assert equal_computations([new_out], [expected_out]) + @pytest.mark.parametrize( + "op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)] + ) + def test_local_mul_div_switch_sink_branch_order(self, op, rewrite): + cond = scalar("cond", dtype="bool") + zero_branch = constant(np.array(0.0, dtype="float64"), "zero_branch") + other_branch = scalar("other_branch", dtype="float64") + outer_var = scalar("outer_var", dtype="float64") + + left = op(switch(cond, zero_branch, other_branch), outer_var) + right = op(switch(cond, other_branch, zero_branch), outer_var) + fgraph = FunctionGraph(outputs=[left, right], clone=False) + [new_left] = rewrite.transform(fgraph, left.owner) + [new_right] = rewrite.transform(fgraph, right.owner) + + expected_left = switch(cond, zero_branch, op(other_branch, outer_var)) + expected_right = switch(cond, op(other_branch, outer_var), zero_branch) + assert equal_computations( + [new_left, new_right], [expected_left, expected_right] + ) + @pytest.mark.skipif( config.cxx == "", @@ -3784,14 +3806,9 @@ def test_local_expm1(): for n in h.maker.fgraph.toposort() ) - # This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked - expect_rewrite = config.mode != "FAST_COMPILE" - assert ( - any( - isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1) - for n in r.maker.fgraph.toposort() - ) - == expect_rewrite + assert any( + isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1) + for n in r.maker.fgraph.toposort() ) @@ -4418,23 +4435,6 @@ def test_local_add_neg_to_sub(first_negative): assert np.allclose(f(x_test, y_test), exp) -def test_local_add_neg_to_sub_const(): - x = vector("x") - const = 5.0 - - f = function([x], x + (-const), mode=Mode("py")) - - nodes = [ - node.op - for node in f.maker.fgraph.toposort() - if not isinstance(node.op, DimShuffle) - ] - assert nodes == [pt.sub] - - x_test = np.array([3, 4], dtype=config.floatX) - assert np.allclose(f(x_test), x_test + (-const)) - - def test_log1mexp_stabilization(): mode = Mode("py").including("stabilize") @@ -4557,3 +4557,17 @@ def test_local_batched_matmul_to_core_matmul(): x_test = rng.normal(size=(5, 3, 2)) y_test = rng.normal(size=(5, 2, 2)) np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test) + + +def test_log_kv_stabilization(): + x = pt.scalar("x") + out = log(kv(4.5, x)) + + # Expression would underflow to -inf without rewrite + mode = get_default_mode().including("stabilize") + # Reference value from mpmath + # mpmath.log(mpmath.besselk(4.5, 1000.0)) + np.testing.assert_allclose( + out.eval({x: 1000.0}, mode=mode), + -1003.2180912984705, + ) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 72a7a0f235..fcfd72ddf2 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -1389,7 +1389,7 @@ def test_none_slice(self): for x_s in self.x_shapes: x_val = self.rng.uniform(size=x_s).astype(config.floatX) - for i_val in zip(*values): + for i_val in zip(*values, strict=True): f(x_val, *i_val) def test_none_index(self): @@ -1447,7 +1447,7 @@ def test_none_index(self): for x_s in self.x_shapes: x_val = self.rng.uniform(size=x_s).astype(config.floatX) - for i_val in zip(*values): + for i_val in zip(*values, strict=True): # The index could be out of bounds # In that case, an Exception should be raised, # otherwise, we let DebugMode check f @@ -1568,7 +1568,7 @@ def test_stack_trace(self): incs = [set_subtensor(x[idx], y) for y in ys] outs = [inc[idx] for inc in incs] - for y, out in zip(ys, outs): + for y, out in zip(ys, outs, strict=True): f = function([x, y, idx], out, self.mode) assert check_stack_trace(f, ops_to_check=(Assert, ps.Cast)) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 323d401f42..8e3636c814 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -46,7 +46,6 @@ default, diag, expand_dims, - extract_constant, eye, fill, flatnonzero, @@ -420,7 +419,7 @@ def test_make_vector(self, dtype, inputs): # The gradient should be 0 utt.assert_allclose(g_val, 0) else: - for var, grval in zip((b, i, d), g_val): + for var, grval in zip((b, i, d), g_val, strict=True): float_inputs = [] if var.dtype in int_dtypes: pass @@ -777,6 +776,7 @@ def test_alloc_constant_folding(self): # AdvancedIncSubtensor (some_matrix[idx, idx], 1), ], + strict=True, ): derp = pt_sum(dense_dot(subtensor, variables)) @@ -1120,7 +1120,7 @@ def check(m): assert np.allclose(res_matrix, np.vstack(np.nonzero(m))) - for i, j in zip(res_tuple, np.nonzero(m)): + for i, j in zip(res_tuple, np.nonzero(m), strict=True): assert np.allclose(i, j) rand0d = np.empty(()) @@ -2170,7 +2170,7 @@ def test_split_view(self, linker): ) x_test = np.arange(5, dtype=config.floatX) res = f(x_test) - for r, expected in zip(res, ([], [0, 1, 2], [3, 4])): + for r, expected in zip(res, ([], [0, 1, 2], [3, 4]), strict=True): assert np.allclose(r, expected) if linker == "py": assert r.base is x_test @@ -2951,8 +2951,8 @@ def test_mgrid_numpy_equiv(self): mgrid[0:1:0.1, 1:10:1.0, 10:100:10.0], mgrid[0:2:1, 1:10:1, 10:100:10], ) - for n, t in zip(nmgrid, tmgrid): - for ng, tg in zip(n, t): + for n, t in zip(nmgrid, tmgrid, strict=True): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg.eval()) def test_ogrid_numpy_equiv(self): @@ -2966,8 +2966,8 @@ def test_ogrid_numpy_equiv(self): ogrid[0:1:0.1, 1:10:1.0, 10:100:10.0], ogrid[0:2:1, 1:10:1, 10:100:10], ) - for n, t in zip(nogrid, togrid): - for ng, tg in zip(n, t): + for n, t in zip(nogrid, togrid, strict=True): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg.eval()) def test_mgrid_pytensor_variable_numpy_equiv(self): @@ -2979,8 +2979,10 @@ def test_mgrid_pytensor_variable_numpy_equiv(self): timgrid = mgrid[l:2:1, 1:m:1, 10:100:n] ff = pytensor.function([i, j, k], tfmgrid) fi = pytensor.function([l, m, n], timgrid) - for n, t in zip((nfmgrid, nimgrid), (ff(0, 10, 10.0), fi(0, 10, 10))): - for ng, tg in zip(n, t): + for n, t in zip( + (nfmgrid, nimgrid), (ff(0, 10, 10.0), fi(0, 10, 10)), strict=True + ): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg) def test_ogrid_pytensor_variable_numpy_equiv(self): @@ -2992,8 +2994,10 @@ def test_ogrid_pytensor_variable_numpy_equiv(self): tiogrid = ogrid[l:2:1, 1:m:1, 10:100:n] ff = pytensor.function([i, j, k], tfogrid) fi = pytensor.function([l, m, n], tiogrid) - for n, t in zip((nfogrid, niogrid), (ff(0, 10, 10.0), fi(0, 10, 10))): - for ng, tg in zip(n, t): + for n, t in zip( + (nfogrid, niogrid), (ff(0, 10, 10.0), fi(0, 10, 10)), strict=True + ): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg) @@ -3038,7 +3042,7 @@ def test_dim2(self): assert np.all(f_inverse(inv_val) == p_val) # Check that, for each permutation, # permutation(inverse) == inverse(permutation) = identity - for p_row, i_row in zip(p_val, inv_val): + for p_row, i_row in zip(p_val, inv_val, strict=True): assert np.all(p_row[i_row] == np.arange(10)) assert np.all(i_row[p_row] == np.arange(10)) @@ -3104,7 +3108,9 @@ def test_2_2(self): # Each row of p contains a permutation to apply to the corresponding # row of input - out_bis = np.asarray([i_row[p_row] for i_row, p_row in zip(input_val, p_val)]) + out_bis = np.asarray( + [i_row[p_row] for i_row, p_row in zip(input_val, p_val, strict=True)] + ) assert np.all(out_val == out_bis) # Verify gradient @@ -3266,7 +3272,6 @@ def test_autocast_custom(): assert (dvector() + 1.1).dtype == "float64" assert (fvector() + np.float32(1.1)).dtype == "float32" assert (fvector() + np.float64(1.1)).dtype == "float64" - assert (fvector() + 1.1).dtype == config.floatX assert (lvector() + np.int64(1)).dtype == "int64" assert (lvector() + np.int32(1)).dtype == "int64" assert (lvector() + np.int16(1)).dtype == "int64" @@ -3564,12 +3569,13 @@ def test_second(self): assert get_underlying_scalar_constant_value(s) == c.data def test_copy(self): - # Make sure we do not return the internal storage of a constant, + # Make sure we do not return a writeable internal storage of a constant, # so we cannot change the value of a constant by mistake. c = constant(3) - d = extract_constant(c) - d += 1 - e = extract_constant(c) + d = get_scalar_constant_value(c) + with pytest.raises(ValueError, match="output array is read-only"): + d += 1 + e = get_scalar_constant_value(c) assert e == 3, (c, d, e) @pytest.mark.parametrize("only_process_constants", (True, False)) @@ -4674,7 +4680,7 @@ def test_where(): np.testing.assert_allclose(np.where(cond, ift, iff), where(cond, ift, iff).eval()) # Test for only condition input - for np_output, pt_output in zip(np.where(cond), where(cond)): + for np_output, pt_output in zip(np.where(cond), where(cond), strict=True): np.testing.assert_allclose(np_output, pt_output.eval()) # Test for error diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 3b6115a107..1e4afb8928 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -16,7 +16,6 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import grad -from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import in2out from pytensor.graph.utils import InconsistencyError from pytensor.tensor import inplace @@ -28,12 +27,8 @@ Gemm, Gemv, Ger, - _as_scalar, _dot22, _dot22scalar, - _factor_canonicalized, - _gemm_canonicalize, - _is_real_matrix, batched_dot, batched_tensordot, gemm, @@ -44,19 +39,15 @@ gemv_no_inplace, ger, ger_destructive, - res_is_a, ) -from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt +from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger from pytensor.tensor.type import ( cmatrix, - col, cscalar, dmatrix, drow, dscalar, - dvector, fmatrix, fscalar, imatrix, @@ -65,7 +56,6 @@ ivector, matrices, matrix, - row, scalar, scalars, tensor, @@ -572,67 +562,6 @@ def test_gemm(self): self.run_gemm(dtype, alpha, beta, tA, tB, tC, sA, sB, sC, rng) -def test_res_is_a(): - X, Y, Z, a, b = XYZab() - - assert not res_is_a(None, a, sqrt) - assert not res_is_a(None, a + a, sqrt) - assert res_is_a(None, sqrt(a + a), sqrt) - - sqrt_term = sqrt(a + a) - fg = FunctionGraph([a], [2 * sqrt_term], clone=False) - assert res_is_a(fg, sqrt_term, sqrt, 2) - assert not res_is_a(fg, sqrt_term, sqrt, 0) - - -class TestAsScalar: - def test_basic(self): - # Test that it works on scalar constants - a = pt.constant(2.5) - b = pt.constant(np.asarray([[[0.5]]])) - b2 = b.dimshuffle() - assert b2.ndim == 0 - d_a = DimShuffle(input_ndim=0, new_order=[])(a) - d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b) - d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a) - - assert _as_scalar(a) == a - assert _as_scalar(b) != b - assert _as_scalar(d_a) != d_a - assert _as_scalar(d_b) != d_b - assert _as_scalar(d_a2) != d_a2 - - def test_basic_1(self): - # Test that it fails on nonscalar constants - a = pt.constant(np.ones(5)) - assert _as_scalar(a) is None - assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None - - def test_basic_2(self): - # Test that it works on scalar variables - a = dscalar() - d_a = DimShuffle(input_ndim=0, new_order=[])(a) - d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a) - - assert _as_scalar(a) is a - assert _as_scalar(d_a) is a - assert _as_scalar(d_a2) is a - - def test_basic_3(self): - # Test that it fails on nonscalar variables - a = matrix() - assert _as_scalar(a) is None - assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None - - -class TestRealMatrix: - def test_basic(self): - assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix())) - assert not _is_real_matrix( - DimShuffle(input_ndim=1, new_order=["x", 0])(dvector()) - ) - - """ This test suite ensures that Gemm is inserted where it belongs, and that the resulting functions compute the same things as the originals. @@ -774,78 +703,6 @@ def test_gemm_opt_double_gemm(): assert max_abs_err <= eps, "GEMM is computing the wrong output. max_rel_err =" -def test_gemm_canonicalize(): - X, Y, Z, a, b = ( - matrix("X"), - matrix("Y"), - matrix("Z"), - scalar("a"), - scalar("b"), - ) - c, d = scalar("c"), scalar("d") - u = row("u") - v = vector("v") - w = col("w") - - can = [] - fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False) - _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) - assert can == [(1.0, X), (1.0, Y), (1.0, Z)] - - can = [] - fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False) - _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) - assert can == [(1.0, X), (1.0, Y), (1.0, u)], can - - can = [] - fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False) - _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) - # [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))] - assert can[:2] == [(1.0, X), (1.0, Y)] - assert isinstance(can[2], tuple) - assert len(can[2]) == 2 - assert can[2][0] == 1.0 - assert can[2][1].owner - assert isinstance(can[2][1].owner.op, DimShuffle) - assert can[2][1].owner.inputs == [v] - - can = [] - fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False) - _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) - assert can == [(1.0, X), (1.0, Y), (1.0, w)], can - - can = [] - fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False) - _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) - assert can[0] == (a, X) - assert can[1] == (1.0, Y) - assert can[2][0].owner.op == mul - assert can[2][0].owner.inputs[0].owner.op == neg - assert can[2][0].owner.inputs[0].owner.inputs[0] == c - assert can[2][0].owner.inputs[1] == b - - can = [] - fg = FunctionGraph( - [a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False - ) - _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0) - assert can[0][0].owner.op == neg - assert can[0][0].owner.inputs[0] == d - assert can[0][1] == X - assert can[1][0].owner.op == neg - assert can[1][0].owner.inputs[0] == a - assert can[2] == (-1.0, Y) - assert can[3][0].owner.op == mul - assert can[3][0].owner.inputs == [c, b] - - -def test_gemm_factor(): - X, Y = matrix("X"), matrix("Y") - - assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)]) - assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)]) - - def test_upcasting_scalar_nogemm(): # Test that the optimization does not crash when the scale has an incorrect # dtype, and forces upcasting of the result @@ -2594,7 +2451,7 @@ def test_ger(self): lambda xs, ys: np.asarray( [ x * y if x.ndim == 0 or y.ndim == 0 else np.dot(x, y) - for x, y in zip(xs, ys) + for x, y in zip(xs, ys, strict=True) ], dtype=ps.upcast(xs.dtype, ys.dtype), ) @@ -2697,7 +2554,7 @@ def check_first_dim(inverted): assert x.strides[0] == direction * np.dtype(config.floatX).itemsize assert not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]) result = f(x, w) - ref_result = np.asarray([np.dot(u, v) for u, v in zip(x, w)]) + ref_result = np.asarray([np.dot(u, v) for u, v in zip(x, w, strict=True)]) utt.assert_allclose(ref_result, result) for inverted in (0, 1): diff --git a/tests/tensor/test_blas_scipy.py b/tests/tensor/test_blas_scipy.py index 7cdfaadc34..716eab7bbe 100644 --- a/tests/tensor/test_blas_scipy.py +++ b/tests/tensor/test_blas_scipy.py @@ -1,7 +1,6 @@ import pickle import numpy as np -import pytest import pytensor from pytensor import tensor as pt @@ -12,7 +11,6 @@ from tests.unittest_tools import OptimizationTestMixin -@pytest.mark.skipif(not pytensor.tensor.blas_scipy.have_fblas, reason="fblas needed") class TestScipyGer(OptimizationTestMixin): def setup_method(self): self.mode = pytensor.compile.get_default_mode() diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 51b381861a..51862562ac 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -259,6 +259,58 @@ def test_blockwise_shape(): assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4) +def test_blockwise_infer_core_shape(): + class TestOpWithInferShape(Op): + def make_node(self, a, b): + assert a.type.ndim == 1 + assert b.type.ndim == 1 + c = tensor(shape=(None,)) + d = tensor(shape=(None,)) + return Apply(self, [a, b], [c, d]) + + def perform(self, node, inputs, outputs): + a, b = inputs + c, d = outputs + c[0] = np.arange(a.size + b.size) + d[0] = np.arange(a.sum() + b.sum()) + + def infer_shape(self, fgraph, node, input_shapes): + # First output shape depends only on input_shapes + # Second output shape depends on input values + x, y = node.inputs + [(x_shape,), (y_shape,)] = input_shapes + return (x_shape + y_shape,), (x.sum() + y.sum(),) + + blockwise_op = Blockwise( + core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)" + ) + + a = tensor("a", shape=(5, 3)) + b = tensor("b", shape=(1, 4)) + c, d = blockwise_op(a, b) + assert c.type.shape == (5, None) + assert d.type.shape == (5, None) + + c_shape_fn = pytensor.function([a, b], c.shape) + # c_shape can be computed from the input shapes alone + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape) + for n in c_shape_fn.maker.fgraph.apply_nodes + ) + + d_shape_fn = pytensor.function([a, b], d.shape) + # d_shape cannot be computed from the input shapes alone + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape) + for n in d_shape_fn.maker.fgraph.apply_nodes + ) + + a_test = np.zeros(a.type.shape, dtype=a.type.dtype) + b_test = np.zeros(b.type.shape, dtype=b.type.dtype) + assert tuple(c_shape_fn(a_test, b_test)) == (5, 7) + assert tuple(d_shape_fn(a_test, b_test)) == (5, 0) + + class BlockwiseOpTester: """Base class to test Blockwise works for specific Ops""" @@ -295,7 +347,7 @@ def create_batched_inputs(self, batch_idx: int | None = None): vec_inputs = [] vec_inputs_testvals = [] for idx, (batch_shape, param_sig) in enumerate( - zip(batch_shapes, self.params_sig) + zip(batch_shapes, self.params_sig, strict=True) ): if batch_idx is not None and idx != batch_idx: # Skip out combinations in which other inputs are batched @@ -538,7 +590,7 @@ def core_scipy_fn(A, b): A_val_copy, b_val_copy ) np.testing.assert_allclose( - out, expected_out, atol=1e-5 if config.floatX == "float32" else 0 + out, expected_out, atol=1e-4 if config.floatX == "float32" else 0 ) # Confirm input was destroyed diff --git a/tests/tensor/test_casting.py b/tests/tensor/test_casting.py index 6907988369..7194153a37 100644 --- a/tests/tensor/test_casting.py +++ b/tests/tensor/test_casting.py @@ -71,6 +71,7 @@ def test_illegal(self): _convert_to_float32, _convert_to_float64, ], + strict=True, ), ) def test_basic(self, type1, type2, converter): diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py index 9131cda056..426ed13dcd 100644 --- a/tests/tensor/test_einsum.py +++ b/tests/tensor/test_einsum.py @@ -136,7 +136,7 @@ def test_einsum_signatures(static_shape_known, signature): operands = [ pt.tensor(name, shape=static_shape) - for name, static_shape in zip(ascii_lowercase, static_shapes) + for name, static_shape in zip(ascii_lowercase, static_shapes, strict=False) ] out = pt.einsum(signature, *operands) assert out.owner.op.optimized == static_shape_known or len(operands) <= 2 @@ -156,11 +156,11 @@ def test_einsum_signatures(static_shape_known, signature): def test_batch_dim(): - shapes = ( - (7, 3, 5), - (5, 2), - ) - x, y = (pt.tensor(name, shape=shape) for name, shape in zip("xy", shapes)) + shapes = { + "x": (7, 3, 5), + "y": (5, 2), + } + x, y = (pt.tensor(name, shape=shape) for name, shape in shapes.items()) out = pt.einsum("mij,jk->mik", x, y) assert out.type.shape == (7, 3, 2) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 76906232af..df24609dff 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -19,7 +19,7 @@ from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.tensor import as_tensor_variable -from pytensor.tensor.basic import second +from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Any, Sum, exp from pytensor.tensor.math import all as pt_all @@ -121,7 +121,11 @@ def test_infer_shape(self): def test_too_big_rank(self): x = self.type(self.dtype, shape=())() - y = x.dimshuffle(("x",) * (np.MAXDIMS + 1)) + if np.__version__ >= "2.0": + # np.MAXDIMS removed, max number of dims increased to 64 from 32 + y = x.dimshuffle(("x",) * (64 + 1)) + else: + y = x.dimshuffle(("x",) * (32 + 1)) with pytest.raises(ValueError): y.eval({x: 0}) @@ -330,6 +334,7 @@ def test_fill(self): [self.op, self.cop], [self.type, self.ctype], [self.rand_val, self.rand_cval], + strict=True, ): x = t(pytensor.config.floatX, shape=(None, None))("x") y = t(pytensor.config.floatX, shape=(1, 1))("y") @@ -361,6 +366,7 @@ def test_weird_strides(self): [self.op, self.cop], [self.type, self.ctype], [self.rand_val, self.rand_cval], + strict=True, ): x = t(pytensor.config.floatX, shape=(None,) * 5)("x") y = t(pytensor.config.floatX, shape=(None,) * 5)("y") @@ -381,6 +387,7 @@ def test_same_inputs(self): [self.op, self.cop], [self.type, self.ctype], [self.rand_val, self.rand_cval], + strict=True, ): x = t(pytensor.config.floatX, shape=(None,) * 2)("x") e = op(ps.add)(x, x) @@ -669,7 +676,7 @@ def test_scalar_input(self): assert self.op(ps.add, axis=(-1,))(x).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (-2,) is out of bounds for array of dimension 0"), ): self.op(ps.add, axis=(-2,))(x) @@ -804,8 +811,8 @@ def test_partial_static_shape_info(self): assert len(res_shape) == 1 assert len(res_shape[0]) == 2 - assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1 - assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1 + assert get_scalar_constant_value(res_shape[0][0]) == 1 + assert get_scalar_constant_value(res_shape[0][1]) == 1 def test_infer_shape_multi_output(self): class CustomElemwise(Elemwise): @@ -980,27 +987,33 @@ def test_CAReduce(self): assert vect_node.inputs[0] is bool_tns -@pytest.mark.parametrize( - "axis", - (0, 1, 2, (0, 1), (0, 2), (1, 2), None), - ids=lambda x: f"axis={x}", -) -@pytest.mark.parametrize( - "c_contiguous", - (True, False), - ids=lambda x: f"c_contiguous={x}", -) -def test_careduce_benchmark(axis, c_contiguous, benchmark): +def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark): N = 256 x_test = np.random.uniform(size=(N, N, N)) transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1) x = pytensor.shared(x_test, name="x", shape=x_test.shape) out = x.transpose(transpose_axis).sum(axis=axis) - fn = pytensor.function([], out) + fn = pytensor.function([], out, mode=mode) np.testing.assert_allclose( fn(), x_test.transpose(transpose_axis).sum(axis=axis), ) benchmark(fn) + + +@pytest.mark.parametrize( + "axis", + (0, 1, 2, (0, 1), (0, 2), (1, 2), None), + ids=lambda x: f"axis={x}", +) +@pytest.mark.parametrize( + "c_contiguous", + (True, False), + ids=lambda x: f"c_contiguous={x}", +) +def test_c_careduce_benchmark(axis, c_contiguous, benchmark): + return careduce_benchmark_tester( + axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark + ) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 0da714c3bf..19d78904d8 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -366,6 +366,7 @@ def setup_method(self): [1, None, None], [1, None, 1, 1, None], ], + strict=True, ), ) def test_op(self, shape, var_shape): @@ -389,6 +390,7 @@ def test_op(self, shape, var_shape): [1, None, None], [1, None, 1, 1, None], ], + strict=True, ), ) def test_infer_shape(self, shape, var_shape): @@ -408,6 +410,7 @@ def test_infer_shape(self, shape, var_shape): [True, False, False], [True, False, True, True, False], ], + strict=True, ), ) def test_grad(self, shape, broadcast): @@ -423,6 +426,7 @@ def test_grad(self, shape, broadcast): [1, None, None], [1, None, 1, 1, None], ], + strict=True, ), ) def test_var_interface(self, shape, var_shape): @@ -465,7 +469,7 @@ def test_scalar_input(self): assert squeeze(x, axis=(0,)).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (1,) is out of bounds for array of dimension 0"), ): squeeze(x, axis=1) @@ -505,6 +509,7 @@ def setup_method(self): [1, 1, 0, 1, 0], ], [(2, 3), (4, 3), (4, 3), (4, 3), (4, 3), (3, 5)], + strict=True, ), ) def test_op(self, axis, cond, shape): @@ -689,7 +694,7 @@ def test_perform(self, shp): y = scalar() f = function([x, y], fill_diagonal(x, y)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out), val) @@ -701,7 +706,7 @@ def test_perform_3d(self): x = tensor3() y = scalar() f = function([x, y], fill_diagonal(x, y)) - val = np.cast[config.floatX](rng.random() + 10) + val = rng.random(dtype=config.floatX) + 10 out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert out[0, 0, 0] == val @@ -763,7 +768,7 @@ def test_perform(self, test_offset, shp): f = function([x, y, z], fill_diagonal_offset(x, y, z)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val, test_offset) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out, test_offset), val) @@ -889,11 +894,13 @@ def test_basic_vector(self, x, inp, axis): np.unique(inp, False, True, True, axis=axis), np.unique(inp, True, True, True, axis=axis), ] - for params, outs_expected in zip(self.op_params, list_outs_expected): + for params, outs_expected in zip( + self.op_params, list_outs_expected, strict=True + ): out = pt.unique(x, *params, axis=axis) f = pytensor.function(inputs=[x], outputs=out) outs = f(inp) - for out, out_exp in zip(outs, outs_expected): + for out, out_exp in zip(outs, outs_expected, strict=True): utt.assert_allclose(out, out_exp) @pytest.mark.parametrize( @@ -1062,7 +1069,7 @@ def shape_tuple(x, use_bcast=True): if use_bcast: return tuple( s if not bcast else 1 - for s, bcast in zip(tuple(x.shape), x.broadcastable) + for s, bcast in zip(tuple(x.shape), x.broadcastable, strict=True) ) else: return tuple(s for s in tuple(x.shape)) @@ -1202,12 +1209,12 @@ def test_broadcast_shape_constants(): def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res): s1s = pt.lscalars(len(s1_vals)) eval_point = {} - for s, s_val in zip(s1s, s1_vals): + for s, s_val in zip(s1s, s1_vals, strict=True): eval_point[s] = s_val s.tag.test_value = s_val s2s = pt.lscalars(len(s2_vals)) - for s, s_val in zip(s2s, s2_vals): + for s, s_val in zip(s2s, s2_vals, strict=True): eval_point[s] = s_val s.tag.test_value = s_val diff --git a/tests/tensor/test_interpolate.py b/tests/tensor/test_interpolate.py new file mode 100644 index 0000000000..95ebae10e2 --- /dev/null +++ b/tests/tensor/test_interpolate.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest +from numpy.testing import assert_allclose + +import pytensor +import pytensor.tensor as pt +from pytensor.tensor.interpolate import ( + InterpolationMethod, + interp, + interpolate1d, + valid_methods, +) + + +floatX = pytensor.config.floatX + + +def test_interp(): + xp = [1.0, 2.0, 3.0] + fp = [3.0, 2.0, 0.0] + + x = [0, 1, 1.5, 2.72, 3.14] + + out = interp(x, xp, fp).eval() + np_out = np.interp(x, xp, fp) + + assert_allclose(out, np_out) + + +def test_interp_padded(): + xp = [1.0, 2.0, 3.0] + fp = [3.0, 2.0, 0.0] + + assert interp(3.14, xp, fp, right=-99.0).eval() == -99.0 + assert_allclose( + interp([-1.0, -2.0, -3.0], xp, fp, left=1000.0).eval(), [1000.0, 1000.0, 1000.0] + ) + assert_allclose( + interp([-1.0, 10.0], xp, fp, left=-10, right=10).eval(), [-10, 10.0] + ) + + +@pytest.mark.parametrize("method", valid_methods, ids=str) +@pytest.mark.parametrize( + "left_pad, right_pad", [(None, None), (None, 100), (-100, None), (-100, 100)] +) +def test_interpolate_scalar_no_extrapolate( + method: InterpolationMethod, left_pad, right_pad +): + x = np.linspace(-2, 6, 10) + y = np.sin(x) + + f_op = interpolate1d( + x, y, method, extrapolate=False, left_pad=left_pad, right_pad=right_pad + ) + x_hat_pt = pt.dscalar("x_hat") + f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN") + + # Data points should be returned exactly, except when method == mean + if method not in ["mean", "first"]: + assert f(x[3]) == y[3] + elif method == "first": + assert f(x[3]) == y[2] + else: + # method == 'mean + assert f(x[3]) == (y[2] + y[3]) / 2 + + # When extrapolate=False, points beyond the data envelope should be constant + left_pad = y[0] if left_pad is None else left_pad + right_pad = y[-1] if right_pad is None else right_pad + + assert f(-10) == left_pad + assert f(100) == right_pad + + +@pytest.mark.parametrize("method", valid_methods, ids=str) +def test_interpolate_scalar_extrapolate(method: InterpolationMethod): + x = np.linspace(-2, 6, 10) + y = np.sin(x) + + f_op = interpolate1d(x, y, method) + x_hat_pt = pt.dscalar("x_hat") + f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN") + + left_test_point = -5 + right_test_point = 100 + if method == "linear": + # Linear will compute a slope from the endpoints and continue it + left_slope = (left_test_point - x[0]) / (x[1] - x[0]) + right_slope = (right_test_point - x[-2]) / (x[-1] - x[-2]) + assert f(left_test_point) == y[0] + left_slope * (y[1] - y[0]) + assert f(right_test_point) == y[-2] + right_slope * (y[-1] - y[-2]) + + elif method == "mean": + left_expected = (y[0] + y[1]) / 2 + right_expected = (y[-1] + y[-2]) / 2 + assert f(left_test_point) == left_expected + assert f(right_test_point) == right_expected + + else: + assert f(left_test_point) == y[0] + assert f(right_test_point) == y[-1] + + # For interior points, "first" and "last" should disagree. First should take the left side of the interval, + # and last should take the right. + interior_point = x[3] + 0.1 + assert f(interior_point) == (y[4] if method == "last" else y[3]) diff --git a/tests/tensor/test_io.py b/tests/tensor/test_io.py index cece2af277..4c5e5655fe 100644 --- a/tests/tensor/test_io.py +++ b/tests/tensor/test_io.py @@ -49,7 +49,7 @@ def test_memmap(self): path = Variable(Generic(), None) x = load(path, "int32", (None,), mmap_mode="c") fn = function([path], x) - assert isinstance(fn(self.filename), np.core.memmap) + assert isinstance(fn(self.filename), np.memmap) def teardown_method(self): (pytensor.config.compiledir / "_test.npy").unlink() diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 14bc2614e3..af7242a53f 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -40,7 +40,6 @@ Argmax, Dot, Max, - Mean, Prod, ProdWithoutZeros, Sum, @@ -392,11 +391,15 @@ def test_maximum_minimum_grad(): grad=_grad_broadcast_unary_normal, ) + +neg_good = _good_broadcast_unary_normal.copy() +neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")} TestNegBroadcast = makeBroadcastTester( op=neg, expected=lambda x: -x, - good=_good_broadcast_unary_normal, + good=neg_good, grad=_grad_broadcast_unary_normal, + bad_compile=neg_bad, ) TestSgnBroadcast = makeBroadcastTester( @@ -2458,11 +2461,22 @@ def pytensor_i_scalar(dtype): def numpy_i_scalar(dtype): return numpy_scalar(dtype) + pytensor_funcs = { + "scalar": pytensor_scalar, + "array": pytensor_array, + "i_scalar": pytensor_i_scalar, + } + numpy_funcs = { + "scalar": numpy_scalar, + "array": numpy_array, + "i_scalar": numpy_i_scalar, + } + with config.change_flags(cast_policy="numpy+floatX"): # We will test all meaningful combinations of # scalar and array operations. - pytensor_args = [eval(f"pytensor_{c}") for c in combo] - numpy_args = [eval(f"numpy_{c}") for c in combo] + pytensor_args = [pytensor_funcs[c] for c in combo] + numpy_args = [numpy_funcs[c] for c in combo] pytensor_arg_1 = pytensor_args[0](a_type) pytensor_arg_2 = pytensor_args[1](b_type) pytensor_dtype = op( @@ -2587,15 +2601,6 @@ def test_mod_compile(): class TestInferShape(utt.InferShapeTester): - def test_Mean(self): - adtens3 = dtensor3() - adtens3_val = random(3, 4, 5) - aiscal_val = 2 - self._compile_and_check([adtens3], [Mean(None)(adtens3)], [adtens3_val], Mean) - self._compile_and_check( - [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean - ) - def test_Max(self): adtens3 = dtensor3() adtens3_val = random(4, 5, 3) @@ -3420,7 +3425,11 @@ def test_var_axes(self): def reduce_bitwise_and(x, axis=-1, dtype="int8"): - identity = np.array((-1,), dtype=dtype)[0] + if dtype == "uint8": + # in numpy version >= 2.0, out of bounds uint8 values are not converted + identity = np.array((255,), dtype=dtype)[0] + else: + identity = np.array((-1,), dtype=dtype)[0] shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis) if 0 in shape_without_axis: diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 6ca9279bca..921aae826b 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from pytensor.gradient import verify_grad +from pytensor.gradient import NullTypeGradError, verify_grad from pytensor.scalar import ScalarLoop from pytensor.tensor.elemwise import Elemwise @@ -18,7 +18,7 @@ from pytensor import tensor as pt from pytensor.compile.mode import get_default_mode from pytensor.configdefaults import config -from pytensor.tensor import gammaincc, inplace, vector +from pytensor.tensor import gammaincc, inplace, kv, kve, vector from tests import unittest_tools as utt from tests.tensor.utils import ( _good_broadcast_unary_chi2sf, @@ -1196,3 +1196,37 @@ def test_unused_grad_loop_opt(self, wrt): [dd for i, dd in enumerate(expected_dds) if i in wrt], rtol=rtol, ) + + +def test_kve(): + rng = np.random.default_rng(3772) + v = vector("v") + x = vector("x") + + out = kve(v[:, None], x[None, :]) + test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype) + test_x = np.linspace(0, 1005, 10, dtype=x.type.dtype) + + np.testing.assert_allclose( + out.eval({v: test_v, x: test_x}), + scipy.special.kve(test_v[:, None], test_x[None, :]), + ) + + with pytest.raises(NullTypeGradError): + grad(out.sum(), v) + + verify_grad(lambda x: kv(4.5, x), [test_x + 0.5], rng=rng) + + +def test_kv(): + v = vector("v") + x = vector("x") + + out = kv(v[:, None], x[None, :]) + test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype) + test_x = np.linspace(0, 512, 10, dtype=x.type.dtype) + + np.testing.assert_allclose( + out.eval({v: test_v, x: test_x}), + scipy.special.kv(test_v[:, None], test_x[None, :]), + ) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 1a13992011..4b83446c5f 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -198,7 +198,7 @@ def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag): np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs] - for np_val, pt_val in zip(np_outputs, pt_outputs): + for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True): assert _allclose(np_val, pt_val) def test_svd_infer_shape(self): diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 3d4b6697b8..f46d771938 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -169,7 +169,12 @@ def test_eigvalsh_grad(): ) -class TestSolveBase(utt.InferShapeTester): +class TestSolveBase: + class SolveTest(SolveBase): + def perform(self, node, inputs, outputs): + A, b = inputs + outputs[0][0] = scipy.linalg.solve(A, b) + @pytest.mark.parametrize( "A_func, b_func, error_message", [ @@ -191,16 +196,16 @@ def test_make_node(self, A_func, b_func, error_message): with pytest.raises(ValueError, match=error_message): A = A_func() b = b_func() - SolveBase(b_ndim=2)(A, b) + self.SolveTest(b_ndim=2)(A, b) def test__repr__(self): np.random.default_rng(utt.fetch_seed()) A = matrix() b = matrix() - y = SolveBase(b_ndim=2)(A, b) + y = self.SolveTest(b_ndim=2)(A, b) assert ( y.__repr__() - == "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0" + == "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0" ) @@ -239,8 +244,9 @@ def test_correctness(self): A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) A_val = np.dot(A_val.transpose(), A_val) - assert np.allclose( - scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val) + np.testing.assert_allclose( + scipy.linalg.solve(A_val, b_val, assume_a="gen"), + gen_solve_func(A_val, b_val), ) A_undef = np.array( @@ -253,7 +259,7 @@ def test_correctness(self): ], dtype=config.floatX, ) - assert np.allclose( + np.testing.assert_allclose( scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) ) @@ -450,7 +456,7 @@ def test_solve_dtype(self): fn = function([A, b], x) x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype)) - assert x.dtype == x_result.dtype + assert x.dtype == x_result.dtype, (A_dtype, b_dtype) def test_cho_solve(): diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 7b3f9af617..3886a08f48 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -1056,7 +1056,7 @@ def test_shape_i_const(self): shapes += [data.get_value(borrow=True)[start:stop:step].shape] f = self.function([], outs, mode=mode_opt, op=subtensor_ops, N=0) t_shapes = f() - for t_shape, shape in zip(t_shapes, shapes): + for t_shape, shape in zip(t_shapes, shapes, strict=True): assert np.all(t_shape == shape) assert Subtensor not in [x.op for x in f.maker.fgraph.toposort()] @@ -1320,7 +1320,9 @@ def test_advanced1_inc_and_set(self): f_outs = f(*all_inputs_num) assert len(f_outs) == len(all_outputs_num) - for params, f_out, output_num in zip(all_params, f_outs, all_outputs_num): + for params, f_out, output_num in zip( + all_params, f_outs, all_outputs_num, strict=True + ): # NB: if this assert fails, it will probably be easier to debug if # you enable the debug code above. assert np.allclose(f_out, output_num), (params, f_out, output_num) @@ -1397,7 +1399,7 @@ def test_adv1_inc_sub_notlastdim_1_2dval_broadcast(self): shape_i = ((4,), (4, 2)) shape_val = ((3, 1), (3, 1, 1)) - for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val): + for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val, strict=True): sub_m = m[:, i] m1 = set_subtensor(sub_m, np.zeros(shp_v)) m2 = inc_subtensor(sub_m, np.ones(shp_v)) @@ -1427,7 +1429,7 @@ def test_adv1_inc_sub_notlastdim_1_2dval_no_broadcast(self): shape_i = ((4,), (4, 2)) shape_val = ((3, 4), (3, 4, 2)) - for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val): + for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val, strict=True): sub_m = m[:, i] m1 = set_subtensor(sub_m, np.zeros(shp_v)) m2 = inc_subtensor(sub_m, np.ones(shp_v)) @@ -1863,7 +1865,7 @@ def test_index_into_vec_w_matrix(self): assert a.type.ndim == self.ix2.type.ndim assert all( s1 == s2 - for s1, s2 in zip(a.type.shape, self.ix2.type.shape) + for s1, s2 in zip(a.type.shape, self.ix2.type.shape, strict=True) if s1 == 1 or s2 == 1 ) @@ -2628,7 +2630,9 @@ def idx_as_tensor(x): def bcast_shape_tuple(x): if not hasattr(x, "shape"): return x - return tuple(s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape)) + return tuple( + s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape, strict=True) + ) test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True])) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 1ed3b55a89..1a8b2455ec 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -152,7 +152,7 @@ def upcast_float16_ufunc(fn): """ def ret(*args, **kwargs): - out_dtype = np.find_common_type([a.dtype for a in args], [np.float16]) + out_dtype = np.result_type(np.float16, *args) if out_dtype == "float16": # Force everything to float32 sig = "f" * fn.nin + "->" + "f" * fn.nout @@ -339,6 +339,7 @@ def makeTester( good=None, bad_build=None, bad_runtime=None, + bad_compile=None, grad=None, mode=None, grad_rtol=None, @@ -373,6 +374,7 @@ def makeTester( _test_memmap = test_memmap _check_name = check_name _grad_eps = grad_eps + _bad_compile = bad_compile or {} class Checker: op = staticmethod(_op) @@ -382,6 +384,7 @@ class Checker: good = _good bad_build = _bad_build bad_runtime = _bad_runtime + bad_compile = _bad_compile grad = _grad mode = _mode skip = skip_ @@ -508,7 +511,7 @@ def test_good(self): expecteds = (expecteds,) for i, (variable, expected, out_symbol) in enumerate( - zip(variables, expecteds, node.outputs) + zip(variables, expecteds, node.outputs, strict=True) ): condition = ( variable.dtype != out_symbol.type.dtype @@ -539,6 +542,24 @@ def test_bad_build(self): # instantiated on the following bad inputs: %s" # % (self.op, testname, node, inputs)) + @config.change_flags(compute_test_value="off") + @pytest.mark.skipif(skip, reason="Skipped") + def test_bad_compile(self): + for testname, inputs in self.bad_compile.items(): + inputrs = [shared(input) for input in inputs] + try: + node = safe_make_node(self.op, *inputrs) + except Exception as exc: + err_msg = ( + f"Test {self.op}::{testname}: Error occurred while trying" + f" to make a node with inputs {inputs}" + ) + exc.args += (err_msg,) + raise + + with pytest.raises(Exception): + inplace_func([], node.outputs, mode=mode, name="test_bad_runtime") + @config.change_flags(compute_test_value="off") @pytest.mark.skipif(skip, reason="Skipped") def test_bad_runtime(self): diff --git a/tests/test_gradient.py b/tests/test_gradient.py index c45d07662d..3f23e56c4f 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -68,6 +68,7 @@ def grad_sources_inputs(sources, inputs): wrt=inputs, consider_constant=inputs, ), + strict=True, ) ) @@ -480,12 +481,12 @@ def make_grad_func(X): int_type = imatrix().dtype float_type = "float64" - X = np.cast[int_type](rng.standard_normal((m, d)) * 127.0) - W = np.cast[W.dtype](rng.standard_normal((d, n))) - b = np.cast[b.dtype](rng.standard_normal(n)) + X = (rng.standard_normal((m, d)) * 127.0).astype(int_type) + W = rng.standard_normal((d, n), dtype=W.dtype) + b = rng.standard_normal(n, dtype=b.dtype) int_result = int_func(X, W, b) - float_result = float_func(np.cast[float_type](X), W, b) + float_result = float_func(np.asarray(X, dtype=float_type), W, b) assert np.allclose(int_result, float_result), (int_result, float_result) @@ -507,7 +508,7 @@ def test_grad_disconnected(self): # the output f = pytensor.function([x], g) rng = np.random.default_rng([2012, 9, 5]) - x = np.cast[x.dtype](rng.standard_normal(3)) + x = rng.standard_normal(3, dtype=x.dtype) g = f(x) assert np.allclose(g, np.ones(x.shape, dtype=x.dtype)) @@ -629,7 +630,10 @@ def test_known_grads(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()] - values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)] + values = [ + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) + ] true_grads = grad(cost, inputs, disconnected_inputs="ignore") true_grads = pytensor.function(inputs, true_grads) @@ -637,14 +641,14 @@ def test_known_grads(): for layer in layers: first = grad(cost, layer, disconnected_inputs="ignore") - known = dict(zip(layer, first)) + known = dict(zip(layer, first, strict=True)) full = grad( cost=None, known_grads=known, wrt=inputs, disconnected_inputs="ignore" ) full = pytensor.function(inputs, full) full = full(*values) assert len(true_grads) == len(full) - for a, b, var in zip(true_grads, full, inputs): + for a, b, var in zip(true_grads, full, inputs, strict=True): assert np.allclose(a, b) @@ -676,7 +680,7 @@ def test_known_grads_integers(): f = pytensor.function([g_expected], g_grad) x = -3 - gv = np.cast[config.floatX](0.6) + gv = np.asarray(0.6, dtype=config.floatX) g_actual = f(gv) @@ -742,7 +746,10 @@ def test_subgraph_grad(): inputs = [t, x] rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(2), rng.standard_normal(3)] - values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)] + values = [ + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) + ] wrt = [w2, w1] cost = cost2 + cost1 @@ -755,13 +762,13 @@ def test_subgraph_grad(): param_grad, next_grad = subgraph_grad( wrt=params[i], end=grad_ends[i], start=next_grad, cost=costs[i] ) - next_grad = dict(zip(grad_ends[i], next_grad)) + next_grad = dict(zip(grad_ends[i], next_grad, strict=True)) param_grads.extend(param_grad) pgrads = pytensor.function(inputs, param_grads) pgrads = pgrads(*values) - for true_grad, pgrad in zip(true_grads, pgrads): + for true_grad, pgrad in zip(true_grads, pgrads, strict=True): assert np.sum(np.abs(true_grad - pgrad)) < 0.00001 @@ -1026,21 +1033,21 @@ def test_jacobian_scalar(): # test when the jacobian is called with a tensor as wrt Jx = jacobian(y, x) f = pytensor.function([x], Jx) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a tuple as wrt Jx = jacobian(y, (x,)) assert isinstance(Jx, tuple) f = pytensor.function([x], Jx[0]) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a list as wrt Jx = jacobian(y, [x]) assert isinstance(Jx, list) f = pytensor.function([x], Jx[0]) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a list of two elements @@ -1048,8 +1055,8 @@ def test_jacobian_scalar(): y = x * z Jx = jacobian(y, [x, z]) f = pytensor.function([x, z], Jx) - vx = np.cast[pytensor.config.floatX](rng.uniform()) - vz = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) vJx = f(vx, vz) assert np.allclose(vJx[0], vz) diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index d506d96df6..5ca7de6e63 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -234,14 +234,14 @@ def test_multiple_out_grad(self): np.asarray(rng.uniform(size=(l,)), pytensor.config.floatX) for l in lens ] outs_1 = f(1, *values) - assert all(x.shape[0] == y for x, y in zip(outs_1, lens)) + assert all(x.shape[0] == y for x, y in zip(outs_1, lens, strict=True)) assert np.all(outs_1[0] == 1.0) assert np.all(outs_1[1] == 1.0) assert np.all(outs_1[2] == 0.0) assert np.all(outs_1[3] == 0.0) outs_0 = f(0, *values) - assert all(x.shape[0] == y for x, y in zip(outs_1, lens)) + assert all(x.shape[0] == y for x, y in zip(outs_1, lens, strict=True)) assert np.all(outs_0[0] == 0.0) assert np.all(outs_0[1] == 0.0) assert np.all(outs_0[2] == 1.0) diff --git a/tests/test_printing.py b/tests/test_printing.py index d5b0707442..be5dbbc5a1 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -17,13 +17,13 @@ PatternPrinter, PPrinter, Print, + _try_pydot_import, char_from_number, debugprint, default_printer, get_node_by_id, min_informative_str, pp, - pydot_imported, pydotprint, ) from pytensor.tensor import as_tensor_variable @@ -31,6 +31,13 @@ from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable +try: + _try_pydot_import() + pydot_imported = True +except Exception: + pydot_imported = False + + @pytest.mark.parametrize( "number,s", [ @@ -385,7 +392,7 @@ def test_debugprint_inner_graph(): โ””โ”€ *1- [id F] """ - for exp_line, res_line in zip(exp_res.split("\n"), lines): + for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): assert exp_line.strip() == res_line.strip() # Test nested inner-graph `Op`s @@ -413,7 +420,7 @@ def test_debugprint_inner_graph(): โ””โ”€ *1- [id E] """ - for exp_line, res_line in zip(exp_res.split("\n"), lines): + for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): assert exp_line.strip() == res_line.strip() diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py index 4b309c2324..19598bfb21 100644 --- a/tests/typed_list/test_basic.py +++ b/tests/typed_list/test_basic.py @@ -577,17 +577,17 @@ def test_correct_answer(self): x = tensor3() y = tensor3() - A = np.cast[pytensor.config.floatX](np.random.random((5, 3))) - B = np.cast[pytensor.config.floatX](np.random.random((7, 2))) - X = np.cast[pytensor.config.floatX](np.random.random((5, 6, 1))) - Y = np.cast[pytensor.config.floatX](np.random.random((1, 9, 3))) + A = np.random.random((5, 3)).astype(pytensor.config.floatX) + B = np.random.random((7, 2)).astype(pytensor.config.floatX) + X = np.random.random((5, 6, 1)).astype(pytensor.config.floatX) + Y = np.random.random((1, 9, 3)).astype(pytensor.config.floatX) make_list((3.0, 4.0)) c = make_list((a, b)) z = make_list((x, y)) fc = pytensor.function([a, b], c) fz = pytensor.function([x, y], z) - for m, n in zip(fc(A, B), [A, B]): + for m, n in zip(fc(A, B), [A, B], strict=True): assert (m == n).all() - for m, n in zip(fz(X, Y), [X, Y]): + for m, n in zip(fz(X, Y), [X, Y], strict=True): assert (m == n).all() diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py index 9134b29b65..a5b0a21a49 100644 --- a/tests/unittest_tools.py +++ b/tests/unittest_tools.py @@ -216,7 +216,7 @@ def _compile_and_check( if excluding: mode = mode.excluding(*excluding) if warn: - for var, inp in zip(inputs, numeric_inputs): + for var, inp in zip(inputs, numeric_inputs, strict=True): if isinstance(inp, int | float | list | tuple): inp = var.type.filter(inp) if not hasattr(inp, "shape"): @@ -261,7 +261,7 @@ def _compile_and_check( # Check that the shape produced agrees with the actual shape. numeric_outputs = outputs_function(*numeric_inputs) numeric_shapes = shapes_function(*numeric_inputs) - for out, shape in zip(numeric_outputs, numeric_shapes): + for out, shape in zip(numeric_outputs, numeric_shapes, strict=True): assert np.all(out.shape == shape), (out.shape, shape)