diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 3462dd00ff..e588a5eaeb 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -189,5 +189,5 @@ jobs: name: universal_wheel path: dist - - uses: pypa/gh-action-pypi-publish@v1.12.2 + - uses: pypa/gh-action-pypi-publish@v1.12.4 # Implicitly attests that the packages were uploaded in the context of this workflow. diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 53f1e16606..5bb416f893 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -65,7 +65,7 @@ jobs: - uses: pre-commit/action@v3.0.1 test: - name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}" + name: "${{ matrix.os }} test py${{ matrix.python-version }} numpy${{ matrix.numpy-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}" needs: - changes - style @@ -76,6 +76,7 @@ jobs: matrix: os: ["ubuntu-latest"] python-version: ["3.10", "3.12"] + numpy-version: ["~=1.26.0", ">=2.0"] fast-compile: [0, 1] float32: [0, 1] install-numba: [0] @@ -105,45 +106,68 @@ jobs: float32: 1 - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" fast-compile: 1 + - numpy-version: "~=1.26.0" + fast-compile: 1 + - numpy-version: "~=1.26.0" + float32: 1 + - numpy-version: "~=1.26.0" + python-version: "3.12" + - numpy-version: "~=1.26.0" + part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" include: - install-numba: 1 os: "ubuntu-latest" python-version: "3.10" + numpy-version: "~=2.1.0" fast-compile: 0 float32: 0 part: "tests/link/numba" - install-numba: 1 os: "ubuntu-latest" python-version: "3.12" + numpy-version: "~=2.1.0" fast-compile: 0 float32: 0 part: "tests/link/numba" - install-jax: 1 os: "ubuntu-latest" python-version: "3.10" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/jax" - install-jax: 1 os: "ubuntu-latest" python-version: "3.12" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/jax" - install-torch: 1 os: "ubuntu-latest" python-version: "3.10" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/pytorch" - os: macos-15 python-version: "3.12" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 install-numba: 0 install-jax: 0 install-torch: 0 part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py" + - os: "ubuntu-latest" + python-version: "3.10" + numpy-version: "~=1.26.0" + fast-compile: 0 + float32: 0 + install-numba: 0 + install-jax: 0 + install-torch: 0 + part: "tests/tensor/test_math.py" steps: - uses: actions/checkout@v4 @@ -174,9 +198,9 @@ jobs: run: | if [[ $OS == "macos-15" ]]; then - micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" numpy scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate; + micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate; else - micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; + micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service 
graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi @@ -193,6 +217,7 @@ jobs: fi env: PYTHON_VERSION: ${{ matrix.python-version }} + NUMPY_VERSION: ${{ matrix.numpy-version }} INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}} diff --git a/doc/extending/op.rst b/doc/extending/op.rst index ddd397dee9..b1585c4ecd 100644 --- a/doc/extending/op.rst +++ b/doc/extending/op.rst @@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`. the outputs) back to their corresponding shapes and return them as the output of the :meth:`Op.R_op` method. - :ref:`List of op with r op support `. diff --git a/doc/library/gradient.rst b/doc/library/gradient.rst deleted file mode 100644 index f823a1c381..0000000000 --- a/doc/library/gradient.rst +++ /dev/null @@ -1,76 +0,0 @@ -.. _libdoc_gradient: - -=========================================== -:mod:`gradient` -- Symbolic Differentiation -=========================================== - -.. module:: gradient - :platform: Unix, Windows - :synopsis: low-level automatic differentiation -.. moduleauthor:: LISA - -.. testsetup:: * - - from pytensor.gradient import * - -Symbolic gradient is usually computed from :func:`gradient.grad`, which offers a -more convenient syntax for the common case of wanting the gradient of some -scalar cost with respect to some input expressions. The :func:`grad_sources_inputs` -function does the underlying work, and is more flexible, but is also more -awkward to use when :func:`gradient.grad` can do the job. - - -Gradient related functions -========================== - -.. automodule:: pytensor.gradient - :members: - -.. _R_op_list: - - -List of Implemented R op -======================== - - -See the :ref:`gradient tutorial ` for the R op documentation. - -list of ops that support R-op: - * with test - * SpecifyShape - * MaxAndArgmax - * Subtensor - * IncSubtensor set_subtensor too - * Alloc - * Dot - * Elemwise - * Sum - * Softmax - * Shape - * Join - * Rebroadcast - * Reshape - * DimShuffle - * Scan [In tests/scan/test_basic.test_rop] - - * without test - * Split - * ARange - * ScalarFromTensor - * AdvancedSubtensor1 - * AdvancedIncSubtensor1 - * AdvancedIncSubtensor - -Partial list of ops without support for R-op: - - * All sparse ops - * All linear algebra ops. - * PermuteRowElements - * AdvancedSubtensor - * TensorDot - * Outer - * Prod - * MulwithoutZeros - * ProdWithoutZeros - * CAReduce(for max,... done for MaxAndArgmax op) - * MaxAndArgmax(only for matrix on axis 0 or 1) diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst index 8d22c1e577..4f087b6788 100644 --- a/doc/library/tensor/basic.rst +++ b/doc/library/tensor/basic.rst @@ -1791,5 +1791,3 @@ Gradient / Differentiation :members: grad :noindex: -See the :ref:`gradient ` page for complete documentation -of the gradient module. diff --git a/doc/tutorial/gradients.rst b/doc/tutorial/gradients.rst index edb38bb018..f8b7f7ff98 100644 --- a/doc/tutorial/gradients.rst +++ b/doc/tutorial/gradients.rst @@ -86,9 +86,7 @@ of symbolic differentiation). 
``i`` of the output list is the gradient of the first argument of
 `pt.grad` with respect to the ``i``-th element of the list given as
 second argument. The first argument of `pt.grad` has to be a scalar (a tensor
-of size 1). For more information on the semantics of the arguments of
-`pt.grad` and details about the implementation, see
-:ref:`this <libdoc_gradient>` section of the library.
+of size 1).
 
 Additional information on the inner workings of differentiation may also be
 found in the more advanced tutorial :ref:`Extending PyTensor`.
@@ -204,7 +202,21 @@ you need to do something similar to this:
 >>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
 array([ 2.,  2.])
 
-:ref:`List <R_op_list>` of Op that implement Rop.
+By default, the R-operator is implemented as a double application of the L_operator
+(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
+In most cases this should be as performant as a specialized implementation of the R-operator.
+However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
+such as Scan and OpFromGraph, that would be more easily avoidable in a direct implementation of the R-operator.
+
+When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
+`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
+
+
+>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
+>>> f = pytensor.function([W, V, x], JV)
+>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
+array([ 2.,  2.])
+
 L-operator
 ----------
@@ -234,7 +246,6 @@ array([[ 0.,  0.],
 as the input parameter, while the result of the R-operator has a shape similar
 to that of the output.
 
-:ref:`List of op with r op support <R_op_list>`.
 
 Hessian times a Vector
 ======================
diff --git a/environment-osx-arm64.yml b/environment-osx-arm64.yml
index 13a68faaaa..c9dc703dcc 100644
--- a/environment-osx-arm64.yml
+++ b/environment-osx-arm64.yml
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python=>3.10
   - compilers
-  - numpy>=1.17.0,<2
+  - numpy>=1.17.0
   - scipy>=1,<2
   - filelock>=3.15
   - etuples
diff --git a/environment.yml b/environment.yml
index 1571ae0d11..9bdddfb6f6 100644
--- a/environment.yml
+++ b/environment.yml
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0,<2
+  - numpy>=1.17.0
  - scipy>=1,<2
   - filelock>=3.15
   - etuples
diff --git a/pyproject.toml b/pyproject.toml
index 4e2a1fdb05..e796e35a10 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,7 @@ keywords = [
 dependencies = [
     "setuptools>=59.0.0",
     "scipy>=1,<2",
-    "numpy>=1.17.0,<2",
+    "numpy>=1.17.0",
     "filelock>=3.15",
     "etuples",
     "logical-unification",
@@ -129,7 +129,7 @@ exclude = ["doc/", "pytensor/_version.py"]
 docstring-code-format = true
 
 [tool.ruff.lint]
-select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"]
+select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
 ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
 unfixable = [
     # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
@@ -144,7 +144,12 @@ lines-after-imports = 2
 # TODO: Get rid of these:
 "**/__init__.py" = ["F401", "E402", "F403"]
 "pytensor/tensor/linalg.py" = ["F403"]
-"pytensor/link/c/cmodule.py" = ["PTH"]
+"pytensor/link/c/cmodule.py" = ["PTH", "T201"]
+"pytensor/misc/elemwise_time_test.py" = ["T201"]
+"pytensor/misc/elemwise_openmp_speedup.py" = ["T201"]
+"pytensor/misc/check_duplicate_key.py" = ["T201"]
+"pytensor/misc/check_blas.py" = ["T201"] +"pytensor/bin/pytensor_cache.py" = ["T201"] # For the tests we skip because `pytest.importorskip` is used: "tests/link/jax/test_scalar.py" = ["E402"] "tests/link/jax/test_tensor_basic.py" = ["E402"] @@ -158,6 +163,8 @@ lines-after-imports = 2 "tests/sparse/test_sp2.py" = ["E402"] "tests/sparse/test_utils.py" = ["E402"] "tests/sparse/sandbox/test_sp.py" = ["E402", "F401"] +"tests/compile/test_monitormode.py" = ["T201"] +"scripts/run_mypy.py" = ["T201"] [tool.mypy] diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index 314f2a7325..3d59b5c24c 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -108,14 +108,14 @@ def perform(self, node, inputs, output_storage): f"'{self.name}' could not be casted to NumPy arrays" ) - print("\n") - print("-------------------------------------------------") - print(f"Conditional breakpoint '{self.name}' activated\n") - print("The monitored variables are stored, in order,") - print("in the list variable 'monitored' as NumPy arrays.\n") - print("Their contents can be altered and, when execution") - print("resumes, the updated values will be used.") - print("-------------------------------------------------") + print("\n") # noqa: T201 + print("-------------------------------------------------") # noqa: T201 + print(f"Conditional breakpoint '{self.name}' activated\n") # noqa: T201 + print("The monitored variables are stored, in order,") # noqa: T201 + print("in the list variable 'monitored' as NumPy arrays.\n") # noqa: T201 + print("Their contents can be altered and, when execution") # noqa: T201 + print("resumes, the updated values will be used.") # noqa: T201 + print("-------------------------------------------------") # noqa: T201 try: import pudb diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 49baa3bb26..a4a3d1840a 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -340,6 +340,12 @@ def __init__( ``None``, this will be used as the connection_pattern for this :class:`Op`. + .. warning:: + + rop overrides is ignored when `pytensor.gradient.Rop` is called with + `use_op_rop_implementation=False` (default). In this case the Lop + is used twice to obtain a mathematically equivalent Rop. + strict: bool, default False If true, it raises when any variables needed to compute the inner graph are not provided as explici inputs. 
This can only happen for graphs with @@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self): return rop_overrides eval_points = [inp_t() for inp_t in self.input_types] - fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points) + fn_rop = partial( + Rop, + wrt=inner_inputs, + eval_points=eval_points, + use_op_rop_implementation=True, + ) callable_args = (inner_inputs, eval_points) if rop_overrides is None: diff --git a/pytensor/compile/compiledir.py b/pytensor/compile/compiledir.py index 0482ed6cd8..127b971b2e 100644 --- a/pytensor/compile/compiledir.py +++ b/pytensor/compile/compiledir.py @@ -95,10 +95,10 @@ def cleanup(): def print_title(title, overline="", underline=""): len_title = len(title) if overline: - print(str(overline) * len_title) - print(title) + print(str(overline) * len_title) # noqa: T201 + print(title) # noqa: T201 if underline: - print(str(underline) * len_title) + print(str(underline) * len_title) # noqa: T201 def print_compiledir_content(): @@ -159,7 +159,7 @@ def print_compiledir_content(): _logger.error(f"Could not read key file '{filename}'.") print_title(f"PyTensor cache: {compiledir}", overline="=", underline="=") - print() + print() # noqa: T201 print_title(f"List of {len(table)} compiled individual ops", underline="+") print_title( @@ -168,9 +168,9 @@ def print_compiledir_content(): ) table = sorted(table, key=lambda t: str(t[1])) for dir, op, types, compile_time in table: - print(dir, f"{compile_time:.3f}s", op, types) + print(dir, f"{compile_time:.3f}s", op, types) # noqa: T201 - print() + print() # noqa: T201 print_title( f"List of {len(table_multiple_ops)} compiled sets of ops", underline="+" ) @@ -180,9 +180,9 @@ def print_compiledir_content(): ) table_multiple_ops = sorted(table_multiple_ops, key=lambda t: (t[1], t[2])) for dir, ops_to_str, types_to_str, compile_time in table_multiple_ops: - print(dir, f"{compile_time:.3f}s", ops_to_str, types_to_str) + print(dir, f"{compile_time:.3f}s", ops_to_str, types_to_str) # noqa: T201 - print() + print() # noqa: T201 print_title( ( f"List of {len(table_op_class)} compiled Op classes and " @@ -191,33 +191,33 @@ def print_compiledir_content(): underline="+", ) for op_class, nb in reversed(table_op_class.most_common()): - print(op_class, nb) + print(op_class, nb) # noqa: T201 if big_key_files: big_key_files = sorted(big_key_files, key=lambda t: str(t[1])) big_total_size = sum(sz for _, sz, _ in big_key_files) - print( + print( # noqa: T201 f"There are directories with key files bigger than {int(max_key_file_size)} bytes " "(they probably contain big tensor constants)" ) - print( + print( # noqa: T201 f"They use {int(big_total_size)} bytes out of {int(total_key_sizes)} (total size " "used by all key files)" ) for dir, size, ops in big_key_files: - print(dir, size, ops) + print(dir, size, ops) # noqa: T201 nb_keys = sorted(nb_keys.items()) - print() + print() # noqa: T201 print_title("Number of keys for a compiled module", underline="+") print_title( "number of keys/number of modules with that number of keys", underline="-" ) for n_k, n_m in nb_keys: - print(n_k, n_m) - print() - print( + print(n_k, n_m) # noqa: T201 + print() # noqa: T201 + print( # noqa: T201 f"Skipped {int(zeros_op)} files that contained 0 op " "(are they always pytensor.scalar ops?)" ) @@ -242,18 +242,18 @@ def basecompiledir_ls(): subdirs = sorted(subdirs) others = sorted(others) - print(f"Base compile dir is {config.base_compiledir}") - print("Sub-directories (possible compile caches):") + print(f"Base compile dir is 
{config.base_compiledir}") # noqa: T201 + print("Sub-directories (possible compile caches):") # noqa: T201 for d in subdirs: - print(f" {d}") + print(f" {d}") # noqa: T201 if not subdirs: - print(" (None)") + print(" (None)") # noqa: T201 if others: - print() - print("Other files in base_compiledir:") + print() # noqa: T201 + print("Other files in base_compiledir:") # noqa: T201 for f in others: - print(f" {f}") + print(f" {f}") # noqa: T201 def basecompiledir_purge(): diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index cc1a5b225a..5c51222a1b 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -1315,9 +1315,9 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def printstuff(self): for key in self.equiv: - print(key) + print(key) # noqa: T201 for e in self.equiv[key]: - print(" ", e) + print(" ", e) # noqa: T201 # List of default version of make thunk. @@ -1569,7 +1569,7 @@ def f(): ##### for r, s in storage_map.items(): if s[0] is not None: - print(r, s) + print(r, s) # noqa: T201 assert s[0] is None # try: @@ -2079,7 +2079,7 @@ def __init__( raise StochasticOrder(infolog.getvalue()) else: if self.verbose: - print( + print( # noqa: T201 "OPTCHECK: optimization", i, "of", diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index ae905089b5..43a5e131cb 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -178,7 +178,7 @@ def __init__(self, header): def apply(self, fgraph): import pytensor.printing - print("PrintCurrentFunctionGraph:", self.header) + print("PrintCurrentFunctionGraph:", self.header) # noqa: T201 pytensor.printing.debugprint(fgraph.outputs) diff --git a/pytensor/compile/monitormode.py b/pytensor/compile/monitormode.py index 770d4e2f7e..8663bc8832 100644 --- a/pytensor/compile/monitormode.py +++ b/pytensor/compile/monitormode.py @@ -108,8 +108,8 @@ def detect_nan(fgraph, i, node, fn): not isinstance(output[0], np.random.RandomState | np.random.Generator) and np.isnan(output[0]).any() ): - print("*** NaN detected ***") + print("*** NaN detected ***") # noqa: T201 debugprint(node) - print(f"Inputs : {[input[0] for input in fn.inputs]}") - print(f"Outputs: {[output[0] for output in fn.outputs]}") + print(f"Inputs : {[input[0] for input in fn.inputs]}") # noqa: T201 + print(f"Outputs: {[output[0] for output in fn.outputs]}") # noqa: T201 break diff --git a/pytensor/compile/nanguardmode.py b/pytensor/compile/nanguardmode.py index 7f90825953..e2fd44cda3 100644 --- a/pytensor/compile/nanguardmode.py +++ b/pytensor/compile/nanguardmode.py @@ -236,7 +236,7 @@ def do_check_on(value, nd, var=None): if config.NanGuardMode__action == "raise": raise AssertionError(msg) elif config.NanGuardMode__action == "pdb": - print(msg) + print(msg) # noqa: T201 import pdb pdb.set_trace() diff --git a/pytensor/compile/profiling.py b/pytensor/compile/profiling.py index 3dfe5283bb..a68365527f 100644 --- a/pytensor/compile/profiling.py +++ b/pytensor/compile/profiling.py @@ -82,7 +82,7 @@ def _atexit_print_fn(): to_sum.append(ps) else: # TODO print the name if there is one! 
- print("Skipping empty Profile") + print("Skipping empty Profile") # noqa: T201 if len(to_sum) > 1: # Make a global profile cum = copy.copy(to_sum[0]) @@ -125,7 +125,7 @@ def _atexit_print_fn(): assert len(merge) == len(cum.rewriter_profile[1]) cum.rewriter_profile = (cum.rewriter_profile[0], merge) except Exception as e: - print(e) + print(e) # noqa: T201 cum.rewriter_profile = None else: cum.rewriter_profile = None diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 13ca943383..04572b29d0 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -142,13 +142,50 @@ def __str__(self): disconnected_type = DisconnectedType() -def Rop( - f: Variable | Sequence[Variable], - wrt: Variable | Sequence[Variable], - eval_points: Variable | Sequence[Variable], +def pushforward_through_pullback( + outputs: Sequence[Variable], + inputs: Sequence[Variable], + tangents: Sequence[Variable], disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", return_disconnected: Literal["none", "zero", "disconnected"] = "zero", -) -> Variable | None | Sequence[Variable | None]: +) -> Sequence[Variable | None]: + """Compute the pushforward (Rop) through two applications of a pullback (Lop) operation. + + References + ---------- + .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017. + Available: https://j-towns.github.io/2017/06/12/A-new-trick.html + + """ + # Cotangents are just auxiliary variables that should be pruned from the final graph, + # but that would require a graph rewrite before the user tries to compile a pytensor function. + # To avoid trouble we use .zeros_like() instead of .type(), which does not create a new root variable. + cotangents = [out.zeros_like(dtype=config.floatX) for out in outputs] # type: ignore + + input_cotangents = Lop( + f=outputs, + wrt=inputs, + eval_points=cotangents, + disconnected_inputs=disconnected_outputs, + return_disconnected="zero", + ) + + return Lop( + f=input_cotangents, # type: ignore + wrt=cotangents, + eval_points=tangents, + disconnected_inputs="ignore", + return_disconnected=return_disconnected, + ) + + +def _rop_legacy( + f: Sequence[Variable], + wrt: Sequence[Variable], + eval_points: Sequence[Variable], + disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", + return_disconnected: Literal["none", "zero", "disconnected"] = "zero", +) -> Sequence[Variable | None]: """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`. Mathematically this stands for the Jacobian of `f` right multiplied by the @@ -190,38 +227,6 @@ def Rop( If `f` is a list/tuple, then return a list/tuple with the results. """ - if not isinstance(wrt, list | tuple): - _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)] - else: - _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] - - if not isinstance(eval_points, list | tuple): - _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)] - else: - _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points] - - if not isinstance(f, list | tuple): - _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)] - else: - _f = [pytensor.tensor.as_tensor_variable(x) for x in f] - - if len(_wrt) != len(_eval_points): - raise ValueError("`wrt` must be the same length as `eval_points`.") - - # Check that each element of wrt corresponds to an element - # of eval_points with the same dimensionality. 
-    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
-        try:
-            if wrt_elem.type.ndim != eval_point.type.ndim:
-                raise ValueError(
-                    f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
-                    f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
-                )
-        except AttributeError:
-            # wrt_elem and eval_point don't always have ndim like random type
-            # Tensor, Sparse have the ndim attribute
-            pass
-
     seen_nodes: dict[Apply, Sequence[Variable]] = {}
 
     def _traverse(node):
@@ -237,8 +242,8 @@ def _traverse(node):
             # inputs of the node
             local_eval_points = []
             for inp in inputs:
-                if inp in _wrt:
-                    local_eval_points.append(_eval_points[_wrt.index(inp)])
+                if inp in wrt:
+                    local_eval_points.append(eval_points[wrt.index(inp)])
                 elif inp.owner is None:
                     try:
                         local_eval_points.append(inp.zeros_like())
@@ -292,13 +297,13 @@ def _traverse(node):
     # end _traverse
 
     # Populate the dictionary
-    for out in _f:
+    for out in f:
         _traverse(out.owner)
 
     rval: list[Variable | None] = []
-    for out in _f:
-        if out in _wrt:
-            rval.append(_eval_points[_wrt.index(out)])
+    for out in f:
+        if out in wrt:
+            rval.append(eval_points[wrt.index(out)])
         elif (
             seen_nodes.get(out.owner, None) is None
             or seen_nodes[out.owner][out.owner.outputs.index(out)] is None
@@ -337,6 +342,116 @@ def _traverse(node):
         else:
             rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
 
+    return rval
+
+
+def Rop(
+    f: Variable | Sequence[Variable],
+    wrt: Variable | Sequence[Variable],
+    eval_points: Variable | Sequence[Variable],
+    disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
+    return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
+    use_op_rop_implementation: bool = False,
+) -> Variable | None | Sequence[Variable | None]:
+    """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
+
+    Mathematically this stands for the Jacobian of `f` right multiplied by the
+    `eval_points`.
+
+    By default, the R-operator is implemented as a double application of the L_operator [1]_.
+    In most cases this should be as performant as a specialized implementation of the R-operator.
+    However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
+    such as Scan and OpFromGraph, that would be more easily avoidable in a direct implementation of the R-operator.
+
+    When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
+    `use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
+
+    Parameters
+    ----------
+    f
+        The outputs of the computational graph to which the R-operator is
+        applied.
+    wrt
+        Variables for which the R-operator of `f` is computed.
+    eval_points
+        Points at which to evaluate each of the variables in `wrt`.
+    disconnected_outputs
+        Defines the behaviour if some of the variables in `f`
+        have no dependency on any of the variable in `wrt` (or if
+        all links are non-differentiable). The possible values are:
+
+        - ``'ignore'``: considers that the gradient on these parameters is zero.
+        - ``'warn'``: consider the gradient zero, and print a warning.
+        - ``'raise'``: raise `DisconnectedInputError`.
+
+    return_disconnected
+        - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``wrt[i].zeros_like()``.
+        - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``None``
+        - ``'disconnected'`` : returns variables of type `DisconnectedType`
+    use_op_rop_implementation: bool, default=False
+        If `False` (default), Rop is obtained via a double application of Lop.
+        If `True`, the legacy implementation based on the `Op.R_op` methods is used.
+        The number of graphs that support the legacy form is much more restricted,
+        and the generated graphs may be less optimized.
+
+    Returns
+    -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
+        A symbolic expression obeying
+        ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
+        where the indices in that expression are magic multidimensional
+        indices that specify both the position within a list and all
+        coordinates of the tensor elements.
+        If `f` is a list/tuple, then return a list/tuple with the results.
+
+    References
+    ----------
+    .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017.
+       Available: https://j-towns.github.io/2017/06/12/A-new-trick.html
+    """
+
+    if not isinstance(wrt, list | tuple):
+        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
+    else:
+        _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
+
+    if not isinstance(eval_points, list | tuple):
+        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
+    else:
+        _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
+
+    if not isinstance(f, list | tuple):
+        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
+    else:
+        _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
+
+    if len(_wrt) != len(_eval_points):
+        raise ValueError("`wrt` must be the same length as `eval_points`.")
+
+    # Check that each element of wrt corresponds to an element
+    # of eval_points with the same dimensionality.
+    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
+        try:
+            if wrt_elem.type.ndim != eval_point.type.ndim:
+                raise ValueError(
+                    f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
+                    f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
+                )
+        except AttributeError:
+            # wrt_elem and eval_point don't always have ndim like random type
+            # Tensor, Sparse have the ndim attribute
+            pass
+
+    if use_op_rop_implementation:
+        rval = _rop_legacy(
+            _f, _wrt, _eval_points, disconnected_outputs, return_disconnected
+        )
+    else:
+        rval = pushforward_through_pullback(
+            _f, _wrt, _eval_points, disconnected_outputs, return_disconnected
+        )
+
     using_list = isinstance(f, list)
     using_tuple = isinstance(f, tuple)
     return as_list_or_tuple(using_list, using_tuple, rval)
@@ -348,6 +463,7 @@ def Lop(
     eval_points: Variable | Sequence[Variable],
     consider_constant: Sequence[Variable] | None = None,
     disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
+    return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
 ) -> Variable | None | Sequence[Variable | None]:
     """Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
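# ------------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): the identity behind
# `pushforward_through_pullback`, checked numerically against the legacy
# `Op.R_op` path. The graph `y = tanh(x)` and the names `u`, `v` are
# illustrative only; `u` plays the role of the auxiliary cotangent that the
# new code builds with `zeros_like`.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop, Rop

x = pt.vector("x")
y = pt.tanh(x)
v = pt.vector("v")  # tangent on the input

u = y.zeros_like()  # auxiliary cotangent; zeros_like avoids a new root input
g = Lop(y, x, u)  # pullback: g(u) = J^T u
jvp_via_lop = Lop(g, u, v)  # pullback of the pullback = pushforward, J v
jvp_legacy = Rop(y, x, v, use_op_rop_implementation=True)  # direct Op.R_op path

fn = pytensor.function([x, v], [jvp_via_lop, jvp_legacy])
a, b = fn([0.1, 0.2], [1.0, -1.0])
np.testing.assert_allclose(a, b)  # both equal (1 - tanh(x)^2) * v
# ------------------------------------------------------------------------------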
@@ -404,6 +520,7 @@ def Lop( consider_constant=consider_constant, wrt=_wrt, disconnected_inputs=disconnected_inputs, + return_disconnected=return_disconnected, ) using_list = isinstance(wrt, list) diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 93321fa61f..06be6d013a 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -491,7 +491,7 @@ def validate_(self, fgraph): if verbose: r = uf.f_locals.get("r", "") reason = uf_info.function - print(f"validate failed on node {r}.\n Reason: {reason}, {e}") + print(f"validate failed on node {r}.\n Reason: {reason}, {e}") # noqa: T201 raise t1 = time.perf_counter() if fgraph.profile: @@ -603,13 +603,13 @@ def replace_all_validate( except Exception as e: fgraph.revert(chk) if verbose: - print( + print( # noqa: T201 f"rewriting: validate failed on node {r}.\n Reason: {reason}, {e}" ) raise if verbose: - print( + print( # noqa: T201 f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}" ) @@ -692,11 +692,11 @@ def on_import(self, fgraph, node, reason): except TypeError: # node.op is unhashable return except Exception as e: - print("OFFENDING node", type(node), type(node.op), file=sys.stderr) + print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201 try: - print("OFFENDING node hash", hash(node.op), file=sys.stderr) + print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201 except Exception: - print("OFFENDING node not hashable", file=sys.stderr) + print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201 raise e def on_prune(self, fgraph, node, reason): @@ -725,7 +725,7 @@ def __init__(self, active=True): def on_attach(self, fgraph): if self.active: - print("-- attaching to: ", fgraph) + print("-- attaching to: ", fgraph) # noqa: T201 def on_detach(self, fgraph): """ @@ -733,19 +733,19 @@ def on_detach(self, fgraph): that it installed into the function_graph """ if self.active: - print("-- detaching from: ", fgraph) + print("-- detaching from: ", fgraph) # noqa: T201 def on_import(self, fgraph, node, reason): if self.active: - print(f"-- importing: {node}, reason: {reason}") + print(f"-- importing: {node}, reason: {reason}") # noqa: T201 def on_prune(self, fgraph, node, reason): if self.active: - print(f"-- pruning: {node}, reason: {reason}") + print(f"-- pruning: {node}, reason: {reason}") # noqa: T201 def on_change_input(self, fgraph, node, i, r, new_r, reason=None): if self.active: - print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") + print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") # noqa: T201 class PreserveVariableAttributes(Feature): diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 1d845e2eb3..e9b676f51a 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -491,7 +491,7 @@ def replace( if verbose is None: verbose = config.optimizer_verbose if verbose: - print( + print( # noqa: T201 f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}" ) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 344d6a1940..b91e743bb6 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -1002,7 +1002,7 @@ def transform(self, fgraph, node, *args, **kwargs): # ensure we have data for all input variables that need it if missing: if self.verbose > 0: - print( + print( # noqa: T201 f"{self.__class__.__name__} cannot meta-rewrite {node}, " f"{len(missing)} of {int(node.nin)} input 
shapes unknown" ) @@ -1010,7 +1010,7 @@ def transform(self, fgraph, node, *args, **kwargs): # now we can apply the different rewrites in turn, # compile the resulting subgraphs and time their execution if self.verbose > 1: - print( + print( # noqa: T201 f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):" ) timings = [] @@ -1027,20 +1027,20 @@ def transform(self, fgraph, node, *args, **kwargs): continue except Exception as e: if self.verbose > 0: - print(f"* {node_rewriter}: exception", e) + print(f"* {node_rewriter}: exception", e) # noqa: T201 continue else: if self.verbose > 1: - print(f"* {node_rewriter}: {timing:.5g} sec") + print(f"* {node_rewriter}: {timing:.5g} sec") # noqa: T201 timings.append((timing, outputs, node_rewriter)) else: if self.verbose > 0: - print(f"* {node_rewriter}: not applicable") + print(f"* {node_rewriter}: not applicable") # noqa: T201 # finally, we choose the fastest one if timings: timings.sort() if self.verbose > 1: - print(f"= {timings[0][2]}") + print(f"= {timings[0][2]}") # noqa: T201 return timings[0][1] return @@ -1305,7 +1305,7 @@ def transform(self, fgraph, node): new_vars = list(new_repl.values()) if config.optimizer_verbose: - print( + print( # noqa: T201 f"rewriting: rewrite {rewrite} replaces node {node} with {new_repl}" ) @@ -2641,21 +2641,21 @@ def print_profile(cls, stream, prof, level=0): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: - print(blanc, "merge not implemented for ", o) + print(blanc, "merge not implemented for ", o) # noqa: T201 for o, prof in zip( rewrite.final_rewriters, final_sub_profs[i], strict=True ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: - print(blanc, "merge not implemented for ", o) + print(blanc, "merge not implemented for ", o) # noqa: T201 for o, prof in zip( rewrite.cleanup_rewriters, cleanup_sub_profs[i], strict=True ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: - print(blanc, "merge not implemented for ", o) + print(blanc, "merge not implemented for ", o) # noqa: T201 @staticmethod def merge_profile(prof1, prof2): @@ -2800,16 +2800,6 @@ def _check_chain(r, chain): return r is not None -def check_chain(r, *chain): - """ - WRITEME - - """ - if isinstance(r, Apply): - r = r.outputs[0] - return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) - - def pre_greedy_node_rewriter( fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable ) -> Variable: diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py index 9c2eef5049..42ebbcd216 100644 --- a/pytensor/graph/utils.py +++ b/pytensor/graph/utils.py @@ -274,9 +274,9 @@ def __repr__(self): return "scratchpad" + str(self.__dict__) def info(self): - print(f"") + print(f"") # noqa: T201 for k, v in self.__dict__.items(): - print(f" {k}: {v}") + print(f" {k}: {v}") # noqa: T201 # These two methods have been added to help Mypy def __getattribute__(self, name): diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index 0b717c74a6..d509bd1d76 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -10,8 +10,6 @@ from io import StringIO from typing import TYPE_CHECKING, Any, Optional -import numpy as np - from pytensor.compile.compilelock import lock_ctx from pytensor.configdefaults import config from pytensor.graph.basic import ( @@ -33,6 +31,7 @@ from pytensor.link.c.cmodule import get_module_cache as _get_module_cache from pytensor.link.c.interface import CLinkerObject, CLinkerOp, 
CLinkerType from pytensor.link.utils import gc_helper, map_storage, raise_with_op, streamline +from pytensor.npy_2_compat import ndarray_c_version from pytensor.utils import difference, uniq @@ -875,10 +874,10 @@ def code_gen(self): self.c_init_code_apply = c_init_code_apply if (self.init_tasks, self.tasks) != self.get_init_tasks(): - print("init_tasks\n", self.init_tasks, file=sys.stderr) - print(self.get_init_tasks()[0], file=sys.stderr) - print("tasks\n", self.tasks, file=sys.stderr) - print(self.get_init_tasks()[1], file=sys.stderr) + print("init_tasks\n", self.init_tasks, file=sys.stderr) # noqa: T201 + print(self.get_init_tasks()[0], file=sys.stderr) # noqa: T201 + print("tasks\n", self.tasks, file=sys.stderr) # noqa: T201 + print(self.get_init_tasks()[1], file=sys.stderr) # noqa: T201 assert (self.init_tasks, self.tasks) == self.get_init_tasks() # List of indices that should be ignored when passing the arguments @@ -1367,10 +1366,6 @@ def cmodule_key_( # We must always add the numpy ABI version here as # DynamicModule always add the include - if np.lib.NumpyVersion(np.__version__) < "1.16.0a": - ndarray_c_version = np.core.multiarray._get_ndarray_c_version() - else: - ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}") if c_compiler: sig.append("c_compiler_str=" + c_compiler.version_str()) @@ -1756,7 +1751,7 @@ def __call__(self): exc_value = exc_type(_exc_value) exc_value.__thunk_trace__ = trace except Exception: - print( + print( # noqa: T201 ( "ERROR retrieving error_storage." "Was the error set in the c code?" @@ -1764,7 +1759,7 @@ def __call__(self): end=" ", file=sys.stderr, ) - print(self.error_storage, file=sys.stderr) + print(self.error_storage, file=sys.stderr) # noqa: T201 raise raise exc_value.with_traceback(exc_trace) diff --git a/pytensor/link/c/c_code/lazylinker_c.c b/pytensor/link/c/c_code/lazylinker_c.c index a64614a908..08f3e4d0fb 100644 --- a/pytensor/link/c/c_code/lazylinker_c.c +++ b/pytensor/link/c/c_code/lazylinker_c.c @@ -5,9 +5,6 @@ #if PY_VERSION_HEX >= 0x03000000 #include "numpy/npy_3kcompat.h" -#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr -#define PyCObject_GetDesc NpyCapsule_GetDesc -#define PyCObject_Check NpyCapsule_Check #endif #ifndef Py_TYPE @@ -323,9 +320,9 @@ static int CLazyLinker_init(CLazyLinker *self, PyObject *args, PyObject *kwds) { if (PyObject_HasAttrString(thunk, "cthunk")) { PyObject *cthunk = PyObject_GetAttrString(thunk, "cthunk"); // new reference - assert(cthunk && PyCObject_Check(cthunk)); - self->thunk_cptr_fn[i] = PyCObject_AsVoidPtr(cthunk); - self->thunk_cptr_data[i] = PyCObject_GetDesc(cthunk); + assert(cthunk && NpyCapsule_Check(cthunk)); + self->thunk_cptr_fn[i] = NpyCapsule_AsVoidPtr(cthunk); + self->thunk_cptr_data[i] = NpyCapsule_GetDesc(cthunk); Py_DECREF(cthunk); // cthunk is kept alive by membership in self->thunks } @@ -487,8 +484,8 @@ static PyObject *pycall(CLazyLinker *self, Py_ssize_t node_idx, int verbose) { PyList_SetItem(self->call_times, node_idx, PyFloat_FromDouble(t1 - t0 + ti)); PyObject *count = PyList_GetItem(self->call_counts, node_idx); - long icount = PyInt_AsLong(count); - PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount + 1)); + long icount = PyLong_AsLong(count); + PyList_SetItem(self->call_counts, node_idx, PyLong_FromLong(icount + 1)); } } else { if (verbose) { @@ -512,8 +509,8 @@ static int c_call(CLazyLinker *self, Py_ssize_t node_idx, int verbose) { PyList_SetItem(self->call_times, node_idx, 
PyFloat_FromDouble(t1 - t0 + ti)); PyObject *count = PyList_GetItem(self->call_counts, node_idx); - long icount = PyInt_AsLong(count); - PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount + 1)); + long icount = PyLong_AsLong(count); + PyList_SetItem(self->call_counts, node_idx, PyLong_FromLong(icount + 1)); } else { err = fn(self->thunk_cptr_data[node_idx]); } @@ -774,20 +771,20 @@ static PyObject *CLazyLinker_call(PyObject *_self, PyObject *args, output_subset = (char *)calloc(self->n_output_vars, sizeof(char)); for (int it = 0; it < output_subset_size; ++it) { PyObject *elem = PyList_GetItem(output_subset_ptr, it); - if (!PyInt_Check(elem)) { + if (!PyLong_Check(elem)) { err = 1; PyErr_SetString(PyExc_RuntimeError, "Some elements of output_subset list are not int"); } - output_subset[PyInt_AsLong(elem)] = 1; + output_subset[PyLong_AsLong(elem)] = 1; } } } self->position_of_error = -1; // create constants used to fill the var_compute_cells - PyObject *one = PyInt_FromLong(1); - PyObject *zero = PyInt_FromLong(0); + PyObject *one = PyLong_FromLong(1); + PyObject *zero = PyLong_FromLong(0); // pre-allocate our return value Py_INCREF(Py_None); @@ -942,11 +939,8 @@ static PyMemberDef CLazyLinker_members[] = { }; static PyTypeObject lazylinker_ext_CLazyLinkerType = { -#if defined(NPY_PY3K) PyVarObject_HEAD_INIT(NULL, 0) -#else - PyObject_HEAD_INIT(NULL) 0, /*ob_size*/ -#endif + "lazylinker_ext.CLazyLinker", /*tp_name*/ sizeof(CLazyLinker), /*tp_basicsize*/ 0, /*tp_itemsize*/ @@ -987,7 +981,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { }; static PyObject *get_version(PyObject *dummy, PyObject *args) { - PyObject *result = PyFloat_FromDouble(0.212); + PyObject *result = PyFloat_FromDouble(0.3); return result; } @@ -996,7 +990,7 @@ static PyMethodDef lazylinker_ext_methods[] = { {NULL, NULL, 0, NULL} /* Sentinel */ }; -#if defined(NPY_PY3K) + static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT, "lazylinker_ext", NULL, @@ -1006,28 +1000,19 @@ static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT, NULL, NULL, NULL}; -#endif -#if defined(NPY_PY3K) -#define RETVAL m + PyMODINIT_FUNC PyInit_lazylinker_ext(void) { -#else -#define RETVAL -PyMODINIT_FUNC initlazylinker_ext(void) { -#endif + PyObject *m; lazylinker_ext_CLazyLinkerType.tp_new = PyType_GenericNew; if (PyType_Ready(&lazylinker_ext_CLazyLinkerType) < 0) - return RETVAL; -#if defined(NPY_PY3K) + return NULL; + m = PyModule_Create(&moduledef); -#else - m = Py_InitModule3("lazylinker_ext", lazylinker_ext_methods, - "Example module that creates an extension type."); -#endif Py_INCREF(&lazylinker_ext_CLazyLinkerType); PyModule_AddObject(m, "CLazyLinker", (PyObject *)&lazylinker_ext_CLazyLinkerType); - return RETVAL; + return m; } diff --git a/pytensor/link/c/c_code/pytensor_mod_helper.h b/pytensor/link/c/c_code/pytensor_mod_helper.h index d3e4b29a2b..2f857e6775 100644 --- a/pytensor/link/c/c_code/pytensor_mod_helper.h +++ b/pytensor/link/c/c_code/pytensor_mod_helper.h @@ -18,14 +18,8 @@ #define PYTENSOR_EXTERN #endif -#if PY_MAJOR_VERSION < 3 -#define PYTENSOR_RTYPE void -#else -#define PYTENSOR_RTYPE PyObject * -#endif - /* We need to redefine PyMODINIT_FUNC to add MOD_PUBLIC in the middle */ #undef PyMODINIT_FUNC -#define PyMODINIT_FUNC PYTENSOR_EXTERN MOD_PUBLIC PYTENSOR_RTYPE +#define PyMODINIT_FUNC PYTENSOR_EXTERN MOD_PUBLIC PyObject * #endif diff --git a/pytensor/link/c/interface.py b/pytensor/link/c/interface.py index 7e281af947..e9375d2511 100644 --- a/pytensor/link/c/interface.py +++ 
b/pytensor/link/c/interface.py @@ -1,7 +1,7 @@ import typing import warnings from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Hashable from typing import Optional from pytensor.graph.basic import Apply, Constant @@ -155,7 +155,7 @@ def c_init_code(self, **kwargs) -> list[str]: """Return a list of code snippets to be inserted in module initialization.""" return [] - def c_code_cache_version(self) -> tuple[int, ...]: + def c_code_cache_version(self) -> tuple[Hashable, ...]: """Return a tuple of integers indicating the version of this `Op`. An empty tuple indicates an "unversioned" `Op` that will not be cached @@ -223,7 +223,7 @@ def c_code( """ raise NotImplementedError() - def c_code_cache_version_apply(self, node: Apply) -> tuple[int, ...]: + def c_code_cache_version_apply(self, node: Apply) -> tuple[Hashable, ...]: """Return a tuple of integers indicating the version of this `Op`. An empty tuple indicates an "unversioned" `Op` that will not be diff --git a/pytensor/link/c/lazylinker_c.py b/pytensor/link/c/lazylinker_c.py index 679cb4e290..ce67190342 100644 --- a/pytensor/link/c/lazylinker_c.py +++ b/pytensor/link/c/lazylinker_c.py @@ -14,7 +14,7 @@ _logger = logging.getLogger(__file__) force_compile = False -version = 0.212 # must match constant returned in function get_version() +version = 0.3 # must match constant returned in function get_version() lazylinker_ext: ModuleType | None = None diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 74905d686f..b668f242e1 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -79,7 +79,7 @@ def is_f16(t): # that don't implement c code. In those cases, we # don't want to print a warning. cl.get_dynamic_module() - print(f"Disabling C code for {self} due to unsupported float16") + warnings.warn(f"Disabling C code for {self} due to unsupported float16") raise NotImplementedError("float16") outputs = cl.make_thunk( input_storage=node_input_storage, output_storage=node_output_storage diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 9a89bf1406..8a33dfac13 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -56,7 +56,7 @@ def assert_size_argument_jax_compatible(node): @jax_typify.register(Generator) def jax_typify_Generator(rng, **kwargs): - state = rng.__getstate__() + state = rng.bit_generator.state state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] # XXX: Is this a reasonable approach? 
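# ------------------------------------------------------------------------------
# Reviewer note (not part of the patch): why `rng.__getstate__()` was replaced
# above. The old code relied on `Generator.__getstate__()` returning the
# bit-generator state dict, which NumPy 2.x no longer guarantees, while
# `Generator.bit_generator.state` yields the same mapping on both 1.x and 2.x:
import numpy as np

rng = np.random.default_rng(123)
state = rng.bit_generator.state
print(state["bit_generator"])  # e.g. "PCG64" -- the key remapped by jax_typify
# ------------------------------------------------------------------------------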
@@ -128,7 +128,6 @@ def jax_sample_fn(op, node): @jax_sample_fn.register(ptr.BetaRV) @jax_sample_fn.register(ptr.DirichletRV) @jax_sample_fn.register(ptr.PoissonRV) -@jax_sample_fn.register(ptr.MvNormalRV) def jax_sample_fn_generic(op, node): """Generic JAX implementation of random variables.""" name = op.name @@ -173,6 +172,20 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn +@jax_sample_fn.register(ptr.MvNormalRV) +def jax_sample_mvnormal(op, node): + def sample_fn(rng, size, dtype, mean, cov): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + sample = jax.random.multivariate_normal( + sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method + ) + rng["jax_state"] = rng_key + return (rng, sample) + + return sample_fn + + @jax_sample_fn.register(ptr.BernoulliRV) def jax_sample_fn_bernoulli(op, node): """JAX implementation of `BernoulliRV`.""" diff --git a/pytensor/link/numba/dispatch/_LAPACK.py b/pytensor/link/numba/dispatch/_LAPACK.py new file mode 100644 index 0000000000..ab5561650c --- /dev/null +++ b/pytensor/link/numba/dispatch/_LAPACK.py @@ -0,0 +1,392 @@ +import ctypes + +import numpy as np +from numba.core import cgutils, types +from numba.core.extending import get_cython_function_address, intrinsic +from numba.np.linalg import ensure_lapack, get_blas_kind + + +_PTR = ctypes.POINTER + +_dbl = ctypes.c_double +_float = ctypes.c_float +_char = ctypes.c_char +_int = ctypes.c_int + +_ptr_float = _PTR(_float) +_ptr_dbl = _PTR(_dbl) +_ptr_char = _PTR(_char) +_ptr_int = _PTR(_int) + + +def _get_lapack_ptr_and_ptr_type(dtype, name): + d = get_blas_kind(dtype) + func_name = f"{d}{name}" + float_pointer = _get_float_pointer_for_dtype(d) + lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) + + return lapack_ptr, float_pointer + + +def _get_underlying_float(dtype): + s_dtype = str(dtype) + out_type = s_dtype + if s_dtype == "complex64": + out_type = "float32" + elif s_dtype == "complex128": + out_type = "float64" + + return np.dtype(out_type) + + +def _get_float_pointer_for_dtype(blas_dtype): + if blas_dtype in ["s", "c"]: + return _ptr_float + elif blas_dtype in ["d", "z"]: + return _ptr_dbl + + +def _get_output_ctype(dtype): + s_dtype = str(dtype) + if s_dtype in ["float32", "complex64"]: + return _float + elif s_dtype in ["float64", "complex128"]: + return _dbl + + +@intrinsic +def sptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float32(types.CPointer(types.float32)) + return sig, impl + + +@intrinsic +def dptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float64(types.CPointer(types.float64)) + return sig, impl + + +@intrinsic +def int_ptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.int32(types.CPointer(types.int32)) + return sig, impl + + +@intrinsic +def val_to_int_ptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.int32)(types.int32) + return sig, impl + + +@intrinsic +def val_to_sptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float32)(types.float32) + return sig, impl + + +@intrinsic +def 
val_to_zptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.complex128)(types.complex128) + return sig, impl + + +@intrinsic +def val_to_dptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float64)(types.float64) + return sig, impl + + +class _LAPACK: + """ + Functions to return type signatures for wrapped LAPACK functions. + + Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 + """ + + def __init__(self): + ensure_lapack() + + @classmethod + def numba_xtrtrs(cls, dtype): + """ + Solve a triangular system of equations of the form A @ X = B or A.T @ X = B. + + Called by scipy.linalg.solve_triangular + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # TRANS + _ptr_int, # DIAG + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + + return functype(lapack_ptr) + + @classmethod + def numba_xpotrf(cls, dtype): + """ + Compute the Cholesky factorization of a real symmetric positive definite matrix. + + Called by scipy.linalg.cholesky + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO, + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xpotrs(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky + factorization computed by numba_potrf. + + Called by scipy.linalg.cho_solve + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xlange(cls, dtype): + """ + Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of + a general M-by-N matrix A. + + Called by scipy.linalg.solve + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange") + output_ctype = _get_output_ctype(dtype) + functype = ctypes.CFUNCTYPE( + output_ctype, # Output + _ptr_int, # NORM + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # WORK + ) + return functype(lapack_ptr) + + @classmethod + def numba_xlamch(cls, dtype): + """ + Determine machine precision for floating point arithmetic. + """ + + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lamch") + output_dtype = _get_output_ctype(dtype) + functype = ctypes.CFUNCTYPE( + output_dtype, # Output + _ptr_int, # CMACH + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgecon(cls, dtype): + """ + Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf. 
+ + Called by scipy.linalg.solve when assume_a == "gen" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # NORM + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgetrf(cls, dtype): + """ + Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges. + + Called by scipy.linalg.lu_factor + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgetrs(cls, dtype): + """ + Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU + factorization computed by GETRF. + + Called by scipy.linalg.lu_solve + """ + ... + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # TRANS + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xsysv(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method, + factorizing A into LDL^T or UDU^T form, depending on the value of UPLO + + Called by scipy.linalg.solve when assume_a == "sym" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # B + _ptr_int, # LDB + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xsycon(cls, dtype): + """ + Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization + computed by xSYTRF. + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xpocon(cls, dtype): + """ + Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization + computed by potrf. + + Called by scipy.linalg.solve when assume_a == "pos" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xposv(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky + factorization computed by potrf. 
+ """ + + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0b2b58904a..c66a237f06 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs): def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): - """Create a Numba compatible function from an Aesara `Op`.""" + """Create a Numba compatible function from a Pytensor `Op`.""" warnings.warn( f"Numba will use object mode to run {op}'s perform method", diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2a98985efe..03c7084a8f 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -4,7 +4,6 @@ import numba import numpy as np from numba.core.extending import overload -from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic @@ -19,6 +18,7 @@ store_core_outputs, ) from pytensor.link.utils import compile_function_src +from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.scalar.basic import ( AND, OR, diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 04181e8335..e20d99c605 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from copy import copy +from copy import copy, deepcopy from functools import singledispatch from textwrap import dedent @@ -34,7 +34,7 @@ def copy_NumPyRandomGenerator(rng): def impl(rng): # TODO: Open issue on Numba? 
with numba.objmode(new_rng=types.npy_rng): - new_rng = copy(rng) + new_rng = deepcopy(rng) return new_rng @@ -144,11 +144,24 @@ def random_fn(rng, p): @numba_core_rv_funcify.register(ptr.MvNormalRV) def core_MvNormalRV(op, node): + method = op.method + @numba_basic.numba_njit def random_fn(rng, mean, cov): - chol = np.linalg.cholesky(cov) - stdnorm = rng.normal(size=cov.shape[-1]) - return np.dot(chol, stdnorm) + mean + if method == "cholesky": + A = np.linalg.cholesky(cov) + elif method == "svd": + A, s, _ = np.linalg.svd(cov) + A *= np.sqrt(s)[None, :] + else: + w, A = np.linalg.eigh(cov) + A *= np.sqrt(w)[None, :] + + out = rng.normal(size=cov.shape[-1]) + # out argument not working correctly: https://github.com/numba/numba/issues/9924 + out[:] = np.dot(A, out) + out += mean + return out random_fn.handles_out = True return random_fn diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 96a8da282e..a3f5ea9491 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -1,136 +1,52 @@ -import ctypes +from collections.abc import Callable import numba import numpy as np -from numba.core import cgutils, types -from numba.extending import get_cython_function_address, intrinsic, overload -from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind +from numba.core import types +from numba.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from numpy.linalg import LinAlgError from scipy import linalg from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) from pytensor.link.numba.dispatch.basic import numba_funcify -from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular +from pytensor.tensor.slinalg import ( + BlockDiagonal, + Cholesky, + CholeskySolve, + Solve, + SolveTriangular, +) -_PTR = ctypes.POINTER - -_dbl = ctypes.c_double -_float = ctypes.c_float -_char = ctypes.c_char -_int = ctypes.c_int - -_ptr_float = _PTR(_float) -_ptr_dbl = _PTR(_dbl) -_ptr_char = _PTR(_char) -_ptr_int = _PTR(_int) - - -@numba.core.extending.register_jitable -def _check_finite_matrix(a, func_name): - for v in np.nditer(a): - if not np.isfinite(v.item()): - raise np.linalg.LinAlgError( - "Non-numeric values (nan or inf) in input to " + func_name +@numba_basic.numba_njit(inline="always") +def _solve_check(n, info, lamch=False, rcond=None): + """ + Check arguments during the different steps of the solution phase + Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38 + """ + if info < 0: + # TODO: figure out how to do an fstring here + msg = "LAPACK reported an illegal value in input" + raise ValueError(msg) + elif 0 < info: + raise LinAlgError("Matrix is singular.") + + if lamch: + E = _xlamch("E") + if rcond < E: + # TODO: This should be a warning, but we can't raise warnings in numba mode + print( # noqa: T201 + "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate." 
) -@intrinsic -def val_to_dptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.float64)(types.float64) - return sig, impl - - -@intrinsic -def val_to_zptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.complex128)(types.complex128) - return sig, impl - - -@intrinsic -def val_to_sptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.float32)(types.float32) - return sig, impl - - -@intrinsic -def val_to_int_ptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.int32)(types.int32) - return sig, impl - - -@intrinsic -def int_ptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.int32(types.CPointer(types.int32)) - return sig, impl - - -@intrinsic -def dptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.float64(types.CPointer(types.float64)) - return sig, impl - - -@intrinsic -def sptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.float32(types.CPointer(types.float32)) - return sig, impl - - -def _get_float_pointer_for_dtype(blas_dtype): - if blas_dtype in ["s", "c"]: - return _ptr_float - elif blas_dtype in ["d", "z"]: - return _ptr_dbl - - -def _get_underlying_float(dtype): - s_dtype = str(dtype) - out_type = s_dtype - if s_dtype == "complex64": - out_type = "float32" - elif s_dtype == "complex128": - out_type = "float64" - - return np.dtype(out_type) - - -def _get_lapack_ptr_and_ptr_type(dtype, name): - d = get_blas_kind(dtype) - func_name = f"{d}{name}" - float_pointer = _get_float_pointer_for_dtype(d) - lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) - - return lapack_ptr, float_pointer - - def _check_scipy_linalg_matrix(a, func_name): """ Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 @@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name): raise numba.TypingError(msg, highlighting=False) -class _LAPACK: +def _solve_triangular( + A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False +): """ - Functions to return type signatures for wrapped LAPACK functions. + Thin wrapper around scipy.linalg.solve_triangular. - Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 - """ - - def __init__(self): - ensure_lapack() + This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who + import pytensor. - @classmethod - def numba_xtrtrs(cls, dtype): - """ - Called by scipy.linalg.solve_triangular - """ - lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") + The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not + used by scipy.linalg.solve_triangular. 
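+
+    Hedged sketch of the equivalence preserved here (hypothetical data; in
+    python mode this simply calls scipy, while in numba mode the overload
+    below runs instead):
+
+        x = _solve_triangular(L, b, trans=0, lower=True, b_ndim=1)
+        # same result as linalg.solve_triangular(L, b, lower=True)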
+    """
+    return linalg.solve_triangular(
+        A,
+        B,
+        trans=trans,
+        lower=lower,
+        unit_diagonal=unit_diagonal,
+        overwrite_b=overwrite_b,
+    )
 
-        functype = ctypes.CFUNCTYPE(
-            None,
-            _ptr_int,  # UPLO
-            _ptr_int,  # TRANS
-            _ptr_int,  # DIAG
-            _ptr_int,  # N
-            _ptr_int,  # NRHS
-            float_pointer,  # A
-            _ptr_int,  # LDA
-            float_pointer,  # B
-            _ptr_int,  # LDB
-            _ptr_int,  # INFO
-        )
-        return functype(lapack_ptr)
-
-    @classmethod
-    def numba_xpotrf(cls, dtype):
-        """
-        Called by scipy.linalg.cholesky
-        """
-        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
-        functype = ctypes.CFUNCTYPE(
-            None,
-            _ptr_int,  # UPLO,
-            _ptr_int,  # N
-            float_pointer,  # A
-            _ptr_int,  # LDA
-            _ptr_int,  # INFO
-        )
-        return functype(lapack_ptr)
 
+@numba_basic.numba_njit(inline="always")
+def _trans_char_to_int(trans):
+    if trans not in [0, 1, 2]:
+        raise ValueError('Parameter "trans" should be one of 0, 1, 2')
+    if trans == 0:
+        return ord("N")
+    elif trans == 1:
+        return ord("T")
+    else:
+        return ord("C")
 
-def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
-    return linalg.solve_triangular(
-        A, B, trans=trans, lower=lower, unit_diagonal=unit_diagonal
-    )
+@numba_basic.numba_njit(inline="always")
+def _solve_check_input_shapes(A, B):
+    if A.shape[0] != B.shape[0]:
+        raise linalg.LinAlgError("Dimensions of A and B do not conform")
+    if A.shape[-2] != A.shape[-1]:
+        raise linalg.LinAlgError("Last 2 dimensions of A must be square")
 
 
 @overload(_solve_triangular)
-def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
+def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
     ensure_lapack()
 
     _check_scipy_linalg_matrix(A, "solve_triangular")
@@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
     w_type = _get_underlying_float(dtype)
     numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
 
-    def impl(A, B, trans=0, lower=False, unit_diagonal=False):
-        B_is_1d = B.ndim == 1
-
+    def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
         _N = np.int32(A.shape[-1])
-        if A.shape[-2] != _N:
-            raise linalg.LinAlgError("Last 2 dimensions of A must be square")
+        _solve_check_input_shapes(A, B)
 
-        if A.shape[0] != B.shape[0]:
-            raise linalg.LinAlgError("Dimensions of A and B do not conform")
+        B_is_1d = B.ndim == 1
 
-        if B_is_1d:
-            B_copy = np.asfortranarray(np.expand_dims(B, -1))
-        else:
+        if not overwrite_b:
             B_copy = _copy_to_fortran_order(B)
-
-        if trans not in [0, 1, 2]:
-            raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
-        if trans == 0:
-            transval = ord("N")
-        elif trans == 1:
-            transval = ord("T")
         else:
-            transval = ord("C")
+            B_copy = B
 
-        B_NDIM = 1 if B_is_1d else int(B.shape[1])
+        if B_is_1d:
+            B_copy = np.expand_dims(B_copy, -1)
+
+        NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
 
         UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
-        TRANS = val_to_int_ptr(transval)
+        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
         DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
         N = val_to_int_ptr(_N)
-        NRHS = val_to_int_ptr(B_NDIM)
+        NRHS = val_to_int_ptr(NRHS)
         LDA = val_to_int_ptr(_N)
         LDB = val_to_int_ptr(_N)
         INFO = val_to_int_ptr(0)
@@ -266,19 +158,24 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False):
             INFO,
         )
 
+        _solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO))
+
         if B_is_1d:
-            return B_copy[..., 0], int_ptr_to_val(INFO)
-        return B_copy, int_ptr_to_val(INFO)
+            return B_copy[..., 0]
+
+        return B_copy
 
     return impl
 
 
 @numba_funcify.register(SolveTriangular)
 def numba_funcify_SolveTriangular(op, node, **kwargs):
-    trans = op.trans
+    trans = bool(op.trans)
     lower = op.lower
     unit_diagonal = op.unit_diagonal
     check_finite = op.check_finite
+    overwrite_b = op.overwrite_b
+    b_ndim = op.b_ndim
 
     dtype = node.inputs[0].dtype
     if str(dtype).startswith("complex"):
@@ -298,11 +195,16 @@ def solve_triangular(a, b):
                 "Non-numeric values (nan or inf) in input b to solve_triangular"
             )
 
-        res, info = _solve_triangular(a, b, trans, lower, unit_diagonal)
-        if info != 0:
-            raise np.linalg.LinAlgError(
-                "Singular matrix in input A to solve_triangular"
-            )
+        res = _solve_triangular(
+            a,
+            b,
+            trans=trans,
+            lower=lower,
+            unit_diagonal=unit_diagonal,
+            overwrite_b=overwrite_b,
+            b_ndim=b_ndim,
+        )
+
        return res
 
     return solve_triangular
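+
+# Usage sketch (hypothetical): once registered, the dispatcher above is picked
+# up automatically when a graph containing SolveTriangular is compiled with the
+# numba backend, e.g.
+#
+#     import numpy as np
+#     import pytensor
+#     import pytensor.tensor as pt
+#     from pytensor.tensor.slinalg import solve_triangular
+#
+#     A, b = pt.dmatrix("A"), pt.dvector("b")
+#     x = solve_triangular(A, b, lower=True)
+#     fn = pytensor.function([A, b], x, mode="NUMBA")
+#     fn(np.tril(np.ones((3, 3))), np.ones(3))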
@@ -429,3 +331,853 @@ def block_diag(*arrs):
         return out
 
     return block_diag
+
+
+def _xlamch(kind: str = "E"):
+    """
+    Placeholder for getting machine precision; used by linalg.solve. Will never be called in python mode; the
+    numba overload below provides the implementation.
+    """
+    pass
+
+
+@overload(_xlamch)
+def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
+    """
+    Compute the machine precision for a given floating point type.
+    """
+    from pytensor import config
+
+    ensure_lapack()
+    w_type = _get_underlying_float(config.floatX)
+
+    if w_type == "float32":
+        dtype = types.float32
+    elif w_type == "float64":
+        dtype = types.float64
+    else:
+        raise NotImplementedError("Unsupported dtype")
+
+    numba_lamch = _LAPACK().numba_xlamch(dtype)
+
+    def impl(kind: str = "E") -> float:
+        KIND = val_to_int_ptr(ord(kind))
+        return numba_lamch(KIND)  # type: ignore
+
+    return impl
+
+
+def _xlange(A: np.ndarray, order: str | None = None) -> float:
+    """
+    Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
+    """
+    return  # type: ignore
+
+
+@overload(_xlange)
+def xlange_impl(
+    A: np.ndarray, order: str | None = None
+) -> Callable[[np.ndarray, str], float]:
+    """
+    xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
+    largest absolute value of a matrix A.
+    """
+    ensure_lapack()
+    _check_scipy_linalg_matrix(A, "norm")
+    dtype = A.dtype
+    w_type = _get_underlying_float(dtype)
+    numba_lange = _LAPACK().numba_xlange(dtype)
+
+    def impl(A: np.ndarray, order: str | None = None):
+        _M, _N = np.int32(A.shape[-2:])  # type: ignore
+
+        A_copy = _copy_to_fortran_order(A)
+
+        M = val_to_int_ptr(_M)  # type: ignore
+        N = val_to_int_ptr(_N)  # type: ignore
+        LDA = val_to_int_ptr(_M)  # type: ignore
+
+        NORM = (
+            val_to_int_ptr(ord(order))
+            if order is not None
+            else val_to_int_ptr(ord("1"))
+        )
+        WORK = np.empty(_M, dtype=dtype)  # type: ignore
+
+        result = numba_lange(
+            NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
+        )
+
+        return result
+
+    return impl
+
+
+def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
+    """
+    Placeholder for computing the condition number of a matrix; used by linalg.solve. Will never be called in
+    python mode.
+    """
+    return  # type: ignore
+
+
+@overload(_xgecon)
+def xgecon_impl(
+    A: np.ndarray, A_norm: float, norm: str
+) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]:
+    """
+    Compute the reciprocal of the condition number of a matrix A.
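+
+    GECON estimates RCOND = 1 / (norm(A) * norm(inv(A))) from a factored
+    matrix, so callers pass the LU factor of A together with the norm of the
+    *original* A. Hedged sketch of the sequence used by solve_gen_impl below
+    (hypothetical data; the helpers are the placeholders in this module):
+
+        a_norm = _xlange(A, order="1")
+        LU, ipiv, info = _getrf(A)
+        rcond, info = _xgecon(LU, a_norm, "1")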
+ """ + ensure_lapack() + _check_scipy_linalg_matrix(A, "gecon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_gecon = _LAPACK().numba_xgecon(dtype) + + def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + A_NORM = np.array(A_norm, dtype=dtype) + NORM = val_to_int_ptr(ord(norm)) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(4 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(1) + + numba_gecon( + NORM, + N, + A_copy.view(w_type).ctypes, + LDA, + A_NORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for LU factorization; used by linalg.solve. + + # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. + """ + return # type: ignore + + +@overload(_getrf) +def getrf_impl( + A: np.ndarray, overwrite_a: bool = False +) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "getrf") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_getrf = _LAPACK().numba_xgetrf(dtype) + + def impl( + A: np.ndarray, overwrite_a: bool = False + ) -> tuple[np.ndarray, np.ndarray, int]: + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + INFO = val_to_int_ptr(0) + + numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) + + return A_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _getrs( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve. + + # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode. 
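+
+    Hedged sketch of the factor-once / solve-many pattern GETRS enables
+    (hypothetical data, python-style calls of the placeholders in this module):
+
+        LU, ipiv, info = _getrf(A)
+        x1, info = _getrs(LU, b1, ipiv, 0, False)  # solve A @ x1 = b1
+        x2, info = _getrs(LU, b2, ipiv, 1, False)  # solve A.T @ x2 = b2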
+ """ + return # type: ignore + + +@overload(_getrs) +def getrs_impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(LU, "getrs") + _check_scipy_linalg_matrix(B, "getrs") + dtype = LU.dtype + w_type = _get_underlying_float(dtype) + numba_getrs = _LAPACK().numba_xgetrs(dtype) + + def impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool + ) -> tuple[np.ndarray, int]: + _N = np.int32(LU.shape[-1]) + _solve_check_input_shapes(LU, B) + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) + + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + IPIV = _copy_to_fortran_order(IPIV) + INFO = val_to_int_ptr(0) + + numba_getrs( + TRANS, + N, + NRHS, + LU.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _solve_gen( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects + for users who import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="gen", + transposed=transposed, + ) + + +@overload(_solve_gen) +def solve_gen_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _N = np.int32(A.shape[-1]) + _solve_check_input_shapes(A, B) + + order = "I" if transposed else "1" + norm = _xlange(A, order=order) + + N = A.shape[1] + LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) + _solve_check(N, INFO) + + X, INFO = _getrs( + LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b + ) + _solve_check(N, INFO) + + RCOND, INFO = _xgecon(LU, norm, "1") + _solve_check(N, INFO, True, RCOND) + + return X + + return impl + + +def _sysv( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve. 
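+
+    SYSV factorizes A as L @ D @ L.T (or U @ D @ U.T) with diagonal
+    (Bunch-Kaufman) pivoting and solves in the same call, returning
+    (solution, pivots, info) like the overload below. Hedged sketch
+    (hypothetical data):
+
+        x, ipiv, info = _sysv(A, b, True, False, False)  # lower=True, no overwrites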
+ """ + return # type: ignore + + +@overload(_sysv) +def sysv_impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int] +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sysv") + _check_scipy_linalg_matrix(B, "sysv") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sysv = _LAPACK().numba_xsysv(dtype) + + def impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool + ): + _LDA, _N = np.int32(A.shape[-2:]) # type: ignore + _solve_check_input_shapes(A, B) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + if B_is_1d: + B_copy = np.asfortranarray(np.expand_dims(B_copy, -1)) + + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) # type: ignore + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_LDA) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + LDB = val_to_int_ptr(_N) # type: ignore + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + INFO = val_to_int_ptr(0) + + # Workspace query + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + WS_SIZE = np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual solve + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], IPIV, int_ptr_to_val(INFO) + return B_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in + python mode. + """ + return # type: ignore + + +@overload(_sycon) +def sycon_impl( + A: np.ndarray, ipiv: np.ndarray, anorm: float +) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sycon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sycon = _LAPACK().numba_xsycon(dtype) + + def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + UPLO = val_to_int_ptr(ord("L")) + ANORM = np.array(anorm, dtype=dtype) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(2 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(0) + + numba_sycon( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + ipiv.ctypes, + ANORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _solve_symmetric( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve for symmetric matrices. 
Used as an overload target for numba to avoid + unexpected side-effects when users import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="sym", + transposed=transposed, + ) + + +@overload(_solve_symmetric) +def solve_symmetric_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _solve_check_input_shapes(A, B) + + x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) + _solve_check(A.shape[-1], info) + + rcond, info = _sycon(A, ipiv, _xlange(A, order="I")) + _solve_check(A.shape[-1], info, True, rcond) + + return x + + return impl + + +def _posv( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. + """ + return # type: ignore + + +@overload(_posv) +def posv_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int] +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_posv = _LAPACK().numba_xposv(dtype) + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> tuple[np.ndarray, int]: + _solve_check_input_shapes(A, B) + + _N = np.int32(A.shape[-1]) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_posv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by + linalg.solve when assume_a = "pos". 
+    """
+    return  # type: ignore
+
+
+@overload(_pocon)
+def pocon_impl(
+    A: np.ndarray, anorm: float
+) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]:
+    ensure_lapack()
+    _check_scipy_linalg_matrix(A, "pocon")
+    dtype = A.dtype
+    w_type = _get_underlying_float(dtype)
+    numba_pocon = _LAPACK().numba_xpocon(dtype)
+
+    def impl(A: np.ndarray, anorm: float):
+        _N = np.int32(A.shape[-1])
+        A_copy = _copy_to_fortran_order(A)
+
+        UPLO = val_to_int_ptr(ord("L"))
+        N = val_to_int_ptr(_N)
+        LDA = val_to_int_ptr(_N)
+        ANORM = np.array(anorm, dtype=dtype)
+        RCOND = np.empty(1, dtype=dtype)
+        WORK = np.empty(3 * _N, dtype=dtype)
+        IWORK = np.empty(_N, dtype=np.int32)
+        INFO = val_to_int_ptr(0)
+
+        numba_pocon(
+            UPLO,
+            N,
+            A_copy.view(w_type).ctypes,
+            LDA,
+            ANORM.view(w_type).ctypes,
+            RCOND.view(w_type).ctypes,
+            WORK.view(w_type).ctypes,
+            IWORK.ctypes,
+            INFO,
+        )
+
+        return RCOND, int_ptr_to_val(INFO)
+
+    return impl
+
+
+def _solve_psd(
+    A: np.ndarray,
+    B: np.ndarray,
+    lower: bool,
+    overwrite_a: bool,
+    overwrite_b: bool,
+    check_finite: bool,
+    transposed: bool,
+):
+    """Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
+    avoid unexpected side-effects when users import pytensor."""
+    return linalg.solve(
+        A,
+        B,
+        lower=lower,
+        overwrite_a=overwrite_a,
+        overwrite_b=overwrite_b,
+        check_finite=check_finite,
+        transposed=transposed,
+        assume_a="pos",
+    )
+
+
+@overload(_solve_psd)
+def solve_psd_impl(
+    A: np.ndarray,
+    B: np.ndarray,
+    lower: bool,
+    overwrite_a: bool,
+    overwrite_b: bool,
+    check_finite: bool,
+    transposed: bool,
+) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
+    ensure_lapack()
+    _check_scipy_linalg_matrix(A, "solve")
+    _check_scipy_linalg_matrix(B, "solve")
+
+    def impl(
+        A: np.ndarray,
+        B: np.ndarray,
+        lower: bool,
+        overwrite_a: bool,
+        overwrite_b: bool,
+        check_finite: bool,
+        transposed: bool,
+    ) -> np.ndarray:
+        _solve_check_input_shapes(A, B)
+
+        x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
+        _solve_check(A.shape[-1], info)
+
+        rcond, info = _pocon(x, _xlange(A))
+        _solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
+
+        return x
+
+    return impl
+
+
+@numba_funcify.register(Solve)
+def numba_funcify_Solve(op, node, **kwargs):
+    assume_a = op.assume_a
+    lower = op.lower
+    check_finite = op.check_finite
+    overwrite_a = op.overwrite_a
+    overwrite_b = op.overwrite_b
+    transposed = False  # TODO: Solve doesn't currently allow the transposed argument
+
+    dtype = node.inputs[0].dtype
+    if str(dtype).startswith("complex"):
+        raise NotImplementedError(
+            "Complex inputs not currently supported by solve in Numba mode"
+        )
+
+    if assume_a == "gen":
+        solve_fn = _solve_gen
+    elif assume_a == "sym":
+        solve_fn = _solve_symmetric
+    elif assume_a == "her":
+        raise NotImplementedError(
+            'Use assume_a = "sym" for symmetric real matrices. If you need complex support, '
+            "please open an issue on github."
+        )
+    elif assume_a == "pos":
+        solve_fn = _solve_psd
+    else:
+        raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
+
+    @numba_basic.numba_njit(inline="always")
+    def solve(a, b):
+        if check_finite:
+            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
+                raise np.linalg.LinAlgError(
+                    "Non-numeric values (nan or inf) in input A to solve"
+                )
+            if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
+                raise np.linalg.LinAlgError(
+                    "Non-numeric values (nan or inf) in input b to solve"
+                )
+
+        res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed)
+        return res
+
+    return solve
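+
+
+# Usage sketch (hypothetical): with this registration, scipy-style solves
+# compile straight to the LAPACK-backed helpers above, e.g.
+#
+#     import numpy as np
+#     import pytensor
+#     import pytensor.tensor as pt
+#
+#     A, b = pt.dmatrix("A"), pt.dvector("b")
+#     x = pt.linalg.solve(A, b, assume_a="sym")
+#     fn = pytensor.function([A, b], x, mode="NUMBA")
+#     fn(np.eye(3) + np.ones((3, 3)), np.arange(3.0))  # routed to _solve_symmetric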
+
+
+def _cho_solve(C, B, lower=False, overwrite_b=False, check_finite=True):
+    """
+    Solve a positive-definite linear system using the Cholesky decomposition.
+
+    Returns a (solution, info) pair to match the numba overload below; info is
+    always 0 here because scipy raises on failure instead of reporting it.
+    """
+    return (
+        linalg.cho_solve(
+            (C, lower), B, overwrite_b=overwrite_b, check_finite=check_finite
+        ),
+        0,
+    )
+
+
+@overload(_cho_solve)
+def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
+    ensure_lapack()
+    _check_scipy_linalg_matrix(C, "cho_solve")
+    _check_scipy_linalg_matrix(B, "cho_solve")
+    dtype = C.dtype
+    w_type = _get_underlying_float(dtype)
+    numba_potrs = _LAPACK().numba_xpotrs(dtype)
+
+    def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
+        _solve_check_input_shapes(C, B)
+
+        _N = np.int32(C.shape[-1])
+        C_copy = _copy_to_fortran_order(C)
+
+        B_is_1d = B.ndim == 1
+        if B_is_1d:
+            B_copy = np.asfortranarray(np.expand_dims(B, -1))
+        else:
+            B_copy = _copy_to_fortran_order(B)
+
+        NRHS = 1 if B_is_1d else int(B.shape[-1])
+
+        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
+        N = val_to_int_ptr(_N)
+        NRHS = val_to_int_ptr(NRHS)
+        LDA = val_to_int_ptr(_N)
+        LDB = val_to_int_ptr(_N)
+        INFO = val_to_int_ptr(0)
+
+        numba_potrs(
+            UPLO,
+            N,
+            NRHS,
+            C_copy.view(w_type).ctypes,
+            LDA,
+            B_copy.view(w_type).ctypes,
+            LDB,
+            INFO,
+        )
+
+        if B_is_1d:
+            return B_copy[..., 0], int_ptr_to_val(INFO)
+        return B_copy, int_ptr_to_val(INFO)
+
+    return impl
+
+
+@numba_funcify.register(CholeskySolve)
+def numba_funcify_CholeskySolve(op, node, **kwargs):
+    lower = op.lower
+    overwrite_b = op.overwrite_b
+    check_finite = op.check_finite
+
+    dtype = node.inputs[0].dtype
+    if str(dtype).startswith("complex"):
+        raise NotImplementedError(
+            "Complex inputs not currently supported by cho_solve in Numba mode"
+        )
+
+    @numba_basic.numba_njit(inline="always")
+    def cho_solve(c, b):
+        if check_finite:
+            if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
+                raise np.linalg.LinAlgError(
+                    "Non-numeric values (nan or inf) in input A to cho_solve"
+                )
+            if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
+                raise np.linalg.LinAlgError(
+                    "Non-numeric values (nan or inf) in input b to cho_solve"
+                )
+
+        res, info = _cho_solve(
+            c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite
+        )
+
+        if info < 0:
+            raise np.linalg.LinAlgError("Illegal values found in input to cho_solve")
+        elif info > 0:
+            raise np.linalg.LinAlgError(
+                "Matrix is not positive definite in input to cho_solve"
+            )
+        return res
+
+    return cho_solve
diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index 11e1d6c63a..ef4bf10637 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import
IfElse from pytensor.link.utils import fgraph_to_python @@ -19,6 +20,7 @@ Eye, Join, MakeVector, + Split, TensorFromScalar, ) @@ -120,14 +122,23 @@ def arange(start, stop, step): @pytorch_funcify.register(Join) -def pytorch_funcify_Join(op, **kwargs): - def join(axis, *tensors): - # tensors could also be tuples, and in this case they don't have a ndim - tensors = [torch.tensor(tensor) for tensor in tensors] +def pytorch_funcify_Join(op, node, **kwargs): + axis = node.inputs[0] - return torch.cat(tensors, dim=axis) + if isinstance(axis, Constant): + axis = int(axis.data) - return join + def join_constant_axis(_, *tensors): + return torch.cat(tensors, dim=axis) + + return join_constant_axis + + else: + + def join(axis, *tensors): + return torch.cat(tensors, dim=axis) + + return join @pytorch_funcify.register(Eye) @@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs): @pytorch_funcify.register(OpFromGraph) def pytorch_funcify_OpFromGraph(op, node, **kwargs): kwargs.pop("storage_map", None) - # Apply inner rewrites PYTORCH.optimizer(op.fgraph) fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) @@ -185,3 +195,23 @@ def tensorfromscalar(x): return torch.as_tensor(x) return tensorfromscalar + + +@pytorch_funcify.register(Split) +def pytorch_funcify_Split(op, node, **kwargs): + x, dim, split_sizes = node.inputs + if isinstance(dim, Constant) and isinstance(split_sizes, Constant): + dim = int(dim.data) + split_sizes = tuple(int(size) for size in split_sizes.data) + + def split_constant_axis_and_sizes(x, *_): + return x.split(split_sizes, dim=dim) + + return split_constant_axis_and_sizes + + else: + + def inner_fn(x, dim, split_amounts): + return x.split(split_amounts.tolist(), dim=dim.item()) + + return inner_fn diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 65170b1f53..6a1c6b235e 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -5,12 +5,18 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar.basic import ( Cast, + Invert, ScalarOp, ) from pytensor.scalar.loop import ScalarLoop from pytensor.scalar.math import Softplus +@pytorch_funcify.register(Invert) +def pytorch_funcify_invert(op, node, **kwargs): + return torch.bitwise_not + + @pytorch_funcify.register(ScalarOp) def pytorch_funcify_ScalarOp(op, node, **kwargs): """Return pytorch function that implements the same computation as the Scalar Op. 
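Note on the dispatch pattern used in the PyTorch changes above (Join, Split) and below (Reshape, Subtensor): when an input such as an axis or shape is a graph Constant, the dispatcher closes over the plain Python value, so torch.compile traces a static argument instead of a tensor and avoids a graph break. A minimal standalone sketch of the same idea (hypothetical names, not part of the PR):

    import torch

    def make_join(static_axis=None):
        if static_axis is not None:
            def join_constant_axis(_, *tensors):
                # axis baked into the closure: no data-dependent control flow
                return torch.cat(tensors, dim=static_axis)
            return join_constant_axis

        def join(axis, *tensors):
            # axis arrives as a traced runtime value and may break the graph
            return torch.cat(tensors, dim=int(axis))
        return join
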
diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py
index f771ac7211..c15b3a3779 100644
--- a/pytensor/link/pytorch/dispatch/shape.py
+++ b/pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
index 75e7ec0776..34358797fb 100644
--- a/pytensor/link/pytorch/dispatch/subtensor.py
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
 
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index d47aa43dda..b8475e3157 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch
 
+        # flag that tends to help our graphs
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
         from pytensor.link.pytorch.dispatch import pytorch_typify
 
         class wrapper:
diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py
index af44af3254..c6e1283806 100644
--- a/pytensor/link/vm.py
+++ b/pytensor/link/vm.py
@@ -118,7 +118,7 @@ def calculate_reallocate_info(
             # where gc
             for i in range(idx + 1, len(order)):
                 if reuse_out is not None:
-                    break  # type: ignore
+                    break
                 for out in order[i].outputs:
                     if (
                         getattr(out.type, "ndim", None) == 0
diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py
new file mode 100644
index 0000000000..667a5c074e
--- /dev/null
+++ b/pytensor/npy_2_compat.py
@@ -0,0 +1,308 @@
+from textwrap import dedent
+
+import numpy as np
+
+
+# Conditional numpy imports for numpy 1.26 and 2.x compatibility
+try:
+    from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
+except ModuleNotFoundError:
+    # numpy < 2.0
+    from numpy.core.multiarray import normalize_axis_index  # type: ignore[no-redef]
+    from numpy.core.numeric import normalize_axis_tuple  # type: ignore[no-redef]
+
+
+try:
+    from
numpy._core.einsumfunc import ( # type: ignore[attr-defined] + _find_contraction, + _parse_einsum_input, + ) +except ModuleNotFoundError: + from numpy.core.einsumfunc import ( # type: ignore[no-redef] + _find_contraction, + _parse_einsum_input, + ) + + +# suppress linting warning by "using" the imports here: +__all__ = [ + "_find_contraction", + "_parse_einsum_input", + "normalize_axis_index", + "normalize_axis_tuple", +] + + +numpy_version_tuple = tuple(int(n) for n in np.__version__.split(".")[:2]) +numpy_version = np.lib.NumpyVersion( + np.__version__ +) # used to compare with version strings, e.g. numpy_version < "1.16.0" +using_numpy_2 = numpy_version >= "2.0.0rc1" + + +if using_numpy_2: + ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() +else: + ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined] + + +# used in tests: the type of error thrown if a value is too large for the specified +# numpy data type is different in numpy 2.x +UintOverflowError = OverflowError if using_numpy_2 else TypeError + + +# to patch up some of the C code, we need to use these special values... +if using_numpy_2: + numpy_axis_is_none_flag = np.iinfo(np.int32).min # the value of "NPY_RAVEL_AXIS" +else: + # 32 is the value used to mark axis = None in Numpy C-API prior to version 2.0 + numpy_axis_is_none_flag = 32 + + +# max number of dims is 64 in numpy 2.x; 32 in older versions +numpy_maxdims = 64 if using_numpy_2 else 32 + + +# function that replicates np.unique from numpy < 2.0 +def old_np_unique( + arr, return_index=False, return_inverse=False, return_counts=False, axis=None +): + """Replicate np.unique from numpy versions < 2.0""" + if not return_inverse or not using_numpy_2: + return np.unique(arr, return_index, return_inverse, return_counts, axis) + + outs = list(np.unique(arr, return_index, return_inverse, return_counts, axis)) + + inv_idx = 2 if return_index else 1 + + if axis is None: + outs[inv_idx] = np.ravel(outs[inv_idx]) + else: + inv_shape = (arr.shape[axis],) + outs[inv_idx] = outs[inv_idx].reshape(inv_shape) + + return tuple(outs) + + +# compatibility header for C code +def npy_2_compat_header() -> str: + """Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x""" + return dedent(""" + #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ + #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ + + + /* + * This header is meant to be included by downstream directly for 1.x compat. + * In that case we need to ensure that users first included the full headers + * and not just `ndarraytypes.h`. + */ + + #ifndef NPY_FEATURE_VERSION + #error "The NumPy 2 compat header requires `import_array()` for which " \\ + "the `ndarraytypes.h` header include is not sufficient. Please " \\ + "include it after `numpy/ndarrayobject.h` or similar." \\ + "" \\ + "To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\ + "which is defined in the compat header and is lightweight (can be)." + #endif + + #if NPY_ABI_VERSION < 0x02000000 + /* + * Define 2.0 feature version as it is needed below to decide whether we + * compile for both 1.x and 2.x (defining it gaurantees 1.x only). + */ + #define NPY_2_0_API_VERSION 0x00000012 + /* + * If we are compiling with NumPy 1.x, PyArray_RUNTIME_VERSION so we + * pretend the `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`. + * This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to. 
+ */ + #define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION + /* Compiling on NumPy 1.x where these are the same: */ + #define PyArray_DescrProto PyArray_Descr + #endif + + + /* + * Define a better way to call `_import_array()` to simplify backporting as + * we now require imports more often (necessary to make ABI flexible). + */ + #ifdef import_array1 + + static inline int + PyArray_ImportNumPyAPI() + { + if (NPY_UNLIKELY(PyArray_API == NULL)) { + import_array1(-1); + } + return 0; + } + + #endif /* import_array1 */ + + + /* + * NPY_DEFAULT_INT + * + * The default integer has changed, `NPY_DEFAULT_INT` is available at runtime + * for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`. + * + * NPY_RAVEL_AXIS + * + * This was introduced in NumPy 2.0 to allow indicating that an axis should be + * raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose. + * + * NPY_MAXDIMS + * + * A constant indicating the maximum number dimensions allowed when creating + * an ndarray. + * + * NPY_NTYPES_LEGACY + * + * The number of built-in NumPy dtypes. + */ + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + #define NPY_DEFAULT_INT NPY_INTP + #define NPY_RAVEL_AXIS NPY_MIN_INT + #define NPY_MAXARGS 64 + + #elif NPY_ABI_VERSION < 0x02000000 + #define NPY_DEFAULT_INT NPY_LONG + #define NPY_RAVEL_AXIS 32 + #define NPY_MAXARGS 32 + + /* Aliases of 2.x names to 1.x only equivalent names */ + #define NPY_NTYPES NPY_NTYPES_LEGACY + #define PyArray_DescrProto PyArray_Descr + #define _PyArray_LegacyDescr PyArray_Descr + /* NumPy 2 definition always works, but add it for 1.x only */ + #define PyDataType_ISLEGACY(dtype) (1) + #else + #define NPY_DEFAULT_INT \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG) + #define NPY_RAVEL_AXIS \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32) + #define NPY_MAXARGS \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32) + #endif + + + /* + * Access inline functions for descriptor fields. Except for the first + * few fields, these needed to be moved (elsize, alignment) for + * additional space. Or they are descriptor specific and are not generally + * available anymore (metadata, c_metadata, subarray, names, fields). + * + * Most of these are defined via the `DESCR_ACCESSOR` macro helper. 
+ */ + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000 + /* Compiling for 1.x or 2.x only, direct field access is OK: */ + + static inline void + PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size) + { + dtype->elsize = size; + } + + static inline npy_uint64 + PyDataType_FLAGS(const PyArray_Descr *dtype) + { + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + return dtype->flags; + #else + return (unsigned char)dtype->flags; /* Need unsigned cast on 1.x */ + #endif + } + + #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\ + static inline type \\ + PyDataType_##FIELD(const PyArray_Descr *dtype) { \\ + if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\ + return (type)0; \\ + } \\ + return ((_PyArray_LegacyDescr *)dtype)->field; \\ + } + #else /* compiling for both 1.x and 2.x */ + + static inline void + PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + ((_PyArray_DescrNumPy2 *)dtype)->elsize = size; + } + else { + ((PyArray_DescrProto *)dtype)->elsize = (int)size; + } + } + + static inline npy_uint64 + PyDataType_FLAGS(const PyArray_Descr *dtype) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + return ((_PyArray_DescrNumPy2 *)dtype)->flags; + } + else { + return (unsigned char)((PyArray_DescrProto *)dtype)->flags; + } + } + + /* Cast to LegacyDescr always fine but needed when `legacy_only` */ + #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\ + static inline type \\ + PyDataType_##FIELD(const PyArray_Descr *dtype) { \\ + if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\ + return (type)0; \\ + } \\ + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { \\ + return ((_PyArray_LegacyDescr *)dtype)->field; \\ + } \\ + else { \\ + return ((PyArray_DescrProto *)dtype)->field; \\ + } \\ + } + #endif + + DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0) + DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0) + DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1) + DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1) + DESCR_ACCESSOR(NAMES, names, PyObject *, 1) + DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1) + DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1) + + #undef DESCR_ACCESSOR + + + #if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD) + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return _PyDataType_GetArrFuncs(descr); + } + #elif NPY_ABI_VERSION < 0x02000000 + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return descr->f; + } + #else + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + return _PyDataType_GetArrFuncs(descr); + } + else { + return ((PyArray_DescrProto *)descr)->f; + } + } + #endif + + + #endif /* not internal build */ + + #endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */ + + """) diff --git a/pytensor/printing.py b/pytensor/printing.py index 6a18f6e8e5..bc42029c11 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -726,7 +726,7 @@ def _print_fn(op, xin): pmsg = temp() else: pmsg = temp - print(op.message, attr, "=", pmsg) + print(op.message, attr, "=", pmsg) # noqa: T201 class Print(Op): @@ -1657,7 +1657,7 @@ def apply_name(node): raise if print_output_file: - print("The output file is available at", outfile) + print("The output file is available at", outfile) # noqa: T201 
 class _TagGenerator:
@@ -1824,8 +1824,7 @@ def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> s
     # The __str__ method is encoding the object's id in its str
     name = position_independent_str(obj)
     if " at 0x" in name:
-        print(name)
-        raise AssertionError()
+        raise AssertionError(name)
 
     prefix = cur_tag + "="
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 3c33434e56..f8ecabd7b2 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -36,7 +36,6 @@
 from pytensor.utils import (
     apply_across_args,
     difference,
-    from_return_values,
     to_return_values,
 )
 
@@ -184,7 +183,9 @@ def __call__(self, x):
 
         for dtype in try_dtypes:
             x_ = np.asarray(x).astype(dtype=dtype)
-            if np.all(x == x_):
+            if np.all(
+                np.asarray(x) == x_
+            ):  # use np.asarray(x) to match TensorType.filter
                 break
 
         # returns either an exact x_==x, or the last cast x_
         return x_
@@ -350,6 +351,8 @@ def c_headers(self, c_compiler=None, **kwargs):
         # we declare them here and they will be re-used by TensorType
         l.append("<numpy/arrayobject.h>")
         l.append("<numpy/arrayscalars.h>")
+        l.append("<numpy/npy_math.h>")
+
         if config.lib__amdlibm and c_compiler.supports_amdlibm:
             l += ["<amdlibm.h>"]
         return l
@@ -518,73 +521,167 @@ def c_support_code(self, **kwargs):
         # In that case we add the 'int' type to the real types.
         real_types.append("int")
 
+        # Macros for backwards compatibility with numpy < 2.0
+        #
+        # In numpy 2.0+, these are defined in npy_math.h, but
+        # for early versions, they must be vendored by users (e.g. PyTensor)
+        backwards_compat_macros = """
+        #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
+        #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
+
+        #include <numpy/npy_math.h>
+
+        #ifndef NPY_CSETREALF
+        #define NPY_CSETREALF(c, r) (c)->real = (r)
+        #endif
+        #ifndef NPY_CSETIMAGF
+        #define NPY_CSETIMAGF(c, i) (c)->imag = (i)
+        #endif
+        #ifndef NPY_CSETREAL
+        #define NPY_CSETREAL(c, r) (c)->real = (r)
+        #endif
+        #ifndef NPY_CSETIMAG
+        #define NPY_CSETIMAG(c, i) (c)->imag = (i)
+        #endif
+        #ifndef NPY_CSETREALL
+        #define NPY_CSETREALL(c, r) (c)->real = (r)
+        #endif
+        #ifndef NPY_CSETIMAGL
+        #define NPY_CSETIMAGL(c, i) (c)->imag = (i)
+        #endif
+
+        #endif
+        """
+
+        def _make_get_set_real_imag(scalar_type: str) -> str:
+            """Make overloaded getter/setter functions for real/imag parts of numpy complex types.
+
+            The functions called by these getter/setter functions are defined in npy_math.h, or
+            in the `backwards_compat_macros` defined above.
+
+            Args:
+                scalar_type: float, double, or longdouble
+
+            Returns:
+                C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
+                given type.
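+
+            For scalar_type "float", for example, the real-part pair expands to:
+
+                static inline float get_real(const npy_cfloat z)
+                {
+                    return npy_crealf(z);
+                }
+
+                static inline void set_real(npy_cfloat *z, const float r)
+                {
+                    NPY_CSETREALF(z, r);
+                }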
+            """
+            complex_type = "npy_c" + scalar_type
+            suffix = "" if scalar_type == "double" else scalar_type[0]
+
+            if scalar_type == "longdouble":
+                scalar_type = "npy_" + scalar_type
+
+            return_type = scalar_type
+
+            template = f"""
+            static inline {return_type} get_real(const {complex_type} z)
+            {{
+                return npy_creal{suffix}(z);
+            }}
+
+            static inline void set_real({complex_type} *z, const {scalar_type} r)
+            {{
+                NPY_CSETREAL{suffix.upper()}(z, r);
+            }}
+
+            static inline {return_type} get_imag(const {complex_type} z)
+            {{
+                return npy_cimag{suffix}(z);
+            }}
+
+            static inline void set_imag({complex_type} *z, const {scalar_type} i)
+            {{
+                NPY_CSETIMAG{suffix.upper()}(z, i);
+            }}
+            """
+            return template
+
+        get_set_aliases = "\n".join(
+            _make_get_set_real_imag(stype)
+            for stype in ["float", "double", "longdouble"]
+        )
+
+        get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases
+
+        # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
+        #
+        # The npy_complex64, npy_complex128 types are aliases defined at run time based on
+        # the size of floats and doubles on the machine. This means that both types are
+        # not necessarily defined on every machine, but a machine with 32-bit floats and
+        # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
+        # as an alias of npy_cdouble.
+        #
+        # In any case, the get/set real/imag functions defined above will always work for
+        # npy_complex64 and npy_complex128.
         template = """
-    struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
-    {
-        typedef pytensor_complex%(nbits)s complex_type;
-        typedef npy_float%(half_nbits)s scalar_type;
-
-        complex_type operator +(const complex_type &y) const {
-            complex_type ret;
-            ret.real = this->real + y.real;
-            ret.imag = this->imag + y.imag;
-            return ret;
-        }
-
-        complex_type operator -() const {
-            complex_type ret;
-            ret.real = -this->real;
-            ret.imag = -this->imag;
-            return ret;
-        }
-        bool operator ==(const complex_type &y) const {
-            return (this->real == y.real) && (this->imag == y.imag);
-        }
-        bool operator ==(const scalar_type &y) const {
-            return (this->real == y) && (this->imag == 0);
-        }
-        complex_type operator -(const complex_type &y) const {
-            complex_type ret;
-            ret.real = this->real - y.real;
-            ret.imag = this->imag - y.imag;
-            return ret;
-        }
-        complex_type operator *(const complex_type &y) const {
-            complex_type ret;
-            ret.real = this->real * y.real - this->imag * y.imag;
-            ret.imag = this->real * y.imag + this->imag * y.real;
-            return ret;
-        }
-        complex_type operator /(const complex_type &y) const {
-            complex_type ret;
-            scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
-            ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
-            ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
-            return ret;
-        }
-        template <typename T>
-        complex_type& operator =(const T& y);
-
-        pytensor_complex%(nbits)s() {}
-
-        template <typename T>
-        pytensor_complex%(nbits)s(const T& y) { *this = y; }
-
-        template <typename TR, typename TI>
-        pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
+    struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
+      typedef pytensor_complex%(nbits)s complex_type;
+      typedef npy_float%(half_nbits)s scalar_type;
+
+      complex_type operator+(const complex_type &y) const {
+        complex_type ret;
+        set_real(&ret, get_real(*this) + get_real(y));
+        set_imag(&ret, get_imag(*this) + get_imag(y));
+        return ret;
+      }
+
+      complex_type operator-() const {
+        complex_type ret;
+        set_real(&ret, -get_real(*this));
+        set_imag(&ret, -get_imag(*this));
+        return ret;
+      }
+      bool operator==(const complex_type &y) const {
+        return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
+      }
+      bool operator==(const scalar_type &y) const {
+        return (get_real(*this) == y) && (get_imag(*this) == 0);
+      }
+      complex_type operator-(const complex_type &y) const {
+        complex_type ret;
+        set_real(&ret, get_real(*this) - get_real(y));
+        set_imag(&ret, get_imag(*this) - get_imag(y));
+        return ret;
+      }
+      complex_type operator*(const complex_type &y) const {
+        complex_type ret;
+        set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
+        set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
+        return ret;
+      }
+      complex_type operator/(const complex_type &y) const {
+        complex_type ret;
+        scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
+        set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
+        set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
+        return ret;
+      }
+      template <typename T> complex_type &operator=(const T &y);
+
+
+      pytensor_complex%(nbits)s() {}
+
+      template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
+
+      template <typename TR, typename TI>
+      pytensor_complex%(nbits)s(const TR &r, const TI &i) {
+        set_real(this, r);
+        set_imag(this, i);
+      }
+    };
+    """
 
         def operator_eq_real(mytype, othertype):
             return f"""
             template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
-            {{ this->real=y; this->imag=0; return *this; }}
+            {{ set_real(this, y); set_imag(this, 0); return *this; }}
            """
 
        def operator_eq_cplx(mytype, othertype):
            return f"""
            template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
-            {{ this->real=y.real; this->imag=y.imag; return *this; }}
+            {{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }}
            """
 
        operator_eq = "".join(
@@ -606,10 +703,10 @@ def operator_eq_cplx(mytype, othertype):
        def operator_plus_real(mytype, othertype):
            return f"""
            const {mytype} operator+(const {mytype} &x, const {othertype} &y)
-            {{ return {mytype}(x.real+y, x.imag); }}
+            {{ return {mytype}(get_real(x) + y, get_imag(x)); }}
 
            const {mytype} operator+(const {othertype} &y, const {mytype} &x)
-            {{ return {mytype}(x.real+y, x.imag); }}
+            {{ return {mytype}(get_real(x) + y, get_imag(x)); }}
            """
 
        operator_plus = "".join(
@@ -621,10 +718,10 @@ def operator_plus_real(mytype, othertype):
        def operator_minus_real(mytype, othertype):
            return f"""
            const {mytype} operator-(const {mytype} &x, const {othertype} &y)
-            {{ return {mytype}(x.real-y, x.imag); }}
+            {{ return {mytype}(get_real(x) - y, get_imag(x)); }}
 
            const {mytype} operator-(const {othertype} &y, const {mytype} &x)
-            {{ return {mytype}(y-x.real, -x.imag); }}
+            {{ return {mytype}(y - get_real(x), -get_imag(x)); }}
            """
 
        operator_minus = "".join(
@@ -636,10 +733,10 @@ def operator_minus_real(mytype, othertype):
        def operator_mul_real(mytype, othertype):
            return f"""
            const {mytype} operator*(const {mytype} &x, const {othertype} &y)
-            {{ return {mytype}(x.real*y, x.imag*y); }}
+            {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
 
            const {mytype} operator*(const {othertype} &y, const {mytype} &x)
-            {{ return {mytype}(x.real*y, x.imag*y); }}
+            {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
            """
 
        operator_mul = "".join(
@@ -649,7 +746,8 @@ def operator_mul_real(mytype, othertype):
        )
 
        return (
-            template % dict(nbits=64, half_nbits=32)
+            get_set_aliases
+            + template %
dict(nbits=64, half_nbits=32) + template % dict(nbits=128, half_nbits=64) + operator_eq + operator_plus @@ -664,7 +762,7 @@ def c_init_code(self, **kwargs): return ["import_array();"] def c_code_cache_version(self): - return (13, np.__version__) + return (14, np.__version__) def get_shape_info(self, obj): return obj.itemsize @@ -1081,6 +1179,16 @@ def real_out(type): return (type,) +def _cast_to_promised_scalar_dtype(x, dtype): + try: + return x.astype(dtype) + except AttributeError: + if dtype == "bool": + return np.bool_(x) + else: + return getattr(np, dtype)(x) + + class ScalarOp(COp): nin = -1 nout = 1 @@ -1134,28 +1242,18 @@ def output_types(self, types): else: raise NotImplementedError(f"Cannot calculate the output types for {self}") - @staticmethod - def _cast_scalar(x, dtype): - if hasattr(x, "astype"): - return x.astype(dtype) - elif dtype == "bool": - return np.bool_(x) - else: - return getattr(np, dtype)(x) - def perform(self, node, inputs, output_storage): if self.nout == 1: - dtype = node.outputs[0].dtype - output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype) + output_storage[0][0] = _cast_to_promised_scalar_dtype( + self.impl(*inputs), + node.outputs[0].dtype, + ) else: - variables = from_return_values(self.impl(*inputs)) - assert len(variables) == len(output_storage) # strict=False because we are in a hot loop for out, storage, variable in zip( - node.outputs, output_storage, variables, strict=False + node.outputs, output_storage, self.impl(*inputs), strict=False ): - dtype = out.dtype - storage[0] = self._cast_scalar(variable, dtype) + storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype) def impl(self, *inputs): raise MethodNotDefined("impl", type(self), self.__class__.__name__) @@ -2568,7 +2666,7 @@ def c_code(self, node, name, inputs, outputs, sub): if type in float_types: return f"{z} = fabs({x});" if type in complex_types: - return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);" + return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));" if node.outputs[0].type == bool: return f"{z} = ({x}) ? 
1 : 0;" if type in uint_types: @@ -2967,7 +3065,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (x * np.asarray(math.log(2.0)).astype(x.dtype)),) + return (gz / (x * np.array(math.log(2.0), dtype=x.dtype)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3010,7 +3108,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (x * np.asarray(math.log(10.0)).astype(x.dtype)),) + return (gz / (x * np.array(math.log(10.0), dtype=x.dtype)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3125,7 +3223,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * exp2(x) * log(np.cast[x.type](2)),) + return (gz * exp2(x) * log(np.array(2, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3264,7 +3362,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * np.asarray(np.pi / 180, gz.type),) + return (gz * np.array(np.pi / 180, dtype=gz.type),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3299,7 +3397,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * np.asarray(180.0 / np.pi, gz.type),) + return (gz * np.array(180.0 / np.pi, dtype=gz.type),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3372,7 +3470,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (-gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (-gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3446,7 +3544,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3518,7 +3616,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) + sqr(x)),) + return (gz / (np.array(1, dtype=x.type) + sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3641,7 +3739,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) - np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) - np.array(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3718,7 +3816,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) + np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) + np.array(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3796,7 +3894,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) - sqr(x)),) + return (gz / (np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index a5512c6564..ec7eca76b9 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -9,8 +9,7 @@ from textwrap import dedent import numpy as np -import scipy.special -import scipy.stats +from scipy import special from pytensor.configdefaults import config from pytensor.gradient import grad_not_implemented, grad_undefined @@ -40,7 +39,6 @@ true_div, upcast, upgrade_to_float, - upgrade_to_float64, upgrade_to_float_no_complex, ) from pytensor.scalar.basic import abs as scalar_abs @@ -54,7 +52,7 @@ 
class Erf(UnaryScalarOp): nfunc_spec = ("scipy.special.erf", 1, 1) def impl(self, x): - return scipy.special.erf(x) + return special.erf(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -88,7 +86,7 @@ class Erfc(UnaryScalarOp): nfunc_spec = ("scipy.special.erfc", 1, 1) def impl(self, x): - return scipy.special.erfc(x) + return special.erfc(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -115,7 +113,7 @@ def c_code(self, node, name, inp, out, sub): return f"{z} = erfc(({cast}){x});" -# scipy.special.erfc don't support complex. Why? +# special.erfc don't support complex. Why? erfc = Erfc(upgrade_to_float_no_complex, name="erfc") @@ -137,7 +135,7 @@ class Erfcx(UnaryScalarOp): nfunc_spec = ("scipy.special.erfcx", 1, 1) def impl(self, x): - return scipy.special.erfcx(x) + return special.erfcx(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -193,7 +191,7 @@ class Erfinv(UnaryScalarOp): nfunc_spec = ("scipy.special.erfinv", 1, 1) def impl(self, x): - return scipy.special.erfinv(x) + return special.erfinv(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -228,7 +226,7 @@ class Erfcinv(UnaryScalarOp): nfunc_spec = ("scipy.special.erfcinv", 1, 1) def impl(self, x): - return scipy.special.erfcinv(x) + return special.erfcinv(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -262,12 +260,8 @@ def c_code(self, node, name, inp, out, sub): class Owens_t(BinaryScalarOp): nfunc_spec = ("scipy.special.owens_t", 2, 1) - @staticmethod - def st_impl(h, a): - return scipy.special.owens_t(h, a) - def impl(self, h, a): - return Owens_t.st_impl(h, a) + return special.owens_t(h, a) def grad(self, inputs, grads): (h, a) = inputs @@ -291,12 +285,8 @@ def c_code(self, *args, **kwargs): class Gamma(UnaryScalarOp): nfunc_spec = ("scipy.special.gamma", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.gamma(x) - def impl(self, x): - return Gamma.st_impl(x) + return special.gamma(x) def L_op(self, inputs, outputs, gout): (x,) = inputs @@ -330,12 +320,8 @@ class GammaLn(UnaryScalarOp): nfunc_spec = ("scipy.special.gammaln", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.gammaln(x) - def impl(self, x): - return GammaLn.st_impl(x) + return special.gammaln(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -374,12 +360,8 @@ class Psi(UnaryScalarOp): nfunc_spec = ("scipy.special.psi", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.psi(x) - def impl(self, x): - return Psi.st_impl(x) + return special.psi(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -465,12 +447,8 @@ class TriGamma(UnaryScalarOp): """ - @staticmethod - def st_impl(x): - return scipy.special.polygamma(1, x) - def impl(self, x): - return TriGamma.st_impl(x) + return special.polygamma(1, x) def L_op(self, inputs, outputs, outputs_gradients): (x,) = inputs @@ -568,12 +546,8 @@ def output_types_preference(n_type, x_type): # Scipy doesn't support it return upgrade_to_float_no_complex(x_type) - @staticmethod - def st_impl(n, x): - return scipy.special.polygamma(n, x) - def impl(self, n, x): - return PolyGamma.st_impl(n, x) + return special.polygamma(n, x) def L_op(self, inputs, outputs, output_gradients): (n, x) = inputs @@ -592,50 +566,6 @@ def c_code(self, *args, **kwargs): polygamma = PolyGamma(name="polygamma") -class Chi2SF(BinaryScalarOp): - """ - Compute (1 - chi2_cdf(x)) - ie. 
chi2 pvalue (chi2 'survival function') - """ - - nfunc_spec = ("scipy.stats.chi2.sf", 2, 1) - - @staticmethod - def st_impl(x, k): - return scipy.stats.chi2.sf(x, k) - - def impl(self, x, k): - return Chi2SF.st_impl(x, k) - - def c_support_code(self, **kwargs): - return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") - - def c_code(self, node, name, inp, out, sub): - x, k = inp - (z,) = out - if node.inputs[0].type in float_types: - dtype = "npy_" + node.outputs[0].dtype - return f"""{z} = - ({dtype}) 1 - GammaP({k}/2., {x}/2.);""" - raise NotImplementedError("only floatingpoint is implemented") - - def __eq__(self, other): - return type(self) is type(other) - - def __hash__(self): - return hash(type(self)) - - def c_code_cache_version(self): - v = super().c_code_cache_version() - if v: - return (2, *v) - else: - return v - - -chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf") - - class GammaInc(BinaryScalarOp): """ Compute the regularized lower gamma function (P). @@ -643,12 +573,8 @@ class GammaInc(BinaryScalarOp): nfunc_spec = ("scipy.special.gammainc", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammainc(k, x) - def impl(self, k, x): - return GammaInc.st_impl(k, x) + return special.gammainc(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -694,12 +620,8 @@ class GammaIncC(BinaryScalarOp): nfunc_spec = ("scipy.special.gammaincc", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammaincc(k, x) - def impl(self, k, x): - return GammaIncC.st_impl(k, x) + return special.gammaincc(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -745,12 +667,8 @@ class GammaIncInv(BinaryScalarOp): nfunc_spec = ("scipy.special.gammaincinv", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammaincinv(k, x) - def impl(self, k, x): - return GammaIncInv.st_impl(k, x) + return special.gammaincinv(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -774,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp): nfunc_spec = ("scipy.special.gammainccinv", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammainccinv(k, x) - def impl(self, k, x): - return GammaIncCInv.st_impl(k, x) + return special.gammainccinv(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -1013,12 +927,8 @@ class GammaU(BinaryScalarOp): # Note there is no basic SciPy version so no nfunc_spec. - @staticmethod - def st_impl(k, x): - return scipy.special.gammaincc(k, x) * scipy.special.gamma(k) - def impl(self, k, x): - return GammaU.st_impl(k, x) + return special.gammaincc(k, x) * special.gamma(k) def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") @@ -1049,12 +959,8 @@ class GammaL(BinaryScalarOp): # Note there is no basic SciPy version so no nfunc_spec. 
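The hunks in this file all apply one mechanical refactor: per-class `st_impl` staticmethods are collapsed into plain `impl` methods, and the separate `import scipy.special` / `import scipy.stats` imports are replaced by a single `from scipy import special` alias. A minimal sketch of the before/after shape (the class names here are illustrative stand-ins, not the real `UnaryScalarOp`/`BinaryScalarOp` bases):

    from scipy import special


    class GammaUBefore:
        @staticmethod
        def st_impl(k, x):
            return special.gammaincc(k, x) * special.gamma(k)

        def impl(self, k, x):
            return GammaUBefore.st_impl(k, x)


    class GammaUAfter:
        def impl(self, k, x):
            # Same computation, without the staticmethod indirection.
            return special.gammaincc(k, x) * special.gamma(k)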
- @staticmethod - def st_impl(k, x): - return scipy.special.gammainc(k, x) * scipy.special.gamma(k) - def impl(self, k, x): - return GammaL.st_impl(k, x) + return special.gammainc(k, x) * special.gamma(k) def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") @@ -1085,12 +991,8 @@ class Jv(BinaryScalarOp): nfunc_spec = ("scipy.special.jv", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.jv(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return special.jv(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1114,12 +1016,8 @@ class J1(UnaryScalarOp): nfunc_spec = ("scipy.special.j1", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.j1(x) - def impl(self, x): - return self.st_impl(x) + return special.j1(x) def grad(self, inputs, grads): (x,) = inputs @@ -1145,12 +1043,8 @@ class J0(UnaryScalarOp): nfunc_spec = ("scipy.special.j0", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.j0(x) - def impl(self, x): - return self.st_impl(x) + return special.j0(x) def grad(self, inp, grads): (x,) = inp @@ -1176,12 +1070,8 @@ class Iv(BinaryScalarOp): nfunc_spec = ("scipy.special.iv", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.iv(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return special.iv(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1205,12 +1095,8 @@ class I1(UnaryScalarOp): nfunc_spec = ("scipy.special.i1", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.i1(x) - def impl(self, x): - return self.st_impl(x) + return special.i1(x) def grad(self, inputs, grads): (x,) = inputs @@ -1231,12 +1117,8 @@ class I0(UnaryScalarOp): nfunc_spec = ("scipy.special.i0", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.i0(x) - def impl(self, x): - return self.st_impl(x) + return special.i0(x) def grad(self, inp, grads): (x,) = inp @@ -1257,12 +1139,8 @@ class Ive(BinaryScalarOp): nfunc_spec = ("scipy.special.ive", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.ive(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return special.ive(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1286,12 +1164,8 @@ class Kve(BinaryScalarOp): nfunc_spec = ("scipy.special.kve", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.kve(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return special.kve(v, x) def L_op(self, inputs, outputs, output_grads): v, x = inputs @@ -1321,7 +1195,7 @@ class Sigmoid(UnaryScalarOp): nfunc_spec = ("scipy.special.expit", 1, 1) def impl(self, x): - return scipy.special.expit(x) + return special.expit(x) def grad(self, inp, grads): (x,) = inp @@ -1372,8 +1246,7 @@ class Softplus(UnaryScalarOp): "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" """ - @staticmethod - def static_impl(x): + def impl(self, x): # If x is an int8 or uint8, numpy.exp will compute the result in # half-precision (float16), where we want float32. 
not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8") @@ -1388,9 +1261,6 @@ def static_impl(x): else: return x - def impl(self, x): - return Softplus.static_impl(x) - def grad(self, inp, grads): (x,) = inp (gz,) = grads @@ -1453,16 +1323,12 @@ class Log1mexp(UnaryScalarOp): "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" """ - @staticmethod - def static_impl(x): + def impl(self, x): if x < np.log(0.5): return np.log1p(-np.exp(x)) else: return np.log(-np.expm1(x)) - def impl(self, x): - return Log1mexp.static_impl(x) - def grad(self, inp, grads): (x,) = inp (gz,) = grads @@ -1496,7 +1362,7 @@ class BetaInc(ScalarOp): nfunc_spec = ("scipy.special.betainc", 3, 1) def impl(self, a, b, x): - return scipy.special.betainc(a, b, x) + return special.betainc(a, b, x) def grad(self, inp, grads): a, b, x = inp @@ -1756,7 +1622,7 @@ class BetaIncInv(ScalarOp): nfunc_spec = ("scipy.special.betaincinv", 3, 1) def impl(self, a, b, x): - return scipy.special.betaincinv(a, b, x) + return special.betaincinv(a, b, x) def grad(self, inputs, grads): (a, b, x) = inputs @@ -1794,12 +1660,8 @@ class Hyp2F1(ScalarOp): nin = 4 nfunc_spec = ("scipy.special.hyp2f1", 4, 1) - @staticmethod - def st_impl(a, b, c, z): - return scipy.special.hyp2f1(a, b, c, z) - def impl(self, a, b, c, z): - return Hyp2F1.st_impl(a, b, c, z) + return special.hyp2f1(a, b, c, z) def grad(self, inputs, grads): a, b, c, z = inputs diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index a01347ef9c..1dbc93b9fa 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -72,6 +72,7 @@ from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.replace import clone_replace +from pytensor.graph.type import HasShape from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.link.c.basic import CLinker from pytensor.printing import op_debug_information @@ -2509,13 +2510,25 @@ def compute_all_gradients(known_grads): return rval var_mappings = self.get_oinp_iinp_iout_oout_mappings() - dC_dinps_t = [None for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs] + + n_mit_mot_outs = info.n_mit_mot_outs + # In the case of mit-mot there can be more inner outputs than outer ones + n_extra_mit_mot_outs = n_mit_mot_outs - info.n_mit_mot + idx_nitsot_out_start = n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot + idx_nitsot_out_end = idx_nitsot_out_start + info.n_nit_sot + + # Create dummy variables for the internal input gradients + states = ( + self.inner_mitmot(self_inputs) + + self.inner_mitsot(self_inputs) + + self.inner_sitsot(self_inputs) + ) dC_dXts = [] Xts = [] for idx, Xt in enumerate(diff_outputs): # We are looking for x[t-1] for a given x[t] - if idx >= info.n_mit_mot_outs: + if idx >= n_mit_mot_outs: Xt_placeholder = safe_new(Xt) Xts.append(Xt_placeholder) @@ -2523,9 +2536,7 @@ def compute_all_gradients(known_grads): # or not. NOTE : This cannot be done by using # "if Xt not in self.inner_nitsot_outs(self_outputs)" because # the exact same variable can be used as multiple outputs. 
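The new index bookkeeping in this Scan gradient hunk is easier to follow with concrete counts. The numbers below are hypothetical, chosen only to show how the extra inner mit-mot outputs shift the nit-sot window and the `dC_douts` lookup; they do not come from any real graph:

    # Hypothetical Scan info counts:
    n_mit_mot, n_mit_mot_outs = 2, 5          # 3 extra inner mit-mot outputs
    n_mit_sot, n_sit_sot, n_nit_sot = 1, 1, 2

    n_extra_mit_mot_outs = n_mit_mot_outs - n_mit_mot              # 3
    idx_nitsot_out_start = n_mit_mot_outs + n_mit_sot + n_sit_sot  # 7
    idx_nitsot_out_end = idx_nitsot_out_start + n_nit_sot          # 9

    # Inner output i in the nit-sot window corresponds to outer output
    # i - n_extra_mit_mot_outs, hence the offset applied to dC_douts below.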
-            idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
-            idx_nitsot_end = idx_nitsot_start + info.n_nit_sot
-            if idx < idx_nitsot_start or idx >= idx_nitsot_end:
+            if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end:
                 # What we do here is loop through dC_douts and collect all
                 # those that are connected to the specific one and do an
                 # upcast on all of their dtypes to get the dtype for this
                 # specific previous step is defined or not is done somewhere
                 # else.
                 dtypes = []
-                states = (
-                    self.inner_mitmot(self_inputs)
-                    + self.inner_mitsot(self_inputs)
-                    + self.inner_sitsot(self_inputs)
-                )
-
                 for pos, inp in enumerate(states):
                     if inp in graph_inputs([Xt]):
                         # Get the index of the outer output that to which
@@ -2555,35 +2560,43 @@ def compute_all_gradients(known_grads):
                     new_dtype = config.floatX
                 dC_dXt = safe_new(Xt, dtype=new_dtype)
             else:
-                if isinstance(dC_douts[idx].type, DisconnectedType):
+                # nit-sot outputs
+                # If not disconnected assume the output gradient type is a valid type for the input gradient
+                if isinstance(
+                    dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType
+                ):
                     continue
-                dC_dXt = safe_new(dC_douts[idx][0])
+                dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0])
             dC_dXts.append(dC_dXt)

+        # Handle cases where the very same variable may be used as different outputs
+        # TODO: Couldn't we add a view Op to avoid this when building the Scan graph?
         known_grads = {}
         dc_dxts_idx = 0
         for i in range(len(diff_outputs)):
-            if i < idx_nitsot_start or i >= idx_nitsot_end:
-                if diff_outputs[i] in known_grads:
-                    known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                else:
-                    known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                dc_dxts_idx += 1
+            if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance(
+                dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType
+            ):
+                # Special case where we don't have a dC_dXt for disconnected nitsot outputs
+                continue
+
+            # Just some trouble to avoid a +0
+            if diff_outputs[i] in known_grads:
+                known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
             else:
-                if isinstance(dC_douts[i].type, DisconnectedType):
-                    continue
-                else:
-                    if diff_outputs[i] in known_grads:
-                        known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
-                    else:
-                        known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
-                    dc_dxts_idx += 1
+                known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
+            dc_dxts_idx += 1
+
         dC_dinps_t = compute_all_gradients(known_grads)

         # mask inputs that get no gradients
         for dx in range(len(dC_dinps_t)):
-            if not dC_dinps_t[dx]:
-                dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
+            if dC_dinps_t[dx] is None:
+                dC_dinps_t[dx] = (
+                    pt.zeros_like(diff_inputs[dx])
+                    if isinstance(diff_inputs[dx].type, HasShape)
+                    else pt.zeros(())
+                )
             else:
                 disconnected_dC_dinps_t[dx] = False
         for Xt, Xt_placeholder in zip(
@@ -2846,7 +2859,6 @@ def compute_all_gradients(known_grads):
         for idx in range(info.n_sit_sot):
             mitmot_inp_taps.append([0, 1])
             mitmot_out_taps.append([1])
-            through_shared = False
             if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
                 outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
             else:
@@ -2958,7 +2970,8 @@ def compute_all_gradients(known_grads):
             else:
                 outer_inp_sitsot.append(
                     pt.zeros(
-                        [grad_steps + 1] + [x.shape[i] for i in range(x.ndim)],
+                        [grad_steps + 1]
+                        + (list(x.shape) if isinstance(x.type, HasShape) else []),
                         dtype=y.dtype,
                     )
                 )
@@ -3007,9 +3020,7 @@ def compute_all_gradients(known_grads):
             name=f"grad_of_{self.name}" if
self.name else None, allow_gc=self.allow_gc, ) - outputs = local_op(*outer_inputs) - if not isinstance(outputs, list | tuple): - outputs = [outputs] + outputs = local_op(*outer_inputs, return_list=True) # Re-order the gradients correctly gradients = [DisconnectedType()()] @@ -3095,7 +3106,6 @@ def compute_all_gradients(known_grads): ) ) - start = len(gradients) gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)] begin = end @@ -3155,7 +3165,12 @@ def R_op(self, inputs, eval_points): rop_self_outputs = self_outputs if info.n_shared_outs > 0: rop_self_outputs = rop_self_outputs[: -info.n_shared_outs] - rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points) + rop_outs = Rop( + rop_self_outputs, + rop_of_inputs, + inner_eval_points, + use_op_rop_implementation=True, + ) if not isinstance(rop_outs, list | tuple): rop_outs = [rop_outs] # Step 2. Figure out what corresponds to what in the scan diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index c590bc804a..7f200b2a7c 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -3610,7 +3610,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3647,11 +3647,11 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp nnz = PyArray_DIMS({_indices})[0]; npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; @@ -3744,7 +3744,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3782,11 +3782,11 @@ def c_code(self, node, name, inputs, outputs, sub): // extract number of rows npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index bf6d6f0bc6..13735d2aca 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -158,8 +158,8 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{y}* ydata = 
(dtype_{y}*)PyArray_DATA({y}); dtype_{z}* zdata = (dtype_{z}*)PyArray_DATA({z}); - npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_DESCR({y})->elsize; - npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; + npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_ITEMSIZE({y}); + npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); npy_intp pos; if ({format} == 0){{ @@ -186,7 +186,7 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[3]] def c_code_cache_version(self): - return (2,) + return (3,) @node_rewriter([sparse.AddSD]) @@ -361,13 +361,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -436,7 +436,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3,) + return (4,) sd_csc = StructuredDotCSC() @@ -555,13 +555,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. 
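Every stride change in these sparse kernels follows the same rule: NumPy strides are measured in bytes, so dividing by the item size yields strides in elements. `PyArray_ITEMSIZE` expresses this without reaching into `PyArray_DESCR(...)->elsize`, a direct field access that NumPy 2 discourages, which appears to be the motivation for the swap. The NumPy-level equivalent, for intuition:

    import numpy as np

    a = np.zeros((4, 6), dtype=np.float64)[:, ::2]  # non-contiguous view
    byte_strides = a.strides                         # (48, 16)
    elem_strides = tuple(s // a.itemsize for s in a.strides)  # (6, 2)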
dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -614,7 +614,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (2,) + return (3,) sd_csr = StructuredDotCSR() @@ -845,12 +845,12 @@ def c_code(self, node, name, inputs, outputs, sub): const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA({x_ptr}); const dtype_{alpha} alpha = ((dtype_{alpha}*)PyArray_DATA({alpha}))[0]; - npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_DESCR({zn})->elsize; - npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_DESCR({x_val})->elsize; - npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_DESCR({x_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_DESCR({x_ptr})->elsize; - npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_DESCR({y})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_ITEMSIZE({zn}); + npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_ITEMSIZE({x_val}); + npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_ITEMSIZE({x_ind}); + npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_ITEMSIZE({x_ptr}); + npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_ITEMSIZE({y}); // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL)) @@ -896,7 +896,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3, blas.blas_header_version()) + return (4, blas.blas_header_version()) usmm_csc_dense = UsmmCscDense(inplace=False) @@ -1035,13 +1035,13 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; - npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_DESCR({b_val})->elsize; - npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_DESCR({b_ind})->elsize; - npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_DESCR({b_ptr})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); + npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_ITEMSIZE({b_val}); + npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_ITEMSIZE({b_ind}); + npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_ITEMSIZE({b_ptr}); // pointers to access actual data in the arrays passed as params. 
dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -1086,7 +1086,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (3,) + return (4,) csm_grad_c = CSMGradC() @@ -1482,7 +1482,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (2,) + return (3,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1544,7 +1544,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over rows for (npy_intp j = 0; j < N; ++j) @@ -1655,7 +1655,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (3,) + return (4,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1723,7 +1723,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over columns for (npy_intp j = 0; j < N; ++j) @@ -1868,7 +1868,7 @@ def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): ) def c_code_cache_version(self): - return (4, blas.blas_header_version()) + return (5, blas.blas_header_version()) def c_support_code(self, **kwargs): return blas.blas_header_text() @@ -1995,14 +1995,14 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{z_ind}* __restrict__ Dzi = (dtype_{z_ind}*)PyArray_DATA({z_ind}); dtype_{z_ptr}* __restrict__ Dzp = (dtype_{z_ptr}*)PyArray_DATA({z_ptr}); - const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_DESCR({x})->elsize; - const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; - const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_DESCR({p_data})->elsize; - const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_DESCR({p_ind})->elsize; - const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_DESCR({p_ptr})->elsize; - const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_DESCR({z_data})->elsize; - const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_DESCR({z_ind})->elsize; - const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_DESCR({z_ptr})->elsize; + const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_ITEMSIZE({x}); + const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); + const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_ITEMSIZE({p_data}); + const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_ITEMSIZE({p_ind}); + const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_ITEMSIZE({p_ptr}); + const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_ITEMSIZE({z_data}); + const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_ITEMSIZE({z_ind}); + const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_ITEMSIZE({z_ptr}); memcpy(Dzi, Dpi, PyArray_DIMS({p_ind})[0]*sizeof(dtype_{p_ind})); memcpy(Dzp, Dpp, PyArray_DIMS({p_ptr})[0]*sizeof(dtype_{p_ptr})); diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 67b6ab071e..88d3f33199 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -123,7 +123,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # isort: on # Allow accessing numpy constants 
from pytensor.tensor -from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi +from numpy import e, euler_gamma, inf, nan, newaxis, pi from pytensor.tensor.basic import * from pytensor.tensor.blas import batched_dot, batched_tensordot diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 8ee9894c9d..061a159fc2 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -14,8 +14,7 @@ from typing import cast as type_cast import numpy as np -from numpy.core.multiarray import normalize_axis_index -from numpy.core.numeric import normalize_axis_tuple +from numpy.exceptions import AxisError import pytensor import pytensor.scalar.sharedvar @@ -32,6 +31,7 @@ from pytensor.graph.type import HasShape, Type from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise from pytensor.scalar import int32 @@ -228,7 +228,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: elif x_.ndim > ndim: try: x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim))) - except np.AxisError: + except AxisError: raise ValueError( f"ndarray could not be cast to constant with {int(ndim)} dimensions" ) @@ -613,7 +613,6 @@ def get_scalar_constant_value( """ if isinstance(v, TensorVariable | np.ndarray): if v.ndim != 0: - print(v, v.ndim) raise NotScalarConstantError("Input ndim != 0") return get_underlying_scalar_constant_value( v, @@ -4406,7 +4405,7 @@ def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVa axis = (axis,) out_ndim = len(axis) + a.ndim - axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim) + axis = normalize_axis_tuple(axis, out_ndim) if not axis: return a diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index d0f524e413..592a4ba27c 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -498,7 +498,7 @@ def c_header_dirs(self, **kwargs): int unit = 0; int type_num = PyArray_DESCR(%(_x)s)->type_num; - int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes npy_intp* Nx = PyArray_DIMS(%(_x)s); npy_intp* Ny = PyArray_DIMS(%(_y)s); @@ -789,7 +789,7 @@ def build_gemm_call(self): ) def build_gemm_version(self): - return (13, blas_header_version()) + return (14, blas_header_version()) class Gemm(GemmRelated): @@ -1030,7 +1030,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(x_new, %(_x)s) == -1) + if(PyArray_CopyInto(x_new, %(_x)s) == -1) { %(fail)s } @@ -1056,7 +1056,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(y_new, %(_y)s) == -1) + if(PyArray_CopyInto(y_new, %(_y)s) == -1) { %(fail)s } @@ -1102,7 +1102,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): gv = self.build_gemm_version() if gv: - return (7, *gv) + return (8, *gv) else: return gv @@ -1538,7 +1538,7 @@ def contiguous(var, ndim): return f""" int type_num = PyArray_DESCR({_x})->type_num; - int type_size = PyArray_DESCR({_x})->elsize; // in bytes + int type_size = PyArray_ITEMSIZE({_x}); // in bytes if (PyArray_NDIM({_x}) != 3) {{ PyErr_Format(PyExc_NotImplementedError, @@ -1598,7 +1598,7 @@ def contiguous(var, ndim): def c_code_cache_version(self): from pytensor.tensor.blas_headers import blas_header_version - return (5, blas_header_version()) 
+ return (6, blas_header_version()) def grad(self, inp, grads): x, y = inp diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas_headers.py index 645f04bfb3..5d49b70ec4 100644 --- a/pytensor/tensor/blas_headers.py +++ b/pytensor/tensor/blas_headers.py @@ -1053,7 +1053,7 @@ def openblas_threads_text(): def blas_header_version(): # Version for the base header - version = (9,) + version = (10,) if detect_macos_sdot_bug(): if detect_macos_sdot_bug.fix_works: # Version with fix @@ -1071,7 +1071,7 @@ def ____gemm_code(check_ab, a_init, b_init): const char * error_string = NULL; int type_num = PyArray_DESCR(_x)->type_num; - int type_size = PyArray_DESCR(_x)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(_x); // in bytes npy_intp* Nx = PyArray_DIMS(_x); npy_intp* Ny = PyArray_DIMS(_y); diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index d1dfe44b90..fc937bf404 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -8,6 +8,7 @@ from math import gcd import numpy as np +from numpy.exceptions import ComplexWarning try: @@ -2338,7 +2339,7 @@ def conv( bval = _bvalfromboundary("fill") with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) for b in range(img.shape[0]): for g in range(self.num_groups): for n in range(output_channel_offset): diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index cba40ec6f8..660c16d387 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -6,13 +6,14 @@ from typing import cast import numpy as np -from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore -from numpy.core.numeric import ( # type: ignore + +from pytensor.compile.builders import OpFromGraph +from pytensor.npy_2_compat import ( + _find_contraction, + _parse_einsum_input, normalize_axis_index, normalize_axis_tuple, ) - -from pytensor.compile.builders import OpFromGraph from pytensor.tensor import TensorLike from pytensor.tensor.basic import ( arange, @@ -255,7 +256,7 @@ def _general_dot( .. 
testoutput:: - (3, 4, 2) + (np.int64(3), np.int64(4), np.int64(2)) """ # Shortcut for non batched case if not batch_axes[0] and not batch_axes[1]: diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index cb60427ba0..37acfc8e86 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -4,7 +4,6 @@ from typing import Literal import numpy as np -from numpy.core.numeric import normalize_axis_tuple import pytensor.tensor.basic from pytensor.configdefaults import config @@ -17,6 +16,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.params_type import ParamsType from pytensor.misc.frozendict import frozendict +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type from pytensor.scalar.basic import bool as scalar_bool @@ -41,9 +41,6 @@ from pytensor.utils import uniq -_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] - - class DimShuffle(ExternalCOp): """ Allows to reorder the dimensions of a tensor or insert or remove @@ -166,15 +163,20 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.transposition = self.shuffle + drop # List of dimensions of the output that are broadcastable and were not # in the original input - self.augment = sorted(i for i, x in enumerate(new_order) if x == "x") + self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x") self.drop = drop - self.is_left_expand_dims = self.augment and ( + dims_are_shuffled = sorted(self.shuffle) != self.shuffle + + self.is_transpose = dims_are_shuffled and not augment and not drop + self.is_squeeze = drop and not dims_are_shuffled and not augment + self.is_expand_dims = augment and not dims_are_shuffled and not drop + self.is_left_expand_dims = self.is_expand_dims and ( input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) ) - self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list( - range(input_ndim) - ) + self.is_right_expand_dims = self.is_expand_dims and new_order[ + :input_ndim + ] == list(range(input_ndim)) if self.inplace: self.view_map = {0: [0]} @@ -215,16 +217,15 @@ def make_node(self, inp): return Apply(self, [input], [output]) def __str__(self): - shuffle = sorted(self.shuffle) != self.shuffle - if self.augment and not (shuffle or self.drop): + if self.is_expand_dims: if len(self.augment) == 1: return f"ExpandDims{{axis={self.augment[0]}}}" return f"ExpandDims{{axes={self.augment}}}" - if self.drop and not (self.augment or shuffle): + if self.is_squeeze: if len(self.drop) == 1: - return f"DropDims{{axis={self.drop[0]}}}" - return f"DropDims{{axes={self.drop}}}" - if shuffle and not (self.augment or self.drop): + return f"Squeeze{{axis={self.drop[0]}}}" + return f"Squeeze{{axes={self.drop}}}" + if self.is_transpose: return f"Transpose{{axes={self.shuffle}}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" @@ -667,7 +668,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): and isinstance(self.nfunc, np.ufunc) and node.inputs[0].dtype in discrete_dtypes ): - char = np.sctype2char(out_dtype) + char = np.dtype(out_dtype).char sig = char * node.nin + "->" + char * node.nout node.tag.sig = sig node.tag.fake_node = Apply( diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index fedcd32ab9..7a1bc75b0b 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2,7 +2,6 @@ from collections.abc 
import Collection, Iterable import numpy as np -from numpy.core.multiarray import normalize_axis_index import pytensor import pytensor.scalar.basic as ps @@ -17,8 +16,14 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.link.c.type import EnumList, Generic +from pytensor.npy_2_compat import ( + normalize_axis_index, + npy_2_compat_header, + numpy_axis_is_none_flag, + old_np_unique, +) from pytensor.raise_op import Assert -from pytensor.scalar import int32 as int_t +from pytensor.scalar import int64 as int_t from pytensor.scalar import upcast from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb @@ -43,6 +48,7 @@ from pytensor.tensor.shape import Shape_i from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector +from pytensor.tensor.utils import normalize_reduce_axis from pytensor.tensor.variable import TensorVariable from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH @@ -298,7 +304,11 @@ def __init__(self, axis: int | None = None, mode="add"): self.axis = axis self.mode = mode - c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis) + @property + def c_axis(self) -> int: + if self.axis is None: + return numpy_axis_is_none_flag + return self.axis def make_node(self, x): x = ptb.as_tensor_variable(x) @@ -355,24 +365,37 @@ def infer_shape(self, fgraph, node, shapes): return shapes + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inames, onames, sub): (x,) = inames (z,) = onames fail = sub["fail"] params = sub["params"] - code = f""" - int axis = {params}->c_axis; + if self.axis is None: + axis_code = "int axis = NPY_RAVEL_AXIS;\n" + else: + axis_code = f"int axis = {params}->c_axis;\n" + + code = ( + axis_code + + f""" + #undef NPY_UF_DBG_TRACING + #define NPY_UF_DBG_TRACING 1 + if (axis == 0 && PyArray_NDIM({x}) == 1) - axis = NPY_MAXDIMS; + axis = NPY_RAVEL_AXIS; npy_intp shape[1] = {{ PyArray_SIZE({x}) }}; - if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0])) + if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0])) {{ Py_XDECREF({z}); - {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x})); + {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x})); }} - else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) + else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) {{ Py_XDECREF({z}); {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x})); @@ -399,11 +422,12 @@ def c_code(self, node, name, inames, onames, sub): Py_XDECREF(t); }} """ + ) return code def c_code_cache_version(self): - return (8,) + return (9,) def __str__(self): return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}" @@ -594,11 +618,7 @@ def squeeze(x, axis=None): elif not isinstance(axis, Collection): axis = (axis,) - # scalar inputs are treated as 1D regarding axis in this `Op` - try: - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=_x.ndim) + axis = normalize_reduce_axis(axis, ndim=_x.ndim) if not axis: # Nothing to do @@ 
-646,12 +666,17 @@ class Repeat(Op): __props__ = ("axis",) - def __init__(self, axis=None): + def __init__(self, axis: int | None = None): + if axis is not None: + if not isinstance(axis, int) or axis < 0: + raise ValueError( + f"Repeat only accepts positive integer axis or None, got {axis}" + ) self.axis = axis def make_node(self, x, repeats): x = ptb.as_tensor_variable(x) - repeats = ptb.as_tensor_variable(repeats) + repeats = ptb.as_tensor_variable(repeats, dtype="int64") if repeats.dtype not in integer_dtypes: raise TypeError("repeats.dtype must be an integer.") @@ -687,17 +712,12 @@ def make_node(self, x, repeats): out_shape = list(x.type.shape) out_shape[self.axis] = None - out_type = TensorType( - x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape) - ) - + out_type = TensorType(x.dtype, shape=out_shape) return Apply(self, [x, repeats], [out_type()]) def perform(self, node, inputs, output_storage): - x = inputs[0] - repeats = inputs[1] - z = output_storage[0] - z[0] = np.repeat(x, repeats=repeats, axis=self.axis) + [x, repeats] = inputs + output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis) def connection_pattern(self, node): return [[True], [False]] @@ -705,40 +725,51 @@ def connection_pattern(self, node): def grad(self, inputs, gout): (x, repeats) = inputs (gz,) = gout + axis = self.axis if repeats.ndim == 0: - if self.axis is None: - axis = x.ndim - else: - if self.axis >= 0: - axis = self.axis + 1 - else: - axis = self.axis + x.ndim + 1 + # When axis is a scalar (same number of reps for all elements), + # We can split the repetitions into their own axis with reshape and sum them back + # to the original element location + sum_axis = x.ndim if axis is None else axis + 1 + shape = list(x.shape) + shape.insert(sum_axis, repeats) + gx = gz.reshape(shape).sum(axis=sum_axis) - shape = [x.shape[k] for k in range(x.ndim)] - shape.insert(axis, repeats) - - return [ - gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis), - DisconnectedType()(), - ] elif repeats.ndim == 1: - # For this implementation, we would need to specify the length - # of repeats in order to split gz in the right way to sum - # the good part. - raise NotImplementedError() + # To sum the gradients that belong to the same repeated x, + # We create a repeated eye and dot product it with the gradient. + axis_size = x.size if axis is None else x.shape[axis] + repeated_eye = repeat( + ptb.eye(axis_size), repeats, axis=0 + ) # A sparse repeat would be neat + + if axis is None: + gx = gz @ repeated_eye + # Undo the ravelling when axis=None + gx = gx.reshape(x.shape) + else: + # Place gradient axis at end for dot product + gx = ptb.moveaxis(gz, axis, -1) + gx = gx @ repeated_eye + # Place gradient back into the correct axis + gx = ptb.moveaxis(gx, -1, axis) + else: raise ValueError() + return [gx, DisconnectedType()()] + def infer_shape(self, fgraph, node, ins_shapes): i0_shapes = ins_shapes[0] repeats = node.inputs[1] out_shape = list(i0_shapes) + axis = self.axis # uint64 shape are not supported. 
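A quick NumPy check of the repeated-eye trick used in the new `Repeat.grad` above (vector repeats, `axis=0`; the numbers are arbitrary):

    import numpy as np

    x = np.arange(4.0).reshape(2, 2)
    repeats = np.array([2, 3])
    y = np.repeat(x, repeats, axis=0)                     # shape (5, 2)
    gz = np.ones_like(y)                                  # output cotangent

    repeated_eye = np.repeat(np.eye(2), repeats, axis=0)  # shape (5, 2)
    gx = np.moveaxis(np.moveaxis(gz, 0, -1) @ repeated_eye, -1, 0)

    # Each input row accumulates one unit of gradient per repetition:
    assert np.array_equal(gx, np.array([[2.0, 2.0], [3.0, 3.0]]))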
        dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if self.axis is None:
+        if axis is None:
             if repeats.ndim == 0:
                 if len(i0_shapes) == 0:
                     out_shape = [repeats]
@@ -751,82 +782,115 @@ def infer_shape(self, fgraph, node, ins_shapes):
                 out_shape = [pt_sum(repeats, dtype=dtype)]
         else:
             if repeats.ndim == 0:
-                out_shape[self.axis] = out_shape[self.axis] * repeats
+                out_shape[axis] = out_shape[axis] * repeats
             else:
-                out_shape[self.axis] = pt_sum(repeats, dtype=dtype)
+                out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]


-def repeat(x, repeats, axis=None):
-    """Repeat elements of an array.
+def repeat(
+    a: TensorLike, repeats: TensorLike, axis: int | None = None
+) -> TensorVariable:
+    """Repeat elements of a tensor.

-    It returns an array which has the same shape as `x`, except along the given
-    `axis`. The `axis` parameter is used to specify the axis along which values
-    are repeated. By default, a flattened version of `x` is used.
+    See :func:`numpy.repeat` for more information.

-    The number of repetitions for each element is `repeats`. `repeats` is
-    broadcasted to fit the length of the given `axis`.

     Parameters
     ----------
-    x
-        Input data, tensor variable.
-    repeats
-        int, scalar or tensor variable
+    a: tensor_like
+        Input tensor
+    repeats: tensor_like
+        The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
     axis : int, optional
+        The axis along which to repeat values. By default, use the flattened input
         array, and return a flat output array.

-    See Also
+    Returns
+    -------
+    repeated_tensor: TensorVariable
+        Output tensor which has the same shape as a, except along the given axis
+
+    Examples
     --------
-    tensor.tile
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3], axis=0)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 1]
+         [0 1]
+         [2 3]
+         [2 3]
+         [2 3]]
+
+    When axis is None, the array is first flattened and then repeated
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None)
+        print(out.eval())
+
+    .. testoutput::
+
+        [0 0 1 1 1 3]
+

     .. versionadded:: 0.6

     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)

     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")

     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]

-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
+            # Multiplying int64 (shape) by uint64 (repeats) yields a float64
+            # Which is not valid for the `reshape` operation at the end
             raise TypeError("repeat doesn't support dtype uint64")

         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()

-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)

-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension comparing to x. We use alloc to
         # allocate space for this intermediate tensor to replicate x
         # along that additional dimension.
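The scalar-`repeats` branch below sidesteps the `Repeat` Op entirely by broadcasting into an inserted axis and merging it back with a reshape. A NumPy rendering of that alloc-plus-reshape path, as a sanity check:

    import numpy as np

    a = np.arange(6.0).reshape(2, 3)
    repeats, axis = 4, 1
    expanded = np.broadcast_to(a[:, :, None], (2, 3, repeats))  # the "alloc" step
    out = expanded.reshape(2, 3 * repeats)                      # merge the new axis
    assert np.array_equal(out, np.repeat(a, repeats, axis=axis))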
- shape_ = shape[:] - shape_.insert(axis + 1, repeats) + alloc_shape = repeat_shape[:] + alloc_shape.insert(axis + 1, repeats) - # shape is now the shape of output, where shape[axis] becomes + # repeat_shape is now the shape of output, where shape[axis] becomes # shape[axis]*repeats. - shape[axis] = shape[axis] * repeats - - # dims_ is the dimension of that intermediate tensor. - dims_ = list(np.arange(x.ndim)) - dims_.insert(axis + 1, "x") + repeat_shape[axis] = repeat_shape[axis] * repeats # After the original tensor is duplicated along the additional - # dimension, we reshape it to the expected output shape, and - # return the output z. - z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape) - return z + # dimension, we reshape it to the expected output shape + return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape( + repeat_shape + ) class Bartlett(Op): @@ -1163,6 +1227,9 @@ class Unique(Op): """ Wraps `numpy.unique`. + The indices returned when `return_inverse` is True are ravelled + to match the behavior of `numpy.unique` from before numpy version 2.0. + Examples -------- >>> import numpy as np @@ -1208,17 +1275,21 @@ def make_node(self, x): outputs = [TensorType(dtype=x.dtype, shape=out_shape)()] typ = TensorType(dtype="int64", shape=(None,)) + if self.return_index: outputs.append(typ()) + if self.return_inverse: outputs.append(typ()) + if self.return_counts: outputs.append(typ()) + return Apply(self, [x], outputs) def perform(self, node, inputs, output_storage): [x] = inputs - outs = np.unique( + outs = old_np_unique( x, return_index=self.return_index, return_inverse=self.return_inverse, @@ -1243,9 +1314,14 @@ def infer_shape(self, fgraph, node, i0_shapes): out_shapes[0] = tuple(shape) if self.return_inverse: - shape = prod(x_shape) if self.axis is None else x_shape[axis] return_index_out_idx = 2 if self.return_index else 1 - out_shapes[return_index_out_idx] = (shape,) + + if self.axis is not None: + shape = (x_shape[axis],) + else: + shape = (prod(x_shape),) + + out_shapes[return_index_out_idx] = shape return out_shapes diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index 76738fdb63..cb4476ede0 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -258,11 +258,6 @@ def tri_gamma_inplace(a): """second derivative of the log gamma function""" -@scalar_elemwise -def chi2sf_inplace(x, k): - """chi squared survival function""" - - @scalar_elemwise def gammainc_inplace(k, x): """regularized lower gamma function (P)""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index f11e33b41d..a88d678392 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional import numpy as np -from numpy.core.numeric import normalize_axis_tuple from pytensor import config, printing from pytensor import scalar as ps @@ -14,6 +13,11 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import ( + normalize_axis_tuple, + npy_2_compat_header, + numpy_axis_is_none_flag, +) from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.scalar.basic import BinaryScalarOp @@ -160,7 +164,7 @@ def get_params(self, node): c_axis = np.int64(self.axis[0]) else: # The value here doesn't matter, it won't be used - c_axis = np.int64(-1) + c_axis = numpy_axis_is_none_flag return self.params_type.get_params(c_axis=c_axis) def make_node(self, x): @@ 
-203,13 +207,17 @@ def perform(self, node, inp, outs): max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (argmax,) = out fail = sub["fail"] params = sub["params"] if self.axis is None: - axis_code = "axis = NPY_MAXDIMS;" + axis_code = "axis = NPY_RAVEL_AXIS;" else: if len(self.axis) != 1: raise NotImplementedError() @@ -431,20 +439,25 @@ def L_op(self, inputs, outputs, grads): return (g_x,) def R_op(self, inputs, eval_points): + [x] = inputs if eval_points[0] is None: - return [None, None] - if len(self.axis) != 1: - raise ValueError("R_op supported for max only for one axis!") - if self.axis[0] > 1: - raise ValueError("R_op supported for max only when axis is 0 or 1") + return [None] + axis = tuple(range(x.ndim) if self.axis is None else self.axis) + if isinstance(axis, int): + axis = [axis] + if len(axis) != 1: + raise NotImplementedError("R_op supported for max only for one axis!") + if axis[0] > 1: + raise NotImplementedError("R_op supported for max only when axis is 0 or 1") if inputs[0].ndim != 2: - raise ValueError("R_op supported for max only when input is a matrix") - max_pos = Argmax(self.axis).make_node(*inputs).outputs - # print(eval_points[0].eval()) + raise NotImplementedError( + "R_op supported for max only when input is a matrix" + ) + max_pos = Argmax(self.axis)(*inputs) if self.axis[0] == 0: - return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] + return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]] else: - return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] + return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]] class Min(NonZeroDimsCAReduce): @@ -1154,9 +1167,10 @@ def polygamma(n, x): """Polygamma function of order n evaluated at x""" -@scalar_elemwise def chi2sf(x, k): """chi squared survival function""" + warnings.warn("chi2sf is deprecated. Use `gammaincc(k / 2, x / 2)` instead") + return gammaincc(k / 2, x / 2) @scalar_elemwise @@ -2152,13 +2166,11 @@ def tensordot( a = as_tensor_variable(a) b = as_tensor_variable(b) runtime_shape_a = a.shape - bcast_a = a.broadcastable static_shape_a = a.type.shape - ndim_a = a.ndim + ndim_a = a.type.ndim runtime_shape_b = b.shape - bcast_b = b.broadcastable static_shape_b = b.type.shape - ndim_b = b.ndim + ndim_b = b.type.ndim if na != nb: raise ValueError( "The number of axes supplied for tensordot must be equal for each tensor. " @@ -2166,48 +2178,67 @@ def tensordot( ) axes_a = list(normalize_axis_tuple(axes_a, ndim_a)) axes_b = list(normalize_axis_tuple(axes_b, ndim_b)) + + # The operation is only valid if the original dimensions match in length + # The ravelling of the dimensions to coerce the operation into a single dot + # could mask such errors, so we add an Assert if needed. must_assert_runtime = False - for k in range(na): - ax_a = axes_a[k] - ax_b = axes_b[k] - if (bcast_a[ax_a] != bcast_b[ax_b]) or ( + for ax_a, ax_b in zip(axes_a, axes_b, strict=True): + if ( static_shape_a[ax_a] is not None and static_shape_b[ax_b] is not None and static_shape_a[ax_a] != static_shape_b[ax_b] ): raise ValueError( - "Input arrays have inconsistent broadcastable pattern or type shape along the axes " + "Input arrays have inconsistent type shape along the axes " "that are to be reduced with tensordot." 
) elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None: if must_assert_runtime: a = Assert( "Input array shape along reduced axes of tensordot are not equal" - )(a, eq(a.shape[ax_a], b.shape[ax_b])) + )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b])) must_assert_runtime = True - # Move the axes to sum over to the end of "a" - # and to the front of "b" - notin = [k for k in range(ndim_a) if k not in axes_a] - newaxes_a = notin + axes_a - N2 = 1 - for axis in axes_a: - N2 *= runtime_shape_a[axis] - newshape_a = (-1, N2) - olda = [runtime_shape_a[axis] for axis in notin] - - notin = [k for k in range(ndim_b) if k not in axes_b] - newaxes_b = axes_b + notin - N2 = 1 - for axis in axes_b: - N2 *= runtime_shape_b[axis] - newshape_b = (N2, -1) - oldb = [runtime_shape_b[axis] for axis in notin] - - at = a.transpose(newaxes_a).reshape(newshape_a) - bt = b.transpose(newaxes_b).reshape(newshape_b) - res = _dot(at, bt) - return res.reshape(olda + oldb) + # Convert tensordot into a stacked dot product. + # We stack the summed axes and the non-summed axes of each tensor separately, + # and place the summed axes at the end of a and the beginning of b + non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a] + non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a] + transpose_axes_a = non_summed_axes_a + axes_a + # We only need a reshape when we need to combine summed or non-summed dims + # or introduce a new dimension (expand_dims) when doing a non-scalar outer product (len(axes) = 0) + a_needs_reshape = (ndim_a != 0) and ( + (len(non_summed_axes_a) > 1) or (len(axes_a) != 1) + ) + + non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b] + non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b] + transpose_axes_b = axes_b + non_summed_axes_b + b_needs_reshape = (ndim_b != 0) and ( + (len(non_summed_axes_b) > 1) or (len(axes_b) != 1) + ) + + # summed_size_a and summed_size_b must be the same, + # but to facilitate reasoning about useless reshapes we compute both from their shapes + at = a.transpose(transpose_axes_a) + if a_needs_reshape: + non_summed_size_a = variadic_mul(*non_summed_dims_a) + summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a]) + at = at.reshape((non_summed_size_a, summed_size_a)) + + bt = b.transpose(transpose_axes_b) + if b_needs_reshape: + non_summed_size_b = variadic_mul(*non_summed_dims_b) + summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b]) + bt = bt.reshape((summed_size_b, non_summed_size_b)) + + res = dot(at, bt) + + if a_needs_reshape or b_needs_reshape: + res = res.reshape(non_summed_dims_a + non_summed_dims_b) + + return res def outer(x, y): diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 1f589e1789..ee33f6533c 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -4,13 +4,13 @@ from typing import Literal, cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm @@ -216,9 +216,8 @@ def perform(self, node, inputs, outputs): (z,) = outputs try: z[0] = np.asarray(np.linalg.det(x), dtype=x.dtype) - except 
Exception: - print("Failed to compute determinant", x) - raise + except Exception as e: + raise ValueError("Failed to compute determinant", x) from e def grad(self, inputs, g_outputs): (gz,) = g_outputs @@ -256,9 +255,8 @@ def perform(self, node, inputs, outputs): (sign, det) = outputs try: sign[0], det[0] = (np.array(z, dtype=x.dtype) for z in np.linalg.slogdet(x)) - except Exception: - print("Failed to compute determinant", x) - raise + except Exception as e: + raise ValueError("Failed to compute determinant", x) from e def infer_shape(self, fgraph, node, shapes): return [(), ()] diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index bebcad55be..6d6a4ee270 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1,10 +1,16 @@ import abc import warnings +from typing import Literal import numpy as np import scipy.stats as stats +from numpy import broadcast_shapes as np_broadcast_shapes +from numpy import einsum as np_einsum +from numpy import sqrt as np_sqrt +from numpy.linalg import cholesky as np_cholesky +from numpy.linalg import eigh as np_eigh +from numpy.linalg import svd as np_svd -import pytensor from pytensor.tensor import get_vector_length, specify_shape from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.math import sqrt @@ -831,27 +837,6 @@ def __call__(self, mu, kappa, size=None, **kwargs): vonmises = VonMisesRV() -def safe_multivariate_normal(mean, cov, size=None, rng=None): - """A shape consistent multivariate normal sampler. - - What we mean by "shape consistent": SciPy will return scalars when the - arguments are vectors with dimension of size 1. We require that the output - be at least 1D, so that it's consistent with the underlying random - variable. - - """ - res = np.atleast_1d( - stats.multivariate_normal(mean=mean, cov=cov, allow_singular=True).rvs( - size=size, random_state=rng - ) - ) - - if size is not None: - res = res.reshape([*size, -1]) - - return res - - class MvNormalRV(RandomVariable): r"""A multivariate normal random variable. @@ -870,8 +855,17 @@ class MvNormalRV(RandomVariable): signature = "(n),(n,n)->(n)" dtype = "floatX" _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") + __props__ = ("name", "signature", "dtype", "inplace", "method") - def __call__(self, mean=None, cov=None, size=None, **kwargs): + def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs): + super().__init__(*args, **kwargs) + if method not in ("cholesky", "svd", "eigh"): + raise ValueError( + f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'." + ) + self.method = method + + def __call__(self, mean, cov, size=None, **kwargs): r""" "Draw samples from a multivariate normal distribution. Signature @@ -894,38 +888,34 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs): is specified, a single `N`-dimensional sample is returned. 
""" - dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype - - if mean is None: - mean = np.array([0.0], dtype=dtype) - if cov is None: - cov = np.array([[1.0]], dtype=dtype) return super().__call__(mean, cov, size=size, **kwargs) - @classmethod - def rng_fn(cls, rng, mean, cov, size): - if mean.ndim > 1 or cov.ndim > 2: - # Neither SciPy nor NumPy implement parameter broadcasting for - # multivariate normals (or any other multivariate distributions), - # so we need to implement that here + def rng_fn(self, rng, mean, cov, size): + if size is None: + size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) - if size is None: - mean, cov = broadcast_params([mean, cov], [1, 2]) - else: - mean = np.broadcast_to(mean, size + mean.shape[-1:]) - cov = np.broadcast_to(cov, size + cov.shape[-2:]) - - res = np.empty(mean.shape) - for idx in np.ndindex(mean.shape[:-1]): - m = mean[idx] - c = cov[idx] - res[idx] = safe_multivariate_normal(m, c, rng=rng) - return res + if self.method == "cholesky": + A = np_cholesky(cov) + elif self.method == "svd": + A, s, _ = np_svd(cov) + A *= np_sqrt(s, out=s)[..., None, :] else: - return safe_multivariate_normal(mean, cov, size=size, rng=rng) + w, A = np_eigh(cov) + A *= np_sqrt(w, out=w)[..., None, :] + + out = rng.normal(size=(*size, mean.shape[-1])) + np_einsum( + "...ij,...j->...i", # numpy doesn't have a batch matrix-vector product + A, + out, + optimize=False, # Nothing to optimize with two operands, skip costly setup + out=out, + ) + out += mean + return out -multivariate_normal = MvNormalRV() +multivariate_normal = MvNormalRV(method="cholesky") class DirichletRV(RandomVariable): diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index c76d250c9e..a8b67dee4f 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Sequence -from copy import copy +from copy import deepcopy from typing import Any, cast import numpy as np @@ -395,7 +395,7 @@ def perform(self, node, inputs, outputs): # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. 
if not self.inplace: - rng = copy(rng) + rng = deepcopy(rng) outputs[0][0] = rng outputs[1][0] = np.asarray( diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 88d5e6197f..df8e3b691d 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -87,8 +87,8 @@ def filter(self, data, strict=False, allow_downcast=None): @staticmethod def values_eq(a, b): - sa = a if isinstance(a, dict) else a.__getstate__() - sb = b if isinstance(b, dict) else b.__getstate__() + sa = a if isinstance(a, dict) else a.bit_generator.state + sb = b if isinstance(b, dict) else b.bit_generator.state def _eq(sa, sb): for key in sa: diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index d3fc0398c4..31264f74d4 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -573,7 +573,7 @@ def print_profile(cls, stream, prof, level=0): print(blanc, " callbacks_time", file=stream) for i in sorted(prof[12].items(), key=lambda a: a[1]): if i[1] > 0: - print(i) + print(i) # noqa: T201 @node_rewriter([Dot]) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 3226f9b5a7..eaba64c275 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -314,14 +314,14 @@ def apply(self, fgraph): except (ValueError, InconsistencyError) as e: prof["nb_inconsistent"] += 1 if check_each_change != 1 and not raised_warning: - print( + print( # noqa: T201 ( "Some inplace rewriting was not " "performed due to an unexpected error:" ), file=sys.stderr, ) - print(e, file=sys.stderr) + print(e, file=sys.stderr) # noqa: T201 raised_warning = True fgraph.revert(chk) continue @@ -335,7 +335,7 @@ def apply(self, fgraph): fgraph.validate() except Exception: if not raised_warning: - print( + print( # noqa: T201 ( "Some inplace rewriting was not " "performed due to an unexpected error" @@ -1080,7 +1080,7 @@ def print_profile(stream, prof, level=0): print(blanc, " callbacks_time", file=stream) for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]: if i[1] > 0: - print(blanc, " ", i) + print(blanc, " ", i) # noqa: T201 print(blanc, " time_toposort", prof[7], file=stream) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 065ecfc0b1..0af1d40bf6 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -3434,14 +3434,14 @@ def perform_sigm_times_exp( sigm_minus_x = [] if full_tree is None: full_tree = tree - if False: # Debug code. - print("") - print(f" full_tree = {full_tree}") - print(f" tree = {tree}") - print(f" exp_x = {exp_x}") - print(f" exp_minus_x = {exp_minus_x}") - print(f" sigm_x = {sigm_x}") - print(f" sigm_minus_x= {sigm_minus_x}") + # if False: # Debug code. + # print("") + # print(f" full_tree = {full_tree}") + # print(f" tree = {tree}") + # print(f" exp_x = {exp_x}") + # print(f" exp_minus_x = {exp_minus_x}") + # print(f" sigm_x = {sigm_x}") + # print(f" sigm_minus_x= {sigm_minus_x}") neg, inputs = tree if isinstance(inputs, list): # Recurse through inputs of the multiplication. 
diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index e277772ad4..e86411dd9c 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -12,16 +12,17 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     GraphRewriter,
-    check_chain,
     copy_stack_trace,
     node_rewriter,
 )
 from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
+from pytensor.scalar import ScalarType
 from pytensor.tensor.basic import (
     MakeVector,
     as_tensor_variable,
     cast,
     constant,
+    expand_dims,
     get_scalar_constant_value,
     register_infer_shape,
     stack,
@@ -35,6 +36,7 @@
     register_useless,
     topo_constant_folding,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import (
     Reshape,
     Shape,
@@ -47,6 +49,7 @@
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
 from pytensor.tensor.type_other import NoneConst, NoneTypeT
+from pytensor.tensor.variable import TensorVariable


 class ShapeFeature(Feature):
@@ -755,6 +758,38 @@ def apply(self, fgraph):
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


+@register_useless
+@register_canonicalize
+@node_rewriter([Reshape])
+def local_useless_expand_dims_in_reshape(fgraph, node):
+    """
+    Removes useless expand_dims `DimShuffle` operations inside Reshape:
+      reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
+      reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)
+
+    Implicit (and useless) squeezes are kept in the graph, as they are
+    part of the canonical form of the graph.
+    """
+    expanded_x, new_shape = node.inputs
+
+    if not (
+        expanded_x.owner is not None
+        and isinstance(expanded_x.owner.op, DimShuffle)
+        and expanded_x.owner.op.augment
+    ):
+        return False
+
+    [x] = expanded_x.owner.inputs
+
+    new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
+    if new_order != tuple(range(x.type.ndim)):
+        x = x.dimshuffle(new_order)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Reshape])
@@ -763,30 +798,89 @@ def local_reshape_chain(fgraph, node):
     Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)

     """
-    if not check_chain(node, Reshape, Reshape):
+    inner_reshape, final_shape = node.inputs
+
+    if not (inner_reshape.owner and isinstance(inner_reshape.owner.op, Reshape)):
+        return None
+
+    x, _ = inner_reshape.owner.inputs
+    new_reshape = node.op(x, final_shape)
+
+    copy_stack_trace(node.outputs, new_reshape)
+    return [new_reshape]
+
+
+def _is_shape_i_of_x(
+    var: TensorVariable,
+    x: TensorVariable,
+    i: int,
+    shape_feature: ShapeFeature | None = None,
+) -> bool:
+    if var.type.ndim != 0:
         return False
-    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
-
-    # Copy over stacktrace from previous output node, as any error
-    # in new computational graph would have been caused by last op
-    # in the old computational graph.
-    copy_stack_trace(node.outputs, rval)
-
-    # It might happen that the desired output of this node has a
-    # broadcastable pattern that does not match that of 'rval'. 
This is - # when originally, we were able to figure out that one of the - # dimensions of the reshape is one, but some other transformation - # replaced the shape by one for which this cannot be guessed. - # We should try to figure out why we lost the information about this - # constant value... but in the meantime, better not apply this - # rewrite. - if rval.type.ndim == node.outputs[0].type.ndim and all( - s1 == s2 - for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True) - if s1 == 1 or s2 == 1 - ): - return [rval] + constant_var = get_scalar_constant_value( + var, + only_process_constants=False, + # Don't go through Elemwise to keep things fast + elemwise=False, + raise_not_constant=False, + ) + + # Check var is a constant expression with the same value as x.type.shape[i] + if constant_var == x.type.shape[i]: + return True + + # Match shape_of[x][i] or its constant equivalent + if shape_feature is not None: + i_shape_of_x = shape_feature.get_shape(x, i) + if i_shape_of_x == var or ( + isinstance(i_shape_of_x, Constant) and (i_shape_of_x.data == constant_var) + ): + return True + + if var.owner is None: + # No more constant possibilities + return False + + # Match Shape_i{i}(x) + if isinstance(var.owner.op, Shape_i): + return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore + + # Match Subtensor((ScalarType,))(Shape(input), i) + if isinstance(var.owner.op, Subtensor): + return ( + # Check we have integer indexing operation + # (and not slice or multiple indexing) + len(var.owner.op.idx_list) == 1 + and isinstance(var.owner.op.idx_list[0], ScalarType) + # Check we are indexing on the shape of x + and var.owner.inputs[0].owner is not None + and isinstance(var.owner.inputs[0].owner.op, Shape) + and var.owner.inputs[0].owner.inputs[0] == x + # Check that index == i + and ( + get_scalar_constant_value(var.owner.inputs[1], raise_not_constant=False) + == i + ) + ) + + return False + + +def _unpack_shape_vector(shape: TensorVariable) -> tuple[TensorVariable, ...]: + """Return the elements of a symbolic vector representing a shape. + + Handles the most common constant vector or make_vector cases. + + Returns tuple(shape) as fallback. 
+ """ + if isinstance(shape, Constant): + return tuple(as_tensor_variable(dim, ndim=0) for dim in shape.data) + elif shape.owner and isinstance(shape.owner.op, MakeVector): + return tuple(shape.owner.inputs) + else: + return tuple(shape) @register_useless("shape_unsafe") @@ -821,132 +915,151 @@ def local_useless_reshape(fgraph, node): if shape_input == inp: return [inp] - # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for - # broadcastable and constant dimensions - if isinstance(output_shape, Constant) or ( - output_shape.owner and isinstance(output_shape.owner.op, MakeVector) - ): - if isinstance(output_shape, Constant): - output_shape_is = [ - as_tensor_variable(dim, ndim=0) for dim in output_shape.data - ] - else: - output_shape_is = output_shape.owner.inputs - - shape_feature = getattr(fgraph, "shape_feature", None) - - nb_m1 = 0 - shape_match = [False] * inp.type.ndim - for dim in range(inp.type.ndim): - outshp_i = output_shape_is[dim] - # Match Shape_i{dim}(input) - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Shape_i) - and outshp_i.owner.op.i == dim - and outshp_i.owner.inputs[0] == inp - ): - shape_match[dim] = True - continue + shape_feature = getattr(fgraph, "shape_feature", None) - # Match Shape(input)[dim] - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Subtensor) - and len(outshp_i.owner.inputs) == 2 - and get_scalar_constant_value( - outshp_i.owner.inputs[1], raise_not_constant=False - ) - == dim - ): - subtensor_inp = outshp_i.owner.inputs[0] - if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): - shape_input_i = subtensor_inp.owner.inputs[0] - if shape_input_i == inp: - shape_match[dim] = True - continue - - # Match constant if input.type.shape[dim] == constant - cst_outshp_i = get_scalar_constant_value( - outshp_i, only_process_constants=True, raise_not_constant=False - ) - if inp.type.shape[dim] == cst_outshp_i: - shape_match[dim] = True - continue + # Match case where at least (n-1) entries correspond to the original shape: + # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ... x.shape[-1]]) + # Where y can be -1 or anything with an unknown value, since the only valid reshape is still a no reshape. 
+ output_shape_is = _unpack_shape_vector(output_shape) + nb_m1 = 0 + shape_match = [False] * inp.type.ndim + for dim in range(inp.type.ndim): + outshp_i = output_shape_is[dim] + if _is_shape_i_of_x(outshp_i, inp, dim, shape_feature=shape_feature): + shape_match[dim] = True + elif isinstance(outshp_i, Constant) and outshp_i.data == -1: + shape_match[dim] = True + nb_m1 += 1 - # Match -1 - if cst_outshp_i == -1: - shape_match[dim] = True - nb_m1 += 1 - continue + if nb_m1 <= 1 and all(shape_match): + return [inp] # This is provably correct - # Match shape_of[input][dim] or its constant equivalent - if shape_feature: - inpshp_i = shape_feature.get_shape(inp, dim) - if inpshp_i == outshp_i or ( - get_scalar_constant_value( - inpshp_i, only_process_constants=True, raise_not_constant=False - ) - == get_scalar_constant_value( - outshp_i, only_process_constants=True, raise_not_constant=False - ) - ): - shape_match[dim] = True - continue + # There is one missing match, but all other dimensions match + # Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y) + if (nb_m1 == 0) and (shape_match.count(False) == 1): + return [inp] # This could mask a shape error - if nb_m1 <= 1 and all(shape_match): - return [inp] + return False - if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1): - return [inp] - return False - - -@register_canonicalize +@register_canonicalize("shape_unsafe") @node_rewriter([Reshape]) def local_reshape_to_dimshuffle(fgraph, node): - r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s. + r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions. - The goal is to avoid using `Reshape` to add or remove broadcastable - dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can - cancel out and/or be removed later on. + It's always valid to squeeze an input before doing the same reshape operation. + Equivalently, it's always valid to remove `1` entries from the reshape shape + and replace them by an expand_dims after the rewritten reshape operation. + + We chose to canonicalize the graph in this way as it allows isolating + operations that are unique to the reshaping operation (mixing dimensions) + from those that can be more legibly encoded by DimShuffle (squeeze and expand_dims). + This can allow further simplifications by other rewrites that target + DimShuffle but not Reshape, as well as facilitate the removal of useless reshape operations. For example: - - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)) - - reshape(x, (1, m, 1, n, 1, 1)) - -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n))) + - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n)) + - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0) + - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5)) + """ - op = node.op inp, output_shape = node.inputs [output] = node.outputs - dimshuffle_new_order = [] - new_output_shape = [] - index = 0 # index over the output of the new reshape - for i in range(output.ndim): - # Since output_shape is a symbolic vector, we trust get_scalar_constant_value - # to go through however it is formed to see if its i-th element is 1. - # We need only_process_constants=False for that. 
- dim = get_scalar_constant_value( - output_shape[i], - only_process_constants=False, - elemwise=False, - raise_not_constant=False, - ) - if dim == 1: - dimshuffle_new_order.append("x") - else: - dimshuffle_new_order.append(index) - new_output_shape.append(dim) - index = index + 1 + # Remove any broadcastable dimensions from the input + squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast] + + # Trivial case, all dimensions of input/output are known to be broadcastable: + # there's nothing to reshape + if all(inp.type.broadcastable) or all(output.type.broadcastable): + new_output_shape = [] + expand_axes = tuple(range(output.type.ndim)) + + else: + unpacked_shape = _unpack_shape_vector(output_shape) + new_output_shape = [] + expand_axes = [] + for i, dim_length in enumerate(unpacked_shape): + if isinstance(dim_length, Constant) and ( + dim_length.data == 1 + # -1 can be an implicit expand_dims, but it's tricky to prove + # as we would need to check whether all other dimensions + # already explain the full size of the array. + # Example: np.zeros((2, 2, 2)).reshape((8, -1)) + # We rely on the output static shape which will already have figured + # it out for some (but not all) cases + or (dim_length.data == -1 and output.type.shape[i] == 1) + ): + expand_axes.append(i) + else: + new_output_shape.append(dim_length) + + if squeeze_axes or expand_axes: + new_out = inp.squeeze(squeeze_axes) + + if new_output_shape: + new_out = new_out.reshape(new_output_shape) + copy_stack_trace(output, new_out) + + new_out = expand_dims(new_out, expand_axes) + + if not new_output_shape: + # Eagerly merge consecutive squeeze and expand_dims + new_out = apply_local_dimshuffle_lift(fgraph, new_out) + + copy_stack_trace(output, new_out) + return [new_out] + + +@register_specialize +@node_rewriter([Reshape]) +def local_fuse_squeeze_reshape(fgraph, node): + r"""If there is a squeeze right before a reshape, merge them. + + This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization. + """ + x, new_shape = node.inputs + + if ( + x.owner is not None + and isinstance(x.owner.op, DimShuffle) + and x.owner.op.is_squeeze + ): + # A reshape can always subsume a squeeze. + x = x.owner.inputs[0] + return [x.reshape(new_shape)] + - if index != output.type.ndim: - inner = op.__class__(len(new_output_shape))(inp, new_output_shape) - copy_stack_trace(output, inner) - new_node = [inner.dimshuffle(dimshuffle_new_order)] - copy_stack_trace(output, new_node) - return new_node +@register_specialize +@node_rewriter([DimShuffle]) +def local_fuse_expand_dims_reshape(fgraph, node): + r"""If there is an expand_dims right after a reshape, merge them. + + This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization. + """ + if not node.op.is_expand_dims: + return None + + reshaped_x = node.inputs[0] + + if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)): + return None + + if len(fgraph.clients[reshaped_x]) > 1: + # The reshape is used elsewhere, don't fuse as it can sometimes require a copy. 
+        # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[:, None] * y[None, :]`
+        return None
+
+    x, new_shape = reshaped_x.owner.inputs
+
+    # Add expand_dims to shape
+    new_shape = list(_unpack_shape_vector(new_shape))
+    for i in node.op.augment:
+        new_shape.insert(i, 1)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]


 @register_canonicalize
@@ -1186,44 +1299,6 @@ def local_track_shape_i(fgraph, node):
         return [shape_feature.shape_of[replacement][node.op.i]]


-@register_canonicalize
-@node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
-    """
-    Removes useless DimShuffle operation inside Reshape:
-
-      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
-
-    """
-    op = node.op
-    if not isinstance(op, Reshape):
-        return False
-    if not (
-        node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, DimShuffle)
-    ):
-        return False
-
-    new_order = node.inputs[0].owner.op.new_order
-    inp = node.inputs[0].owner.inputs[0]
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        shape = node.inputs[1]
-        ret = op.__class__(node.outputs[0].ndim)(inp, shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
-
-
 @register_useless
 @register_canonicalize
 @register_specialize
diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py
index 8913d6fb4d..e839ac1f08 100644
--- a/pytensor/tensor/shape.py
+++ b/pytensor/tensor/shape.py
@@ -1,10 +1,11 @@
 import warnings
+from collections.abc import Sequence
 from numbers import Number
 from textwrap import dedent
-from typing import cast
+from typing import TYPE_CHECKING, Union, cast
+from typing import cast as typing_cast

 import numpy as np
-from numpy.core.numeric import normalize_axis_tuple  # type: ignore

 import pytensor
 from pytensor.gradient import DisconnectedType
@@ -14,6 +15,7 @@
 from pytensor.graph.type import HasShape
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
+from pytensor.npy_2_compat import normalize_axis_tuple
 from pytensor.scalar import int32
 from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
 from pytensor.tensor import basic as ptb
@@ -24,6 +26,9 @@
 from pytensor.tensor.variable import TensorConstant, TensorVariable


+if TYPE_CHECKING:
+    from pytensor.tensor import TensorLike
+
 ShapeValueType = None | np.integer | int | Variable


@@ -639,6 +644,8 @@ def make_node(self, x, shp):
         x = ptb.as_tensor_variable(x)
         shp_orig = shp
         shp = ptb.as_tensor_variable(shp, ndim=1)
+        if shp.type.shape == (None,):
+            shp = specify_shape(shp, self.ndim)
         if not (
             shp.dtype in int_dtypes
             or (isinstance(shp, TensorConstant) and shp.data.size == 0)
@@ -842,9 +849,14 @@ def _vectorize_reshape(op, node, x, shape):
     return reshape(x, new_shape, ndim=len(new_shape)).owner


-def reshape(x, newshape, ndim=None):
+def reshape(
+    x: "TensorLike",
+    newshape: Union["TensorLike", Sequence["TensorLike"]],
+    *,
+    ndim: int | None = None,
+) -> TensorVariable:
     if ndim is None:
-        newshape = ptb.as_tensor_variable(newshape)
+ newshape = ptb.as_tensor_variable(newshape) # type: ignore if newshape.type.ndim != 1: raise TypeError( "New shape in reshape must be a vector or a list/tuple of" @@ -862,7 +874,7 @@ def reshape(x, newshape, ndim=None): ) op = Reshape(ndim) rval = op(x, newshape) - return rval + return typing_cast(TensorVariable, rval) def shape_padleft(t, n_ones=1): diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 325567918a..94973810fd 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1,11 +1,12 @@ import logging -import typing import warnings +from collections.abc import Sequence from functools import reduce from typing import Literal, cast import numpy as np -import scipy.linalg +import scipy.linalg as scipy_linalg +from numpy.exceptions import ComplexWarning import pytensor import pytensor.tensor as pt @@ -58,7 +59,7 @@ def make_node(self, x): f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" ) # Call scipy to find output dtype - dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype + dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) def perform(self, node, inputs, outputs): @@ -68,21 +69,21 @@ def perform(self, node, inputs, outputs): # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS # If we have a `C_CONTIGUOUS` array we transpose to benefit from it if self.overwrite_a and x.flags["C_CONTIGUOUS"]: - out[0] = scipy.linalg.cholesky( + out[0] = scipy_linalg.cholesky( x.T, lower=not self.lower, check_finite=self.check_finite, overwrite_a=True, ).T else: - out[0] = scipy.linalg.cholesky( + out[0] = scipy_linalg.cholesky( x, lower=self.lower, check_finite=self.check_finite, overwrite_a=self.overwrite_a, ) - except scipy.linalg.LinAlgError: + except scipy_linalg.LinAlgError: if self.on_error == "raise": raise else: @@ -334,7 +335,7 @@ def __init__(self, **kwargs): def perform(self, node, inputs, output_storage): C, b = inputs - rval = scipy.linalg.cho_solve( + rval = scipy_linalg.cho_solve( (C, self.lower), b, check_finite=self.check_finite, @@ -369,7 +370,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. - b_ndim : int + b_ndim : int Whether the core case of b is a vector (1) or matrix (2). This will influence how batched dimensions are interpreted. 
""" @@ -401,7 +402,7 @@ def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): def perform(self, node, inputs, outputs): A, b = inputs - outputs[0][0] = scipy.linalg.solve_triangular( + outputs[0][0] = scipy_linalg.solve_triangular( A, b, lower=self.lower, @@ -502,7 +503,7 @@ def __init__(self, *, assume_a="gen", **kwargs): def perform(self, node, inputs, outputs): a, b = inputs - outputs[0][0] = scipy.linalg.solve( + outputs[0][0] = scipy_linalg.solve( a=a, b=b, lower=self.lower, @@ -619,9 +620,9 @@ def make_node(self, a, b): def perform(self, node, inputs, outputs): (w,) = outputs if len(inputs) == 2: - w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) + w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) else: - w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) + w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) def grad(self, inputs, g_outputs): a, b = inputs @@ -675,7 +676,7 @@ def make_node(self, a, b, gw): def perform(self, node, inputs, outputs): (a, b, gw) = inputs - w, v = scipy.linalg.eigh(a, b, lower=self.lower) + w, v = scipy_linalg.eigh(a, b, lower=self.lower) gA = v.dot(np.diag(gw).dot(v.T)) gB = -v.dot(np.diag(gw * w).dot(v.T)) @@ -718,7 +719,7 @@ def make_node(self, A): def perform(self, node, inputs, outputs): (A,) = inputs (expm,) = outputs - expm[0] = scipy.linalg.expm(A) + expm[0] = scipy_linalg.expm(A) def grad(self, inputs, outputs): (A,) = inputs @@ -758,8 +759,8 @@ def perform(self, node, inputs, outputs): # this expression. (A, gA) = inputs (out,) = outputs - w, V = scipy.linalg.eig(A, right=True) - U = scipy.linalg.inv(V).T + w, V = scipy_linalg.eig(A, right=True) + U = scipy_linalg.inv(V).T exp_w = np.exp(w) X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w) @@ -767,7 +768,7 @@ def perform(self, node, inputs, outputs): Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T) with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) out[0] = Y.astype(A.dtype) @@ -800,7 +801,7 @@ def perform(self, node, inputs, output_storage): X = output_storage[0] out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) + X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -870,7 +871,7 @@ def perform(self, node, inputs, output_storage): X = output_storage[0] out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( + X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( out_dtype ) @@ -918,7 +919,7 @@ def _direct_solve_discrete_lyapunov( vec_Q = Q.ravel() vec_X = solve(eye - AxA, vec_Q, b_ndim=1) - return cast(TensorVariable, reshape(vec_X, A.shape)) + return reshape(vec_X, A.shape) def solve_discrete_lyapunov( @@ -992,7 +993,7 @@ def perform(self, node, inputs, output_storage): Q = 0.5 * (Q + Q.T) out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) + X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -1064,7 +1065,7 @@ def solve_discrete_are( ) -def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: +def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype: return reduce(lambda l, r: np.promote_types(l, r), [x.dtype 
for x in tensors]) @@ -1118,7 +1119,7 @@ def make_node(self, *matrices): def perform(self, node, inputs, output_storage, params=None): dtype = node.outputs[0].type.dtype - output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype) + output_storage[0][0] = scipy_linalg.block_diag(*inputs).astype(dtype) def block_diag(*matrices: TensorVariable): @@ -1175,4 +1176,5 @@ def block_diag(*matrices: TensorVariable): "solve_discrete_are", "solve_triangular", "block_diag", + "cho_solve", ] diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index a2f02fabd8..5b05ad03f4 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -6,6 +6,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp +from pytensor.npy_2_compat import npy_2_compat_header from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.math import gamma, gammaln, log, neg, sum @@ -60,12 +61,16 @@ def infer_shape(self, fgraph, node, shape): return [shape[1]] def c_code_cache_version(self): - return (4,) + return (5,) + + def c_support_code_apply(self, node: Apply, name: str) -> str: + # return super().c_support_code_apply(node, name) + return npy_2_compat_header() def c_code(self, node, name, inp, out, sub): dy, sm = inp (dx,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -79,7 +84,7 @@ def c_code(self, node, name, inp, out, sub): int sm_ndim = PyArray_NDIM({sm}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || sm_ndim == 1); // Validate inputs if ((PyArray_TYPE({dy}) != NPY_DOUBLE) && @@ -95,13 +100,15 @@ def c_code(self, node, name, inp, out, sub): {fail}; }} - if (axis < 0) axis = sm_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); - {fail}; + if (axis < 0) axis = sm_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); + {fail}; + }} }} - if (({dx} == NULL) || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim))) {{ @@ -289,10 +296,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return ["", ""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] # dtype = node.inputs[0].type.dtype_specs()[1] # TODO: put this into a templated function, in the support code @@ -309,7 +320,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -319,11 +330,14 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in 
Softmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax"); + {fail} + }} }} // Allocate Output Array @@ -481,7 +495,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (4,) + return (5,) def softmax(c, axis=None): @@ -541,10 +555,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return [""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -558,7 +576,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -568,13 +586,15 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); + {fail} + }} }} - // Allocate Output Array if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim))) {{ @@ -730,7 +750,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (1,) + return (2,) def log_softmax(c, axis=None): diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index a3a81f63bd..3a2304eb7b 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,6 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -756,13 +757,15 @@ def get_constant_idx( Example usage where `v` and `a` are appropriately typed PyTensor variables : >>> from pytensor.scalar import int64 >>> from pytensor.tensor import matrix + >>> import numpy as np + >>> >>> v = int64("v") >>> a = matrix("a") >>> b = a[v, 1:3] >>> b.owner.op.idx_list (ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None)) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True) - [v, slice(1, 3, None)] + [v, slice(np.int64(1), np.int64(3), None)] >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs) Traceback (most recent call last): pytensor.tensor.exceptions.NotScalarConstantError @@ -2148,7 +2151,7 @@ def infer_shape(self, fgraph, node, ishapes): def c_support_code(self, **kwargs): # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG, # which is not defined. It should be NPY_MIN_LONG instead in that case. 
- return dedent( + return npy_2_compat_header() + dedent( """\ #ifndef MIN_LONG #define MIN_LONG NPY_MIN_LONG @@ -2173,7 +2176,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (!PyArray_CanCastSafely(i_type, NPY_INTP) && PyArray_SIZE({i_name}) > 0) {{ npy_int64 min_val, max_val; - PyObject* py_min_val = PyArray_Min({i_name}, NPY_MAXDIMS, + PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS, NULL); if (py_min_val == NULL) {{ {fail}; @@ -2183,7 +2186,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (min_val == -1 && PyErr_Occurred()) {{ {fail}; }} - PyObject* py_max_val = PyArray_Max({i_name}, NPY_MAXDIMS, + PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS, NULL); if (py_max_val == NULL) {{ {fail}; @@ -2242,7 +2245,7 @@ def c_code(self, node, name, input_names, output_names, sub): """ def c_code_cache_version(self): - return (0, 1, 2) + return (0, 1, 2, 3) advanced_subtensor1 = AdvancedSubtensor1() @@ -2519,9 +2522,9 @@ def gen_num(typen): return code def c_code(self, node, name, input_names, output_names, sub): - numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] - if bool(numpy_ver < [1, 8]): + if numpy_version < "1.8.0" or using_numpy_2: raise NotImplementedError + x, y, idx = input_names out = output_names[0] copy_of_x = self.copy_of_x(x) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 0f99fa48aa..b96113c8e3 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal, Optional import numpy as np +import numpy.typing as npt import pytensor from pytensor import scalar as ps @@ -69,7 +70,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): def __init__( self, - dtype: str | np.dtype, + dtype: str | npt.DTypeLike, shape: Iterable[bool | int | None] | None = None, name: str | None = None, broadcastable: Iterable[bool] | None = None, @@ -101,11 +102,11 @@ def __init__( if str(dtype) == "floatX": self.dtype = config.floatX else: - if np.obj2sctype(dtype) is None: + try: + self.dtype = str(np.dtype(dtype)) + except TypeError: raise TypeError(f"Invalid dtype: {dtype}") - self.dtype = np.dtype(dtype).name - def parse_bcast_and_shape(s): if isinstance(s, bool | np.bool_): return 1 if s else None @@ -177,7 +178,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: else: if allow_downcast: # Convert to self.dtype, regardless of the type of data - data = np.asarray(data, dtype=self.dtype) + data = np.asarray(data).astype(self.dtype) # TODO: consider to pad shape with ones to make it consistent # with self.broadcastable... like vector->row type thing else: @@ -789,14 +790,16 @@ def tensor( **kwargs, ) -> "TensorVariable": if name is not None: - # Help catching errors with the new tensor API - # Many single letter strings are valid sctypes - if str(name) == "floatX" or (len(str(name)) > 1 and np.obj2sctype(name)): - np.obj2sctype(name) - raise ValueError( - f"The first and only positional argument of tensor is now `name`. Got {name}.\n" - "This name looks like a dtype, which you should pass as a keyword argument only." - ) + try: + # Help catching errors with the new tensor API + # Many single letter strings are valid sctypes + if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type): + raise ValueError( + f"The first and only positional argument of tensor is now `name`. Got {name}.\n" + "This name looks like a dtype, which you should pass as a keyword argument only." 
+ ) + except TypeError: + pass if dtype is None: dtype = config.floatX diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index e6451c9236..9ce12296cd 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -3,10 +3,10 @@ from typing import cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.graph import FunctionGraph, Variable +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.utils import hash_from_code @@ -236,8 +236,8 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None: if axis is not None: try: axis = normalize_axis_tuple(axis, ndim=max(1, ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=ndim) + except np.exceptions.AxisError: + raise np.exceptions.AxisError(axis, ndim=ndim) # TODO: If axis tuple is equivalent to None, return None for more canonicalization? return cast(tuple, axis) diff --git a/pytensor/tensor/xlogx.py b/pytensor/tensor/xlogx.py index 8cc27de9fb..3709688e54 100644 --- a/pytensor/tensor/xlogx.py +++ b/pytensor/tensor/xlogx.py @@ -10,15 +10,11 @@ class XlogX(ps.UnaryScalarOp): """ - @staticmethod - def st_impl(x): + def impl(self, x): if x == 0.0: return 0.0 return x * np.log(x) - def impl(self, x): - return XlogX.st_impl(x) - def grad(self, inputs, grads): (x,) = inputs (gz,) = grads @@ -45,15 +41,11 @@ class XlogY0(ps.BinaryScalarOp): """ - @staticmethod - def st_impl(x, y): + def impl(self, x, y): if x == 0.0: return 0.0 return x * np.log(y) - def impl(self, x, y): - return XlogY0.st_impl(x, y) - def grad(self, inputs, grads): x, y = inputs (gz,) = grads diff --git a/scripts/slowest_tests/extract-slow-tests.py b/scripts/slowest_tests/extract-slow-tests.py index 3a06e4a68b..14df837a7b 100644 --- a/scripts/slowest_tests/extract-slow-tests.py +++ b/scripts/slowest_tests/extract-slow-tests.py @@ -72,7 +72,7 @@ def main(read_lines): lines = read_lines() times = extract_lines(lines) parsed_times = format_times(times) - print("\n".join(parsed_times)) + print("\n".join(parsed_times)) # noqa: T201 if __name__ == "__main__": diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py index f835953b19..9f75ef15d8 100644 --- a/tests/compile/function/test_function.py +++ b/tests/compile/function/test_function.py @@ -11,6 +11,7 @@ from pytensor.compile.function import function, function_dump from pytensor.compile.io import In from pytensor.configdefaults import config +from pytensor.npy_2_compat import UintOverflowError from pytensor.tensor.type import ( bscalar, bvector, @@ -166,12 +167,12 @@ def test_in_allow_downcast_int(self): # Value too big for a, silently ignored assert np.array_equal(f([2**20], np.ones(1, dtype="int8"), 1), [2]) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError (in numpy >= 2.0... 
TypeError in numpy < 2.0) + with pytest.raises(UintOverflowError): f([3], [312], 1) - # Value too big for c, raises TypeError - with pytest.raises(TypeError): + # Value too big for c, raises OverflowError + with pytest.raises(UintOverflowError): f([3], [6], 806) def test_in_allow_downcast_floatX(self): diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 0a9bda9846..3e23b12f74 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -9,6 +9,7 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.graph.utils import MissingInputError +from pytensor.npy_2_compat import UintOverflowError from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.type import ( bscalar, @@ -237,12 +238,12 @@ def test_param_allow_downcast_int(self): # Value too big for a, silently ignored assert np.all(f([2**20], np.ones(1, dtype="int8"), 1) == 2) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): f([3], [312], 1) - # Value too big for c, raises TypeError - with pytest.raises(TypeError): + # Value too big for c, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): f([3], [6], 806) def test_param_allow_downcast_floatX(self): @@ -327,16 +328,19 @@ def test_allow_input_downcast_int(self): with pytest.raises(TypeError): g([3], np.array([6], dtype="int16"), 0) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): g([3], [312], 0) h = pfunc([a, b, c], (a + b + c)) # Default: allow_input_downcast=None # Everything here should behave like with False assert np.all(h([3], [6], 0) == 9) + with pytest.raises(TypeError): h([3], np.array([6], dtype="int16"), 0) - with pytest.raises(TypeError): + + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): h([3], [312], 0) def test_allow_downcast_floatX(self): diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 8fc2a529df..ba0257cdda 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -306,7 +306,8 @@ def lop_ov(inps, outs, grads): @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] ) - def test_rop(self, cls_ofg): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_rop(self, cls_ofg, use_op_rop_implementation): a = vector() M = matrix() b = dot(a, M) @@ -315,7 +316,7 @@ def test_rop(self, cls_ofg): W = matrix() y = op_matmul(x, W) du = vector() - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) xval = np.random.random((16,)).astype(config.floatX) Wval = np.random.random((16, 16)).astype(config.floatX) @@ -324,7 +325,8 @@ def test_rop(self, cls_ofg): dvval2 = fn(xval, Wval, duval) np.testing.assert_array_almost_equal(dvval2, dvval, 4) - def test_rop_multiple_outputs(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_rop_multiple_outputs(self, use_op_rop_implementation): a = vector() M = matrix() b = dot(a, M) @@ -339,21 +341,21 @@ def test_rop_multiple_outputs(self): duval = 
np.random.random((16,)).astype(config.floatX) y = op_matmul(x, W)[0] - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = np.dot(duval, Wval) np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4) y = op_matmul(x, W)[1] - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = -np.dot(duval, Wval) np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4) y = pt.add(*op_matmul(x, W)) - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = np.zeros_like(np.dot(duval, Wval)) @@ -362,7 +364,16 @@ def test_rop_multiple_outputs(self): @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] ) - def test_rop_override(self, cls_ofg): + @pytest.mark.parametrize( + "use_op_rop_implementation", + [ + True, + pytest.param( + False, marks=pytest.mark.xfail(reason="Custom ROp is ignored") + ), + ], + ) + def test_rop_override(self, cls_ofg, use_op_rop_implementation): x, y = vectors("xy") def ro(inps, epts): @@ -380,7 +391,12 @@ def ro(inps, epts): du, dv = vector("du"), vector("dv") for op in [op_mul, op_mul2]: zz = op_mul(xx, yy) - dw = Rop(zz, [xx, yy], [du, dv]) + dw = Rop( + zz, + [xx, yy], + [du, dv], + use_op_rop_implementation=use_op_rop_implementation, + ) fn = function([xx, yy, du, dv], dw) vals = np.random.random((4, 32)).astype(config.floatX) dwval = fn(*vals) diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index 95e52d6b53..fae76fab0d 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -146,7 +146,7 @@ def dontuse_perform(self, node, inp, out_): raise ValueError(self.behaviour) def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inp, out, sub): (a,) = inp @@ -165,8 +165,8 @@ def c_code(self, node, name, inp, out, sub): prep_vars = f""" //the output array has size M x N npy_intp M = PyArray_DIMS({a})[0]; - npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_DESCR({a})->elsize; - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; + npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_ITEMSIZE({a}); + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); npy_double * Da = (npy_double*)PyArray_BYTES({a}); npy_double * Dz = (npy_double*)PyArray_BYTES({z}); diff --git a/tests/d3viz/test_d3viz.py b/tests/d3viz/test_d3viz.py index 7e4b0426a0..38809a5faa 100644 --- a/tests/d3viz/test_d3viz.py +++ b/tests/d3viz/test_d3viz.py @@ -28,7 +28,7 @@ def check(self, f, reference=None, verbose=False): tmp_dir = Path(tempfile.mkdtemp()) html_file = tmp_dir / "index.html" if verbose: - print(html_file) + print(html_file) # noqa: T201 d3v.d3viz(f, html_file) assert html_file.stat().st_size > 0 if reference: diff --git a/tests/link/c/test_cmodule.py b/tests/link/c/test_cmodule.py index 2242bc12e9..46533fef35 100644 --- a/tests/link/c/test_cmodule.py +++ b/tests/link/c/test_cmodule.py @@ -258,7 +258,6 @@ def test_default_blas_ldflags( def patched_compile_tmp(*args, **kwargs): def wrapped(test_code, tmp_prefix, flags, try_run, output): if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]: - print(enabled_accelerate_framework) if enabled_accelerate_framework: return 
(True, True) else: diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index a01f5e3f46..fa25f3aac0 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -18,6 +18,7 @@ batched_permutation_tester, batched_unweighted_choice_without_replacement_tester, batched_weighted_choice_without_replacement_tester, + create_mvnormal_cov_decomposition_method_test, ) @@ -62,7 +63,9 @@ def test_random_updates(rng_ctor): assert all( a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b) for a, b in zip( - rng.get_value().__getstate__(), original_value.__getstate__(), strict=True + rng.get_value().bit_generator.state, + original_value.bit_generator.state, + strict=True, ) ) @@ -547,6 +550,11 @@ def test_random_mvnormal(): np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1) +test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test( + "JAX" +) + + @pytest.mark.parametrize( "parameter, size", [ diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 1b0fa8fd52..f0f73ca74d 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -836,7 +836,6 @@ def test_config_options_fastmath(): with config.change_flags(numba__fastmath=True): pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__)) numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] assert numba_mul_fn.targetoptions["fastmath"] == { "afn", diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 6fbb6e6c58..3dc427cd9c 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -7,58 +7,13 @@ from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph -from pytensor.tensor import nlinalg, slinalg +from pytensor.tensor import nlinalg from tests.link.numba.test_basic import compare_numba_and_py, set_test_value rng = np.random.default_rng(42849) -@pytest.mark.parametrize( - "A, x, lower, exc", - [ - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ( - set_test_value( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ], -) -def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower=lower, b_ndim=1)(A, x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - ) - - @pytest.mark.parametrize( "x, exc", [ diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index b966ed2870..1569ea8ae8 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -22,6 +22,7 @@ batched_permutation_tester, batched_unweighted_choice_without_replacement_tester, batched_weighted_choice_without_replacement_tester, + create_mvnormal_cov_decomposition_method_test, ) @@ -147,6 +148,11 @@ def test_multivariate_normal(): ) +test_mvnormal_cov_decomposition_method = 
create_mvnormal_cov_decomposition_method_test( + "NUMBA" +) + + @pytest.mark.parametrize( "rv_op, dist_args, size", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 8b1f3ececb..8e49627361 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,19 +1,23 @@ import re +from functools import partial +from typing import Literal import numpy as np import pytest +from numpy.testing import assert_allclose +from scipy import linalg as scipy_linalg import pytensor import pytensor.tensor as pt -from pytensor import config from pytensor.graph import FunctionGraph +from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py numba = pytest.importorskip("numba") -ATOL = 0 if config.floatX.endswith("64") else 1e-6 -RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6 +floatX = pytensor.config.floatX + rng = np.random.default_rng(42849) @@ -27,8 +31,8 @@ def transpose_func(x, trans): @pytest.mark.parametrize( - "b_func, b_size", - [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + "b_shape", + [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"], ) @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) @@ -36,50 +40,88 @@ def transpose_func(x, trans): @pytest.mark.parametrize( "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] ) -@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"]) +@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) @pytest.mark.filterwarnings( 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' ) -def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex): - if complex: +def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex): + if is_complex: # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, # why? 
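# A minimal numpy-only sketch of the "why" above (illustrative shapes, not part
# of the patch): the ValueError is a generic numpy restriction, not specific to
# solve_triangular. Re-viewing data as a wider dtype needs a contiguous last
# axis, which transposed (Fortran-ordered) LAPACK outputs typically lack:
#
#     import numpy as np
#     a = np.zeros((4, 4))       # C-contiguous float64
#     a.view(np.complex128)      # OK: adjacent float pairs become complex
#     a.T.view(np.complex128)    # ValueError: ... last axis must be contiguous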
pytest.skip("Complex inputs currently not supported to solve_triangular") - complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" - dtype = complex_dtype if complex else config.floatX + complex_dtype = "complex64" if floatX.endswith("32") else "complex128" + dtype = complex_dtype if is_complex else floatX A = pt.matrix("A", dtype=dtype) - b = b_func("b", dtype=dtype) + b = pt.tensor("b", shape=b_shape, dtype=dtype) + + def A_func(x): + x = x @ x.conj().T + x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype) - X = pt.linalg.solve_triangular( - A, b, lower=lower, trans=trans, unit_diagonal=unit_diag + if unit_diag: + x_tri[np.diag_indices_from(x_tri)] = 1.0 + + return x_tri.astype(dtype) + + solve_op = partial( + pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag ) + + X = solve_op(A, b) f = pytensor.function([A, b], X, mode="NUMBA") A_val = np.random.normal(size=(5, 5)) - b = np.random.normal(size=b_size) + b_val = np.random.normal(size=b_shape) - if complex: + if is_complex: A_val = A_val + np.random.normal(size=(5, 5)) * 1j - b = b + np.random.normal(size=b_size) * 1j - A_sym = A_val @ A_val.conj().T + b_val = b_val + np.random.normal(size=b_shape) * 1j - A_tri = np.linalg.cholesky(A_sym).astype(dtype) - if unit_diag: - adj_mat = np.ones((5, 5)) - adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri) - A_tri = A_tri * adj_mat + X_np = f(A_func(A_val.copy()), b_val.copy()) - A_tri = A_tri.astype(dtype) - b = b.astype(dtype) + test_input = transpose_func(A_func(A_val.copy()), trans) - if not lower: - A_tri = A_tri.T + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 - X_np = f(A_tri, b) - np.testing.assert_allclose( - transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL + np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) + + compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()]) + + +@pytest.mark.parametrize( + "lower, unit_diag, trans", + [(True, True, True), (False, False, False)], + ids=["lower_unit_trans", "defaults"], +) +def test_solve_triangular_grad(lower, unit_diag, trans): + A_val = np.random.normal(size=(5, 5)).astype(floatX) + b_val = np.random.normal(size=(5, 5)).astype(floatX) + + # utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When + # a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raise, but the result will be + # wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices + # to the space of triangular matrices, and test the gradient of that entire graph. 
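# A small standalone demonstration of that silent-wrong-result behaviour
# (plain numpy/scipy, names are illustrative only): entries outside the assumed
# triangle are ignored by the LAPACK kernel, so finite-difference perturbations
# in those directions have no visible effect on the solution:
#
#     import numpy as np
#     from scipy.linalg import solve_triangular
#
#     A = np.tril(np.random.default_rng(0).normal(size=(3, 3))) + 3 * np.eye(3)
#     b = np.ones(3)
#     A_bad = A.copy()
#     A_bad[0, 2] += 0.5  # perturbation outside the lower triangle
#     x1 = solve_triangular(A, b, lower=True)
#     x2 = solve_triangular(A_bad, b, lower=True)
#     assert np.allclose(x1, x2)  # no error raised, result silently identical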
+ def A_func_pt(x): + x = x @ x.conj().T + x_tri = pt.linalg.cholesky(x, lower=lower).astype(floatX) + + if unit_diag: + n = A_val.shape[0] + x_tri = x_tri[np.diag_indices(n)].set(1.0) + + return transpose_func(x_tri.astype(floatX), trans) + + solve_op = partial( + pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag + ) + + utt.verify_grad( + lambda A, b: solve_op(A_func_pt(A), b), + [A_val.copy(), b_val.copy()], + mode="NUMBA", ) @@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value): X = pt.linalg.solve_triangular(A, b, check_finite=True) f = pytensor.function([A, b], X, mode="NUMBA") - A_val = np.random.normal(size=(5, 5)) + A_val = np.random.normal(size=(5, 5)).astype(floatX) A_sym = A_val @ A_val.conj().T - A_tri = np.linalg.cholesky(A_sym).astype(config.floatX) - b = np.full((5, 1), value) + A_tri = np.linalg.cholesky(A_sym).astype(floatX) + b = np.full((5, 1), value).astype(floatX) with pytest.raises( np.linalg.LinAlgError, @@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans): fg = FunctionGraph(outputs=[chol]) - x = np.array([0.1, 0.2, 0.3]) - val = np.eye(3) + x[None, :] * x[:, None] + x = np.array([0.1, 0.2, 0.3]).astype(floatX) + val = np.eye(3).astype(floatX) + x[None, :] * x[:, None] compare_numba_and_py(fg, [val]) def test_numba_Cholesky_raises_on_nan_input(): - test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value = rng.random(size=(3, 3)).astype(floatX) test_value[0, 0] = np.nan - x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = pt.tensor(dtype=floatX, shape=(3, 3)) x = x.T.dot(x) - g = pt.linalg.cholesky(x, check_finite=True) + g = pt.linalg.cholesky(x) f = pytensor.function([x], g, mode="NUMBA") with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): @@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input(): @pytest.mark.parametrize("on_error", ["nan", "raise"]) def test_numba_Cholesky_raise_on(on_error): - test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value = rng.random(size=(3, 3)).astype(floatX) - x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = pt.tensor(dtype=floatX, shape=(3, 3)) g = pt.linalg.cholesky(x, on_error=on_error) f = pytensor.function([x], g, mode="NUMBA") @@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error): assert np.all(np.isnan(f(test_value))) +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +def test_numba_Cholesky_grad(lower): + rng = np.random.default_rng(utt.fetch_seed()) + L = rng.normal(size=(5, 5)).astype(floatX) + X = L @ L.T + + chol_op = partial(pt.linalg.cholesky, lower=lower) + utt.verify_grad(chol_op, [X], mode="NUMBA") + + def test_block_diag(): A = pt.matrix("A") B = pt.matrix("B") @@ -162,9 +214,242 @@ def test_block_diag(): D = pt.matrix("D") X = pt.linalg.block_diag(A, B, C, D) - A_val = np.random.normal(size=(5, 5)) - B_val = np.random.normal(size=(3, 3)) - C_val = np.random.normal(size=(2, 2)) - D_val = np.random.normal(size=(4, 4)) + A_val = np.random.normal(size=(5, 5)).astype(floatX) + B_val = np.random.normal(size=(3, 3)).astype(floatX) + C_val = np.random.normal(size=(2, 2)).astype(floatX) + D_val = np.random.normal(size=(4, 4)).astype(floatX) out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X]) compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val]) + + +def test_lamch(): + from scipy.linalg import get_lapack_funcs + + from pytensor.link.numba.dispatch.slinalg import _xlamch + + @numba.njit() + def xlamch(kind): + return _xlamch(kind) + + lamch = 
get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),)) + + np.testing.assert_allclose(xlamch("E"), lamch("E")) + np.testing.assert_allclose(xlamch("S"), lamch("S")) + np.testing.assert_allclose(xlamch("P"), lamch("P")) + np.testing.assert_allclose(xlamch("B"), lamch("B")) + np.testing.assert_allclose(xlamch("R"), lamch("R")) + np.testing.assert_allclose(xlamch("M"), lamch("M")) + + +@pytest.mark.parametrize( + "ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)] +) +def test_xlange(ord_numba, ord_scipy): + # xlange is called internally only, we don't dispatch pt.linalg.norm to it + from scipy import linalg + + from pytensor.link.numba.dispatch.slinalg import _xlange + + @numba.njit() + def xlange(x, ord): + return _xlange(x, ord) + + x = np.random.normal(size=(5, 5)).astype(floatX) + np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy)) + + +@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)]) +def test_xgecon(ord_numba, ord_scipy): + # gecon is called internally only, we don't dispatch pt.linalg.norm to it + from scipy.linalg import get_lapack_funcs + + from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange + + @numba.njit() + def gecon(x, norm): + anorm = _xlange(x, norm) + cond, info = _xgecon(x, anorm, norm) + return cond, info + + x = np.random.normal(size=(5, 5)).astype(floatX) + + rcond, info = gecon(x, norm=ord_numba) + + # Test against direct call to the underlying LAPACK functions + # Solution does **not** agree with 1 / np.linalg.cond(x) ! + lange, gecon = get_lapack_funcs(("lange", "gecon"), (x,)) + norm = lange(ord_numba, x) + rcond2, _ = gecon(x, norm, norm=ord_numba) + + assert info == 0 + np.testing.assert_allclose(rcond, rcond2) + + +@pytest.mark.parametrize("overwrite_a", [True, False]) +def test_getrf(overwrite_a): + from scipy.linalg import lu_factor + + from pytensor.link.numba.dispatch.slinalg import _getrf + + # TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor + + @numba.njit() + def getrf(x, overwrite_a): + return _getrf(x, overwrite_a=overwrite_a) + + x = np.random.normal(size=(5, 5)).astype(floatX) + x = np.asfortranarray( + x + ) # x needs to be fortran-contiguous going into getrf for the overwrite option to work + + lu, ipiv = lu_factor(x, overwrite_a=False) + LU, IPIV, info = getrf(x, overwrite_a=overwrite_a) + + assert info == 0 + assert_allclose(LU, lu) + + if overwrite_a: + assert_allclose(x, LU) + + # TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing + # this, though. 
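# Supporting detail for the TODO above, taken from the LAPACK and scipy docs
# rather than this codebase: ?getrf documents IPIV in one-based Fortran
# convention ("row i was interchanged with row IPIV(i)"), while
# scipy.linalg.lu_factor documents zero-based pivots, so the shift presumably
# happens inside scipy's LAPACK wrappers. Hence the subtraction below: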
+ assert_allclose(IPIV - 1, ipiv) + + +@pytest.mark.parametrize("trans", [0, 1]) +@pytest.mark.parametrize("overwrite_a", [True, False]) +@pytest.mark.parametrize("overwrite_b", [True, False]) +@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"]) +def test_getrs(trans, overwrite_a, overwrite_b, b_shape): + from scipy.linalg import lu_factor + from scipy.linalg import lu_solve as sp_lu_solve + + from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs + + # TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor + + @numba.njit() + def lu_solve(a, b, trans, overwrite_a, overwrite_b): + lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a) + x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b) + return x, lu, info + + a = np.random.normal(size=(5, 5)).astype(floatX) + b = np.random.normal(size=b_shape).astype(floatX) + + # inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work + a = np.asfortranarray(a) + b = np.asfortranarray(b) + + lu_and_piv = lu_factor(a, overwrite_a=False) + x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False) + + x, lu, info = lu_solve( + a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b + ) + assert info == 0 + if overwrite_a: + assert_allclose(a, lu) + if overwrite_b: + assert_allclose(b, x) + + assert_allclose(x, x_sp) + + +@pytest.mark.parametrize( + "b_shape", + [(5, 1), (5, 5), (5,)], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) +@pytest.mark.filterwarnings( + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' +) +def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]): + A = pt.matrix("A", dtype=floatX) + b = pt.tensor("b", shape=b_shape, dtype=floatX) + + A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX)) + b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX)) + + def A_func(x): + if assume_a == "pos": + x = x @ x.T + elif assume_a == "sym": + x = (x + x.T) / 2 + return x + + X = pt.linalg.solve( + A_func(A), + b, + assume_a=assume_a, + b_ndim=len(b_shape), + ) + f = pytensor.function( + [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA" + ) + op = f.maker.fgraph.outputs[0].owner.op + + compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True) + + # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first. + A_val_copy = A_val.copy() + b_val_copy = b_val.copy() + + X_np = f(A_val, b_val) + + # overwrite_b is preferred when both inputs can be destroyed + assert op.destroy_map == {0: [1]} + + # Confirm inputs were destroyed by checking against the copies + assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) + assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) + + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 + + # Confirm b_val is used to store to solution + np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL) + assert not np.allclose(b_val, b_val_copy) + + # Test that the result is numerically correct. 
Need to use the unmodified copy + np.testing.assert_allclose( + A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL + ) + + # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here + utt.verify_grad( + lambda A, b: pt.linalg.solve( + A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape) + ), + [A_val_copy, b_val_copy], + mode="NUMBA", + ) + + +@pytest.mark.parametrize( + "b_func, b_size", + [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}") +def test_cho_solve(b_func, b_size, lower): + A = pt.matrix("A", dtype=floatX) + b = b_func("b", dtype=floatX) + + C = pt.linalg.cholesky(A, lower=lower) + X = pt.linalg.cho_solve((C, lower), b) + f = pytensor.function([A, b], X, mode="NUMBA") + + A = np.random.normal(size=(5, 5)).astype(floatX) + A = A @ A.conj().T + + b = np.random.normal(size=b_size) + b = b.astype(floatX) + + X_np = f(A, b) + + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 + + np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 2ac8ee7c3b..d5c23c83e4 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries(): compare_pytorch_and_py( f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) ) + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "n_splits, axis, values, sizes", + [ + ( + 0, + 0, + rng.normal(size=20).astype(config.floatX), + [], + ), + ( + 5, + 0, + rng.normal(size=5).astype(config.floatX), + rng.multinomial(5, np.ones(5) / 5), + ), + ( + 5, + 0, + rng.normal(size=10).astype(config.floatX), + rng.multinomial(10, np.ones(5) / 5), + ), + ( + 5, + -1, + rng.normal(size=(11, 7)).astype(config.floatX), + rng.multinomial(7, np.ones(5) / 5), + ), + ( + 5, + -2, + rng.normal(size=(11, 7)).astype(config.floatX), + rng.multinomial(11, np.ones(5) / 5), + ), + ], +) +def test_Split(n_splits, axis, values, sizes): + i = pt.tensor("i", shape=values.shape, dtype=config.floatX) + s = pt.vector("s", dtype="int64") + g = pt.split(i, s, n_splits, axis=axis) + assert len(g) == n_splits + if n_splits == 0: + return + g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g) + + compare_pytorch_and_py(g_fg, [values, sizes]) diff --git a/tests/link/test_vm.py b/tests/link/test_vm.py index 69a922e731..dad7ed4fdd 100644 --- a/tests/link/test_vm.py +++ b/tests/link/test_vm.py @@ -1,4 +1,3 @@ -import time from collections import Counter import numpy as np @@ -108,23 +107,25 @@ def numpy_version(x, depth): return z def time_numpy(): + # TODO: Make this a benchmark test steps_a = 5 steps_b = 100 x = np.asarray([2.0, 3.0], dtype=config.floatX) numpy_version(x, steps_a) - t0 = time.perf_counter() - # print numpy_version(x, steps_a) - t1 = time.perf_counter() - t2 = time.perf_counter() - # print numpy_version(x, steps_b) - t3 = time.perf_counter() - t_a = t1 - t0 - t_b = t3 - t2 + # t0 = time.perf_counter() + numpy_version(x, steps_a) + # t1 = time.perf_counter() + # t2 = time.perf_counter() + numpy_version(x, steps_b) + # t3 = time.perf_counter() + # t_a = t1 - t0 + # t_b = t3 - t2 - print(f"numpy takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") + # print(f"numpy takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} 
s/Kop") def time_linker(name, linker): + # TODO: Make this a benchmark test steps_a = 5 steps_b = 100 x = vector() @@ -135,20 +136,20 @@ def time_linker(name, linker): f_b = function([x], b, mode=Mode(optimizer=None, linker=linker())) f_a([2.0, 3.0]) - t0 = time.perf_counter() + # t0 = time.perf_counter() f_a([2.0, 3.0]) - t1 = time.perf_counter() + # t1 = time.perf_counter() f_b([2.0, 3.0]) - t2 = time.perf_counter() + # t2 = time.perf_counter() f_b([2.0, 3.0]) - t3 = time.perf_counter() + # t3 = time.perf_counter() - t_a = t1 - t0 - t_b = t3 - t2 + # t_a = t1 - t0 + # t_b = t3 - t2 - print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") + # print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") time_linker("c|py", OpWiseCLinker) time_linker("vmLinker", VMLinker) @@ -167,7 +168,7 @@ def time_linker(name, linker): ], ) def test_speed_lazy(linker): - # TODO FIXME: This isn't a real test. + # TODO FIXME: This isn't a real test. Make this a benchmark test def build_graph(x, depth=5): z = x @@ -185,20 +186,20 @@ def build_graph(x, depth=5): f_b = function([x], b, mode=Mode(optimizer=None, linker=linker)) f_a([2.0]) - t0 = time.perf_counter() + # t0 = time.perf_counter() f_a([2.0]) - t1 = time.perf_counter() + # t1 = time.perf_counter() f_b([2.0]) - t2 = time.perf_counter() + # t2 = time.perf_counter() f_b([2.0]) - t3 = time.perf_counter() + # t3 = time.perf_counter() - t_a = t1 - t0 - t_b = t3 - t2 + # t_a = t1 - t0 + # t_b = t3 - t2 - print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") + # print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") @pytest.mark.parametrize( diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index b75e9ca852..351c2e703a 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -12,7 +12,6 @@ import os import pickle import shutil -import sys from pathlib import Path from tempfile import mkdtemp @@ -1923,7 +1922,8 @@ def inner_fn(): fgrad = function([], g_sh) assert fgrad() == 1 - def test_R_op(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op(self, use_op_rop_implementation): seed = utt.fetch_seed() rng = np.random.default_rng(seed) floatX = config.floatX @@ -1958,9 +1958,9 @@ def rnn_fn(_u, _y, _W): eh0 = vector("eh0") eW = matrix("eW") - nwo_u = Rop(o, _u, eu) - nwo_h0 = Rop(o, _h0, eh0) - nwo_W = Rop(o, _W, eW) + nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation) + nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation) + nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation) fn_rop = function( [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore" ) @@ -1993,12 +1993,13 @@ def rnn_fn(_u, _y, _W): vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) - utt.assert_allclose(vnu, tnu, atol=1e-6) - utt.assert_allclose(vnh0, tnh0, atol=1e-6) - utt.assert_allclose(vnW, tnW, atol=1e-6) + np.testing.assert_allclose(vnu, tnu, atol=1e-6) + np.testing.assert_allclose(vnh0, tnh0, atol=1e-6) + np.testing.assert_allclose(vnW, tnW, atol=1e-6) @pytest.mark.slow - def test_R_op_2(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op_2(self, use_op_rop_implementation): seed = utt.fetch_seed() rng = np.random.default_rng(seed) floatX = config.floatX @@ -2041,9 +2042,9 @@ def rnn_fn(_u, _y, _W): eh0 = vector("eh0") eW = matrix("eW") - nwo_u 
= Rop(o, _u, eu) - nwo_h0 = Rop(o, _h0, eh0) - nwo_W = Rop(o, _W, eW) + nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation) + nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation) + nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation) fn_rop = function( [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W, o], on_unused_input="ignore" ) @@ -2075,11 +2076,12 @@ def rnn_fn(_u, _y, _W): ) tnu, tnh0, tnW, tno = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) - utt.assert_allclose(vnu, tnu, atol=1e-6) - utt.assert_allclose(vnh0, tnh0, atol=1e-6) - utt.assert_allclose(vnW, tnW, atol=2e-6) + np.testing.assert_allclose(vnu, tnu, atol=1e-6) + np.testing.assert_allclose(vnh0, tnh0, atol=1e-6) + np.testing.assert_allclose(vnW, tnW, atol=2e-6) - def test_R_op_mitmot(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op_mitmot(self, use_op_rop_implementation): # this test is a copy paste from the script given by Justin Bayer to # reproduce this bug # We have 2 parameter groups with the following shapes. @@ -2095,13 +2097,10 @@ def test_R_op_mitmot(self): W1 = pars[:3].reshape(W1shape) W2 = pars[3:].reshape(W2shape) - # Define recurrent model. We are using a model where each input is a - # tensor - # of shape (T, B, D) where T is the number of timesteps, B is the - # number of - # sequences iterated over in parallel and D is the dimensionality of - # each - # item at a timestep. + # Define recurrent model. We are using a model where each input + # is a tensor of shape (T, B, D) where T is the number of timesteps, + # B is the number of sequences iterated over in parallel and + # D is the dimensionality of each item at a timestep. inpt = tensor3("inpt") target = tensor3("target") @@ -2129,7 +2128,130 @@ def test_R_op_mitmot(self): d_cost_wrt_pars = grad(cost, pars) p = dvector() - Rop(d_cost_wrt_pars, pars, p) + # TODO: We should test something about the Rop! 
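# When use_op_rop_implementation=False, the call below derives the pushforward
# from two pullbacks instead of dedicated R_op methods. A minimal sketch of
# that identity, with illustrative names y, x, u (the exact internals of
# pytensor.gradient.Rop may differ):
#
#     from pytensor.gradient import Lop
#     w = y.type()          # fresh dummy cotangent with y's type
#     g = Lop(y, x, w)      # g = J.T @ w, linear in w
#     jvp = Lop(g, w, u)    # differentiating g w.r.t. w against u gives J @ u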
+ Rop( + d_cost_wrt_pars, + pars, + p, + use_op_rop_implementation=use_op_rop_implementation, + ) + + def test_second_derivative_disconnected_cost_with_mit_mot(self): + # This test is a regression test for a bug that was revealed + # when we computed the pushforward of a Scan gradient via two applications of pullback + seq = pt.vector("seq", shape=(2,)) + z = pt.scalar("z") + x0 = pt.vector("x0", shape=(2,)) + + # When s is 1 and z is 2, xs[-1] is just a sneaky + # x ** 4 (after two nsteps) + # grad should be 4 * x ** 3 + # and grad of grad should be 12 * x ** 2 + def step(s, xtm2, xtm1, z): + return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2) + + xs, _ = scan( + step, + sequences=[seq], + outputs_info=[{"initial": x0, "taps": (-2, -1)}], + non_sequences=[z], + n_steps=2, + ) + last_x = xs[-1] + + g_wrt_x0, g_wrt_z, g_wrt_seq = pt.grad(last_x, [x0, z, seq]) + g = g_wrt_x0.sum() + g_wrt_z.sum() * 0 + g_wrt_seq.sum() * 0 + assert g.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 4 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96 + + # Leave out z + g_wrt_x0, g_wrt_seq = pt.grad(last_x, [x0, seq]) + g = g_wrt_x0.sum() + g_wrt_seq.sum() * 0 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96 + + # Leave out seq + g_wrt_x0, g_wrt_z = pt.grad(last_x, [x0, z]) + g = g_wrt_x0.sum() + g_wrt_z.sum() * 0 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2 + + # Leave out z and seq + g_wrt_x0 = pt.grad(last_x, x0) + g = g_wrt_x0.sum() + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2 + + @pytest.mark.parametrize("case", ("inside-explicit", "inside-implicit", "outside")) + def test_non_shaped_input_disconnected_gradient(self, case): + """Test that Scan gradient works when non shaped variables are disconnected from the gradient. + + Regression test for https://github.com/pymc-devs/pytensor/issues/6 + """ + + # In all cases rng is disconnected from the output gradient + # Note that when it is an input to the scan (explicit or not) it is still not updated by the scan, + # so it is equivalent to the `outside` case. A rewrite could have legally hoisted the rng out of the scan. 
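# A minimal non-scan analogue of the disconnection exercised here (variable
# names are illustrative only): the rng feeds the forward graph but has no
# tangent space, so grad() must treat it as disconnected rather than fail.
#
#     rng = shared(np.random.default_rng())
#     idx = pt.random.integers(16, rng=rng)  # no gradient flows through rng
#     s = pt.scalar("s")
#     y = (pt.zeros(16)[idx] + s) ** 2
#     g = grad(y, wrt=s)  # well-defined: 2 * (data[idx] + s)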
+        rng = shared(np.random.default_rng())
+
+        data = pt.zeros(16)
+
+        nonlocal_random_index = pt.random.integers(16, rng=rng)
+        nonlocal_random_datum = data[nonlocal_random_index]
+
+        if case == "outside":
+
+            def step(s, random_datum):
+                return (random_datum + s) ** 2
+
+            strict = True
+            non_sequences = [nonlocal_random_datum]
+
+        elif case == "inside-implicit":
+
+            def step(s):
+                return (nonlocal_random_datum + s) ** 2
+
+            strict = False
+            non_sequences = []  # Scan will introduce the non_sequences for us
+
+        elif case == "inside-explicit":
+
+            def step(s, data, rng):
+                random_index = pt.random.integers(
+                    16, rng=rng
+                )  # Not updated by the scan
+                random_datum = data[random_index]
+                return (random_datum + s) ** 2
+
+            strict = True
+            non_sequences = [data, rng]
+
+        else:
+            raise ValueError(f"Invalid case: {case}")
+
+        seq = vector("seq")
+        xs, _ = scan(
+            step,
+            sequences=[seq],
+            non_sequences=non_sequences,
+            strict=strict,
+        )
+        x0 = xs[0]
+
+        np.testing.assert_allclose(
+            x0.eval({seq: [np.pi, np.nan, np.nan]}),
+            np.pi**2,
+        )
+        np.testing.assert_allclose(
+            grad(x0, seq)[0].eval({seq: [np.pi, np.nan, np.nan]}),
+            2 * np.pi,
+        )
@@ -3076,7 +3198,7 @@ def loss_inner(sum_inner, W):
     cost = result_outer[0][-1]
     H = hessian(cost, W)
-    print(".", file=sys.stderr)
+    # print(".", file=sys.stderr)
     f = function([W, n_steps], H)
     benchmark(f, np.ones((8,), dtype="float32"), 1)
diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py
index 6f77625f2f..fd9c43b129 100644
--- a/tests/scan/test_rewriting.py
+++ b/tests/scan/test_rewriting.py
@@ -673,7 +673,7 @@ def test_machine_translation(self):
         zi = tensor3("zi")
         zi_value = x_value
-        init = pt.alloc(np.cast[config.floatX](0), batch_size, dim)
+        init = pt.alloc(np.asarray(0, dtype=config.floatX), batch_size, dim)
         def rnn_step1(
             # sequences
diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py
index acc793156f..f8a6c243c0 100644
--- a/tests/tensor/random/rewriting/test_basic.py
+++ b/tests/tensor/random/rewriting/test_basic.py
@@ -778,8 +778,10 @@ def rand_bool_mask(shape, rng=None):
         multivariate_normal,
         (
             np.array([200, 250], dtype=config.floatX),
-            # Second covariance is invalid, to test it is not chosen
-            np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
+            # Second covariance is very large, to test it is not chosen
+            np.dstack([np.eye(2), np.eye(2) * 1000, np.eye(2)]).T.astype(
+                config.floatX
+            )
             * 1e-6,
         ),
         (3,),
diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py
index 7d24a49228..4192a6c473 100644
--- a/tests/tensor/random/test_basic.py
+++ b/tests/tensor/random/test_basic.py
@@ -1,6 +1,6 @@
 import pickle
 import re
-from copy import copy
+from copy import deepcopy
 import numpy as np
 import pytest
@@ -19,6 +19,7 @@
 from pytensor.tensor import ones, stack
 from pytensor.tensor.random.basic import (
     ChoiceWithoutReplacement,
+    MvNormalRV,
     PermutationRV,
     _gamma,
     bernoulli,
@@ -113,7 +114,9 @@ def test_fn(*args, random_state=None, **kwargs):
     pt_rng = shared(rng, borrow=True)
-    numpy_res = np.asarray(test_fn(*param_vals, random_state=copy(rng), **kwargs_vals))
+    numpy_res = np.asarray(
+        test_fn(*param_vals, random_state=deepcopy(rng), **kwargs_vals)
+    )
     pytensor_res = rv(*params, rng=pt_rng, **kwargs)
@@ -521,13 +524,19 @@ def test_fn(shape, scale, **kwargs):
 def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
-    if mean is None:
-        mean = np.array([0.0], dtype=config.floatX)
-    if cov is None:
-        cov = np.array([[1.0]], dtype=config.floatX)
-    if size is not None:
-        size = tuple(size)
-    return multivariate_normal.rng_fn(random_state, mean, cov, size)
+    rng = random_state if random_state is not None else np.random.default_rng()
+
+    if size is None:
+        size = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
+
+    mean = np.broadcast_to(mean, (*size, *mean.shape[-1:]))
+    cov = np.broadcast_to(cov, (*size, *cov.shape[-2:]))
+
+    @np.vectorize(signature="(n),(n,n)->(n)")
+    def vec_mvnormal(mean, cov):
+        return rng.multivariate_normal(mean, cov, method="cholesky")
+
+    return vec_mvnormal(mean, cov)
 @pytest.mark.parametrize(
@@ -609,18 +618,30 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
         ),
     ],
 )
+@pytest.mark.skipif(
+    config.floatX == "float32",
+    reason="Draws are only strictly equal to numpy in float64",
+)
 def test_mvnormal_samples(mu, cov, size):
     compare_sample_values(
         multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn
     )
-def test_mvnormal_default_args():
-    compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn)
+def test_mvnormal_no_default_args():
+    with pytest.raises(
+        TypeError, match="missing 2 required positional arguments: 'mean' and 'cov'"
+    ):
+        multivariate_normal()
+
+def test_mvnormal_impl_catches_incompatible_size():
     with pytest.raises(ValueError, match="operands could not be broadcast together "):
         multivariate_normal.rng_fn(
-            None, np.zeros((3, 2)), np.ones((3, 2, 2)), size=(4,)
+            np.random.default_rng(),
+            np.zeros((3, 2)),
+            np.broadcast_to(np.eye(2), (3, 2, 2)),
+            size=(4,),
         )
@@ -668,6 +689,49 @@ def test_mvnormal_ShapeFeature():
     assert s4.get_test_value() == 3
+def create_mvnormal_cov_decomposition_method_test(mode):
+    @pytest.mark.parametrize("psd", (True, False))
+    @pytest.mark.parametrize("method", ("cholesky", "svd", "eigh"))
+    def test_mvnormal_cov_decomposition_method(method, psd):
+        mean = 2 ** np.arange(3)
+        if psd:
+            cov = [
+                [1, 0.5, -1],
+                [0.5, 2, 0],
+                [-1, 0, 3],
+            ]
+        else:
+            cov = [
+                [1, 0.5, 0],
+                [0.5, 2, 0],
+                [0, 0, 0],
+            ]
+        rng = shared(np.random.default_rng(675))
+        draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
+        assert draws.owner.op.method == method
+
+        if not psd and method == "cholesky":
+            if mode == "JAX":
+                # JAX doesn't raise errors at runtime, instead it returns nan
+                assert np.isnan(draws.eval(mode=mode)).all()
+            else:
+                with pytest.raises(np.linalg.LinAlgError):
+                    draws.eval(mode=mode)
+
+        else:
+            draws_eval = draws.eval(mode=mode)
+            np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.02)
+            np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1)
+
+    return test_mvnormal_cov_decomposition_method
+
+
+test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
+    None
+)
+
+
 @pytest.mark.parametrize(
     "alphas, size",
     [
diff --git a/tests/tensor/random/test_type.py b/tests/tensor/random/test_type.py
index d289862347..d358f2a93a 100644
--- a/tests/tensor/random/test_type.py
+++ b/tests/tensor/random/test_type.py
@@ -52,7 +52,7 @@ def test_filter(self):
         with pytest.raises(TypeError):
             rng_type.filter(1)
-        rng_dict = rng.__getstate__()
+        rng_dict = rng.bit_generator.state
         assert rng_type.is_valid_value(rng_dict) is False
         assert rng_type.is_valid_value(rng_dict, strict=False)
@@ -88,13 +88,13 @@ def test_values_eq(self):
         assert rng_type.values_eq(bitgen_g, bitgen_h)
         assert rng_type.is_valid_value(bitgen_a, strict=True)
-        assert 
rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_b.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_c, strict=True) - assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_d.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_e, strict=True) - assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_f.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_g, strict=True) - assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_h.bit_generator.state, strict=False) def test_may_share_memory(self): bg_a = np.random.PCG64() diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 70e8a710e9..f7d8731c1b 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -165,14 +165,20 @@ def test_seed(self, rng_ctor): state_rng = random.state_updates[0][0].get_value(borrow=True) if hasattr(state_rng, "get_state"): - ref_state = ref_rng.get_state() random_state = state_rng.get_state() + + # hack to try to get something reasonable for ref_rng + try: + ref_state = ref_rng.get_state() + except AttributeError: + ref_state = list(ref_rng.bit_generator.state.values()) + assert np.array_equal(random_state[1], ref_state[1]) assert random_state[0] == ref_state[0] assert random_state[2:] == ref_state[2:] else: - ref_state = ref_rng.__getstate__() - random_state = state_rng.__getstate__() + ref_state = ref_rng.bit_generator.state + random_state = state_rng.bit_generator.state assert random_state["bit_generator"] == ref_state["bit_generator"] assert random_state["state"] == ref_state["state"] diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 8911f56630..ac8576a8a1 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -332,7 +332,6 @@ def test_basic_tile(self): mode = rewrite_mode.including( "local_dimshuffle_lift", - "local_useless_dimshuffle_in_reshape", "local_alloc_sink_dimshuffle", ) f = function([x], [y], mode=mode) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index f1b71949d1..6fb0594ed5 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -56,7 +56,10 @@ from pytensor.tensor.math import round as pt_round from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift -from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape +from pytensor.tensor.rewriting.shape import ( + local_fuse_squeeze_reshape, + local_useless_expand_dims_in_reshape, +) from pytensor.tensor.shape import reshape from pytensor.tensor.type import ( TensorType, @@ -182,7 +185,7 @@ def test_dimshuffle_lift_multi_out_elemwise(self): assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner) -def test_local_useless_dimshuffle_in_reshape(): +def test_local_useless_expand_dims_in_reshape(): vec = TensorType(dtype="float64", shape=(None,))("vector") mat = TensorType(dtype="float64", shape=(None, None))("mat") row = TensorType(dtype="float64", shape=(1, None))("row") @@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape(): clone=False, ) assert len(g.apply_nodes) == 4 * 3 - useless_dimshuffle_in_reshape = 
out2in(local_useless_dimshuffle_in_reshape) + useless_dimshuffle_in_reshape = out2in( + local_useless_expand_dims_in_reshape, + # Useless squeeze in reshape is not a canonicalization anymore + local_fuse_squeeze_reshape, + ) useless_dimshuffle_in_reshape.rewrite(g) assert equal_computations( g.outputs, @@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape(): # Check stacktrace was copied over correctly after rewrite was applied assert check_stack_trace(g, ops_to_check="all") - # Check that the rewrite does not get applied when the order - # of dimensions has changed. + # Check that the rewrite does not mess meaningful transpositions before the reshape reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False) assert len(h.apply_nodes) == 3 useless_dimshuffle_in_reshape.rewrite(h) - assert equal_computations( - h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)] - ) + assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)]) class TestFusion: diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c9b9afff19..50e48ce95d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -14,7 +14,7 @@ from pytensor.tensor import swapaxes from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import _allclose, dot, matmul +from pytensor.tensor.math import dot, matmul from pytensor.tensor.nlinalg import ( SVD, Det, @@ -42,18 +42,19 @@ from tests.test_rop import break_op -ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8 - - -def test_rop_lop(): +def test_matrix_inverse_rop_lop(): + rtol = 1e-7 if config.floatX == "float64" else 1e-5 mx = matrix("mx") mv = matrix("mv") v = vector("v") y = MatrixInverse()(mx).sum(axis=0) - yv = pytensor.gradient.Rop(y, mx, mv) + yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True) rop_f = function([mx, mv], yv) + yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False) + rop_via_lop_f = function([mx, mv], yv_via_lop) + sy, _ = pytensor.scan( lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(), sequences=pt.arange(y.shape[0]), @@ -65,22 +66,16 @@ def test_rop_lop(): vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) + np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol) - assert _allclose(v1, v2), f"ROP mismatch: {v1} {v2}" - - raised = False - try: + with pytest.raises(ValueError): pytensor.gradient.Rop( - pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv - ) - except ValueError: - raised = True - if not raised: - raise Exception( - "Op did not raised an error even though the function" - " is not differentiable" + pytensor.clone_replace(y, replace={mx: break_op(mx)}), + mx, + mv, + use_op_rop_implementation=True, ) vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) @@ -90,9 +85,9 @@ def test_rop_lop(): sy = pytensor.gradient.grad((v * y).sum(), mx) scan_f = function([mx, v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + v = lop_f(vx, vv) + np.testing.assert_allclose(v, v_ref, rtol=rtol) def 
test_transinv_to_invtrans(): @@ -630,11 +625,12 @@ def test_inv_diag_from_eye_mul(shape, inv_op): inverse_matrix = np.linalg.inv(x_test_matrix) rewritten_inverse = f_rewritten(x_test) + atol = rtol = 1e-3 if config.floatX == "float32" else 1e-8 assert_allclose( inverse_matrix, rewritten_inverse, - atol=ATOL, - rtol=RTOL, + atol=atol, + rtol=rtol, ) @@ -657,11 +653,12 @@ def test_inv_diag_from_diag(inv_op): inverse_matrix = np.linalg.inv(x_test_matrix) rewritten_inverse = f_rewritten(x_test) + atol = rtol = 1e-3 if config.floatX == "float32" else 1e-8 assert_allclose( inverse_matrix, rewritten_inverse, - atol=ATOL, - rtol=RTOL, + atol=atol, + rtol=rtol, ) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index a1759ef81b..9a092663a9 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1628,6 +1628,7 @@ def test_local_mul_specialize(): def speed_local_pow_specialize_range(): + # TODO: This should be a benchmark test val = np.random.random(1e7) v = vector() mode = get_default_mode() @@ -1641,9 +1642,9 @@ def speed_local_pow_specialize_range(): t2 = time.perf_counter() f2(val) t3 = time.perf_counter() - print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) + # print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) if not t2 - t1 < t3 - t2: - print("WARNING WE ARE SLOWER") + raise ValueError("WARNING WE ARE SLOWER") for i in range(-3, -1500, -1): f1 = function([v], v**i, mode=mode) f2 = function([v], v**i, mode=mode_without_pow_rewrite) @@ -1653,9 +1654,9 @@ def speed_local_pow_specialize_range(): t2 = time.perf_counter() f2(val) t3 = time.perf_counter() - print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) + # print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) if not t2 - t1 < t3 - t2: - print("WARNING WE ARE SLOWER") + raise ValueError("WARNING WE ARE SLOWER") def test_local_pow_specialize(): @@ -2483,19 +2484,20 @@ def test_local_grad_log_erfc_neg(self): assert f.maker.fgraph.outputs[0].dtype == config.floatX def speed_local_log_erfc(self): + # TODO: Make this a benchmark test! 
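# A sketch of the conversion the TODO asks for, mirroring the pytest-benchmark
# fixture already used in tests/scan/test_basic.py (function name and size are
# illustrative; note that modern numpy requires an integer size, so the
# np.random.random(1e6) call below would itself raise a TypeError today):
#
#     def test_log_erfc_benchmark(benchmark):
#         x = vector()
#         f = function([x], log(erfc(x)), mode=get_mode("FAST_RUN"))
#         val = np.random.random(1_000_000).astype(config.floatX)
#         f(val)  # compile and warm up outside the timed loop
#         benchmark(f, val)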
val = np.random.random(1e6) x = vector() mode = get_mode("FAST_RUN") f1 = function([x], log(erfc(x)), mode=mode.excluding("local_log_erfc")) f2 = function([x], log(erfc(x)), mode=mode) - print(f1.maker.fgraph.toposort()) - print(f2.maker.fgraph.toposort()) - t0 = time.perf_counter() + # print(f1.maker.fgraph.toposort()) + # print(f2.maker.fgraph.toposort()) + # t0 = time.perf_counter() f1(val) - t1 = time.perf_counter() + # t1 = time.perf_counter() f2(val) - t2 = time.perf_counter() - print(t1 - t0, t2 - t1) + # t2 = time.perf_counter() + # print(t1 - t0, t2 - t1) class TestLocalMergeSwitchSameCond: @@ -4144,13 +4146,13 @@ def check(expr1, expr2): perform_sigm_times_exp(trees[0]) trees[0] = simplify_mul(trees[0]) good = is_same_graph(compute_mul(trees[0]), compute_mul(trees[1])) - if not good: - print(trees[0]) - print(trees[1]) - print("***") - pytensor.printing.debugprint(compute_mul(trees[0])) - print("***") - pytensor.printing.debugprint(compute_mul(trees[1])) + # if not good: + # print(trees[0]) + # print(trees[1]) + # print("***") + # pytensor.printing.debugprint(compute_mul(trees[0])) + # print("***") + # pytensor.printing.debugprint(compute_mul(trees[1])) assert good check(sigmoid(x) * exp_op(-x), sigmoid(-x)) diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index bbfd829070..27678bd630 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -6,7 +6,7 @@ import pytensor.tensor as pt from pytensor import shared from pytensor.compile.function import function -from pytensor.compile.mode import get_default_mode, get_mode +from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import deep_copy_op from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations @@ -383,6 +383,13 @@ def test_all_but_one_match(self): new_out = rewrite_graph(out) assert new_out is out + # Or if more than one dimension cannot be matched + x = tensor(shape=(None, None, None)) + shape = [x.shape[0], 3, 3] + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is out + class TestLocalReshapeToDimshuffle: def setup_method(self): @@ -419,6 +426,60 @@ def test_basic(self): assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape)) + def test_expand_dims(self): + x = pt.scalar() + # This reshape does an implicit expand_dims + out = x.reshape((1, -1)) + assert isinstance(out.owner.op, Reshape) + new_out = rewrite_graph(out, include=("canonicalize",)) + assert equal_computations([new_out], [pt.expand_dims(x, (0, 1))]) + + def test_squeeze_of_alloc(self): + # This shows up in the graph of repeat + x = pt.vector("x", shape=(9,)) + bcast_x = pt.alloc(x, 1, 12, x.shape[0]) + + # This reshape does an implicit squeeze + out = bcast_x.reshape((12, x.shape[0])) + + new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt")) + assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False) + + +def test_expand_dims_squeeze_reshape_fusion(): + x = pt.tensor("x", shape=(1, 9)) + reshape_x = x.squeeze(0).reshape((3, 3))[..., None] + + assert isinstance(reshape_x.owner.op, DimShuffle) + assert isinstance(reshape_x.owner.inputs[0].owner.op, Reshape) + assert isinstance(reshape_x.owner.inputs[0].owner.inputs[0].owner.op, DimShuffle) + + out = rewrite_graph(reshape_x, include=("specialize",)) + + # In this case we cannot get rid of the reshape, squeeze or expand_dims, + # so we fuse them all in one reshape + assert 
equal_computations([out], [x.reshape((3, 3, 1))]) + + +def test_implicit_broadcasting_via_repeat(): + x = pt.vector("x", shape=(3,), dtype=int) + y = pt.vector("y", shape=(9,), dtype=int) + out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1) + # There are two Reshapes in the graph + assert isinstance(out.owner.inputs[0].owner.op, Reshape) + assert isinstance(out.owner.inputs[1].owner.op, Reshape) + + new_out = rewrite_graph(out, include=("canonicalize", "specialize")) + assert equal_computations([new_out], [x[None] <= y[:, None]]) + + no_rewrite_mode = Mode(linker="py", optimizer=None) + x_test = np.arange(3) + 1 + y_test = np.arange(9) + np.testing.assert_allclose( + new_out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode), + out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode), + ) + def test_local_reshape_lift(): x = tensor4() diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 6b5ec48112..467dc66407 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3198,7 +3198,6 @@ def test_autocast_custom(): assert (dvector() + 1.1).dtype == "float64" assert (fvector() + np.float32(1.1)).dtype == "float32" assert (fvector() + np.float64(1.1)).dtype == "float64" - assert (fvector() + 1.1).dtype == config.floatX assert (lvector() + np.int64(1)).dtype == "int64" assert (lvector() + np.int32(1)).dtype == "int64" assert (lvector() + np.int16(1)).dtype == "int64" diff --git a/tests/tensor/test_complex.py b/tests/tensor/test_complex.py index f0f7333f9c..a1b99751ed 100644 --- a/tests/tensor/test_complex.py +++ b/tests/tensor/test_complex.py @@ -73,9 +73,7 @@ def f(a): try: utt.verify_grad(f, [aval]) except GradientError as e: - print(e.num_grad.gf) - print(e.analytic_grad) - raise + raise ValueError(f"Failed: {e.num_grad.gf=} {e.analytic_grad=}") from e @pytest.mark.skip(reason="Complex grads not enabled, see #178") def test_mul_mixed1(self): @@ -88,9 +86,7 @@ def f(a): try: utt.verify_grad(f, [aval]) except GradientError as e: - print(e.num_grad.gf) - print(e.analytic_grad) - raise + raise ValueError(f"Failed: {e.num_grad.gf=} {e.analytic_grad=}") from e @pytest.mark.skip(reason="Complex grads not enabled, see #178") def test_mul_mixed(self): @@ -104,9 +100,7 @@ def f(a, b): try: utt.verify_grad(f, [aval, bval]) except GradientError as e: - print(e.num_grad.gf) - print(e.analytic_grad) - raise + raise ValueError(f"Failed: {e.num_grad.gf=} {e.analytic_grad=}") from e @pytest.mark.skip(reason="Complex grads not enabled, see #178") def test_polar_grads(self): diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index bd208c5848..5ce533d3a3 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -18,6 +18,7 @@ from pytensor.graph.replace import vectorize_node from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker +from pytensor.npy_2_compat import numpy_maxdims from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -39,7 +40,27 @@ ) from tests import unittest_tools from tests.link.test_link import make_function -from tests.tensor.test_math import reduce_bitwise_and + + +def reduce_bitwise_and(x, axis=-1, dtype="int8"): + """Helper function for TestCAReduce""" + if dtype == "uint8": + # in numpy version >= 2.0, out of bounds uint8 values are not converted + identity = np.array((255,), dtype=dtype)[0] + else: + identity = 
np.array((-1,), dtype=dtype)[0] + + shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis) + if 0 in shape_without_axis: + return np.empty(shape=shape_without_axis, dtype=x.dtype) + + def custom_reduce(a): + out = identity + for i in range(a.size): + out = np.bitwise_and(a[i], out) + return out + + return np.apply_along_axis(custom_reduce, axis, x) class TestDimShuffle(unittest_tools.InferShapeTester): @@ -121,7 +142,8 @@ def test_infer_shape(self): def test_too_big_rank(self): x = self.type(self.dtype, shape=())() - y = x.dimshuffle(("x",) * (np.MAXDIMS + 1)) + y = x.dimshuffle(("x",) * (numpy_maxdims + 1)) + with pytest.raises(ValueError): y.eval({x: 0}) @@ -672,7 +694,7 @@ def test_scalar_input(self): assert self.op(ps.add, axis=(-1,))(x).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (-2,) is out of bounds for array of dimension 0"), ): self.op(ps.add, axis=(-2,))(x) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index c45e6b1e48..6a93f3c7fd 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -9,6 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph.basic import Constant, applys_between, equal_computations +from pytensor.npy_2_compat import old_np_unique from pytensor.raise_op import Assert from pytensor.tensor import alloc from pytensor.tensor.elemwise import DimShuffle @@ -469,7 +470,7 @@ def test_scalar_input(self): assert squeeze(x, axis=(0,)).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (1,) is out of bounds for array of dimension 0"), ): squeeze(x, axis=1) @@ -595,7 +596,6 @@ def test_basic(self, ndim, dtype): isinstance(n.op, Repeat) for n in f.maker.fgraph.toposort() ) - @pytest.mark.slow @pytest.mark.parametrize("ndim", [1, 3]) @pytest.mark.parametrize("dtype", ["int8", "uint8", "uint64"]) def test_infer_shape(self, ndim, dtype): @@ -606,6 +606,10 @@ def test_infer_shape(self, ndim, dtype): a = rng.random(shp).astype(config.floatX) for axis in self._possible_axis(ndim): + if axis is not None and axis < 0: + # Operator does not support negative axis + continue + r_var = scalar(dtype=dtype) r = np.asarray(3, dtype=dtype) if dtype in self.numpy_unsupported_dtypes: @@ -635,12 +639,23 @@ def test_infer_shape(self, ndim, dtype): self.op_class, ) - @pytest.mark.parametrize("ndim", range(3)) - def test_grad(self, ndim): - a = np.random.random((10,) * ndim).astype(config.floatX) - - for axis in self._possible_axis(ndim): - utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a]) + @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}") + @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}") + @pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}") + def test_grad(self, x_ndim, repeats_ndim, axis): + rng = np.random.default_rng( + [653, x_ndim, 2 if axis is None else axis, repeats_ndim] + ) + x_test = rng.normal(size=np.arange(3, 3 + x_ndim)) + if repeats_ndim == 0: + repeats_size = () + else: + repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,) + repeats = rng.integers(1, 6, size=repeats_size) + utt.verify_grad( + lambda x: Repeat(axis=axis)(x, repeats), + [x_test], + ) def test_broadcastable(self): x = TensorType(config.floatX, shape=(None, 1, None))() @@ -694,7 +709,7 @@ def test_perform(self, shp): y = scalar() f = function([x, y], fill_diagonal(x, 
y)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out), val) @@ -706,7 +721,7 @@ def test_perform_3d(self): x = tensor3() y = scalar() f = function([x, y], fill_diagonal(x, y)) - val = np.cast[config.floatX](rng.random() + 10) + val = rng.random(dtype=config.floatX) + 10 out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert out[0, 0, 0] == val @@ -768,7 +783,7 @@ def test_perform(self, test_offset, shp): f = function([x, y, z], fill_diagonal_offset(x, y, z)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val, test_offset) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out, test_offset), val) @@ -885,14 +900,14 @@ def setup_method(self): ) def test_basic_vector(self, x, inp, axis): list_outs_expected = [ - np.unique(inp, axis=axis), - np.unique(inp, True, axis=axis), - np.unique(inp, False, True, axis=axis), - np.unique(inp, True, True, axis=axis), - np.unique(inp, False, False, True, axis=axis), - np.unique(inp, True, False, True, axis=axis), - np.unique(inp, False, True, True, axis=axis), - np.unique(inp, True, True, True, axis=axis), + old_np_unique(inp, axis=axis), + old_np_unique(inp, True, axis=axis), + old_np_unique(inp, False, True, axis=axis), + old_np_unique(inp, True, True, axis=axis), + old_np_unique(inp, False, False, True, axis=axis), + old_np_unique(inp, True, False, True, axis=axis), + old_np_unique(inp, False, True, True, axis=axis), + old_np_unique(inp, True, True, True, axis=axis), ] for params, outs_expected in zip( self.op_params, list_outs_expected, strict=True diff --git a/tests/tensor/test_fft.py b/tests/tensor/test_fft.py index 94c49662bc..3976c67622 100644 --- a/tests/tensor/test_fft.py +++ b/tests/tensor/test_fft.py @@ -43,7 +43,6 @@ def test_1Drfft(self): utt.assert_allclose(rfft_ref, res_rfft_comp) m = rfft.type() - print(m.ndim) irfft = fft.irfft(m) f_irfft = pytensor.function([m], irfft) res_irfft = f_irfft(res_rfft) diff --git a/tests/tensor/test_io.py b/tests/tensor/test_io.py index cece2af277..4c5e5655fe 100644 --- a/tests/tensor/test_io.py +++ b/tests/tensor/test_io.py @@ -49,7 +49,7 @@ def test_memmap(self): path = Variable(Generic(), None) x = load(path, "int32", (None,), mmap_mode="c") fn = function([path], x) - assert isinstance(fn(self.filename), np.core.memmap) + assert isinstance(fn(self.filename), np.memmap) def teardown_method(self): (pytensor.config.compiledir / "_test.npy").unlink() diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 2d19ef0114..9ab4fd104d 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -19,10 +19,11 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, grad, numeric_grad -from pytensor.graph.basic import Variable, ancestors, applys_between +from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import vectorize_node from pytensor.link.c.basic import DualLinker +from pytensor.npy_2_compat import using_numpy_2 from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.tensor import blas, blas_c @@ -391,11 +392,20 @@ def test_maximum_minimum_grad(): 
grad=_grad_broadcast_unary_normal, ) + +# in numpy >= 2.0, negating a uint raises an error +neg_good = _good_broadcast_unary_normal.copy() +if using_numpy_2: + neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")} +else: + neg_bad = None + TestNegBroadcast = makeBroadcastTester( op=neg, expected=lambda x: -x, - good=_good_broadcast_unary_normal, + good=neg_good, grad=_grad_broadcast_unary_normal, + bad_compile=neg_bad, ) TestSgnBroadcast = makeBroadcastTester( @@ -1393,18 +1403,37 @@ def _grad_list(self): # check_grad_max(data, eval_outputs(grad(max_and_argmax(n, # axis=1)[0], n)),axis=1) + @pytest.mark.parametrize( + "dtype", + ( + "uint8", + "uint16", + "uint32", + pytest.param("uint64", marks=pytest.mark.xfail(reason="Fails due to #770")), + ), + ) + def test_uint(self, dtype): + itype = np.iinfo(dtype) + data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) + n = as_tensor_variable(data) + + assert min(n).dtype == dtype + i_min = eval_outputs(min(n)) + assert i_min == itype.min + + assert max(n).dtype == dtype + i_max = eval_outputs(max(n)) + assert i_max == itype.max + @pytest.mark.xfail(reason="Fails due to #770") - def test_uint(self): - for dtype in ("uint8", "uint16", "uint32", "uint64"): - itype = np.iinfo(dtype) - data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) - n = as_tensor_variable(data) - assert min(n).dtype == dtype - i = eval_outputs(min(n)) - assert i == itype.min - assert max(n).dtype == dtype - i = eval_outputs(max(n)) - assert i == itype.max + def test_uint64_special_value(self): + """Example from issue #770""" + dtype = "uint64" + data = np.array([0, 9223372036854775], dtype=dtype) + n = as_tensor_variable(data) + + i_max = eval_outputs(max(n)) + assert i_max == data.max() def test_bool(self): data = np.array([True, False], "bool") @@ -2278,7 +2307,7 @@ def test_type_shape(self): with pytest.raises( ValueError, - match="Input arrays have inconsistent broadcastable pattern or type shape", + match="Input arrays have inconsistent type shape", ): tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1) @@ -2323,6 +2352,41 @@ def test_shape_assert(self, axes, has_assert, values, expected_fail): else: assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv})) + def test_eager_simplification(self): + # Test that cases where tensordot isn't needed, it returns a simple graph + scl = tensor(shape=()) + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + # scalar product + out = tensordot(scl, scl, axes=[[], []]) + assert equal_computations([out], [scl * scl]) + + # vector-vector product + out = tensordot(vec, vec, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(vec, vec)]) + + # matrix-vector product + out = tensordot(mat, vec, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(mat, vec)]) + + out = tensordot(mat, vec, axes=[[-2], [-1]]) + assert equal_computations([out], [dot(mat.T, vec)]) + + # vector-matrix product + out = tensordot(vec, mat, axes=[[-1], [-2]]) + assert equal_computations([out], [dot(vec, mat)]) + + out = tensordot(vec, mat, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(vec, mat.T)]) + + # matrix-matrix product + out = tensordot(mat, mat, axes=[[-1], [-2]]) + assert equal_computations([out], [dot(mat, mat)]) + + out = tensordot(mat, mat, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(mat, mat.T)]) + def test_smallest(): x = dvector() @@ -2457,11 +2521,22 @@ def pytensor_i_scalar(dtype): def 
numpy_i_scalar(dtype): return numpy_scalar(dtype) + pytensor_funcs = { + "scalar": pytensor_scalar, + "array": pytensor_array, + "i_scalar": pytensor_i_scalar, + } + numpy_funcs = { + "scalar": numpy_scalar, + "array": numpy_array, + "i_scalar": numpy_i_scalar, + } + with config.change_flags(cast_policy="numpy+floatX"): # We will test all meaningful combinations of # scalar and array operations. - pytensor_args = [eval(f"pytensor_{c}") for c in combo] - numpy_args = [eval(f"numpy_{c}") for c in combo] + pytensor_args = [pytensor_funcs[c] for c in combo] + numpy_args = [numpy_funcs[c] for c in combo] pytensor_arg_1 = pytensor_args[0](a_type) pytensor_arg_2 = pytensor_args[1](b_type) pytensor_dtype = op( @@ -3409,22 +3484,6 @@ def test_var_axes(self): x.var(a) -def reduce_bitwise_and(x, axis=-1, dtype="int8"): - identity = np.array((-1,), dtype=dtype)[0] - - shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis) - if 0 in shape_without_axis: - return np.empty(shape=shape_without_axis, dtype=x.dtype) - - def custom_reduce(a): - out = identity - for i in range(a.size): - out = np.bitwise_and(a[i], out) - return out - - return np.apply_along_axis(custom_reduce, axis, x) - - def test_clip_grad(): # test the gradient of clip def func(x, y, z): diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 921aae826b..8f70950206 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -306,16 +306,6 @@ def scipy_special_gammal(k, x): name="Chi2SF", ) -TestChi2SFInplaceBroadcast = makeBroadcastTester( - op=inplace.chi2sf_inplace, - expected=expected_chi2sf, - good=_good_broadcast_unary_chi2sf, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, - name="Chi2SF", -) - rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_binary_gamma = dict( normal=( diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 2ffcb25fe5..3f0b04d45d 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -98,6 +98,7 @@ def setup_method(self): Shape_i, DimShuffle, Elemwise, + SpecifyShape, ) super().setup_method() @@ -253,9 +254,7 @@ def test_bad_shape(self): f(a_val, [7, 5]) with pytest.raises(ValueError): f(a_val, [-1, -1]) - with pytest.raises( - ValueError, match=".*Shape argument to Reshape has incorrect length.*" - ): + with pytest.raises(AssertionError): f(a_val, [3, 4, 1]) def test_0(self): @@ -603,7 +602,7 @@ def test_validation(self): class TestRopLop(RopLopChecker): def test_shape(self): - self.check_nondiff_rop(self.x.shape[0]) + self.check_nondiff_rop(self.x.shape[0], self.x, self.v) def test_specifyshape(self): self.check_rop_lop(specify_shape(self.x, self.in_shape), self.in_shape) @@ -797,7 +796,6 @@ def test_reshape(self): assert equal_computations([vect_out], [reshape(mat, new_shape)]) new_shape = stack([[-1, x], [x - 1, -1]], axis=0) - print(new_shape.type) [vect_out] = vectorize_node(node, vec, new_shape).outputs vec_test_value = np.arange(6) np.testing.assert_allclose( diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f46d771938..34f1396f4c 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -209,12 +209,12 @@ def test__repr__(self): ) -class TestSolve(utt.InferShapeTester): - def test__init__(self): - with pytest.raises(ValueError) as excinfo: - Solve(assume_a="test", b_ndim=2) - assert "is not a recognized matrix structure" in str(excinfo.value) +def test_solve_raises_on_invalid_A(): + with pytest.raises(ValueError, match="is 
not a recognized matrix structure"): + Solve(assume_a="test", b_ndim=2) + +class TestSolve(utt.InferShapeTester): @pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) def test_infer_shape(self, b_shape): rng = np.random.default_rng(utt.fetch_seed()) @@ -232,64 +232,78 @@ def test_infer_shape(self, b_shape): warn=False, ) - def test_correctness(self): + @pytest.mark.parametrize( + "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] + ) + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + def test_solve_correctness(self, b_size: tuple[int], assume_a: str): rng = np.random.default_rng(utt.fetch_seed()) - A = matrix() - b = matrix() - y = solve(A, b) - gen_solve_func = pytensor.function([A, b], y) + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("b", shape=b_size) - b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_size).astype(config.floatX) - A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) - A_val = np.dot(A_val.transpose(), A_val) + solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) - np.testing.assert_allclose( - scipy.linalg.solve(A_val, b_val, assume_a="gen"), - gen_solve_func(A_val, b_val), - ) + def A_func(x): + if assume_a == "pos": + return x @ x.T + elif assume_a == "sym": + return (x + x.T) / 2 + else: + return x + + solve_input_val = A_func(A_val) + + y = solve_op(A_func(A), b) + solve_func = pytensor.function([A, b], y) + X_np = solve_func(A_val.copy(), b_val.copy()) + + ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4 - A_undef = np.array( - [ - [1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 1, 1], - [0, 0, 0, 1, 0], - ], - dtype=config.floatX, - ) np.testing.assert_allclose( - scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) + scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a), + X_np, + atol=ATOL, + rtol=RTOL, ) + np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL) + @pytest.mark.parametrize( - "m, n, assume_a, lower", - [ - (5, None, "gen", False), - (5, None, "gen", True), - (4, 2, "gen", False), - (4, 2, "gen", True), - ], + "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] ) - def test_solve_grad(self, m, n, assume_a, lower): + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + @pytest.mark.skipif( + config.floatX == "float32", reason="Gradients not numerically stable in float32" + ) + def test_solve_gradient(self, b_size: tuple[int], assume_a: str): rng = np.random.default_rng(utt.fetch_seed()) - # Ensure diagonal elements of `A` are relatively large to avoid - # numerical precision issues - A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX) + eps = 2e-8 if config.floatX == "float64" else None - if n is None: - b_val = rng.normal(size=m).astype(config.floatX) - else: - b_val = rng.normal(size=(m, n)).astype(config.floatX) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_size).astype(config.floatX) - eps = None - if config.floatX == "float64": - eps = 2e-8 + def A_func(x): + if assume_a == "pos": + return x @ x.T + elif assume_a == "sym": + return (x + x.T) / 2 + else: + return x - solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2) - utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) + solve_op = functools.partial(solve, 
assume_a=assume_a, b_ndim=len(b_size))
+
+        # To correctly check the gradients, we need to include a transformation from the
+        # space of unconstrained matrices (A) to a valid input matrix for the given solver.
+        # This is done by the A_func function. Without it, the random perturbations used by
+        # verify_grad would produce invalid input matrices, and LAPACK would silently do the
+        # wrong thing, making the computed gradients wrong.
+        utt.verify_grad(
+            lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
+        )
 
 
 class TestSolveTriangular(utt.InferShapeTester):
diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py
index 9eb06f28a3..1a8b2455ec 100644
--- a/tests/tensor/utils.py
+++ b/tests/tensor/utils.py
@@ -152,7 +152,7 @@ def upcast_float16_ufunc(fn):
     """
 
     def ret(*args, **kwargs):
-        out_dtype = np.find_common_type([a.dtype for a in args], [np.float16])
+        out_dtype = np.result_type(np.float16, *args)
         if out_dtype == "float16":
             # Force everything to float32
             sig = "f" * fn.nin + "->" + "f" * fn.nout
@@ -339,6 +339,7 @@ def makeTester(
     good=None,
     bad_build=None,
     bad_runtime=None,
+    bad_compile=None,
     grad=None,
     mode=None,
     grad_rtol=None,
@@ -373,6 +374,7 @@ def makeTester(
     _test_memmap = test_memmap
     _check_name = check_name
     _grad_eps = grad_eps
+    _bad_compile = bad_compile or {}
 
     class Checker:
         op = staticmethod(_op)
@@ -382,6 +384,7 @@ class Checker:
         good = _good
         bad_build = _bad_build
         bad_runtime = _bad_runtime
+        bad_compile = _bad_compile
         grad = _grad
         mode = _mode
         skip = skip_
@@ -539,6 +542,24 @@ def test_bad_build(self):
             # instantiated on the following bad inputs: %s"
             # % (self.op, testname, node, inputs))
 
+        @config.change_flags(compute_test_value="off")
+        @pytest.mark.skipif(skip, reason="Skipped")
+        def test_bad_compile(self):
+            for testname, inputs in self.bad_compile.items():
+                inputrs = [shared(input) for input in inputs]
+                try:
+                    node = safe_make_node(self.op, *inputrs)
+                except Exception as exc:
+                    err_msg = (
+                        f"Test {self.op}::{testname}: Error occurred while trying"
+                        f" to make a node with inputs {inputs}"
+                    )
+                    exc.args += (err_msg,)
+                    raise
+
+                with pytest.raises(Exception):
+                    inplace_func([], node.outputs, mode=mode, name="test_bad_compile")
+
         @config.change_flags(compute_test_value="off")
         @pytest.mark.skipif(skip, reason="Skipped")
         def test_bad_runtime(self):
diff --git a/tests/test_config.py b/tests/test_config.py
index 4370309f39..2dd3c32180 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -192,7 +192,7 @@ def test_invalid_configvar_access():
 
     # But we can make sure that nothing crazy happens when we access it:
     with pytest.raises(configparser.ConfigAccessViolation, match="different instance"):
-        print(root.test__on_test_instance)
+        assert root.test__on_test_instance is not None
 
 
 def test_no_more_dotting():
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index 79c55caf44..24f5964c92 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -481,12 +481,12 @@ def make_grad_func(X):
 
         int_type = imatrix().dtype
         float_type = "float64"
-        X = np.cast[int_type](rng.standard_normal((m, d)) * 127.0)
-        W = np.cast[W.dtype](rng.standard_normal((d, n)))
-        b = np.cast[b.dtype](rng.standard_normal(n))
+        X = np.asarray(rng.standard_normal((m, d)) * 127.0, dtype=int_type)
+        W = rng.standard_normal((d, n), dtype=W.dtype)
+        b = rng.standard_normal(n, dtype=b.dtype)
 
         int_result = int_func(X, W, b)
-        float_result = float_func(np.cast[float_type](X), W, b)
+        float_result = float_func(np.asarray(X, dtype=float_type), W, b)
 
         assert
np.allclose(int_result, float_result), (int_result, float_result) @@ -508,7 +508,7 @@ def test_grad_disconnected(self): # the output f = pytensor.function([x], g) rng = np.random.default_rng([2012, 9, 5]) - x = np.cast[x.dtype](rng.standard_normal(3)) + x = rng.standard_normal(3, dtype=x.dtype) g = f(x) assert np.allclose(g, np.ones(x.shape, dtype=x.dtype)) @@ -631,7 +631,8 @@ def test_known_grads(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()] values = [ - np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) ] true_grads = grad(cost, inputs, disconnected_inputs="ignore") @@ -679,7 +680,7 @@ def test_known_grads_integers(): f = pytensor.function([g_expected], g_grad) x = -3 - gv = np.cast[config.floatX](0.6) + gv = np.asarray(0.6, dtype=config.floatX) g_actual = f(gv) @@ -746,7 +747,8 @@ def test_subgraph_grad(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(2), rng.standard_normal(3)] values = [ - np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) ] wrt = [w2, w1] @@ -1031,21 +1033,21 @@ def test_jacobian_scalar(): # test when the jacobian is called with a tensor as wrt Jx = jacobian(y, x) f = pytensor.function([x], Jx) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a tuple as wrt Jx = jacobian(y, (x,)) assert isinstance(Jx, tuple) f = pytensor.function([x], Jx[0]) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a list as wrt Jx = jacobian(y, [x]) assert isinstance(Jx, list) f = pytensor.function([x], Jx[0]) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a list of two elements @@ -1053,8 +1055,8 @@ def test_jacobian_scalar(): y = x * z Jx = jacobian(y, [x, z]) f = pytensor.function([x, z], Jx) - vx = np.cast[pytensor.config.floatX](rng.uniform()) - vz = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) vJx = f(vx, vz) assert np.allclose(vJx[0], vz) diff --git a/tests/test_printing.py b/tests/test_printing.py index be5dbbc5a1..4dd4f3866d 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -138,9 +138,9 @@ def test_min_informative_str(): D. D E. 
E"""
 
-    if mis != reference:
-        print("--" + mis + "--")
-        print("--" + reference + "--")
 
     assert mis == reference
 
 
diff --git a/tests/test_rop.py b/tests/test_rop.py
index 0b9fe41a1e..b592f557a5 100644
--- a/tests/test_rop.py
+++ b/tests/test_rop.py
@@ -16,8 +16,14 @@
 import pytensor
 import pytensor.tensor as pt
-from pytensor import function
-from pytensor.gradient import Lop, Rop, grad, grad_undefined
+from pytensor import config, function
+from pytensor.gradient import (
+    Lop,
+    NullTypeGradError,
+    Rop,
+    grad,
+    grad_undefined,
+)
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor.math import argmax, dot
@@ -61,6 +67,10 @@ class RopLopChecker:
     Rop to class that inherit from it.
     """
 
+    @staticmethod
+    def rtol():
+        return 1e-7 if config.floatX == "float64" else 1e-5
+
     def setup_method(self):
         # Using vectors make things a lot simpler for generating the same
         # computations using scan
@@ -72,13 +82,13 @@ def setup_method(self):
         self.mv = matrix("mv")
         self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3))
 
-    def check_nondiff_rop(self, y):
+    def check_nondiff_rop(self, y, x, v):
         """
         If your op is not differentiable(so you can't define Rop)
         test that an error is raised.
         """
         with pytest.raises(ValueError):
-            Rop(y, self.x, self.v)
+            Rop(y, x, v, use_op_rop_implementation=True)
 
     def check_mat_rop_lop(self, y, out_shape):
         """
@@ -106,8 +116,14 @@ def check_mat_rop_lop(self, y, out_shape):
         vv = np.asarray(
             self.rng.uniform(size=self.mat_in_shape), pytensor.config.floatX
         )
-        yv = Rop(y, self.mx, self.mv)
+        yv = Rop(y, self.mx, self.mv, use_op_rop_implementation=True)
         rop_f = function([self.mx, self.mv], yv, on_unused_input="ignore")
+
+        yv_through_lop = Rop(y, self.mx, self.mv, use_op_rop_implementation=False)
+        rop_through_lop_f = function(
+            [self.mx, self.mv], yv_through_lop, on_unused_input="ignore"
+        )
+
         sy, _ = pytensor.scan(
             lambda i, y, x, v: (grad(y[i], x) * v).sum(),
             sequences=pt.arange(y.shape[0]),
@@ -115,13 +131,14 @@
         )
         scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore")
 
-        v1 = rop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-
-        assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(rop_f(vx, vv), v_ref)
+        np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref)
 
         self.check_nondiff_rop(
-            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)})
+            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
+            self.mx,
+            self.mv,
         )
 
         vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
@@ -131,45 +148,47 @@
         sy = grad((self.v * y).sum(), self.mx)
         scan_f = function([self.mx, self.v], sy)
 
-        v1 = lop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+        v = lop_f(vx, vv)
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(v, v_ref)
 
-    def check_rop_lop(self, y, out_shape):
+    def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True):
         """
         As check_mat_rop_lop, except the input is self.x which is a
         vector. The output is still a vector.
""" + rtol = self.rtol() + # TEST ROP vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) - yv = Rop(y, self.x, self.v) + yv = Rop(y, self.x, self.v, use_op_rop_implementation=True) rop_f = function([self.x, self.v], yv, on_unused_input="ignore") + + yv_through_lop = Rop(y, self.x, self.v, use_op_rop_implementation=False) + rop_through_lop_f = function( + [self.x, self.v], yv_through_lop, on_unused_input="ignore" + ) + J, _ = pytensor.scan( lambda i, y, x: grad(y[i], x), sequences=pt.arange(y.shape[0]), non_sequences=[y, self.x], ) sy = dot(J, self.v) - scan_f = function([self.x, self.v], sy, on_unused_input="ignore") - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) + np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref, rtol=rtol) - try: - Rop( + if check_nondiff_rop: + self.check_nondiff_rop( pytensor.clone_replace(y, replace={self.x: break_op(self.x)}), self.x, self.v, ) - except ValueError: - pytest.skip( - "Rop does not handle non-differentiable inputs " - "correctly. Bug exposed by fixing Add.grad method." - ) vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX) @@ -182,22 +201,20 @@ def check_rop_lop(self, y, out_shape): non_sequences=[y, self.x], ) sy = dot(self.v, J) - scan_f = function([self.x, self.v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v = lop_f(vx, vv) + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(v, v_ref, rtol=rtol) class TestRopLop(RopLopChecker): def test_max(self): - # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ()) self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],)) self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],)) def test_argmax(self): - self.check_nondiff_rop(argmax(self.mx, axis=1)) + self.check_nondiff_rop(argmax(self.mx, axis=1), self.mx, self.mv) def test_subtensor(self): self.check_rop_lop(self.x[:4], (4,)) @@ -252,10 +269,14 @@ def test_dot(self): insh = self.in_shape[0] vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX) W = pytensor.shared(vW) - self.check_rop_lop(dot(self.x, W), self.in_shape) + # check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths + # See: test_Rop_partially_differentiable_paths + self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False) def test_elemwise0(self): - self.check_rop_lop((self.x + 1) ** 2, self.in_shape) + # check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths + # See: test_Rop_partially_differentiable_paths + self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False) def test_elemwise1(self): self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape) @@ -287,18 +308,18 @@ def test_alloc(self): self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0], ) - def test_invalid_input(self): - success = False - - try: - Rop(0.0, [matrix()], [vector()]) - success = True - except ValueError: - pass - - assert not success + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_invalid_input(self, use_op_rop_implementation): + with pytest.raises(ValueError): + Rop( + 0.0, + [matrix()], + [vector()], 
+ use_op_rop_implementation=use_op_rop_implementation, + ) - def test_multiple_outputs(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_multiple_outputs(self, use_op_rop_implementation): m = matrix("m") v = vector("v") m_ = matrix("m_") @@ -309,10 +330,20 @@ def test_multiple_outputs(self): m_val = self.rng.uniform(size=(3, 7)).astype(pytensor.config.floatX) v_val = self.rng.uniform(size=(7,)).astype(pytensor.config.floatX) - rop_out1 = Rop([m, v, m + v], [m, v], [m_, v_]) + rop_out1 = Rop( + [m, v, m + v], + [m, v], + [m_, v_], + use_op_rop_implementation=use_op_rop_implementation, + ) assert isinstance(rop_out1, list) assert len(rop_out1) == 3 - rop_out2 = Rop((m, v, m + v), [m, v], [m_, v_]) + rop_out2 = Rop( + (m, v, m + v), + [m, v], + [m_, v_], + use_op_rop_implementation=use_op_rop_implementation, + ) assert isinstance(rop_out2, tuple) assert len(rop_out2) == 3 @@ -322,12 +353,65 @@ def test_multiple_outputs(self): f = pytensor.function([m, v, m_, v_], all_outs) f(mval, vval, m_val, v_val) - def test_Rop_dot_bug_18Oct2013_Jeremiah(self): + @pytest.mark.parametrize( + "use_op_rop_implementation", + [pytest.param(True, marks=pytest.mark.xfail()), False], + ) + def test_Rop_partially_differentiable_paths(self, use_op_rop_implementation): # This test refers to a bug reported by Jeremiah Lowin on 18th Oct # 2013. The bug consists when through a dot operation there is only # one differentiable path (i.e. there is no gradient wrt to one of # the inputs). x = pt.arange(20.0).reshape([1, 20]) - v = pytensor.shared(np.ones([20])) + v = pytensor.shared(np.ones([20]), name="v") d = dot(x, v).sum() - Rop(grad(d, v), v, v) + + Rop( + grad(d, v), + v, + v, + use_op_rop_implementation=use_op_rop_implementation, + # 2025: This is a tricky case, the gradient of the gradient does not depend on v + # although v still exists in the graph inside a `Second` operator. + # The original test was checking that Rop wouldn't raise an error, but Lop does. + # Since the correct behavior is ambiguous, I let both implementations off the hook. + disconnected_outputs="raise" if use_op_rop_implementation else "ignore", + ) + + # 2025: Here is an unambiguous test for the original commented issue: + x = pt.matrix("x") + y = pt.matrix("y") + out = dot(x, break_op(y)).sum() + # Should not raise an error + Rop( + out, + [x], + [x.type()], + use_op_rop_implementation=use_op_rop_implementation, + disconnected_outputs="raise", + ) + + # More extensive testing shows that the legacy Rop implementation FAILS to raise when + # the cost is linked through strictly non-differentiable paths. + # This is not Dot specific, we would observe the same with any operation where the gradient + # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...) + out = dot(break_op(x), y).sum() + with pytest.raises((ValueError, NullTypeGradError)): + Rop( + out, + [x], + [x.type()], + use_op_rop_implementation=use_op_rop_implementation, + disconnected_outputs="raise", + ) + + # Only when both paths are non-differentiable is an error correctly raised again. 
+ out = dot(break_op(x), break_op(y)).sum() + with pytest.raises((ValueError, NullTypeGradError)): + Rop( + out, + [x], + [x.type()], + use_op_rop_implementation=use_op_rop_implementation, + disconnected_outputs="raise", + ) diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py index 466bdc865d..19598bfb21 100644 --- a/tests/typed_list/test_basic.py +++ b/tests/typed_list/test_basic.py @@ -577,10 +577,10 @@ def test_correct_answer(self): x = tensor3() y = tensor3() - A = np.cast[pytensor.config.floatX](np.random.random((5, 3))) - B = np.cast[pytensor.config.floatX](np.random.random((7, 2))) - X = np.cast[pytensor.config.floatX](np.random.random((5, 6, 1))) - Y = np.cast[pytensor.config.floatX](np.random.random((1, 9, 3))) + A = np.random.random((5, 3)).astype(pytensor.config.floatX) + B = np.random.random((7, 2)).astype(pytensor.config.floatX) + X = np.random.random((5, 6, 1)).astype(pytensor.config.floatX) + Y = np.random.random((1, 9, 3)).astype(pytensor.config.floatX) make_list((3.0, 4.0)) c = make_list((a, b)) diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py index a5b0a21a49..adb83fe7c0 100644 --- a/tests/unittest_tools.py +++ b/tests/unittest_tools.py @@ -1,5 +1,6 @@ import logging import sys +import warnings from copy import copy, deepcopy from functools import wraps @@ -41,12 +42,9 @@ def fetch_seed(pseed=None): else: seed = None except ValueError: - print( - ( - "Error: config.unittests__rseed contains " - "invalid seed, using None instead" - ), - file=sys.stderr, + warnings.warn( + "Error: config.unittests__rseed contains " + "invalid seed, using None instead" ) seed = None
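
A note on the `np.cast` removals in the test_gradient.py, test_math.py, and typed_list hunks above: NumPy 2.0 removed the `np.cast` alias dict as part of the NEP 52 cleanup, and `np.asarray(..., dtype=...)` is the one-to-one replacement used throughout this patch. A minimal sketch of the two idioms (the variable names here are illustrative, not from the PR):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((3, 2))

    # NumPy < 2.0 idiom, removed in 2.0:
    #   y = np.cast["float32"](x)
    y = np.asarray(x, dtype="float32")

    # Generators can also draw values at the target dtype directly,
    # which is what the fill_diagonal tests above switch to:
    z = rng.random(size=3, dtype="float32")
    assert y.dtype == np.float32 and z.dtype == np.float32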
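Similarly, `np.find_common_type` (replaced in the tests/tensor/utils.py hunk) is gone in NumPy 2.0; `np.result_type` accepts any mix of dtypes and arrays and applies the same promotion rules. A small sketch, assuming NumPy >= 2.0:

    import numpy as np

    a = np.zeros(3, dtype="float16")
    b = np.zeros(3, dtype="int8")
    # Promote over the float16 baseline plus all argument arrays,
    # mirroring the upcast_float16_ufunc change above:
    out_dtype = np.result_type(np.float16, a, b)
    assert out_dtype == np.dtype("float16")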
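The new `test_eager_simplification` in tests/tensor/test_math.py leans on `equal_computations`, which compares two graphs structurally rather than numerically. A tiny self-contained illustration (the variables are ours):

    import pytensor.tensor as pt
    from pytensor.graph.basic import equal_computations

    x = pt.vector("x")
    y = pt.vector("y")
    # Identical graphs compare equal; structurally different ones do not,
    # even when they are mathematically equivalent:
    assert equal_computations([x * y], [x * y])
    assert not equal_computations([x * y], [y * x])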
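Restating the reasoning in the `test_solve_gradient` comment: `utt.verify_grad` perturbs inputs with unstructured noise, so any structure a solver assumes (symmetry, positive-definiteness) must be re-imposed inside the graph under test. A sketch of the pattern with the symmetrizing map inlined, assuming the tests run from the repository root so `tests.unittest_tools` is importable, and float64 for numerical stability as the hunk above notes:

    import numpy as np
    import pytensor
    from pytensor.tensor.slinalg import solve
    from tests import unittest_tools as utt

    rng = np.random.default_rng(utt.fetch_seed())
    A_val = rng.normal(size=(5, 5)).astype(pytensor.config.floatX)
    b_val = rng.normal(size=(5,)).astype(pytensor.config.floatX)

    # Symmetrize inside the graph: perturbations of A remain valid inputs
    # for assume_a="sym", so finite differences match the symbolic gradient.
    utt.verify_grad(
        lambda A, b: solve((A + A.T) / 2, b, assume_a="sym", b_ndim=1),
        [A_val, b_val],
        3,
        rng,
    )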
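On the `use_op_rop_implementation` switch exercised in tests/test_rop.py: when it is False, the R-op is derived from L-ops rather than from per-Op `R_op` methods, via the standard identity Rop(y, x, v) = Lop(Lop(y, x, u), u, v) for a dummy cotangent u. A self-contained sketch of that identity using only public helpers; the name `jvp_via_double_lop` and the dummy `u` are ours, and this illustrates the identity, not PyTensor's internal implementation:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Lop

    def jvp_via_double_lop(y, x, v):
        u = y.type()          # dummy cotangent with the same type as y
        g = Lop(y, x, u)      # g = J^T u, linear in u
        return Lop(g, u, v)   # differentiating g w.r.t. u against v yields J v

    x = pt.vector("x")
    v = pt.vector("v")
    y = pt.tanh(x)
    f = pytensor.function([x, v], jvp_via_double_lop(y, x, v))

    xv = np.array([0.1, 0.2], dtype=pytensor.config.floatX)
    vv = np.array([1.0, 0.5], dtype=pytensor.config.floatX)
    # For elementwise tanh, Jv is simply (1 - tanh(x)^2) * v:
    np.testing.assert_allclose(f(xv, vv), (1 - np.tanh(xv) ** 2) * vv, rtol=1e-5)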