From 17748b7d3419e78e7dfad3dc768908ac240aa6df Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 5 Feb 2025 10:39:59 +0100 Subject: [PATCH 01/43] Remove accidental print statements --- pyproject.toml | 11 ++++- pytensor/breakpoint.py | 16 +++---- pytensor/compile/compiledir.py | 46 +++++++++---------- pytensor/compile/debugmode.py | 8 ++-- pytensor/compile/mode.py | 2 +- pytensor/compile/monitormode.py | 6 +-- pytensor/compile/nanguardmode.py | 2 +- pytensor/compile/profiling.py | 4 +- pytensor/graph/features.py | 22 ++++----- pytensor/graph/fg.py | 2 +- pytensor/graph/rewriting/basic.py | 20 ++++---- pytensor/graph/utils.py | 4 +- pytensor/link/c/basic.py | 12 ++--- pytensor/link/c/op.py | 2 +- pytensor/printing.py | 7 ++- pytensor/tensor/basic.py | 1 - pytensor/tensor/nlinalg.py | 10 ++-- pytensor/tensor/rewriting/blas.py | 2 +- pytensor/tensor/rewriting/elemwise.py | 8 ++-- pytensor/tensor/rewriting/math.py | 16 +++---- scripts/slowest_tests/extract-slow-tests.py | 2 +- tests/d3viz/test_d3viz.py | 2 +- tests/link/c/test_cmodule.py | 1 - tests/link/numba/test_basic.py | 1 - tests/link/test_vm.py | 51 +++++++++++---------- tests/scan/test_basic.py | 3 +- tests/tensor/rewriting/test_math.py | 36 ++++++++------- tests/tensor/test_complex.py | 12 ++--- tests/tensor/test_fft.py | 1 - tests/tensor/test_shape.py | 1 - tests/test_config.py | 2 +- tests/test_printing.py | 6 +-- tests/unittest_tools.py | 10 ++-- 33 files changed, 161 insertions(+), 168 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4e2a1fdb05..e82c42753a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,7 @@ exclude = ["doc/", "pytensor/_version.py"] docstring-code-format = true [tool.ruff.lint] -select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"] +select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20"] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] unfixable = [ # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead @@ -144,7 +144,12 @@ lines-after-imports = 2 # TODO: Get rid of these: "**/__init__.py" = ["F401", "E402", "F403"] "pytensor/tensor/linalg.py" = ["F403"] -"pytensor/link/c/cmodule.py" = ["PTH"] +"pytensor/link/c/cmodule.py" = ["PTH", "T201"] +"pytensor/misc/elemwise_time_test.py" = ["T201"] +"pytensor/misc/elemwise_openmp_speedup.py" = ["T201"] +"pytensor/misc/check_duplicate_key.py" = ["T201"] +"pytensor/misc/check_blas.py" = ["T201"] +"pytensor/bin/pytensor_cache.py" = ["T201"] # For the tests we skip because `pytest.importorskip` is used: "tests/link/jax/test_scalar.py" = ["E402"] "tests/link/jax/test_tensor_basic.py" = ["E402"] @@ -158,6 +163,8 @@ lines-after-imports = 2 "tests/sparse/test_sp2.py" = ["E402"] "tests/sparse/test_utils.py" = ["E402"] "tests/sparse/sandbox/test_sp.py" = ["E402", "F401"] +"tests/compile/test_monitormode.py" = ["T201"] +"scripts/run_mypy.py" = ["T201"] [tool.mypy] diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index 314f2a7325..3d59b5c24c 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -108,14 +108,14 @@ def perform(self, node, inputs, output_storage): f"'{self.name}' could not be casted to NumPy arrays" ) - print("\n") - print("-------------------------------------------------") - print(f"Conditional breakpoint '{self.name}' activated\n") - print("The monitored variables are stored, in order,") - print("in the list variable 'monitored' as NumPy arrays.\n") - print("Their contents can be altered and, when execution") - print("resumes, the updated values will be used.") - print("-------------------------------------------------") + print("\n") # noqa: T201 + print("-------------------------------------------------") # noqa: T201 + print(f"Conditional breakpoint '{self.name}' activated\n") # noqa: T201 + print("The monitored variables are stored, in order,") # noqa: T201 + print("in the list variable 'monitored' as NumPy arrays.\n") # noqa: T201 + print("Their contents can be altered and, when execution") # noqa: T201 + print("resumes, the updated values will be used.") # noqa: T201 + print("-------------------------------------------------") # noqa: T201 try: import pudb diff --git a/pytensor/compile/compiledir.py b/pytensor/compile/compiledir.py index 0482ed6cd8..127b971b2e 100644 --- a/pytensor/compile/compiledir.py +++ b/pytensor/compile/compiledir.py @@ -95,10 +95,10 @@ def cleanup(): def print_title(title, overline="", underline=""): len_title = len(title) if overline: - print(str(overline) * len_title) - print(title) + print(str(overline) * len_title) # noqa: T201 + print(title) # noqa: T201 if underline: - print(str(underline) * len_title) + print(str(underline) * len_title) # noqa: T201 def print_compiledir_content(): @@ -159,7 +159,7 @@ def print_compiledir_content(): _logger.error(f"Could not read key file '{filename}'.") print_title(f"PyTensor cache: {compiledir}", overline="=", underline="=") - print() + print() # noqa: T201 print_title(f"List of {len(table)} compiled individual ops", underline="+") print_title( @@ -168,9 +168,9 @@ def print_compiledir_content(): ) table = sorted(table, key=lambda t: str(t[1])) for dir, op, types, compile_time in table: - print(dir, f"{compile_time:.3f}s", op, types) + print(dir, f"{compile_time:.3f}s", op, types) # noqa: T201 - print() + print() # noqa: T201 print_title( f"List of {len(table_multiple_ops)} compiled sets of ops", underline="+" ) @@ -180,9 +180,9 @@ def print_compiledir_content(): ) table_multiple_ops = sorted(table_multiple_ops, key=lambda t: (t[1], t[2])) for dir, ops_to_str, types_to_str, compile_time in table_multiple_ops: - print(dir, f"{compile_time:.3f}s", ops_to_str, types_to_str) + print(dir, f"{compile_time:.3f}s", ops_to_str, types_to_str) # noqa: T201 - print() + print() # noqa: T201 print_title( ( f"List of {len(table_op_class)} compiled Op classes and " @@ -191,33 +191,33 @@ def print_compiledir_content(): underline="+", ) for op_class, nb in reversed(table_op_class.most_common()): - print(op_class, nb) + print(op_class, nb) # noqa: T201 if big_key_files: big_key_files = sorted(big_key_files, key=lambda t: str(t[1])) big_total_size = sum(sz for _, sz, _ in big_key_files) - print( + print( # noqa: T201 f"There are directories with key files bigger than {int(max_key_file_size)} bytes " "(they probably contain big tensor constants)" ) - print( + print( # noqa: T201 f"They use {int(big_total_size)} bytes out of {int(total_key_sizes)} (total size " "used by all key files)" ) for dir, size, ops in big_key_files: - print(dir, size, ops) + print(dir, size, ops) # noqa: T201 nb_keys = sorted(nb_keys.items()) - print() + print() # noqa: T201 print_title("Number of keys for a compiled module", underline="+") print_title( "number of keys/number of modules with that number of keys", underline="-" ) for n_k, n_m in nb_keys: - print(n_k, n_m) - print() - print( + print(n_k, n_m) # noqa: T201 + print() # noqa: T201 + print( # noqa: T201 f"Skipped {int(zeros_op)} files that contained 0 op " "(are they always pytensor.scalar ops?)" ) @@ -242,18 +242,18 @@ def basecompiledir_ls(): subdirs = sorted(subdirs) others = sorted(others) - print(f"Base compile dir is {config.base_compiledir}") - print("Sub-directories (possible compile caches):") + print(f"Base compile dir is {config.base_compiledir}") # noqa: T201 + print("Sub-directories (possible compile caches):") # noqa: T201 for d in subdirs: - print(f" {d}") + print(f" {d}") # noqa: T201 if not subdirs: - print(" (None)") + print(" (None)") # noqa: T201 if others: - print() - print("Other files in base_compiledir:") + print() # noqa: T201 + print("Other files in base_compiledir:") # noqa: T201 for f in others: - print(f" {f}") + print(f" {f}") # noqa: T201 def basecompiledir_purge(): diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index cc1a5b225a..5c51222a1b 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -1315,9 +1315,9 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def printstuff(self): for key in self.equiv: - print(key) + print(key) # noqa: T201 for e in self.equiv[key]: - print(" ", e) + print(" ", e) # noqa: T201 # List of default version of make thunk. @@ -1569,7 +1569,7 @@ def f(): ##### for r, s in storage_map.items(): if s[0] is not None: - print(r, s) + print(r, s) # noqa: T201 assert s[0] is None # try: @@ -2079,7 +2079,7 @@ def __init__( raise StochasticOrder(infolog.getvalue()) else: if self.verbose: - print( + print( # noqa: T201 "OPTCHECK: optimization", i, "of", diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index ae905089b5..43a5e131cb 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -178,7 +178,7 @@ def __init__(self, header): def apply(self, fgraph): import pytensor.printing - print("PrintCurrentFunctionGraph:", self.header) + print("PrintCurrentFunctionGraph:", self.header) # noqa: T201 pytensor.printing.debugprint(fgraph.outputs) diff --git a/pytensor/compile/monitormode.py b/pytensor/compile/monitormode.py index 770d4e2f7e..8663bc8832 100644 --- a/pytensor/compile/monitormode.py +++ b/pytensor/compile/monitormode.py @@ -108,8 +108,8 @@ def detect_nan(fgraph, i, node, fn): not isinstance(output[0], np.random.RandomState | np.random.Generator) and np.isnan(output[0]).any() ): - print("*** NaN detected ***") + print("*** NaN detected ***") # noqa: T201 debugprint(node) - print(f"Inputs : {[input[0] for input in fn.inputs]}") - print(f"Outputs: {[output[0] for output in fn.outputs]}") + print(f"Inputs : {[input[0] for input in fn.inputs]}") # noqa: T201 + print(f"Outputs: {[output[0] for output in fn.outputs]}") # noqa: T201 break diff --git a/pytensor/compile/nanguardmode.py b/pytensor/compile/nanguardmode.py index 7f90825953..e2fd44cda3 100644 --- a/pytensor/compile/nanguardmode.py +++ b/pytensor/compile/nanguardmode.py @@ -236,7 +236,7 @@ def do_check_on(value, nd, var=None): if config.NanGuardMode__action == "raise": raise AssertionError(msg) elif config.NanGuardMode__action == "pdb": - print(msg) + print(msg) # noqa: T201 import pdb pdb.set_trace() diff --git a/pytensor/compile/profiling.py b/pytensor/compile/profiling.py index 3dfe5283bb..a68365527f 100644 --- a/pytensor/compile/profiling.py +++ b/pytensor/compile/profiling.py @@ -82,7 +82,7 @@ def _atexit_print_fn(): to_sum.append(ps) else: # TODO print the name if there is one! - print("Skipping empty Profile") + print("Skipping empty Profile") # noqa: T201 if len(to_sum) > 1: # Make a global profile cum = copy.copy(to_sum[0]) @@ -125,7 +125,7 @@ def _atexit_print_fn(): assert len(merge) == len(cum.rewriter_profile[1]) cum.rewriter_profile = (cum.rewriter_profile[0], merge) except Exception as e: - print(e) + print(e) # noqa: T201 cum.rewriter_profile = None else: cum.rewriter_profile = None diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 93321fa61f..06be6d013a 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -491,7 +491,7 @@ def validate_(self, fgraph): if verbose: r = uf.f_locals.get("r", "") reason = uf_info.function - print(f"validate failed on node {r}.\n Reason: {reason}, {e}") + print(f"validate failed on node {r}.\n Reason: {reason}, {e}") # noqa: T201 raise t1 = time.perf_counter() if fgraph.profile: @@ -603,13 +603,13 @@ def replace_all_validate( except Exception as e: fgraph.revert(chk) if verbose: - print( + print( # noqa: T201 f"rewriting: validate failed on node {r}.\n Reason: {reason}, {e}" ) raise if verbose: - print( + print( # noqa: T201 f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}" ) @@ -692,11 +692,11 @@ def on_import(self, fgraph, node, reason): except TypeError: # node.op is unhashable return except Exception as e: - print("OFFENDING node", type(node), type(node.op), file=sys.stderr) + print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201 try: - print("OFFENDING node hash", hash(node.op), file=sys.stderr) + print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201 except Exception: - print("OFFENDING node not hashable", file=sys.stderr) + print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201 raise e def on_prune(self, fgraph, node, reason): @@ -725,7 +725,7 @@ def __init__(self, active=True): def on_attach(self, fgraph): if self.active: - print("-- attaching to: ", fgraph) + print("-- attaching to: ", fgraph) # noqa: T201 def on_detach(self, fgraph): """ @@ -733,19 +733,19 @@ def on_detach(self, fgraph): that it installed into the function_graph """ if self.active: - print("-- detaching from: ", fgraph) + print("-- detaching from: ", fgraph) # noqa: T201 def on_import(self, fgraph, node, reason): if self.active: - print(f"-- importing: {node}, reason: {reason}") + print(f"-- importing: {node}, reason: {reason}") # noqa: T201 def on_prune(self, fgraph, node, reason): if self.active: - print(f"-- pruning: {node}, reason: {reason}") + print(f"-- pruning: {node}, reason: {reason}") # noqa: T201 def on_change_input(self, fgraph, node, i, r, new_r, reason=None): if self.active: - print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") + print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") # noqa: T201 class PreserveVariableAttributes(Feature): diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 1d845e2eb3..e9b676f51a 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -491,7 +491,7 @@ def replace( if verbose is None: verbose = config.optimizer_verbose if verbose: - print( + print( # noqa: T201 f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}" ) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 344d6a1940..16b5b65a0e 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -1002,7 +1002,7 @@ def transform(self, fgraph, node, *args, **kwargs): # ensure we have data for all input variables that need it if missing: if self.verbose > 0: - print( + print( # noqa: T201 f"{self.__class__.__name__} cannot meta-rewrite {node}, " f"{len(missing)} of {int(node.nin)} input shapes unknown" ) @@ -1010,7 +1010,7 @@ def transform(self, fgraph, node, *args, **kwargs): # now we can apply the different rewrites in turn, # compile the resulting subgraphs and time their execution if self.verbose > 1: - print( + print( # noqa: T201 f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):" ) timings = [] @@ -1027,20 +1027,20 @@ def transform(self, fgraph, node, *args, **kwargs): continue except Exception as e: if self.verbose > 0: - print(f"* {node_rewriter}: exception", e) + print(f"* {node_rewriter}: exception", e) # noqa: T201 continue else: if self.verbose > 1: - print(f"* {node_rewriter}: {timing:.5g} sec") + print(f"* {node_rewriter}: {timing:.5g} sec") # noqa: T201 timings.append((timing, outputs, node_rewriter)) else: if self.verbose > 0: - print(f"* {node_rewriter}: not applicable") + print(f"* {node_rewriter}: not applicable") # noqa: T201 # finally, we choose the fastest one if timings: timings.sort() if self.verbose > 1: - print(f"= {timings[0][2]}") + print(f"= {timings[0][2]}") # noqa: T201 return timings[0][1] return @@ -1305,7 +1305,7 @@ def transform(self, fgraph, node): new_vars = list(new_repl.values()) if config.optimizer_verbose: - print( + print( # noqa: T201 f"rewriting: rewrite {rewrite} replaces node {node} with {new_repl}" ) @@ -2641,21 +2641,21 @@ def print_profile(cls, stream, prof, level=0): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: - print(blanc, "merge not implemented for ", o) + print(blanc, "merge not implemented for ", o) # noqa: T201 for o, prof in zip( rewrite.final_rewriters, final_sub_profs[i], strict=True ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: - print(blanc, "merge not implemented for ", o) + print(blanc, "merge not implemented for ", o) # noqa: T201 for o, prof in zip( rewrite.cleanup_rewriters, cleanup_sub_profs[i], strict=True ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: - print(blanc, "merge not implemented for ", o) + print(blanc, "merge not implemented for ", o) # noqa: T201 @staticmethod def merge_profile(prof1, prof2): diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py index 9c2eef5049..42ebbcd216 100644 --- a/pytensor/graph/utils.py +++ b/pytensor/graph/utils.py @@ -274,9 +274,9 @@ def __repr__(self): return "scratchpad" + str(self.__dict__) def info(self): - print(f"") + print(f"") # noqa: T201 for k, v in self.__dict__.items(): - print(f" {k}: {v}") + print(f" {k}: {v}") # noqa: T201 # These two methods have been added to help Mypy def __getattribute__(self, name): diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index 0b717c74a6..d7f43e7377 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -875,10 +875,10 @@ def code_gen(self): self.c_init_code_apply = c_init_code_apply if (self.init_tasks, self.tasks) != self.get_init_tasks(): - print("init_tasks\n", self.init_tasks, file=sys.stderr) - print(self.get_init_tasks()[0], file=sys.stderr) - print("tasks\n", self.tasks, file=sys.stderr) - print(self.get_init_tasks()[1], file=sys.stderr) + print("init_tasks\n", self.init_tasks, file=sys.stderr) # noqa: T201 + print(self.get_init_tasks()[0], file=sys.stderr) # noqa: T201 + print("tasks\n", self.tasks, file=sys.stderr) # noqa: T201 + print(self.get_init_tasks()[1], file=sys.stderr) # noqa: T201 assert (self.init_tasks, self.tasks) == self.get_init_tasks() # List of indices that should be ignored when passing the arguments @@ -1756,7 +1756,7 @@ def __call__(self): exc_value = exc_type(_exc_value) exc_value.__thunk_trace__ = trace except Exception: - print( + print( # noqa: T201 ( "ERROR retrieving error_storage." "Was the error set in the c code?" @@ -1764,7 +1764,7 @@ def __call__(self): end=" ", file=sys.stderr, ) - print(self.error_storage, file=sys.stderr) + print(self.error_storage, file=sys.stderr) # noqa: T201 raise raise exc_value.with_traceback(exc_trace) diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 74905d686f..b668f242e1 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -79,7 +79,7 @@ def is_f16(t): # that don't implement c code. In those cases, we # don't want to print a warning. cl.get_dynamic_module() - print(f"Disabling C code for {self} due to unsupported float16") + warnings.warn(f"Disabling C code for {self} due to unsupported float16") raise NotImplementedError("float16") outputs = cl.make_thunk( input_storage=node_input_storage, output_storage=node_output_storage diff --git a/pytensor/printing.py b/pytensor/printing.py index 6a18f6e8e5..bc42029c11 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -726,7 +726,7 @@ def _print_fn(op, xin): pmsg = temp() else: pmsg = temp - print(op.message, attr, "=", pmsg) + print(op.message, attr, "=", pmsg) # noqa: T201 class Print(Op): @@ -1657,7 +1657,7 @@ def apply_name(node): raise if print_output_file: - print("The output file is available at", outfile) + print("The output file is available at", outfile) # noqa: T201 class _TagGenerator: @@ -1824,8 +1824,7 @@ def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> s # The __str__ method is encoding the object's id in its str name = position_independent_str(obj) if " at 0x" in name: - print(name) - raise AssertionError() + raise AssertionError(name) prefix = cur_tag + "=" diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 8ee9894c9d..26bd34692b 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -613,7 +613,6 @@ def get_scalar_constant_value( """ if isinstance(v, TensorVariable | np.ndarray): if v.ndim != 0: - print(v, v.ndim) raise NotScalarConstantError("Input ndim != 0") return get_underlying_scalar_constant_value( v, diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 1f589e1789..a9d7016099 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -216,9 +216,8 @@ def perform(self, node, inputs, outputs): (z,) = outputs try: z[0] = np.asarray(np.linalg.det(x), dtype=x.dtype) - except Exception: - print("Failed to compute determinant", x) - raise + except Exception as e: + raise ValueError("Failed to compute determinant", x) from e def grad(self, inputs, g_outputs): (gz,) = g_outputs @@ -256,9 +255,8 @@ def perform(self, node, inputs, outputs): (sign, det) = outputs try: sign[0], det[0] = (np.array(z, dtype=x.dtype) for z in np.linalg.slogdet(x)) - except Exception: - print("Failed to compute determinant", x) - raise + except Exception as e: + raise ValueError("Failed to compute determinant", x) from e def infer_shape(self, fgraph, node, shapes): return [(), ()] diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index d3fc0398c4..31264f74d4 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -573,7 +573,7 @@ def print_profile(cls, stream, prof, level=0): print(blanc, " callbacks_time", file=stream) for i in sorted(prof[12].items(), key=lambda a: a[1]): if i[1] > 0: - print(i) + print(i) # noqa: T201 @node_rewriter([Dot]) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 3226f9b5a7..eaba64c275 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -314,14 +314,14 @@ def apply(self, fgraph): except (ValueError, InconsistencyError) as e: prof["nb_inconsistent"] += 1 if check_each_change != 1 and not raised_warning: - print( + print( # noqa: T201 ( "Some inplace rewriting was not " "performed due to an unexpected error:" ), file=sys.stderr, ) - print(e, file=sys.stderr) + print(e, file=sys.stderr) # noqa: T201 raised_warning = True fgraph.revert(chk) continue @@ -335,7 +335,7 @@ def apply(self, fgraph): fgraph.validate() except Exception: if not raised_warning: - print( + print( # noqa: T201 ( "Some inplace rewriting was not " "performed due to an unexpected error" @@ -1080,7 +1080,7 @@ def print_profile(stream, prof, level=0): print(blanc, " callbacks_time", file=stream) for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]: if i[1] > 0: - print(blanc, " ", i) + print(blanc, " ", i) # noqa: T201 print(blanc, " time_toposort", prof[7], file=stream) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 065ecfc0b1..0af1d40bf6 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -3434,14 +3434,14 @@ def perform_sigm_times_exp( sigm_minus_x = [] if full_tree is None: full_tree = tree - if False: # Debug code. - print("") - print(f" full_tree = {full_tree}") - print(f" tree = {tree}") - print(f" exp_x = {exp_x}") - print(f" exp_minus_x = {exp_minus_x}") - print(f" sigm_x = {sigm_x}") - print(f" sigm_minus_x= {sigm_minus_x}") + # if False: # Debug code. + # print("") + # print(f" full_tree = {full_tree}") + # print(f" tree = {tree}") + # print(f" exp_x = {exp_x}") + # print(f" exp_minus_x = {exp_minus_x}") + # print(f" sigm_x = {sigm_x}") + # print(f" sigm_minus_x= {sigm_minus_x}") neg, inputs = tree if isinstance(inputs, list): # Recurse through inputs of the multiplication. diff --git a/scripts/slowest_tests/extract-slow-tests.py b/scripts/slowest_tests/extract-slow-tests.py index 3a06e4a68b..14df837a7b 100644 --- a/scripts/slowest_tests/extract-slow-tests.py +++ b/scripts/slowest_tests/extract-slow-tests.py @@ -72,7 +72,7 @@ def main(read_lines): lines = read_lines() times = extract_lines(lines) parsed_times = format_times(times) - print("\n".join(parsed_times)) + print("\n".join(parsed_times)) # noqa: T201 if __name__ == "__main__": diff --git a/tests/d3viz/test_d3viz.py b/tests/d3viz/test_d3viz.py index 7e4b0426a0..38809a5faa 100644 --- a/tests/d3viz/test_d3viz.py +++ b/tests/d3viz/test_d3viz.py @@ -28,7 +28,7 @@ def check(self, f, reference=None, verbose=False): tmp_dir = Path(tempfile.mkdtemp()) html_file = tmp_dir / "index.html" if verbose: - print(html_file) + print(html_file) # noqa: T201 d3v.d3viz(f, html_file) assert html_file.stat().st_size > 0 if reference: diff --git a/tests/link/c/test_cmodule.py b/tests/link/c/test_cmodule.py index 2242bc12e9..46533fef35 100644 --- a/tests/link/c/test_cmodule.py +++ b/tests/link/c/test_cmodule.py @@ -258,7 +258,6 @@ def test_default_blas_ldflags( def patched_compile_tmp(*args, **kwargs): def wrapped(test_code, tmp_prefix, flags, try_run, output): if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]: - print(enabled_accelerate_framework) if enabled_accelerate_framework: return (True, True) else: diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 1b0fa8fd52..f0f73ca74d 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -836,7 +836,6 @@ def test_config_options_fastmath(): with config.change_flags(numba__fastmath=True): pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__)) numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] assert numba_mul_fn.targetoptions["fastmath"] == { "afn", diff --git a/tests/link/test_vm.py b/tests/link/test_vm.py index 69a922e731..dad7ed4fdd 100644 --- a/tests/link/test_vm.py +++ b/tests/link/test_vm.py @@ -1,4 +1,3 @@ -import time from collections import Counter import numpy as np @@ -108,23 +107,25 @@ def numpy_version(x, depth): return z def time_numpy(): + # TODO: Make this a benchmark test steps_a = 5 steps_b = 100 x = np.asarray([2.0, 3.0], dtype=config.floatX) numpy_version(x, steps_a) - t0 = time.perf_counter() - # print numpy_version(x, steps_a) - t1 = time.perf_counter() - t2 = time.perf_counter() - # print numpy_version(x, steps_b) - t3 = time.perf_counter() - t_a = t1 - t0 - t_b = t3 - t2 + # t0 = time.perf_counter() + numpy_version(x, steps_a) + # t1 = time.perf_counter() + # t2 = time.perf_counter() + numpy_version(x, steps_b) + # t3 = time.perf_counter() + # t_a = t1 - t0 + # t_b = t3 - t2 - print(f"numpy takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") + # print(f"numpy takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") def time_linker(name, linker): + # TODO: Make this a benchmark test steps_a = 5 steps_b = 100 x = vector() @@ -135,20 +136,20 @@ def time_linker(name, linker): f_b = function([x], b, mode=Mode(optimizer=None, linker=linker())) f_a([2.0, 3.0]) - t0 = time.perf_counter() + # t0 = time.perf_counter() f_a([2.0, 3.0]) - t1 = time.perf_counter() + # t1 = time.perf_counter() f_b([2.0, 3.0]) - t2 = time.perf_counter() + # t2 = time.perf_counter() f_b([2.0, 3.0]) - t3 = time.perf_counter() + # t3 = time.perf_counter() - t_a = t1 - t0 - t_b = t3 - t2 + # t_a = t1 - t0 + # t_b = t3 - t2 - print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") + # print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") time_linker("c|py", OpWiseCLinker) time_linker("vmLinker", VMLinker) @@ -167,7 +168,7 @@ def time_linker(name, linker): ], ) def test_speed_lazy(linker): - # TODO FIXME: This isn't a real test. + # TODO FIXME: This isn't a real test. Make this a benchmark test def build_graph(x, depth=5): z = x @@ -185,20 +186,20 @@ def build_graph(x, depth=5): f_b = function([x], b, mode=Mode(optimizer=None, linker=linker)) f_a([2.0]) - t0 = time.perf_counter() + # t0 = time.perf_counter() f_a([2.0]) - t1 = time.perf_counter() + # t1 = time.perf_counter() f_b([2.0]) - t2 = time.perf_counter() + # t2 = time.perf_counter() f_b([2.0]) - t3 = time.perf_counter() + # t3 = time.perf_counter() - t_a = t1 - t0 - t_b = t3 - t2 + # t_a = t1 - t0 + # t_b = t3 - t2 - print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") + # print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") @pytest.mark.parametrize( diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index b75e9ca852..9fa893ab27 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -12,7 +12,6 @@ import os import pickle import shutil -import sys from pathlib import Path from tempfile import mkdtemp @@ -3076,7 +3075,7 @@ def loss_inner(sum_inner, W): cost = result_outer[0][-1] H = hessian(cost, W) - print(".", file=sys.stderr) + # print(".", file=sys.stderr) f = function([W, n_steps], H) benchmark(f, np.ones((8,), dtype="float32"), 1) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index a1759ef81b..9a092663a9 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1628,6 +1628,7 @@ def test_local_mul_specialize(): def speed_local_pow_specialize_range(): + # TODO: This should be a benchmark test val = np.random.random(1e7) v = vector() mode = get_default_mode() @@ -1641,9 +1642,9 @@ def speed_local_pow_specialize_range(): t2 = time.perf_counter() f2(val) t3 = time.perf_counter() - print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) + # print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) if not t2 - t1 < t3 - t2: - print("WARNING WE ARE SLOWER") + raise ValueError("WARNING WE ARE SLOWER") for i in range(-3, -1500, -1): f1 = function([v], v**i, mode=mode) f2 = function([v], v**i, mode=mode_without_pow_rewrite) @@ -1653,9 +1654,9 @@ def speed_local_pow_specialize_range(): t2 = time.perf_counter() f2(val) t3 = time.perf_counter() - print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) + # print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2) if not t2 - t1 < t3 - t2: - print("WARNING WE ARE SLOWER") + raise ValueError("WARNING WE ARE SLOWER") def test_local_pow_specialize(): @@ -2483,19 +2484,20 @@ def test_local_grad_log_erfc_neg(self): assert f.maker.fgraph.outputs[0].dtype == config.floatX def speed_local_log_erfc(self): + # TODO: Make this a benchmark test! val = np.random.random(1e6) x = vector() mode = get_mode("FAST_RUN") f1 = function([x], log(erfc(x)), mode=mode.excluding("local_log_erfc")) f2 = function([x], log(erfc(x)), mode=mode) - print(f1.maker.fgraph.toposort()) - print(f2.maker.fgraph.toposort()) - t0 = time.perf_counter() + # print(f1.maker.fgraph.toposort()) + # print(f2.maker.fgraph.toposort()) + # t0 = time.perf_counter() f1(val) - t1 = time.perf_counter() + # t1 = time.perf_counter() f2(val) - t2 = time.perf_counter() - print(t1 - t0, t2 - t1) + # t2 = time.perf_counter() + # print(t1 - t0, t2 - t1) class TestLocalMergeSwitchSameCond: @@ -4144,13 +4146,13 @@ def check(expr1, expr2): perform_sigm_times_exp(trees[0]) trees[0] = simplify_mul(trees[0]) good = is_same_graph(compute_mul(trees[0]), compute_mul(trees[1])) - if not good: - print(trees[0]) - print(trees[1]) - print("***") - pytensor.printing.debugprint(compute_mul(trees[0])) - print("***") - pytensor.printing.debugprint(compute_mul(trees[1])) + # if not good: + # print(trees[0]) + # print(trees[1]) + # print("***") + # pytensor.printing.debugprint(compute_mul(trees[0])) + # print("***") + # pytensor.printing.debugprint(compute_mul(trees[1])) assert good check(sigmoid(x) * exp_op(-x), sigmoid(-x)) diff --git a/tests/tensor/test_complex.py b/tests/tensor/test_complex.py index f0f7333f9c..a1b99751ed 100644 --- a/tests/tensor/test_complex.py +++ b/tests/tensor/test_complex.py @@ -73,9 +73,7 @@ def f(a): try: utt.verify_grad(f, [aval]) except GradientError as e: - print(e.num_grad.gf) - print(e.analytic_grad) - raise + raise ValueError(f"Failed: {e.num_grad.gf=} {e.analytic_grad=}") from e @pytest.mark.skip(reason="Complex grads not enabled, see #178") def test_mul_mixed1(self): @@ -88,9 +86,7 @@ def f(a): try: utt.verify_grad(f, [aval]) except GradientError as e: - print(e.num_grad.gf) - print(e.analytic_grad) - raise + raise ValueError(f"Failed: {e.num_grad.gf=} {e.analytic_grad=}") from e @pytest.mark.skip(reason="Complex grads not enabled, see #178") def test_mul_mixed(self): @@ -104,9 +100,7 @@ def f(a, b): try: utt.verify_grad(f, [aval, bval]) except GradientError as e: - print(e.num_grad.gf) - print(e.analytic_grad) - raise + raise ValueError(f"Failed: {e.num_grad.gf=} {e.analytic_grad=}") from e @pytest.mark.skip(reason="Complex grads not enabled, see #178") def test_polar_grads(self): diff --git a/tests/tensor/test_fft.py b/tests/tensor/test_fft.py index 94c49662bc..3976c67622 100644 --- a/tests/tensor/test_fft.py +++ b/tests/tensor/test_fft.py @@ -43,7 +43,6 @@ def test_1Drfft(self): utt.assert_allclose(rfft_ref, res_rfft_comp) m = rfft.type() - print(m.ndim) irfft = fft.irfft(m) f_irfft = pytensor.function([m], irfft) res_irfft = f_irfft(res_rfft) diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 2ffcb25fe5..e85b8cfd46 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -797,7 +797,6 @@ def test_reshape(self): assert equal_computations([vect_out], [reshape(mat, new_shape)]) new_shape = stack([[-1, x], [x - 1, -1]], axis=0) - print(new_shape.type) [vect_out] = vectorize_node(node, vec, new_shape).outputs vec_test_value = np.arange(6) np.testing.assert_allclose( diff --git a/tests/test_config.py b/tests/test_config.py index 4370309f39..2dd3c32180 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -192,7 +192,7 @@ def test_invalid_configvar_access(): # But we can make sure that nothing crazy happens when we access it: with pytest.raises(configparser.ConfigAccessViolation, match="different instance"): - print(root.test__on_test_instance) + assert root.test__on_test_instance is not None def test_no_more_dotting(): diff --git a/tests/test_printing.py b/tests/test_printing.py index be5dbbc5a1..4dd4f3866d 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -138,9 +138,9 @@ def test_min_informative_str(): D. D E. E""" - if mis != reference: - print("--" + mis + "--") - print("--" + reference + "--") + # if mis != reference: + # print("--" + mis + "--") + # print("--" + reference + "--") assert mis == reference diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py index a5b0a21a49..adb83fe7c0 100644 --- a/tests/unittest_tools.py +++ b/tests/unittest_tools.py @@ -1,5 +1,6 @@ import logging import sys +import warnings from copy import copy, deepcopy from functools import wraps @@ -41,12 +42,9 @@ def fetch_seed(pseed=None): else: seed = None except ValueError: - print( - ( - "Error: config.unittests__rseed contains " - "invalid seed, using None instead" - ), - file=sys.stderr, + warnings.warn( + "Error: config.unittests__rseed contains " + "invalid seed, using None instead" ) seed = None From 4fa9bb878b94703063b89b434a20b9dcb72d9472 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Date: Mon, 10 Feb 2025 02:05:23 +0100 Subject: [PATCH 02/43] PyTorch inline constants in dispatch to avoid graph breaks (#1118) * Split and inverse * PyTorch inline constants in dispatch to avoid graph breaks --- pytensor/link/pytorch/dispatch/basic.py | 44 +++++++++++++++--- pytensor/link/pytorch/dispatch/scalar.py | 6 +++ pytensor/link/pytorch/dispatch/shape.py | 19 ++++++-- pytensor/link/pytorch/dispatch/subtensor.py | 15 +++++++ pytensor/link/pytorch/linker.py | 3 ++ tests/link/pytorch/test_basic.py | 50 +++++++++++++++++++++ 6 files changed, 127 insertions(+), 10 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 11e1d6c63a..ef4bf10637 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -8,6 +8,7 @@ from pytensor.compile import PYTORCH from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp +from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse from pytensor.link.utils import fgraph_to_python @@ -19,6 +20,7 @@ Eye, Join, MakeVector, + Split, TensorFromScalar, ) @@ -120,14 +122,23 @@ def arange(start, stop, step): @pytorch_funcify.register(Join) -def pytorch_funcify_Join(op, **kwargs): - def join(axis, *tensors): - # tensors could also be tuples, and in this case they don't have a ndim - tensors = [torch.tensor(tensor) for tensor in tensors] +def pytorch_funcify_Join(op, node, **kwargs): + axis = node.inputs[0] - return torch.cat(tensors, dim=axis) + if isinstance(axis, Constant): + axis = int(axis.data) - return join + def join_constant_axis(_, *tensors): + return torch.cat(tensors, dim=axis) + + return join_constant_axis + + else: + + def join(axis, *tensors): + return torch.cat(tensors, dim=axis) + + return join @pytorch_funcify.register(Eye) @@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs): @pytorch_funcify.register(OpFromGraph) def pytorch_funcify_OpFromGraph(op, node, **kwargs): kwargs.pop("storage_map", None) - # Apply inner rewrites PYTORCH.optimizer(op.fgraph) fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) @@ -185,3 +195,23 @@ def tensorfromscalar(x): return torch.as_tensor(x) return tensorfromscalar + + +@pytorch_funcify.register(Split) +def pytorch_funcify_Split(op, node, **kwargs): + x, dim, split_sizes = node.inputs + if isinstance(dim, Constant) and isinstance(split_sizes, Constant): + dim = int(dim.data) + split_sizes = tuple(int(size) for size in split_sizes.data) + + def split_constant_axis_and_sizes(x, *_): + return x.split(split_sizes, dim=dim) + + return split_constant_axis_and_sizes + + else: + + def inner_fn(x, dim, split_amounts): + return x.split(split_amounts.tolist(), dim=dim.item()) + + return inner_fn diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 65170b1f53..6a1c6b235e 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -5,12 +5,18 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar.basic import ( Cast, + Invert, ScalarOp, ) from pytensor.scalar.loop import ScalarLoop from pytensor.scalar.math import Softplus +@pytorch_funcify.register(Invert) +def pytorch_funcify_invert(op, node, **kwargs): + return torch.bitwise_not + + @pytorch_funcify.register(ScalarOp) def pytorch_funcify_ScalarOp(op, node, **kwargs): """Return pytorch function that implements the same computation as the Scalar Op. diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py index f771ac7211..c15b3a3779 100644 --- a/pytensor/link/pytorch/dispatch/shape.py +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -1,15 +1,28 @@ import torch +from pytensor.graph.basic import Constant from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast @pytorch_funcify.register(Reshape) def pytorch_funcify_Reshape(op, node, **kwargs): - def reshape(x, shape): - return torch.reshape(x, tuple(shape)) + _, shape = node.inputs - return reshape + if isinstance(shape, Constant): + constant_shape = tuple(int(dim) for dim in shape.data) + + def reshape_constant_shape(x, *_): + return torch.reshape(x, constant_shape) + + return reshape_constant_shape + + else: + + def reshape(x, shape): + return torch.reshape(x, tuple(shape)) + + return reshape @pytorch_funcify.register(Shape) diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py index 75e7ec0776..34358797fb 100644 --- a/pytensor/link/pytorch/dispatch/subtensor.py +++ b/pytensor/link/pytorch/dispatch/subtensor.py @@ -1,3 +1,4 @@ +from pytensor.graph.basic import Constant from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -23,7 +24,21 @@ def check_negative_steps(indices): @pytorch_funcify.register(Subtensor) def pytorch_funcify_Subtensor(op, node, **kwargs): idx_list = op.idx_list + x, *idxs = node.inputs + if all(isinstance(idx, Constant) for idx in idxs): + # Use constant indices to avoid graph break + constant_indices = indices_from_subtensor( + [int(idx.data) for idx in idxs], idx_list + ) + check_negative_steps(constant_indices) + + def constant_index_subtensor(x, *_): + return x[constant_indices] + + return constant_index_subtensor + + # Fallback that will introduce a graph break def subtensor(x, *flattened_indices): indices = indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..b8475e3157 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs): def jit_compile(self, fn): import torch + # flag that tend to help our graphs + torch._dynamo.config.capture_dynamic_output_shape_ops = True + from pytensor.link.pytorch.dispatch import pytorch_typify class wrapper: diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 2ac8ee7c3b..d5c23c83e4 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries(): compare_pytorch_and_py( f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) ) + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "n_splits, axis, values, sizes", + [ + ( + 0, + 0, + rng.normal(size=20).astype(config.floatX), + [], + ), + ( + 5, + 0, + rng.normal(size=5).astype(config.floatX), + rng.multinomial(5, np.ones(5) / 5), + ), + ( + 5, + 0, + rng.normal(size=10).astype(config.floatX), + rng.multinomial(10, np.ones(5) / 5), + ), + ( + 5, + -1, + rng.normal(size=(11, 7)).astype(config.floatX), + rng.multinomial(7, np.ones(5) / 5), + ), + ( + 5, + -2, + rng.normal(size=(11, 7)).astype(config.floatX), + rng.multinomial(11, np.ones(5) / 5), + ), + ], +) +def test_Split(n_splits, axis, values, sizes): + i = pt.tensor("i", shape=values.shape, dtype=config.floatX) + s = pt.vector("s", dtype="int64") + g = pt.split(i, s, n_splits, axis=axis) + assert len(g) == n_splits + if n_splits == 0: + return + g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g) + + compare_pytorch_and_py(g_fg, [values, sizes]) From da4960b809050a60215700fe6c2b9e07f366b013 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 10 Feb 2025 23:49:31 +0100 Subject: [PATCH 03/43] Remove unnecessary type ignore in new version of mypy --- pytensor/link/vm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index af44af3254..c6e1283806 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -118,7 +118,7 @@ def calculate_reallocate_info( # where gc for i in range(idx + 1, len(order)): if reuse_out is not None: - break # type: ignore + break for out in order[i].outputs: if ( getattr(out.type, "ndim", None) == 0 From ffdde1cd3408ce6f4166a0e73f8b23bfa37acfdb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 5 Feb 2025 10:24:47 +0100 Subject: [PATCH 04/43] Implement gradient for vector repetitions Also cleans up implementation and documentation --- pytensor/tensor/extra_ops.py | 176 ++++++++++++++++++++------------- tests/tensor/test_extra_ops.py | 28 ++++-- 2 files changed, 131 insertions(+), 73 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index fedcd32ab9..27eabc5ba4 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -646,12 +646,17 @@ class Repeat(Op): __props__ = ("axis",) - def __init__(self, axis=None): + def __init__(self, axis: int | None = None): + if axis is not None: + if not isinstance(axis, int) or axis < 0: + raise ValueError( + f"Repeat only accepts positive integer axis or None, got {axis}" + ) self.axis = axis def make_node(self, x, repeats): x = ptb.as_tensor_variable(x) - repeats = ptb.as_tensor_variable(repeats) + repeats = ptb.as_tensor_variable(repeats, dtype="int64") if repeats.dtype not in integer_dtypes: raise TypeError("repeats.dtype must be an integer.") @@ -687,17 +692,12 @@ def make_node(self, x, repeats): out_shape = list(x.type.shape) out_shape[self.axis] = None - out_type = TensorType( - x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape) - ) - + out_type = TensorType(x.dtype, shape=out_shape) return Apply(self, [x, repeats], [out_type()]) def perform(self, node, inputs, output_storage): - x = inputs[0] - repeats = inputs[1] - z = output_storage[0] - z[0] = np.repeat(x, repeats=repeats, axis=self.axis) + [x, repeats] = inputs + output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis) def connection_pattern(self, node): return [[True], [False]] @@ -705,40 +705,51 @@ def connection_pattern(self, node): def grad(self, inputs, gout): (x, repeats) = inputs (gz,) = gout + axis = self.axis if repeats.ndim == 0: - if self.axis is None: - axis = x.ndim - else: - if self.axis >= 0: - axis = self.axis + 1 - else: - axis = self.axis + x.ndim + 1 - - shape = [x.shape[k] for k in range(x.ndim)] - shape.insert(axis, repeats) + # When axis is a scalar (same number of reps for all elements), + # We can split the repetitions into their own axis with reshape and sum them back + # to the original element location + sum_axis = x.ndim if axis is None else axis + 1 + shape = list(x.shape) + shape.insert(sum_axis, repeats) + gx = gz.reshape(shape).sum(axis=sum_axis) - return [ - gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis), - DisconnectedType()(), - ] elif repeats.ndim == 1: - # For this implementation, we would need to specify the length - # of repeats in order to split gz in the right way to sum - # the good part. - raise NotImplementedError() + # To sum the gradients that belong to the same repeated x, + # We create a repeated eye and dot product it with the gradient. + axis_size = x.size if axis is None else x.shape[axis] + repeated_eye = repeat( + ptb.eye(axis_size), repeats, axis=0 + ) # A sparse repeat would be neat + + if axis is None: + gx = gz @ repeated_eye + # Undo the ravelling when axis=None + gx = gx.reshape(x.shape) + else: + # Place gradient axis at end for dot product + gx = ptb.moveaxis(gz, axis, -1) + gx = gx @ repeated_eye + # Place gradient back into the correct axis + gx = ptb.moveaxis(gx, -1, axis) + else: raise ValueError() + return [gx, DisconnectedType()()] + def infer_shape(self, fgraph, node, ins_shapes): i0_shapes = ins_shapes[0] repeats = node.inputs[1] out_shape = list(i0_shapes) + axis = self.axis # uint64 shape are not supported. dtype = None if repeats.dtype in ("uint8", "uint16", "uint32"): dtype = "int64" - if self.axis is None: + if axis is None: if repeats.ndim == 0: if len(i0_shapes) == 0: out_shape = [repeats] @@ -751,82 +762,115 @@ def infer_shape(self, fgraph, node, ins_shapes): out_shape = [pt_sum(repeats, dtype=dtype)] else: if repeats.ndim == 0: - out_shape[self.axis] = out_shape[self.axis] * repeats + out_shape[axis] = out_shape[axis] * repeats else: - out_shape[self.axis] = pt_sum(repeats, dtype=dtype) + out_shape[axis] = pt_sum(repeats, dtype=dtype) return [out_shape] -def repeat(x, repeats, axis=None): - """Repeat elements of an array. +def repeat( + a: TensorLike, repeats: TensorLike, axis: int or None = None +) -> TensorVariable: + """Repeat elements of a tensor. - It returns an array which has the same shape as `x`, except along the given - `axis`. The `axis` parameter is used to specify the axis along which values - are repeated. By default, a flattened version of `x` is used. + See :func:`numpy.repeat` for more information. - The number of repetitions for each element is `repeats`. `repeats` is - broadcasted to fit the length of the given `axis`. Parameters ---------- - x - Input data, tensor variable. - repeats - int, scalar or tensor variable + a: tensor_like + Input tensor + repeats: tensor_like + The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis. axis : int, optional + The axis along which to repeat values. By default, use the flattened input array, and return a flat output array. - See Also + Returns + ------- + repeated_tensor: TensorVariable + Output tensor which as the same shape as a, except along the given axis + + Examples -------- - tensor.tile + + .. testcode:: + + import pytensor.tensor as pt + + a = pt.arange(4).reshape((2, 2)) + out = pt.repeat(a, repeats=[2, 3], axis=0) + print(out.eval()) + + .. testoutput:: + + [[0 1] + [0 1] + [2 3] + [2 3] + [2 3]] + + When axis is None, the array is first flattened and then repeated + + .. testcode:: + + import pytensor.tensor as pt + + a = pt.arange(4).reshape((2, 2)) + out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None) + print(out.eval()) + + .. testoutput:: + + [0 0 1 1 1 3] + .. versionadded:: 0.6 """ + a = ptb.as_tensor_variable(a) + + if axis is not None: + axis = normalize_axis_index(axis, a.ndim) + repeats = ptb.as_tensor_variable(repeats, dtype=np.int64) if repeats.ndim > 1: raise ValueError("The dimension of repeats should not exceed 1.") if repeats.ndim == 1 and not repeats.broadcastable[0]: - return Repeat(axis=axis)(x, repeats) + # We only use the Repeat Op for vector repeats + return Repeat(axis=axis)(a, repeats) else: if repeats.ndim == 1: repeats = repeats[0] - if x.dtype == "uint64": + if a.dtype == "uint64": + # Multiplying int64 (shape) by uint64 (repeats) yields a float64 + # Which is not valid for the `reshape` operation at the end raise TypeError("repeat doesn't support dtype uint64") if axis is None: axis = 0 - x = x.flatten() - else: - if axis >= x.ndim: - raise ValueError("Axis should not exceed x.ndim-1.") - if axis < 0: - axis = x.ndim + axis + a = a.flatten() - shape = [x.shape[i] for i in range(x.ndim)] + repeat_shape = list(a.shape) - # shape_ is the shape of the intermediate tensor which has + # alloc_shape is the shape of the intermediate tensor which has # an additional dimension comparing to x. We use alloc to # allocate space for this intermediate tensor to replicate x # along that additional dimension. - shape_ = shape[:] - shape_.insert(axis + 1, repeats) + alloc_shape = repeat_shape[:] + alloc_shape.insert(axis + 1, repeats) - # shape is now the shape of output, where shape[axis] becomes + # repeat_shape is now the shape of output, where shape[axis] becomes # shape[axis]*repeats. - shape[axis] = shape[axis] * repeats - - # dims_ is the dimension of that intermediate tensor. - dims_ = list(np.arange(x.ndim)) - dims_.insert(axis + 1, "x") + repeat_shape[axis] = repeat_shape[axis] * repeats # After the original tensor is duplicated along the additional - # dimension, we reshape it to the expected output shape, and - # return the output z. - z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape) - return z + # dimension, we reshape it to the expected output shape + return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape( + repeat_shape + ) class Bartlett(Op): diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index c45e6b1e48..e4f4945393 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -595,7 +595,6 @@ def test_basic(self, ndim, dtype): isinstance(n.op, Repeat) for n in f.maker.fgraph.toposort() ) - @pytest.mark.slow @pytest.mark.parametrize("ndim", [1, 3]) @pytest.mark.parametrize("dtype", ["int8", "uint8", "uint64"]) def test_infer_shape(self, ndim, dtype): @@ -606,6 +605,10 @@ def test_infer_shape(self, ndim, dtype): a = rng.random(shp).astype(config.floatX) for axis in self._possible_axis(ndim): + if axis is not None and axis < 0: + # Operator does not support negative axis + continue + r_var = scalar(dtype=dtype) r = np.asarray(3, dtype=dtype) if dtype in self.numpy_unsupported_dtypes: @@ -635,12 +638,23 @@ def test_infer_shape(self, ndim, dtype): self.op_class, ) - @pytest.mark.parametrize("ndim", range(3)) - def test_grad(self, ndim): - a = np.random.random((10,) * ndim).astype(config.floatX) - - for axis in self._possible_axis(ndim): - utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a]) + @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}") + @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}") + @pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}") + def test_grad(self, x_ndim, repeats_ndim, axis): + rng = np.random.default_rng( + [653, x_ndim, 2 if axis is None else axis, repeats_ndim] + ) + x_test = rng.normal(size=np.arange(3, 3 + x_ndim)) + if repeats_ndim == 0: + repeats_size = () + else: + repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,) + repeats = rng.integers(1, 6, size=repeats_size) + utt.verify_grad( + lambda x: Repeat(axis=axis)(x, repeats), + [x_test], + ) def test_broadcastable(self): x = TensorType(config.floatX, shape=(None, 1, None))() From 60c2d925c35e54129bb13a02d845549f8b3a0362 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 12 Feb 2025 14:15:50 +0100 Subject: [PATCH 05/43] Deprecate Chi2SF ScalarOp --- pytensor/scalar/math.py | 45 --------------------------------- pytensor/tensor/inplace.py | 5 ---- pytensor/tensor/math.py | 3 ++- tests/tensor/test_math_scipy.py | 10 -------- 4 files changed, 2 insertions(+), 61 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index a5512c6564..33c6b1c932 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -40,7 +40,6 @@ true_div, upcast, upgrade_to_float, - upgrade_to_float64, upgrade_to_float_no_complex, ) from pytensor.scalar.basic import abs as scalar_abs @@ -592,50 +591,6 @@ def c_code(self, *args, **kwargs): polygamma = PolyGamma(name="polygamma") -class Chi2SF(BinaryScalarOp): - """ - Compute (1 - chi2_cdf(x)) - ie. chi2 pvalue (chi2 'survival function') - """ - - nfunc_spec = ("scipy.stats.chi2.sf", 2, 1) - - @staticmethod - def st_impl(x, k): - return scipy.stats.chi2.sf(x, k) - - def impl(self, x, k): - return Chi2SF.st_impl(x, k) - - def c_support_code(self, **kwargs): - return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") - - def c_code(self, node, name, inp, out, sub): - x, k = inp - (z,) = out - if node.inputs[0].type in float_types: - dtype = "npy_" + node.outputs[0].dtype - return f"""{z} = - ({dtype}) 1 - GammaP({k}/2., {x}/2.);""" - raise NotImplementedError("only floatingpoint is implemented") - - def __eq__(self, other): - return type(self) is type(other) - - def __hash__(self): - return hash(type(self)) - - def c_code_cache_version(self): - v = super().c_code_cache_version() - if v: - return (2, *v) - else: - return v - - -chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf") - - class GammaInc(BinaryScalarOp): """ Compute the regularized lower gamma function (P). diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index 76738fdb63..cb4476ede0 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -258,11 +258,6 @@ def tri_gamma_inplace(a): """second derivative of the log gamma function""" -@scalar_elemwise -def chi2sf_inplace(x, k): - """chi squared survival function""" - - @scalar_elemwise def gammainc_inplace(k, x): """regularized lower gamma function (P)""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index f11e33b41d..b185f686bc 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1154,9 +1154,10 @@ def polygamma(n, x): """Polygamma function of order n evaluated at x""" -@scalar_elemwise def chi2sf(x, k): """chi squared survival function""" + warnings.warn("chi2sf is deprecated. Use `gammaincc(k / 2, x / 2)` instead") + return gammaincc(k / 2, x / 2) @scalar_elemwise diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 921aae826b..8f70950206 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -306,16 +306,6 @@ def scipy_special_gammal(k, x): name="Chi2SF", ) -TestChi2SFInplaceBroadcast = makeBroadcastTester( - op=inplace.chi2sf_inplace, - expected=expected_chi2sf, - good=_good_broadcast_unary_chi2sf, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, - name="Chi2SF", -) - rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_binary_gamma = dict( normal=( From 0b07727b61a46801f34f588c5c088783c14b7afa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 23 Jan 2025 18:29:18 +0100 Subject: [PATCH 06/43] Remove unused ScalarOp.st_impl --- pytensor/scalar/math.py | 139 +++++++-------------------------------- pytensor/tensor/xlogx.py | 12 +--- 2 files changed, 25 insertions(+), 126 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 33c6b1c932..f8bc4a5df0 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -10,7 +10,6 @@ import numpy as np import scipy.special -import scipy.stats from pytensor.configdefaults import config from pytensor.gradient import grad_not_implemented, grad_undefined @@ -261,12 +260,8 @@ def c_code(self, node, name, inp, out, sub): class Owens_t(BinaryScalarOp): nfunc_spec = ("scipy.special.owens_t", 2, 1) - @staticmethod - def st_impl(h, a): - return scipy.special.owens_t(h, a) - def impl(self, h, a): - return Owens_t.st_impl(h, a) + return scipy.special.owens_t(h, a) def grad(self, inputs, grads): (h, a) = inputs @@ -290,12 +285,8 @@ def c_code(self, *args, **kwargs): class Gamma(UnaryScalarOp): nfunc_spec = ("scipy.special.gamma", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.gamma(x) - def impl(self, x): - return Gamma.st_impl(x) + return scipy.special.gamma(x) def L_op(self, inputs, outputs, gout): (x,) = inputs @@ -329,12 +320,8 @@ class GammaLn(UnaryScalarOp): nfunc_spec = ("scipy.special.gammaln", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.gammaln(x) - def impl(self, x): - return GammaLn.st_impl(x) + return scipy.special.gammaln(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -373,12 +360,8 @@ class Psi(UnaryScalarOp): nfunc_spec = ("scipy.special.psi", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.psi(x) - def impl(self, x): - return Psi.st_impl(x) + return scipy.special.psi(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -464,12 +447,8 @@ class TriGamma(UnaryScalarOp): """ - @staticmethod - def st_impl(x): - return scipy.special.polygamma(1, x) - def impl(self, x): - return TriGamma.st_impl(x) + return scipy.special.polygamma(1, x) def L_op(self, inputs, outputs, outputs_gradients): (x,) = inputs @@ -567,12 +546,8 @@ def output_types_preference(n_type, x_type): # Scipy doesn't support it return upgrade_to_float_no_complex(x_type) - @staticmethod - def st_impl(n, x): - return scipy.special.polygamma(n, x) - def impl(self, n, x): - return PolyGamma.st_impl(n, x) + return scipy.special.polygamma(n, x) def L_op(self, inputs, outputs, output_gradients): (n, x) = inputs @@ -598,12 +573,8 @@ class GammaInc(BinaryScalarOp): nfunc_spec = ("scipy.special.gammainc", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammainc(k, x) - def impl(self, k, x): - return GammaInc.st_impl(k, x) + return scipy.special.gammainc(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -649,12 +620,8 @@ class GammaIncC(BinaryScalarOp): nfunc_spec = ("scipy.special.gammaincc", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammaincc(k, x) - def impl(self, k, x): - return GammaIncC.st_impl(k, x) + return scipy.special.gammaincc(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -700,12 +667,8 @@ class GammaIncInv(BinaryScalarOp): nfunc_spec = ("scipy.special.gammaincinv", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammaincinv(k, x) - def impl(self, k, x): - return GammaIncInv.st_impl(k, x) + return scipy.special.gammaincinv(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -729,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp): nfunc_spec = ("scipy.special.gammainccinv", 2, 1) - @staticmethod - def st_impl(k, x): - return scipy.special.gammainccinv(k, x) - def impl(self, k, x): - return GammaIncCInv.st_impl(k, x) + return scipy.special.gammainccinv(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -968,12 +927,8 @@ class GammaU(BinaryScalarOp): # Note there is no basic SciPy version so no nfunc_spec. - @staticmethod - def st_impl(k, x): - return scipy.special.gammaincc(k, x) * scipy.special.gamma(k) - def impl(self, k, x): - return GammaU.st_impl(k, x) + return scipy.special.gammaincc(k, x) * scipy.special.gamma(k) def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") @@ -1004,12 +959,8 @@ class GammaL(BinaryScalarOp): # Note there is no basic SciPy version so no nfunc_spec. - @staticmethod - def st_impl(k, x): - return scipy.special.gammainc(k, x) * scipy.special.gamma(k) - def impl(self, k, x): - return GammaL.st_impl(k, x) + return scipy.special.gammainc(k, x) * scipy.special.gamma(k) def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") @@ -1040,12 +991,8 @@ class Jv(BinaryScalarOp): nfunc_spec = ("scipy.special.jv", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.jv(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return scipy.special.jv(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1069,12 +1016,8 @@ class J1(UnaryScalarOp): nfunc_spec = ("scipy.special.j1", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.j1(x) - def impl(self, x): - return self.st_impl(x) + return scipy.special.j1(x) def grad(self, inputs, grads): (x,) = inputs @@ -1100,12 +1043,8 @@ class J0(UnaryScalarOp): nfunc_spec = ("scipy.special.j0", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.j0(x) - def impl(self, x): - return self.st_impl(x) + return scipy.special.j0(x) def grad(self, inp, grads): (x,) = inp @@ -1131,12 +1070,8 @@ class Iv(BinaryScalarOp): nfunc_spec = ("scipy.special.iv", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.iv(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return scipy.special.iv(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1160,12 +1095,8 @@ class I1(UnaryScalarOp): nfunc_spec = ("scipy.special.i1", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.i1(x) - def impl(self, x): - return self.st_impl(x) + return scipy.special.i1(x) def grad(self, inputs, grads): (x,) = inputs @@ -1186,12 +1117,8 @@ class I0(UnaryScalarOp): nfunc_spec = ("scipy.special.i0", 1, 1) - @staticmethod - def st_impl(x): - return scipy.special.i0(x) - def impl(self, x): - return self.st_impl(x) + return scipy.special.i0(x) def grad(self, inp, grads): (x,) = inp @@ -1212,12 +1139,8 @@ class Ive(BinaryScalarOp): nfunc_spec = ("scipy.special.ive", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.ive(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return scipy.special.ive(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1241,12 +1164,8 @@ class Kve(BinaryScalarOp): nfunc_spec = ("scipy.special.kve", 2, 1) - @staticmethod - def st_impl(v, x): - return scipy.special.kve(v, x) - def impl(self, v, x): - return self.st_impl(v, x) + return scipy.special.kve(v, x) def L_op(self, inputs, outputs, output_grads): v, x = inputs @@ -1327,8 +1246,7 @@ class Softplus(UnaryScalarOp): "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" """ - @staticmethod - def static_impl(x): + def impl(self, x): # If x is an int8 or uint8, numpy.exp will compute the result in # half-precision (float16), where we want float32. not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8") @@ -1343,9 +1261,6 @@ def static_impl(x): else: return x - def impl(self, x): - return Softplus.static_impl(x) - def grad(self, inp, grads): (x,) = inp (gz,) = grads @@ -1408,16 +1323,12 @@ class Log1mexp(UnaryScalarOp): "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" """ - @staticmethod - def static_impl(x): + def impl(self, x): if x < np.log(0.5): return np.log1p(-np.exp(x)) else: return np.log(-np.expm1(x)) - def impl(self, x): - return Log1mexp.static_impl(x) - def grad(self, inp, grads): (x,) = inp (gz,) = grads @@ -1749,12 +1660,8 @@ class Hyp2F1(ScalarOp): nin = 4 nfunc_spec = ("scipy.special.hyp2f1", 4, 1) - @staticmethod - def st_impl(a, b, c, z): - return scipy.special.hyp2f1(a, b, c, z) - def impl(self, a, b, c, z): - return Hyp2F1.st_impl(a, b, c, z) + return scipy.special.hyp2f1(a, b, c, z) def grad(self, inputs, grads): a, b, c, z = inputs diff --git a/pytensor/tensor/xlogx.py b/pytensor/tensor/xlogx.py index 8cc27de9fb..3709688e54 100644 --- a/pytensor/tensor/xlogx.py +++ b/pytensor/tensor/xlogx.py @@ -10,15 +10,11 @@ class XlogX(ps.UnaryScalarOp): """ - @staticmethod - def st_impl(x): + def impl(self, x): if x == 0.0: return 0.0 return x * np.log(x) - def impl(self, x): - return XlogX.st_impl(x) - def grad(self, inputs, grads): (x,) = inputs (gz,) = grads @@ -45,15 +41,11 @@ class XlogY0(ps.BinaryScalarOp): """ - @staticmethod - def st_impl(x, y): + def impl(self, x, y): if x == 0.0: return 0.0 return x * np.log(y) - def impl(self, x, y): - return XlogY0.st_impl(x, y) - def grad(self, inputs, grads): x, y = inputs (gz,) = grads From 0b94be01cd3e1c3f388687ddfb95261f0a55df40 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 23 Jan 2025 18:29:40 +0100 Subject: [PATCH 07/43] Reduce overhead of Scalar python implementation --- pytensor/scalar/basic.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 3c33434e56..c13afbd6fa 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -36,7 +36,6 @@ from pytensor.utils import ( apply_across_args, difference, - from_return_values, to_return_values, ) @@ -1081,6 +1080,16 @@ def real_out(type): return (type,) +def _cast_to_promised_scalar_dtype(x, dtype): + try: + return x.astype(dtype) + except AttributeError: + if dtype == "bool": + return np.bool_(x) + else: + return getattr(np, dtype)(x) + + class ScalarOp(COp): nin = -1 nout = 1 @@ -1134,28 +1143,18 @@ def output_types(self, types): else: raise NotImplementedError(f"Cannot calculate the output types for {self}") - @staticmethod - def _cast_scalar(x, dtype): - if hasattr(x, "astype"): - return x.astype(dtype) - elif dtype == "bool": - return np.bool_(x) - else: - return getattr(np, dtype)(x) - def perform(self, node, inputs, output_storage): if self.nout == 1: - dtype = node.outputs[0].dtype - output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype) + output_storage[0][0] = _cast_to_promised_scalar_dtype( + self.impl(*inputs), + node.outputs[0].dtype, + ) else: - variables = from_return_values(self.impl(*inputs)) - assert len(variables) == len(output_storage) # strict=False because we are in a hot loop for out, storage, variable in zip( - node.outputs, output_storage, variables, strict=False + node.outputs, output_storage, self.impl(*inputs), strict=False ): - dtype = out.dtype - storage[0] = self._cast_scalar(variable, dtype) + storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype) def impl(self, *inputs): raise MethodNotDefined("impl", type(self), self.__class__.__name__) From 7411a0824d13251c0b88bfa646668f39903437ec Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 23 Jan 2025 18:35:07 +0100 Subject: [PATCH 08/43] More direct access to special functions --- pytensor/scalar/math.py | 62 ++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index f8bc4a5df0..ec7eca76b9 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -9,7 +9,7 @@ from textwrap import dedent import numpy as np -import scipy.special +from scipy import special from pytensor.configdefaults import config from pytensor.gradient import grad_not_implemented, grad_undefined @@ -52,7 +52,7 @@ class Erf(UnaryScalarOp): nfunc_spec = ("scipy.special.erf", 1, 1) def impl(self, x): - return scipy.special.erf(x) + return special.erf(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -86,7 +86,7 @@ class Erfc(UnaryScalarOp): nfunc_spec = ("scipy.special.erfc", 1, 1) def impl(self, x): - return scipy.special.erfc(x) + return special.erfc(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -113,7 +113,7 @@ def c_code(self, node, name, inp, out, sub): return f"{z} = erfc(({cast}){x});" -# scipy.special.erfc don't support complex. Why? +# special.erfc don't support complex. Why? erfc = Erfc(upgrade_to_float_no_complex, name="erfc") @@ -135,7 +135,7 @@ class Erfcx(UnaryScalarOp): nfunc_spec = ("scipy.special.erfcx", 1, 1) def impl(self, x): - return scipy.special.erfcx(x) + return special.erfcx(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -191,7 +191,7 @@ class Erfinv(UnaryScalarOp): nfunc_spec = ("scipy.special.erfinv", 1, 1) def impl(self, x): - return scipy.special.erfinv(x) + return special.erfinv(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -226,7 +226,7 @@ class Erfcinv(UnaryScalarOp): nfunc_spec = ("scipy.special.erfcinv", 1, 1) def impl(self, x): - return scipy.special.erfcinv(x) + return special.erfcinv(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -261,7 +261,7 @@ class Owens_t(BinaryScalarOp): nfunc_spec = ("scipy.special.owens_t", 2, 1) def impl(self, h, a): - return scipy.special.owens_t(h, a) + return special.owens_t(h, a) def grad(self, inputs, grads): (h, a) = inputs @@ -286,7 +286,7 @@ class Gamma(UnaryScalarOp): nfunc_spec = ("scipy.special.gamma", 1, 1) def impl(self, x): - return scipy.special.gamma(x) + return special.gamma(x) def L_op(self, inputs, outputs, gout): (x,) = inputs @@ -321,7 +321,7 @@ class GammaLn(UnaryScalarOp): nfunc_spec = ("scipy.special.gammaln", 1, 1) def impl(self, x): - return scipy.special.gammaln(x) + return special.gammaln(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -361,7 +361,7 @@ class Psi(UnaryScalarOp): nfunc_spec = ("scipy.special.psi", 1, 1) def impl(self, x): - return scipy.special.psi(x) + return special.psi(x) def L_op(self, inputs, outputs, grads): (x,) = inputs @@ -448,7 +448,7 @@ class TriGamma(UnaryScalarOp): """ def impl(self, x): - return scipy.special.polygamma(1, x) + return special.polygamma(1, x) def L_op(self, inputs, outputs, outputs_gradients): (x,) = inputs @@ -547,7 +547,7 @@ def output_types_preference(n_type, x_type): return upgrade_to_float_no_complex(x_type) def impl(self, n, x): - return scipy.special.polygamma(n, x) + return special.polygamma(n, x) def L_op(self, inputs, outputs, output_gradients): (n, x) = inputs @@ -574,7 +574,7 @@ class GammaInc(BinaryScalarOp): nfunc_spec = ("scipy.special.gammainc", 2, 1) def impl(self, k, x): - return scipy.special.gammainc(k, x) + return special.gammainc(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -621,7 +621,7 @@ class GammaIncC(BinaryScalarOp): nfunc_spec = ("scipy.special.gammaincc", 2, 1) def impl(self, k, x): - return scipy.special.gammaincc(k, x) + return special.gammaincc(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -668,7 +668,7 @@ class GammaIncInv(BinaryScalarOp): nfunc_spec = ("scipy.special.gammaincinv", 2, 1) def impl(self, k, x): - return scipy.special.gammaincinv(k, x) + return special.gammaincinv(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -693,7 +693,7 @@ class GammaIncCInv(BinaryScalarOp): nfunc_spec = ("scipy.special.gammainccinv", 2, 1) def impl(self, k, x): - return scipy.special.gammainccinv(k, x) + return special.gammainccinv(k, x) def grad(self, inputs, grads): (k, x) = inputs @@ -928,7 +928,7 @@ class GammaU(BinaryScalarOp): # Note there is no basic SciPy version so no nfunc_spec. def impl(self, k, x): - return scipy.special.gammaincc(k, x) * scipy.special.gamma(k) + return special.gammaincc(k, x) * special.gamma(k) def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") @@ -960,7 +960,7 @@ class GammaL(BinaryScalarOp): # Note there is no basic SciPy version so no nfunc_spec. def impl(self, k, x): - return scipy.special.gammainc(k, x) * scipy.special.gamma(k) + return special.gammainc(k, x) * special.gamma(k) def c_support_code(self, **kwargs): return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") @@ -992,7 +992,7 @@ class Jv(BinaryScalarOp): nfunc_spec = ("scipy.special.jv", 2, 1) def impl(self, v, x): - return scipy.special.jv(v, x) + return special.jv(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1017,7 +1017,7 @@ class J1(UnaryScalarOp): nfunc_spec = ("scipy.special.j1", 1, 1) def impl(self, x): - return scipy.special.j1(x) + return special.j1(x) def grad(self, inputs, grads): (x,) = inputs @@ -1044,7 +1044,7 @@ class J0(UnaryScalarOp): nfunc_spec = ("scipy.special.j0", 1, 1) def impl(self, x): - return scipy.special.j0(x) + return special.j0(x) def grad(self, inp, grads): (x,) = inp @@ -1071,7 +1071,7 @@ class Iv(BinaryScalarOp): nfunc_spec = ("scipy.special.iv", 2, 1) def impl(self, v, x): - return scipy.special.iv(v, x) + return special.iv(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1096,7 +1096,7 @@ class I1(UnaryScalarOp): nfunc_spec = ("scipy.special.i1", 1, 1) def impl(self, x): - return scipy.special.i1(x) + return special.i1(x) def grad(self, inputs, grads): (x,) = inputs @@ -1118,7 +1118,7 @@ class I0(UnaryScalarOp): nfunc_spec = ("scipy.special.i0", 1, 1) def impl(self, x): - return scipy.special.i0(x) + return special.i0(x) def grad(self, inp, grads): (x,) = inp @@ -1140,7 +1140,7 @@ class Ive(BinaryScalarOp): nfunc_spec = ("scipy.special.ive", 2, 1) def impl(self, v, x): - return scipy.special.ive(v, x) + return special.ive(v, x) def grad(self, inputs, grads): v, x = inputs @@ -1165,7 +1165,7 @@ class Kve(BinaryScalarOp): nfunc_spec = ("scipy.special.kve", 2, 1) def impl(self, v, x): - return scipy.special.kve(v, x) + return special.kve(v, x) def L_op(self, inputs, outputs, output_grads): v, x = inputs @@ -1195,7 +1195,7 @@ class Sigmoid(UnaryScalarOp): nfunc_spec = ("scipy.special.expit", 1, 1) def impl(self, x): - return scipy.special.expit(x) + return special.expit(x) def grad(self, inp, grads): (x,) = inp @@ -1362,7 +1362,7 @@ class BetaInc(ScalarOp): nfunc_spec = ("scipy.special.betainc", 3, 1) def impl(self, a, b, x): - return scipy.special.betainc(a, b, x) + return special.betainc(a, b, x) def grad(self, inp, grads): a, b, x = inp @@ -1622,7 +1622,7 @@ class BetaIncInv(ScalarOp): nfunc_spec = ("scipy.special.betaincinv", 3, 1) def impl(self, a, b, x): - return scipy.special.betaincinv(a, b, x) + return special.betaincinv(a, b, x) def grad(self, inputs, grads): (a, b, x) = inputs @@ -1661,7 +1661,7 @@ class Hyp2F1(ScalarOp): nfunc_spec = ("scipy.special.hyp2f1", 4, 1) def impl(self, a, b, c, z): - return scipy.special.hyp2f1(a, b, c, z) + return special.hyp2f1(a, b, c, z) def grad(self, inputs, grads): a, b, c, z = inputs From 1ed36119b75294ffe67020c29df015f5d987f236 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 04:18:44 +0000 Subject: [PATCH 09/43] Bump pypa/gh-action-pypi-publish from 1.12.2 to 1.12.4 Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.12.2 to 1.12.4. - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.12.2...v1.12.4) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/pypi.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 3462dd00ff..e588a5eaeb 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -189,5 +189,5 @@ jobs: name: universal_wheel path: dist - - uses: pypa/gh-action-pypi-publish@v1.12.2 + - uses: pypa/gh-action-pypi-publish@v1.12.4 # Implicitly attests that the packages were uploaded in the context of this workflow. From 2823dfcacd819a6c279520fbd9b6364acb731c11 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 12 Feb 2025 12:01:30 +0100 Subject: [PATCH 10/43] Faster python implementation of MvNormal Also remove bad default values --- pytensor/tensor/random/basic.py | 57 ++++++--------------- tests/tensor/random/rewriting/test_basic.py | 6 ++- tests/tensor/random/test_basic.py | 38 ++++++++++---- 3 files changed, 49 insertions(+), 52 deletions(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index bebcad55be..4732bfcb15 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -3,6 +3,9 @@ import numpy as np import scipy.stats as stats +from numpy import broadcast_shapes as np_broadcast_shapes +from numpy import einsum as np_einsum +from numpy.linalg import cholesky as np_cholesky import pytensor from pytensor.tensor import get_vector_length, specify_shape @@ -831,27 +834,6 @@ def __call__(self, mu, kappa, size=None, **kwargs): vonmises = VonMisesRV() -def safe_multivariate_normal(mean, cov, size=None, rng=None): - """A shape consistent multivariate normal sampler. - - What we mean by "shape consistent": SciPy will return scalars when the - arguments are vectors with dimension of size 1. We require that the output - be at least 1D, so that it's consistent with the underlying random - variable. - - """ - res = np.atleast_1d( - stats.multivariate_normal(mean=mean, cov=cov, allow_singular=True).rvs( - size=size, random_state=rng - ) - ) - - if size is not None: - res = res.reshape([*size, -1]) - - return res - - class MvNormalRV(RandomVariable): r"""A multivariate normal random variable. @@ -904,25 +886,20 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs): @classmethod def rng_fn(cls, rng, mean, cov, size): - if mean.ndim > 1 or cov.ndim > 2: - # Neither SciPy nor NumPy implement parameter broadcasting for - # multivariate normals (or any other multivariate distributions), - # so we need to implement that here - - if size is None: - mean, cov = broadcast_params([mean, cov], [1, 2]) - else: - mean = np.broadcast_to(mean, size + mean.shape[-1:]) - cov = np.broadcast_to(cov, size + cov.shape[-2:]) - - res = np.empty(mean.shape) - for idx in np.ndindex(mean.shape[:-1]): - m = mean[idx] - c = cov[idx] - res[idx] = safe_multivariate_normal(m, c, rng=rng) - return res - else: - return safe_multivariate_normal(mean, cov, size=size, rng=rng) + if size is None: + size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + + chol = np_cholesky(cov) + out = rng.normal(size=(*size, mean.shape[-1])) + np_einsum( + "...ij,...j->...i", # numpy doesn't have a batch matrix-vector product + chol, + out, + out=out, + optimize=False, # Nothing to optimize with two operands, skip costly setup + ) + out += mean + return out multivariate_normal = MvNormalRV() diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index acc793156f..f8a6c243c0 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -778,8 +778,10 @@ def rand_bool_mask(shape, rng=None): multivariate_normal, ( np.array([200, 250], dtype=config.floatX), - # Second covariance is invalid, to test it is not chosen - np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX) + # Second covariance is very large, to test it is not chosen + np.dstack([np.eye(2), np.eye(2) * 1000, np.eye(2)]).T.astype( + config.floatX + ) * 1e-6, ), (3,), diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 7d24a49228..7fc6b9e1b9 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -521,13 +521,19 @@ def test_fn(shape, scale, **kwargs): def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): - if mean is None: - mean = np.array([0.0], dtype=config.floatX) - if cov is None: - cov = np.array([[1.0]], dtype=config.floatX) - if size is not None: - size = tuple(size) - return multivariate_normal.rng_fn(random_state, mean, cov, size) + rng = random_state if random_state is not None else np.random.default_rng() + + if size is None: + size = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + + mean = np.broadcast_to(mean, (*size, *mean.shape[-1:])) + cov = np.broadcast_to(cov, (*size, *cov.shape[-2:])) + + @np.vectorize(signature="(n),(n,n)->(n)") + def vec_mvnormal(mean, cov): + return rng.multivariate_normal(mean, cov, method="cholesky") + + return vec_mvnormal(mean, cov) @pytest.mark.parametrize( @@ -609,18 +615,30 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): ), ], ) +@pytest.mark.skipif( + config.floatX == "float32", + reason="Draws are only strictly equal to numpy in float64", +) def test_mvnormal_samples(mu, cov, size): compare_sample_values( multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn ) -def test_mvnormal_default_args(): - compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn) +def test_mvnormal_no_default_args(): + with pytest.raises( + TypeError, match="missing 2 required positional arguments: 'mean' and 'cov'" + ): + multivariate_normal() + +def test_mvnormal_impl_catches_incompatible_size(): with pytest.raises(ValueError, match="operands could not be broadcast together "): multivariate_normal.rng_fn( - None, np.zeros((3, 2)), np.ones((3, 2, 2)), size=(4,) + np.random.default_rng(), + np.zeros((3, 2)), + np.broadcast_to(np.eye(2), (3, 2, 2)), + size=(4,), ) From 2aecb956b35ab54e594ba7bb80be5b86470f9b9e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 12 Feb 2025 13:14:31 +0100 Subject: [PATCH 11/43] Allow decomposition methods in MvNormal --- pytensor/link/jax/dispatch/random.py | 15 ++++++++- pytensor/link/numba/dispatch/random.py | 19 +++++++++-- pytensor/tensor/random/basic.py | 41 ++++++++++++++++-------- tests/link/jax/test_random.py | 6 ++++ tests/link/numba/test_random.py | 6 ++++ tests/tensor/random/test_basic.py | 44 ++++++++++++++++++++++++++ 6 files changed, 113 insertions(+), 18 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 9a89bf1406..d66ddc049d 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -128,7 +128,6 @@ def jax_sample_fn(op, node): @jax_sample_fn.register(ptr.BetaRV) @jax_sample_fn.register(ptr.DirichletRV) @jax_sample_fn.register(ptr.PoissonRV) -@jax_sample_fn.register(ptr.MvNormalRV) def jax_sample_fn_generic(op, node): """Generic JAX implementation of random variables.""" name = op.name @@ -173,6 +172,20 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn +@jax_sample_fn.register(ptr.MvNormalRV) +def jax_sample_mvnormal(op, node): + def sample_fn(rng, size, dtype, mean, cov): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + sample = jax.random.multivariate_normal( + sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method + ) + rng["jax_state"] = rng_key + return (rng, sample) + + return sample_fn + + @jax_sample_fn.register(ptr.BernoulliRV) def jax_sample_fn_bernoulli(op, node): """JAX implementation of `BernoulliRV`.""" diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 04181e8335..e80a033c82 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -144,11 +144,24 @@ def random_fn(rng, p): @numba_core_rv_funcify.register(ptr.MvNormalRV) def core_MvNormalRV(op, node): + method = op.method + @numba_basic.numba_njit def random_fn(rng, mean, cov): - chol = np.linalg.cholesky(cov) - stdnorm = rng.normal(size=cov.shape[-1]) - return np.dot(chol, stdnorm) + mean + if method == "cholesky": + A = np.linalg.cholesky(cov) + elif method == "svd": + A, s, _ = np.linalg.svd(cov) + A *= np.sqrt(s)[None, :] + else: + w, A = np.linalg.eigh(cov) + A *= np.sqrt(w)[None, :] + + out = rng.normal(size=cov.shape[-1]) + # out argument not working correctly: https://github.com/numba/numba/issues/9924 + out[:] = np.dot(A, out) + out += mean + return out random_fn.handles_out = True return random_fn diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 4732bfcb15..6d6a4ee270 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1,13 +1,16 @@ import abc import warnings +from typing import Literal import numpy as np import scipy.stats as stats from numpy import broadcast_shapes as np_broadcast_shapes from numpy import einsum as np_einsum +from numpy import sqrt as np_sqrt from numpy.linalg import cholesky as np_cholesky +from numpy.linalg import eigh as np_eigh +from numpy.linalg import svd as np_svd -import pytensor from pytensor.tensor import get_vector_length, specify_shape from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.math import sqrt @@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable): signature = "(n),(n,n)->(n)" dtype = "floatX" _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") + __props__ = ("name", "signature", "dtype", "inplace", "method") - def __call__(self, mean=None, cov=None, size=None, **kwargs): + def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs): + super().__init__(*args, **kwargs) + if method not in ("cholesky", "svd", "eigh"): + raise ValueError( + f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'." + ) + self.method = method + + def __call__(self, mean, cov, size=None, **kwargs): r""" "Draw samples from a multivariate normal distribution. Signature @@ -876,33 +888,34 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs): is specified, a single `N`-dimensional sample is returned. """ - dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype - - if mean is None: - mean = np.array([0.0], dtype=dtype) - if cov is None: - cov = np.array([[1.0]], dtype=dtype) return super().__call__(mean, cov, size=size, **kwargs) - @classmethod - def rng_fn(cls, rng, mean, cov, size): + def rng_fn(self, rng, mean, cov, size): if size is None: size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) - chol = np_cholesky(cov) + if self.method == "cholesky": + A = np_cholesky(cov) + elif self.method == "svd": + A, s, _ = np_svd(cov) + A *= np_sqrt(s, out=s)[..., None, :] + else: + w, A = np_eigh(cov) + A *= np_sqrt(w, out=w)[..., None, :] + out = rng.normal(size=(*size, mean.shape[-1])) np_einsum( "...ij,...j->...i", # numpy doesn't have a batch matrix-vector product - chol, + A, out, - out=out, optimize=False, # Nothing to optimize with two operands, skip costly setup + out=out, ) out += mean return out -multivariate_normal = MvNormalRV() +multivariate_normal = MvNormalRV(method="cholesky") class DirichletRV(RandomVariable): diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index a01f5e3f46..2c0e4231c8 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -18,6 +18,7 @@ batched_permutation_tester, batched_unweighted_choice_without_replacement_tester, batched_weighted_choice_without_replacement_tester, + create_mvnormal_cov_decomposition_method_test, ) @@ -547,6 +548,11 @@ def test_random_mvnormal(): np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1) +test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test( + "JAX" +) + + @pytest.mark.parametrize( "parameter, size", [ diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index b966ed2870..1569ea8ae8 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -22,6 +22,7 @@ batched_permutation_tester, batched_unweighted_choice_without_replacement_tester, batched_weighted_choice_without_replacement_tester, + create_mvnormal_cov_decomposition_method_test, ) @@ -147,6 +148,11 @@ def test_multivariate_normal(): ) +test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test( + "NUMBA" +) + + @pytest.mark.parametrize( "rv_op, dist_args, size", [ diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 7fc6b9e1b9..23d1b87020 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -19,6 +19,7 @@ from pytensor.tensor import ones, stack from pytensor.tensor.random.basic import ( ChoiceWithoutReplacement, + MvNormalRV, PermutationRV, _gamma, bernoulli, @@ -686,6 +687,49 @@ def test_mvnormal_ShapeFeature(): assert s4.get_test_value() == 3 +def create_mvnormal_cov_decomposition_method_test(mode): + @pytest.mark.parametrize("psd", (True, False)) + @pytest.mark.parametrize("method", ("cholesky", "svd", "eigh")) + def test_mvnormal_cov_decomposition_method(method, psd): + mean = 2 ** np.arange(3) + if psd: + cov = [ + [1, 0.5, -1], + [0.5, 2, 0], + [-1, 0, 3], + ] + else: + cov = [ + [1, 0.5, 0], + [0.5, 2, 0], + [0, 0, 0], + ] + rng = shared(np.random.default_rng(675)) + draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,)) + assert draws.owner.op.method == method + + # JAX doesn't raise errors at runtime + if not psd and method == "cholesky": + if mode == "JAX": + # JAX doesn't raise errors at runtime, instead it returns nan + np.isnan(draws.eval(mode=mode)).all() + else: + with pytest.raises(np.linalg.LinAlgError): + draws.eval(mode=mode) + + else: + draws_eval = draws.eval(mode=mode) + np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.02) + np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1) + + return test_mvnormal_cov_decomposition_method + + +test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test( + None +) + + @pytest.mark.parametrize( "alphas, size", [ From 298bb13351182f0986c0d506750399d467dea0f9 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 14 Feb 2025 17:16:47 +0100 Subject: [PATCH 12/43] Remove global RTOl and ATOL in test file --- tests/tensor/rewriting/test_linalg.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c9b9afff19..bbb20251bf 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -42,9 +42,6 @@ from tests.test_rop import break_op -ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8 - - def test_rop_lop(): mx = matrix("mx") mv = matrix("mv") @@ -630,11 +627,12 @@ def test_inv_diag_from_eye_mul(shape, inv_op): inverse_matrix = np.linalg.inv(x_test_matrix) rewritten_inverse = f_rewritten(x_test) + atol = rtol = 1e-3 if config.floatX == "float32" else 1e-8 assert_allclose( inverse_matrix, rewritten_inverse, - atol=ATOL, - rtol=RTOL, + atol=atol, + rtol=rtol, ) @@ -657,11 +655,12 @@ def test_inv_diag_from_diag(inv_op): inverse_matrix = np.linalg.inv(x_test_matrix) rewritten_inverse = f_rewritten(x_test) + atol = rtol = 1e-3 if config.floatX == "float32" else 1e-8 assert_allclose( inverse_matrix, rewritten_inverse, - atol=ATOL, - rtol=RTOL, + atol=atol, + rtol=rtol, ) From 49cf9d2282a1f641a9e97a262a8cf2322d758564 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 14 Feb 2025 17:43:53 +0100 Subject: [PATCH 13/43] Cleanup Rop tests and fix Max Rop implementation --- pytensor/tensor/math.py | 25 +++-- tests/scan/test_basic.py | 24 +++-- tests/tensor/rewriting/test_linalg.py | 27 ++---- tests/tensor/test_shape.py | 2 +- tests/test_rop.py | 128 +++++++++++++++++--------- 5 files changed, 121 insertions(+), 85 deletions(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index b185f686bc..9fa823feb8 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -431,20 +431,25 @@ def L_op(self, inputs, outputs, grads): return (g_x,) def R_op(self, inputs, eval_points): + [x] = inputs if eval_points[0] is None: - return [None, None] - if len(self.axis) != 1: - raise ValueError("R_op supported for max only for one axis!") - if self.axis[0] > 1: - raise ValueError("R_op supported for max only when axis is 0 or 1") + return [None] + axis = tuple(range(x.ndim) if self.axis is None else self.axis) + if isinstance(axis, int): + axis = [axis] + if len(axis) != 1: + raise NotImplementedError("R_op supported for max only for one axis!") + if axis[0] > 1: + raise NotImplementedError("R_op supported for max only when axis is 0 or 1") if inputs[0].ndim != 2: - raise ValueError("R_op supported for max only when input is a matrix") - max_pos = Argmax(self.axis).make_node(*inputs).outputs - # print(eval_points[0].eval()) + raise NotImplementedError( + "R_op supported for max only when input is a matrix" + ) + max_pos = Argmax(self.axis)(*inputs) if self.axis[0] == 0: - return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] + return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]] else: - return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] + return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]] class Min(NonZeroDimsCAReduce): diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 9fa893ab27..d61c90d904 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -1992,9 +1992,9 @@ def rnn_fn(_u, _y, _W): vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) - utt.assert_allclose(vnu, tnu, atol=1e-6) - utt.assert_allclose(vnh0, tnh0, atol=1e-6) - utt.assert_allclose(vnW, tnW, atol=1e-6) + np.testing.assert_allclose(vnu, tnu, atol=1e-6) + np.testing.assert_allclose(vnh0, tnh0, atol=1e-6) + np.testing.assert_allclose(vnW, tnW, atol=1e-6) @pytest.mark.slow def test_R_op_2(self): @@ -2074,9 +2074,9 @@ def rnn_fn(_u, _y, _W): ) tnu, tnh0, tnW, tno = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) - utt.assert_allclose(vnu, tnu, atol=1e-6) - utt.assert_allclose(vnh0, tnh0, atol=1e-6) - utt.assert_allclose(vnW, tnW, atol=2e-6) + np.testing.assert_allclose(vnu, tnu, atol=1e-6) + np.testing.assert_allclose(vnh0, tnh0, atol=1e-6) + np.testing.assert_allclose(vnW, tnW, atol=2e-6) def test_R_op_mitmot(self): # this test is a copy paste from the script given by Justin Bayer to @@ -2094,13 +2094,10 @@ def test_R_op_mitmot(self): W1 = pars[:3].reshape(W1shape) W2 = pars[3:].reshape(W2shape) - # Define recurrent model. We are using a model where each input is a - # tensor - # of shape (T, B, D) where T is the number of timesteps, B is the - # number of - # sequences iterated over in parallel and D is the dimensionality of - # each - # item at a timestep. + # Define recurrent model. We are using a model where each input + # is a tensor of shape (T, B, D) where T is the number of timesteps, + # B is the number of sequences iterated over in parallel and + # D is the dimensionality of each item at a timestep. inpt = tensor3("inpt") target = tensor3("target") @@ -2128,6 +2125,7 @@ def test_R_op_mitmot(self): d_cost_wrt_pars = grad(cost, pars) p = dvector() + # TODO: We should test something about the Rop! Rop(d_cost_wrt_pars, pars, p) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index bbb20251bf..4cc2ce1e12 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -14,7 +14,7 @@ from pytensor.tensor import swapaxes from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import _allclose, dot, matmul +from pytensor.tensor.math import dot, matmul from pytensor.tensor.nlinalg import ( SVD, Det, @@ -42,7 +42,8 @@ from tests.test_rop import break_op -def test_rop_lop(): +def test_matrix_inverse_rop_lop(): + rtol = 1e-7 if config.floatX == "float64" else 1e-5 mx = matrix("mx") mv = matrix("mv") v = vector("v") @@ -62,23 +63,13 @@ def test_rop_lop(): vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) - assert _allclose(v1, v2), f"ROP mismatch: {v1} {v2}" - - raised = False - try: + with pytest.raises(ValueError): pytensor.gradient.Rop( pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv ) - except ValueError: - raised = True - if not raised: - raise Exception( - "Op did not raised an error even though the function" - " is not differentiable" - ) vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) yv = pytensor.gradient.Lop(y, mx, v) @@ -87,9 +78,9 @@ def test_rop_lop(): sy = pytensor.gradient.grad((v * y).sum(), mx) scan_f = function([mx, v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + v = lop_f(vx, vv) + np.testing.assert_allclose(v, v_ref, rtol=rtol) def test_transinv_to_invtrans(): diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index e85b8cfd46..7700d2b14b 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -603,7 +603,7 @@ def test_validation(self): class TestRopLop(RopLopChecker): def test_shape(self): - self.check_nondiff_rop(self.x.shape[0]) + self.check_nondiff_rop(self.x.shape[0], self.x, self.v) def test_specifyshape(self): self.check_rop_lop(specify_shape(self.x, self.in_shape), self.in_shape) diff --git a/tests/test_rop.py b/tests/test_rop.py index 0b9fe41a1e..769bf247be 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -16,8 +16,14 @@ import pytensor import pytensor.tensor as pt -from pytensor import function -from pytensor.gradient import Lop, Rop, grad, grad_undefined +from pytensor import config, function +from pytensor.gradient import ( + Lop, + NullTypeGradError, + Rop, + grad, + grad_undefined, +) from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.tensor.math import argmax, dot @@ -61,6 +67,10 @@ class RopLopChecker: Rop to class that inherit from it. """ + @staticmethod + def rtol(): + return 1e-7 if config.floatX == "float64" else 1e-5 + def setup_method(self): # Using vectors make things a lot simpler for generating the same # computations using scan @@ -72,13 +82,13 @@ def setup_method(self): self.mv = matrix("mv") self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3)) - def check_nondiff_rop(self, y): + def check_nondiff_rop(self, y, x, v): """ If your op is not differentiable(so you can't define Rop) test that an error is raised. """ with pytest.raises(ValueError): - Rop(y, self.x, self.v) + Rop(y, x, v) def check_mat_rop_lop(self, y, out_shape): """ @@ -115,13 +125,13 @@ def check_mat_rop_lop(self, y, out_shape): ) scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore") - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) - - assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref) self.check_nondiff_rop( - pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}) + pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}), + self.mx, + self.mv, ) vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX) @@ -131,15 +141,17 @@ def check_mat_rop_lop(self, y, out_shape): sy = grad((self.v * y).sum(), self.mx) scan_f = function([self.mx, self.v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v = lop_f(vx, vv) + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(v, v_ref) - def check_rop_lop(self, y, out_shape): + def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True): """ As check_mat_rop_lop, except the input is self.x which is a vector. The output is still a vector. """ + rtol = self.rtol() + # TEST ROP vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) @@ -152,24 +164,17 @@ def check_rop_lop(self, y, out_shape): non_sequences=[y, self.x], ) sy = dot(J, self.v) - scan_f = function([self.x, self.v], sy, on_unused_input="ignore") - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) - try: - Rop( + if check_nondiff_rop: + self.check_nondiff_rop( pytensor.clone_replace(y, replace={self.x: break_op(self.x)}), self.x, self.v, ) - except ValueError: - pytest.skip( - "Rop does not handle non-differentiable inputs " - "correctly. Bug exposed by fixing Add.grad method." - ) vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX) @@ -182,22 +187,20 @@ def check_rop_lop(self, y, out_shape): non_sequences=[y, self.x], ) sy = dot(self.v, J) - scan_f = function([self.x, self.v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v = lop_f(vx, vv) + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(v, v_ref, rtol=rtol) class TestRopLop(RopLopChecker): def test_max(self): - # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ()) self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],)) self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],)) def test_argmax(self): - self.check_nondiff_rop(argmax(self.mx, axis=1)) + self.check_nondiff_rop(argmax(self.mx, axis=1), self.mx, self.mv) def test_subtensor(self): self.check_rop_lop(self.x[:4], (4,)) @@ -252,10 +255,14 @@ def test_dot(self): insh = self.in_shape[0] vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX) W = pytensor.shared(vW) - self.check_rop_lop(dot(self.x, W), self.in_shape) + # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths + # See: test_Rop_partially_differentiable_paths + self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False) def test_elemwise0(self): - self.check_rop_lop((self.x + 1) ** 2, self.in_shape) + # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths + # See: test_Rop_partially_differentiable_paths + self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False) def test_elemwise1(self): self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape) @@ -288,15 +295,8 @@ def test_alloc(self): ) def test_invalid_input(self): - success = False - - try: + with pytest.raises(ValueError): Rop(0.0, [matrix()], [vector()]) - success = True - except ValueError: - pass - - assert not success def test_multiple_outputs(self): m = matrix("m") @@ -322,12 +322,54 @@ def test_multiple_outputs(self): f = pytensor.function([m, v, m_, v_], all_outs) f(mval, vval, m_val, v_val) - def test_Rop_dot_bug_18Oct2013_Jeremiah(self): + @pytest.mark.xfail() + def test_Rop_partially_differentiable_paths(self): # This test refers to a bug reported by Jeremiah Lowin on 18th Oct # 2013. The bug consists when through a dot operation there is only # one differentiable path (i.e. there is no gradient wrt to one of # the inputs). x = pt.arange(20.0).reshape([1, 20]) - v = pytensor.shared(np.ones([20])) + v = pytensor.shared(np.ones([20]), name="v") d = dot(x, v).sum() - Rop(grad(d, v), v, v) + + Rop( + grad(d, v), + v, + v, + disconnected_outputs="raise", + ) + + # 2025: Here is an unambiguous test for the original commented issue: + x = pt.matrix("x") + y = pt.matrix("y") + out = dot(x, break_op(y)).sum() + # Should not raise an error + Rop( + out, + [x], + [x.type()], + disconnected_outputs="raise", + ) + + # More extensive testing shows that the Rop implementation FAILS to raise when + # the cost is linked through strictly non-differentiable paths. + # This is not Dot specific, we would observe the same with any operation where the gradient + # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...) + out = dot(break_op(x), y).sum() + with pytest.raises((ValueError, NullTypeGradError)): + Rop( + out, + [x], + [x.type()], + disconnected_outputs="raise", + ) + + # Only when both paths are non-differentiable is an error correctly raised again. + out = dot(break_op(x), break_op(y)).sum() + with pytest.raises((ValueError, NullTypeGradError)): + Rop( + out, + [x], + [x.type()], + disconnected_outputs="raise", + ) From 4aea87c25486639b0483c66f81f86a28eb524131 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 14 Feb 2025 17:23:00 +0100 Subject: [PATCH 14/43] Fix bug when taking the L_op of a Scan with mit-mot and disconnected output gradients --- pytensor/scan/op.py | 72 +++++++++++++++++++++------------------- tests/scan/test_basic.py | 51 ++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 34 deletions(-) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index a01347ef9c..a588531d9c 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2509,13 +2509,25 @@ def compute_all_gradients(known_grads): return rval var_mappings = self.get_oinp_iinp_iout_oout_mappings() - dC_dinps_t = [None for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs] + + n_mit_mot_outs = info.n_mit_mot_outs + # In the case of mit-mot there can be more inner outputs than outer ones + n_extra_mit_mot_outs = n_mit_mot_outs - info.n_mit_mot + idx_nitsot_out_start = n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot + idx_nitsot_out_end = idx_nitsot_out_start + info.n_nit_sot + + # Create dummy variables for the internal input gradients + states = ( + self.inner_mitmot(self_inputs) + + self.inner_mitsot(self_inputs) + + self.inner_sitsot(self_inputs) + ) dC_dXts = [] Xts = [] for idx, Xt in enumerate(diff_outputs): # We are looking for x[t-1] for a given x[t] - if idx >= info.n_mit_mot_outs: + if idx >= n_mit_mot_outs: Xt_placeholder = safe_new(Xt) Xts.append(Xt_placeholder) @@ -2523,9 +2535,7 @@ def compute_all_gradients(known_grads): # or not. NOTE : This cannot be done by using # "if Xt not in self.inner_nitsot_outs(self_outputs)" because # the exact same variable can be used as multiple outputs. - idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot - idx_nitsot_end = idx_nitsot_start + info.n_nit_sot - if idx < idx_nitsot_start or idx >= idx_nitsot_end: + if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end: # What we do here is loop through dC_douts and collect all # those that are connected to the specific one and do an # upcast on all of their dtypes to get the dtype for this @@ -2533,12 +2543,6 @@ def compute_all_gradients(known_grads): # specific previous step is defined or not is done somewhere # else. dtypes = [] - states = ( - self.inner_mitmot(self_inputs) - + self.inner_mitsot(self_inputs) - + self.inner_sitsot(self_inputs) - ) - for pos, inp in enumerate(states): if inp in graph_inputs([Xt]): # Get the index of the outer output that to which @@ -2555,35 +2559,39 @@ def compute_all_gradients(known_grads): new_dtype = config.floatX dC_dXt = safe_new(Xt, dtype=new_dtype) else: - if isinstance(dC_douts[idx].type, DisconnectedType): + # nit-sot outputs + # If not disconnected assume the output gradient type is a valid type for the input gradient + if isinstance( + dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType + ): continue - dC_dXt = safe_new(dC_douts[idx][0]) + dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0]) dC_dXts.append(dC_dXt) + # Handle cases where the very same variable may be used as different outputs + # TODO: Couldn't we add a view Op to avoid this when building the Scan graph? known_grads = {} dc_dxts_idx = 0 for i in range(len(diff_outputs)): - if i < idx_nitsot_start or i >= idx_nitsot_end: - if diff_outputs[i] in known_grads: - known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] - else: - known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] - dc_dxts_idx += 1 + if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance( + dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType + ): + # Special case where we don't have a dC_dXt for disconnected nitsot outputs + continue + + # Just some trouble to avoid a +0 + if diff_outputs[i] in known_grads: + known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] else: - if isinstance(dC_douts[i].type, DisconnectedType): - continue - else: - if diff_outputs[i] in known_grads: - known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] - else: - known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] - dc_dxts_idx += 1 + known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] + dc_dxts_idx += 1 + dC_dinps_t = compute_all_gradients(known_grads) # mask inputs that get no gradients for dx in range(len(dC_dinps_t)): - if not dC_dinps_t[dx]: - dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx]) + if dC_dinps_t[dx] is None: + dC_dinps_t[dx] = dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx]) else: disconnected_dC_dinps_t[dx] = False for Xt, Xt_placeholder in zip( @@ -2846,7 +2854,6 @@ def compute_all_gradients(known_grads): for idx in range(info.n_sit_sot): mitmot_inp_taps.append([0, 1]) mitmot_out_taps.append([1]) - through_shared = False if not isinstance(dC_douts[idx + offset].type, DisconnectedType): outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) else: @@ -3007,9 +3014,7 @@ def compute_all_gradients(known_grads): name=f"grad_of_{self.name}" if self.name else None, allow_gc=self.allow_gc, ) - outputs = local_op(*outer_inputs) - if not isinstance(outputs, list | tuple): - outputs = [outputs] + outputs = local_op(*outer_inputs, return_list=True) # Re-order the gradients correctly gradients = [DisconnectedType()()] @@ -3095,7 +3100,6 @@ def compute_all_gradients(known_grads): ) ) - start = len(gradients) gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)] begin = end diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index d61c90d904..f3fe8c167f 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -2128,6 +2128,57 @@ def test_R_op_mitmot(self): # TODO: We should test something about the Rop! Rop(d_cost_wrt_pars, pars, p) + def test_second_derivative_disconnected_cost_with_mit_mot(self): + # This test is a regression test for a bug that was revealed + # when we computed the pushforward of a Scan gradient via two applications of pullback + seq = pt.vector("seq", shape=(2,)) + z = pt.scalar("z") + x0 = pt.vector("x0", shape=(2,)) + + # When s is 1 and z is 2, xs[-1] is just a sneaky + # x ** 4 (after two nsteps) + # grad should be 4 * x ** 3 + # and grad of grad should be 12 * x ** 2 + def step(s, xtm2, xtm1, z): + return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2) + + xs, _ = scan( + step, + sequences=[seq], + outputs_info=[{"initial": x0, "taps": (-2, -1)}], + non_sequences=[z], + n_steps=2, + ) + last_x = xs[-1] + + g_wrt_x0, g_wrt_z, g_wrt_seq = pt.grad(last_x, [x0, z, seq]) + g = g_wrt_x0.sum() + g_wrt_z.sum() * 0 + g_wrt_seq.sum() * 0 + assert g.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 4 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96 + + # Leave out z + g_wrt_x0, g_wrt_seq = pt.grad(last_x, [x0, seq]) + g = g_wrt_x0.sum() + g_wrt_seq.sum() * 0 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96 + + # Leave out seq + g_wrt_x0, g_wrt_z = pt.grad(last_x, [x0, z]) + g = g_wrt_x0.sum() + g_wrt_z.sum() * 0 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2 + + # Leave out z and seq + g_wrt_x0 = pt.grad(last_x, x0) + g = g_wrt_x0.sum() + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2 + @pytest.mark.skipif( not config.cxx, reason="G++ not available, so we need to skip this test." From 84c7802702c380ee817f6853bb3d350be3ad2ed6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 14 Feb 2025 18:43:32 +0100 Subject: [PATCH 15/43] Handle Scan gradients of non shaped disconnected inputs --- pytensor/scan/op.py | 10 ++++-- tests/scan/test_basic.py | 66 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index a588531d9c..9e55739ed4 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -72,6 +72,7 @@ from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.replace import clone_replace +from pytensor.graph.type import HasShape from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.link.c.basic import CLinker from pytensor.printing import op_debug_information @@ -2591,7 +2592,11 @@ def compute_all_gradients(known_grads): # mask inputs that get no gradients for dx in range(len(dC_dinps_t)): if dC_dinps_t[dx] is None: - dC_dinps_t[dx] = dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx]) + dC_dinps_t[dx] = dC_dinps_t[dx] = ( + pt.zeros_like(diff_inputs[dx]) + if isinstance(diff_inputs[dx].type, HasShape) + else pt.zeros(()) + ) else: disconnected_dC_dinps_t[dx] = False for Xt, Xt_placeholder in zip( @@ -2965,7 +2970,8 @@ def compute_all_gradients(known_grads): else: outer_inp_sitsot.append( pt.zeros( - [grad_steps + 1] + [x.shape[i] for i in range(x.ndim)], + [grad_steps + 1] + + (list(x.shape) if isinstance(x.type, HasShape) else []), dtype=y.dtype, ) ) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index f3fe8c167f..b86423a6ff 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -2179,6 +2179,72 @@ def step(s, xtm2, xtm1, z): assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2 + @pytest.mark.parametrize("case", ("inside-explicit", "inside-implicit", "outside")) + def test_non_shaped_input_disconnected_gradient(self, case): + """Test that Scan gradient works when non shaped variables are disconnected from the gradient. + + Regression test for https://github.com/pymc-devs/pytensor/issues/6 + """ + + # In all cases rng is disconnected from the output gradient + # Note that when it is an input to the scan (explicit or not) it is still not updated by the scan, + # so it is equivalent to the `outside` case. A rewrite could have legally hoisted the rng out of the scan. + rng = shared(np.random.default_rng()) + + data = pt.zeros(16) + + nonlocal_random_index = pt.random.integers(16, rng=rng) + nonlocal_random_datum = data[nonlocal_random_index] + + if case == "outside": + + def step(s, random_datum): + return (random_datum + s) ** 2 + + strict = True + non_sequences = [nonlocal_random_datum] + + elif case == "inside-implicit": + + def step(s): + return (nonlocal_random_datum + s) ** 2 + + strict = False + non_sequences = [] # Scan will introduce the non_sequences for us + + elif case == "inside-explicit": + + def step(s, data, rng): + random_index = pt.random.integers( + 16, rng=rng + ) # Not updated by the scan + random_datum = data[random_index] + return (random_datum + s) ** 2 + + strict = (True,) + non_sequences = [data, rng] + + else: + raise ValueError(f"Invalid case: {case}") + + seq = vector("seq") + xs, _ = scan( + step, + sequences=[seq], + non_sequences=non_sequences, + strict=strict, + ) + x0 = xs[0] + + np.testing.assert_allclose( + x0.eval({seq: [np.pi, np.nan, np.nan]}), + np.pi**2, + ) + np.testing.assert_allclose( + grad(x0, seq)[0].eval({seq: [np.pi, np.nan, np.nan]}), + 2 * np.pi, + ) + @pytest.mark.skipif( not config.cxx, reason="G++ not available, so we need to skip this test." From b5a64c775eea7f3d0fed56b3a9237906701e3d28 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Feb 2025 22:22:24 +0100 Subject: [PATCH 16/43] Compute pushforward via double application of pullback Also fixes bug in Scan L_op and Max R_op Co-authored-by: Adrian Seyboldt --- doc/extending/op.rst | 1 - doc/library/gradient.rst | 76 ---------- doc/library/tensor/basic.rst | 2 - doc/tutorial/gradients.rst | 21 ++- pytensor/compile/builders.py | 13 +- pytensor/gradient.py | 203 ++++++++++++++++++++------ pytensor/scan/op.py | 7 +- tests/compile/test_builders.py | 32 +++- tests/scan/test_basic.py | 28 ++-- tests/tensor/rewriting/test_linalg.py | 11 +- tests/test_rop.py | 70 +++++++-- 11 files changed, 301 insertions(+), 163 deletions(-) delete mode 100644 doc/library/gradient.rst diff --git a/doc/extending/op.rst b/doc/extending/op.rst index ddd397dee9..b1585c4ecd 100644 --- a/doc/extending/op.rst +++ b/doc/extending/op.rst @@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`. the outputs) back to their corresponding shapes and return them as the output of the :meth:`Op.R_op` method. - :ref:`List of op with r op support `. diff --git a/doc/library/gradient.rst b/doc/library/gradient.rst deleted file mode 100644 index f823a1c381..0000000000 --- a/doc/library/gradient.rst +++ /dev/null @@ -1,76 +0,0 @@ -.. _libdoc_gradient: - -=========================================== -:mod:`gradient` -- Symbolic Differentiation -=========================================== - -.. module:: gradient - :platform: Unix, Windows - :synopsis: low-level automatic differentiation -.. moduleauthor:: LISA - -.. testsetup:: * - - from pytensor.gradient import * - -Symbolic gradient is usually computed from :func:`gradient.grad`, which offers a -more convenient syntax for the common case of wanting the gradient of some -scalar cost with respect to some input expressions. The :func:`grad_sources_inputs` -function does the underlying work, and is more flexible, but is also more -awkward to use when :func:`gradient.grad` can do the job. - - -Gradient related functions -========================== - -.. automodule:: pytensor.gradient - :members: - -.. _R_op_list: - - -List of Implemented R op -======================== - - -See the :ref:`gradient tutorial ` for the R op documentation. - -list of ops that support R-op: - * with test - * SpecifyShape - * MaxAndArgmax - * Subtensor - * IncSubtensor set_subtensor too - * Alloc - * Dot - * Elemwise - * Sum - * Softmax - * Shape - * Join - * Rebroadcast - * Reshape - * DimShuffle - * Scan [In tests/scan/test_basic.test_rop] - - * without test - * Split - * ARange - * ScalarFromTensor - * AdvancedSubtensor1 - * AdvancedIncSubtensor1 - * AdvancedIncSubtensor - -Partial list of ops without support for R-op: - - * All sparse ops - * All linear algebra ops. - * PermuteRowElements - * AdvancedSubtensor - * TensorDot - * Outer - * Prod - * MulwithoutZeros - * ProdWithoutZeros - * CAReduce(for max,... done for MaxAndArgmax op) - * MaxAndArgmax(only for matrix on axis 0 or 1) diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst index 8d22c1e577..4f087b6788 100644 --- a/doc/library/tensor/basic.rst +++ b/doc/library/tensor/basic.rst @@ -1791,5 +1791,3 @@ Gradient / Differentiation :members: grad :noindex: -See the :ref:`gradient ` page for complete documentation -of the gradient module. diff --git a/doc/tutorial/gradients.rst b/doc/tutorial/gradients.rst index edb38bb018..f8b7f7ff98 100644 --- a/doc/tutorial/gradients.rst +++ b/doc/tutorial/gradients.rst @@ -86,9 +86,7 @@ of symbolic differentiation). ``i`` of the output list is the gradient of the first argument of `pt.grad` with respect to the ``i``-th element of the list given as second argument. The first argument of `pt.grad` has to be a scalar (a tensor - of size 1). For more information on the semantics of the arguments of - `pt.grad` and details about the implementation, see - :ref:`this` section of the library. + of size 1). Additional information on the inner workings of differentiation may also be found in the more advanced tutorial :ref:`Extending PyTensor`. @@ -204,7 +202,21 @@ you need to do something similar to this: >>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1]) array([ 2., 2.]) -:ref:`List ` of Op that implement Rop. +By default, the R-operator is implemented as a double application of the L_operator +(see `reference `_). +In most cases this should be as performant as a specialized implementation of the R-operator. +However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators, +such as Scan and OpFromGraph, that would be more easily avoidable in a direct implentation of the R-operator. + +When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing +`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method. + + +>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True) +>>> f = pytensor.function([W, V, x], JV) +>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1]) +array([ 2., 2.]) + L-operator ---------- @@ -234,7 +246,6 @@ array([[ 0., 0.], as the input parameter, while the result of the R-operator has a shape similar to that of the output. - :ref:`List of op with r op support `. Hessian times a Vector ====================== diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 49baa3bb26..a4a3d1840a 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -340,6 +340,12 @@ def __init__( ``None``, this will be used as the connection_pattern for this :class:`Op`. + .. warning:: + + rop overrides is ignored when `pytensor.gradient.Rop` is called with + `use_op_rop_implementation=False` (default). In this case the Lop + is used twice to obtain a mathematically equivalent Rop. + strict: bool, default False If true, it raises when any variables needed to compute the inner graph are not provided as explici inputs. This can only happen for graphs with @@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self): return rop_overrides eval_points = [inp_t() for inp_t in self.input_types] - fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points) + fn_rop = partial( + Rop, + wrt=inner_inputs, + eval_points=eval_points, + use_op_rop_implementation=True, + ) callable_args = (inner_inputs, eval_points) if rop_overrides is None: diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 13ca943383..04572b29d0 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -142,13 +142,50 @@ def __str__(self): disconnected_type = DisconnectedType() -def Rop( - f: Variable | Sequence[Variable], - wrt: Variable | Sequence[Variable], - eval_points: Variable | Sequence[Variable], +def pushforward_through_pullback( + outputs: Sequence[Variable], + inputs: Sequence[Variable], + tangents: Sequence[Variable], disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", return_disconnected: Literal["none", "zero", "disconnected"] = "zero", -) -> Variable | None | Sequence[Variable | None]: +) -> Sequence[Variable | None]: + """Compute the pushforward (Rop) through two applications of a pullback (Lop) operation. + + References + ---------- + .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017. + Available: https://j-towns.github.io/2017/06/12/A-new-trick.html + + """ + # Cotangents are just auxiliary variables that should be pruned from the final graph, + # but that would require a graph rewrite before the user tries to compile a pytensor function. + # To avoid trouble we use .zeros_like() instead of .type(), which does not create a new root variable. + cotangents = [out.zeros_like(dtype=config.floatX) for out in outputs] # type: ignore + + input_cotangents = Lop( + f=outputs, + wrt=inputs, + eval_points=cotangents, + disconnected_inputs=disconnected_outputs, + return_disconnected="zero", + ) + + return Lop( + f=input_cotangents, # type: ignore + wrt=cotangents, + eval_points=tangents, + disconnected_inputs="ignore", + return_disconnected=return_disconnected, + ) + + +def _rop_legacy( + f: Sequence[Variable], + wrt: Sequence[Variable], + eval_points: Sequence[Variable], + disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", + return_disconnected: Literal["none", "zero", "disconnected"] = "zero", +) -> Sequence[Variable | None]: """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`. Mathematically this stands for the Jacobian of `f` right multiplied by the @@ -190,38 +227,6 @@ def Rop( If `f` is a list/tuple, then return a list/tuple with the results. """ - if not isinstance(wrt, list | tuple): - _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)] - else: - _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] - - if not isinstance(eval_points, list | tuple): - _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)] - else: - _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points] - - if not isinstance(f, list | tuple): - _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)] - else: - _f = [pytensor.tensor.as_tensor_variable(x) for x in f] - - if len(_wrt) != len(_eval_points): - raise ValueError("`wrt` must be the same length as `eval_points`.") - - # Check that each element of wrt corresponds to an element - # of eval_points with the same dimensionality. - for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)): - try: - if wrt_elem.type.ndim != eval_point.type.ndim: - raise ValueError( - f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: " - f"{wrt_elem.type.ndim} and {eval_point.type.ndim}" - ) - except AttributeError: - # wrt_elem and eval_point don't always have ndim like random type - # Tensor, Sparse have the ndim attribute - pass - seen_nodes: dict[Apply, Sequence[Variable]] = {} def _traverse(node): @@ -237,8 +242,8 @@ def _traverse(node): # inputs of the node local_eval_points = [] for inp in inputs: - if inp in _wrt: - local_eval_points.append(_eval_points[_wrt.index(inp)]) + if inp in wrt: + local_eval_points.append(eval_points[wrt.index(inp)]) elif inp.owner is None: try: local_eval_points.append(inp.zeros_like()) @@ -292,13 +297,13 @@ def _traverse(node): # end _traverse # Populate the dictionary - for out in _f: + for out in f: _traverse(out.owner) rval: list[Variable | None] = [] - for out in _f: - if out in _wrt: - rval.append(_eval_points[_wrt.index(out)]) + for out in f: + if out in wrt: + rval.append(eval_points[wrt.index(out)]) elif ( seen_nodes.get(out.owner, None) is None or seen_nodes[out.owner][out.owner.outputs.index(out)] is None @@ -337,6 +342,116 @@ def _traverse(node): else: rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)]) + return rval + + +def Rop( + f: Variable | Sequence[Variable], + wrt: Variable | Sequence[Variable], + eval_points: Variable | Sequence[Variable], + disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", + return_disconnected: Literal["none", "zero", "disconnected"] = "zero", + use_op_rop_implementation: bool = False, +) -> Variable | None | Sequence[Variable | None]: + """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`. + + Mathematically this stands for the Jacobian of `f` right multiplied by the + `eval_points`. + + By default, the R-operator is implemented as a double application of the L_operator [1]_. + In most cases this should be as performant as a specialized implementation of the R-operator. + However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators, + such as Scan and OpFromGraph, that would be more easily avoidable in a direct implentation of the R-operator. + + When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing + `use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method. + + Parameters + ---------- + f + The outputs of the computational graph to which the R-operator is + applied. + wrt + Variables for which the R-operator of `f` is computed. + eval_points + Points at which to evaluate each of the variables in `wrt`. + disconnected_outputs + Defines the behaviour if some of the variables in `f` + have no dependency on any of the variable in `wrt` (or if + all links are non-differentiable). The possible values are: + + - ``'ignore'``: considers that the gradient on these parameters is zero. + - ``'warn'``: consider the gradient zero, and print a warning. + - ``'raise'``: raise `DisconnectedInputError`. + + return_disconnected + - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be + ``wrt[i].zeros_like()``. + - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be + ``None`` + - ``'disconnected'`` : returns variables of type `DisconnectedType` + use_op_lop_implementation: bool, default=True + If `True`, we obtain Rop via double application of Lop. + If `False`, the legacy Rop implementation is used. The number of graphs that support this form + is much more restricted, and the generated graphs may be less optimized. + + Returns + ------- + :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables + A symbolic expression such obeying + ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``, + where the indices in that expression are magic multidimensional + indices that specify both the position within a list and all + coordinates of the tensor elements. + If `f` is a list/tuple, then return a list/tuple with the results. + + References + ---------- + .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017. + Available: https://j-towns.github.io/2017/06/12/A-new-trick.html + """ + + if not isinstance(wrt, list | tuple): + _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)] + else: + _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] + + if not isinstance(eval_points, list | tuple): + _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)] + else: + _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points] + + if not isinstance(f, list | tuple): + _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)] + else: + _f = [pytensor.tensor.as_tensor_variable(x) for x in f] + + if len(_wrt) != len(_eval_points): + raise ValueError("`wrt` must be the same length as `eval_points`.") + + # Check that each element of wrt corresponds to an element + # of eval_points with the same dimensionality. + for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)): + try: + if wrt_elem.type.ndim != eval_point.type.ndim: + raise ValueError( + f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: " + f"{wrt_elem.type.ndim} and {eval_point.type.ndim}" + ) + except AttributeError: + # wrt_elem and eval_point don't always have ndim like random type + # Tensor, Sparse have the ndim attribute + pass + + if use_op_rop_implementation: + rval = _rop_legacy( + _f, _wrt, _eval_points, disconnected_outputs, return_disconnected + ) + else: + rval = pushforward_through_pullback( + _f, _wrt, _eval_points, disconnected_outputs, return_disconnected + ) + using_list = isinstance(f, list) using_tuple = isinstance(f, tuple) return as_list_or_tuple(using_list, using_tuple, rval) @@ -348,6 +463,7 @@ def Lop( eval_points: Variable | Sequence[Variable], consider_constant: Sequence[Variable] | None = None, disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise", + return_disconnected: Literal["none", "zero", "disconnected"] = "zero", ) -> Variable | None | Sequence[Variable | None]: """Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`. @@ -404,6 +520,7 @@ def Lop( consider_constant=consider_constant, wrt=_wrt, disconnected_inputs=disconnected_inputs, + return_disconnected=return_disconnected, ) using_list = isinstance(wrt, list) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 9e55739ed4..1dbc93b9fa 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -3165,7 +3165,12 @@ def R_op(self, inputs, eval_points): rop_self_outputs = self_outputs if info.n_shared_outs > 0: rop_self_outputs = rop_self_outputs[: -info.n_shared_outs] - rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points) + rop_outs = Rop( + rop_self_outputs, + rop_of_inputs, + inner_eval_points, + use_op_rop_implementation=True, + ) if not isinstance(rop_outs, list | tuple): rop_outs = [rop_outs] # Step 2. Figure out what corresponds to what in the scan diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 8fc2a529df..ba0257cdda 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -306,7 +306,8 @@ def lop_ov(inps, outs, grads): @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] ) - def test_rop(self, cls_ofg): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_rop(self, cls_ofg, use_op_rop_implementation): a = vector() M = matrix() b = dot(a, M) @@ -315,7 +316,7 @@ def test_rop(self, cls_ofg): W = matrix() y = op_matmul(x, W) du = vector() - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) xval = np.random.random((16,)).astype(config.floatX) Wval = np.random.random((16, 16)).astype(config.floatX) @@ -324,7 +325,8 @@ def test_rop(self, cls_ofg): dvval2 = fn(xval, Wval, duval) np.testing.assert_array_almost_equal(dvval2, dvval, 4) - def test_rop_multiple_outputs(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_rop_multiple_outputs(self, use_op_rop_implementation): a = vector() M = matrix() b = dot(a, M) @@ -339,21 +341,21 @@ def test_rop_multiple_outputs(self): duval = np.random.random((16,)).astype(config.floatX) y = op_matmul(x, W)[0] - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = np.dot(duval, Wval) np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4) y = op_matmul(x, W)[1] - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = -np.dot(duval, Wval) np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4) y = pt.add(*op_matmul(x, W)) - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = np.zeros_like(np.dot(duval, Wval)) @@ -362,7 +364,16 @@ def test_rop_multiple_outputs(self): @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] ) - def test_rop_override(self, cls_ofg): + @pytest.mark.parametrize( + "use_op_rop_implementation", + [ + True, + pytest.param( + False, marks=pytest.mark.xfail(reason="Custom ROp is ignored") + ), + ], + ) + def test_rop_override(self, cls_ofg, use_op_rop_implementation): x, y = vectors("xy") def ro(inps, epts): @@ -380,7 +391,12 @@ def ro(inps, epts): du, dv = vector("du"), vector("dv") for op in [op_mul, op_mul2]: zz = op_mul(xx, yy) - dw = Rop(zz, [xx, yy], [du, dv]) + dw = Rop( + zz, + [xx, yy], + [du, dv], + use_op_rop_implementation=use_op_rop_implementation, + ) fn = function([xx, yy, du, dv], dw) vals = np.random.random((4, 32)).astype(config.floatX) dwval = fn(*vals) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index b86423a6ff..351c2e703a 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -1922,7 +1922,8 @@ def inner_fn(): fgrad = function([], g_sh) assert fgrad() == 1 - def test_R_op(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op(self, use_op_rop_implementation): seed = utt.fetch_seed() rng = np.random.default_rng(seed) floatX = config.floatX @@ -1957,9 +1958,9 @@ def rnn_fn(_u, _y, _W): eh0 = vector("eh0") eW = matrix("eW") - nwo_u = Rop(o, _u, eu) - nwo_h0 = Rop(o, _h0, eh0) - nwo_W = Rop(o, _W, eW) + nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation) + nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation) + nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation) fn_rop = function( [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore" ) @@ -1997,7 +1998,8 @@ def rnn_fn(_u, _y, _W): np.testing.assert_allclose(vnW, tnW, atol=1e-6) @pytest.mark.slow - def test_R_op_2(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op_2(self, use_op_rop_implementation): seed = utt.fetch_seed() rng = np.random.default_rng(seed) floatX = config.floatX @@ -2040,9 +2042,9 @@ def rnn_fn(_u, _y, _W): eh0 = vector("eh0") eW = matrix("eW") - nwo_u = Rop(o, _u, eu) - nwo_h0 = Rop(o, _h0, eh0) - nwo_W = Rop(o, _W, eW) + nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation) + nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation) + nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation) fn_rop = function( [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W, o], on_unused_input="ignore" ) @@ -2078,7 +2080,8 @@ def rnn_fn(_u, _y, _W): np.testing.assert_allclose(vnh0, tnh0, atol=1e-6) np.testing.assert_allclose(vnW, tnW, atol=2e-6) - def test_R_op_mitmot(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op_mitmot(self, use_op_rop_implementation): # this test is a copy paste from the script given by Justin Bayer to # reproduce this bug # We have 2 parameter groups with the following shapes. @@ -2126,7 +2129,12 @@ def test_R_op_mitmot(self): p = dvector() # TODO: We should test something about the Rop! - Rop(d_cost_wrt_pars, pars, p) + Rop( + d_cost_wrt_pars, + pars, + p, + use_op_rop_implementation=use_op_rop_implementation, + ) def test_second_derivative_disconnected_cost_with_mit_mot(self): # This test is a regression test for a bug that was revealed diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 4cc2ce1e12..50e48ce95d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -49,9 +49,12 @@ def test_matrix_inverse_rop_lop(): v = vector("v") y = MatrixInverse()(mx).sum(axis=0) - yv = pytensor.gradient.Rop(y, mx, mv) + yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True) rop_f = function([mx, mv], yv) + yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False) + rop_via_lop_f = function([mx, mv], yv_via_lop) + sy, _ = pytensor.scan( lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(), sequences=pt.arange(y.shape[0]), @@ -65,10 +68,14 @@ def test_matrix_inverse_rop_lop(): v_ref = scan_f(vx, vv) np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) + np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol) with pytest.raises(ValueError): pytensor.gradient.Rop( - pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv + pytensor.clone_replace(y, replace={mx: break_op(mx)}), + mx, + mv, + use_op_rop_implementation=True, ) vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) diff --git a/tests/test_rop.py b/tests/test_rop.py index 769bf247be..b592f557a5 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -88,7 +88,7 @@ def check_nondiff_rop(self, y, x, v): test that an error is raised. """ with pytest.raises(ValueError): - Rop(y, x, v) + Rop(y, x, v, use_op_rop_implementation=True) def check_mat_rop_lop(self, y, out_shape): """ @@ -116,8 +116,14 @@ def check_mat_rop_lop(self, y, out_shape): vv = np.asarray( self.rng.uniform(size=self.mat_in_shape), pytensor.config.floatX ) - yv = Rop(y, self.mx, self.mv) + yv = Rop(y, self.mx, self.mv, use_op_rop_implementation=True) rop_f = function([self.mx, self.mv], yv, on_unused_input="ignore") + + yv_through_lop = Rop(y, self.mx, self.mv, use_op_rop_implementation=False) + rop_through_lop_f = function( + [self.mx, self.mv], yv_through_lop, on_unused_input="ignore" + ) + sy, _ = pytensor.scan( lambda i, y, x, v: (grad(y[i], x) * v).sum(), sequences=pt.arange(y.shape[0]), @@ -127,6 +133,7 @@ def check_mat_rop_lop(self, y, out_shape): v_ref = scan_f(vx, vv) np.testing.assert_allclose(rop_f(vx, vv), v_ref) + np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref) self.check_nondiff_rop( pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}), @@ -156,8 +163,14 @@ def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True): vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) - yv = Rop(y, self.x, self.v) + yv = Rop(y, self.x, self.v, use_op_rop_implementation=True) rop_f = function([self.x, self.v], yv, on_unused_input="ignore") + + yv_through_lop = Rop(y, self.x, self.v, use_op_rop_implementation=False) + rop_through_lop_f = function( + [self.x, self.v], yv_through_lop, on_unused_input="ignore" + ) + J, _ = pytensor.scan( lambda i, y, x: grad(y[i], x), sequences=pt.arange(y.shape[0]), @@ -168,6 +181,7 @@ def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True): v_ref = scan_f(vx, vv) np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) + np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref, rtol=rtol) if check_nondiff_rop: self.check_nondiff_rop( @@ -255,12 +269,12 @@ def test_dot(self): insh = self.in_shape[0] vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX) W = pytensor.shared(vW) - # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths + # check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths # See: test_Rop_partially_differentiable_paths self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False) def test_elemwise0(self): - # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths + # check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths # See: test_Rop_partially_differentiable_paths self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False) @@ -294,11 +308,18 @@ def test_alloc(self): self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0], ) - def test_invalid_input(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_invalid_input(self, use_op_rop_implementation): with pytest.raises(ValueError): - Rop(0.0, [matrix()], [vector()]) + Rop( + 0.0, + [matrix()], + [vector()], + use_op_rop_implementation=use_op_rop_implementation, + ) - def test_multiple_outputs(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_multiple_outputs(self, use_op_rop_implementation): m = matrix("m") v = vector("v") m_ = matrix("m_") @@ -309,10 +330,20 @@ def test_multiple_outputs(self): m_val = self.rng.uniform(size=(3, 7)).astype(pytensor.config.floatX) v_val = self.rng.uniform(size=(7,)).astype(pytensor.config.floatX) - rop_out1 = Rop([m, v, m + v], [m, v], [m_, v_]) + rop_out1 = Rop( + [m, v, m + v], + [m, v], + [m_, v_], + use_op_rop_implementation=use_op_rop_implementation, + ) assert isinstance(rop_out1, list) assert len(rop_out1) == 3 - rop_out2 = Rop((m, v, m + v), [m, v], [m_, v_]) + rop_out2 = Rop( + (m, v, m + v), + [m, v], + [m_, v_], + use_op_rop_implementation=use_op_rop_implementation, + ) assert isinstance(rop_out2, tuple) assert len(rop_out2) == 3 @@ -322,8 +353,11 @@ def test_multiple_outputs(self): f = pytensor.function([m, v, m_, v_], all_outs) f(mval, vval, m_val, v_val) - @pytest.mark.xfail() - def test_Rop_partially_differentiable_paths(self): + @pytest.mark.parametrize( + "use_op_rop_implementation", + [pytest.param(True, marks=pytest.mark.xfail()), False], + ) + def test_Rop_partially_differentiable_paths(self, use_op_rop_implementation): # This test refers to a bug reported by Jeremiah Lowin on 18th Oct # 2013. The bug consists when through a dot operation there is only # one differentiable path (i.e. there is no gradient wrt to one of @@ -336,7 +370,12 @@ def test_Rop_partially_differentiable_paths(self): grad(d, v), v, v, - disconnected_outputs="raise", + use_op_rop_implementation=use_op_rop_implementation, + # 2025: This is a tricky case, the gradient of the gradient does not depend on v + # although v still exists in the graph inside a `Second` operator. + # The original test was checking that Rop wouldn't raise an error, but Lop does. + # Since the correct behavior is ambiguous, I let both implementations off the hook. + disconnected_outputs="raise" if use_op_rop_implementation else "ignore", ) # 2025: Here is an unambiguous test for the original commented issue: @@ -348,10 +387,11 @@ def test_Rop_partially_differentiable_paths(self): out, [x], [x.type()], + use_op_rop_implementation=use_op_rop_implementation, disconnected_outputs="raise", ) - # More extensive testing shows that the Rop implementation FAILS to raise when + # More extensive testing shows that the legacy Rop implementation FAILS to raise when # the cost is linked through strictly non-differentiable paths. # This is not Dot specific, we would observe the same with any operation where the gradient # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...) @@ -361,6 +401,7 @@ def test_Rop_partially_differentiable_paths(self): out, [x], [x.type()], + use_op_rop_implementation=use_op_rop_implementation, disconnected_outputs="raise", ) @@ -371,5 +412,6 @@ def test_Rop_partially_differentiable_paths(self): out, [x], [x.type()], + use_op_rop_implementation=use_op_rop_implementation, disconnected_outputs="raise", ) From fe8804fab9cc36077ad98071b7e7bf33c0c23bb8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 11:49:12 +0100 Subject: [PATCH 17/43] Cache sub-type of DimShuffle --- pytensor/tensor/elemwise.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index cb60427ba0..c37597906a 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -166,15 +166,20 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.transposition = self.shuffle + drop # List of dimensions of the output that are broadcastable and were not # in the original input - self.augment = sorted(i for i, x in enumerate(new_order) if x == "x") + self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x") self.drop = drop - self.is_left_expand_dims = self.augment and ( + dims_are_shuffled = sorted(self.shuffle) != self.shuffle + + self.is_transpose = dims_are_shuffled and not augment and not drop + self.is_squeeze = drop and not dims_are_shuffled and not augment + self.is_expand_dims = augment and not dims_are_shuffled and not drop + self.is_left_expand_dims = self.is_expand_dims and ( input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) ) - self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list( - range(input_ndim) - ) + self.is_right_expand_dims = self.is_expand_dims and new_order[ + :input_ndim + ] == list(range(input_ndim)) if self.inplace: self.view_map = {0: [0]} @@ -215,16 +220,15 @@ def make_node(self, inp): return Apply(self, [input], [output]) def __str__(self): - shuffle = sorted(self.shuffle) != self.shuffle - if self.augment and not (shuffle or self.drop): + if self.is_expand_dims: if len(self.augment) == 1: return f"ExpandDims{{axis={self.augment[0]}}}" return f"ExpandDims{{axes={self.augment}}}" - if self.drop and not (self.augment or shuffle): + if self.is_squeeze: if len(self.drop) == 1: - return f"DropDims{{axis={self.drop[0]}}}" - return f"DropDims{{axes={self.drop}}}" - if shuffle and not (self.augment or self.drop): + return f"Squeeze{{axis={self.drop[0]}}}" + return f"Squeeze{{axes={self.drop}}}" + if self.is_transpose: return f"Transpose{{axes={self.shuffle}}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" From 947b9409ac0129aa4c1ad4bbf763d8c9b887d097 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 10 Feb 2025 16:11:07 +0100 Subject: [PATCH 18/43] Make reshape ndim keyword only --- pytensor/tensor/shape.py | 18 ++++++++++++++---- pytensor/tensor/slinalg.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 8913d6fb4d..45a22b8714 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -1,7 +1,9 @@ import warnings +from collections.abc import Sequence from numbers import Number from textwrap import dedent -from typing import cast +from typing import TYPE_CHECKING, Union, cast +from typing import cast as typing_cast import numpy as np from numpy.core.numeric import normalize_axis_tuple # type: ignore @@ -24,6 +26,9 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable +if TYPE_CHECKING: + from pytensor.tensor import TensorLike + ShapeValueType = None | np.integer | int | Variable @@ -842,9 +847,14 @@ def _vectorize_reshape(op, node, x, shape): return reshape(x, new_shape, ndim=len(new_shape)).owner -def reshape(x, newshape, ndim=None): +def reshape( + x: "TensorLike", + newshape: Union["TensorLike", Sequence["TensorLike"]], + *, + ndim: int | None = None, +) -> TensorVariable: if ndim is None: - newshape = ptb.as_tensor_variable(newshape) + newshape = ptb.as_tensor_variable(newshape) # type: ignore if newshape.type.ndim != 1: raise TypeError( "New shape in reshape must be a vector or a list/tuple of" @@ -862,7 +872,7 @@ def reshape(x, newshape, ndim=None): ) op = Reshape(ndim) rval = op(x, newshape) - return rval + return typing_cast(TensorVariable, rval) def shape_padleft(t, n_ones=1): diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 325567918a..7f0be47656 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -918,7 +918,7 @@ def _direct_solve_discrete_lyapunov( vec_Q = Q.ravel() vec_X = solve(eye - AxA, vec_Q, b_ndim=1) - return cast(TensorVariable, reshape(vec_X, A.shape)) + return reshape(vec_X, A.shape) def solve_discrete_lyapunov( From 141307f0490db6281941d4c527bd294b5b188137 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 10 Feb 2025 23:39:37 +0100 Subject: [PATCH 19/43] Fix bug in local_useless_reshape --- pytensor/tensor/rewriting/shape.py | 3 ++- tests/tensor/rewriting/test_shape.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index e277772ad4..cc70338559 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -897,7 +897,8 @@ def local_useless_reshape(fgraph, node): if nb_m1 <= 1 and all(shape_match): return [inp] - if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1): + # There is one missing match, but all other dimensions match + if (nb_m1 == 0) and (shape_match.count(False) == 1): return [inp] return False diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index bbfd829070..f3120a5001 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -383,6 +383,13 @@ def test_all_but_one_match(self): new_out = rewrite_graph(out) assert new_out is out + # Or if more than one dimension cannot be matched + x = tensor(shape=(None, None, None)) + shape = [x.shape[0], 3, 3] + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is out + class TestLocalReshapeToDimshuffle: def setup_method(self): From 02545ed54833d74f9363acccf998d768bd2c1673 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 11 Feb 2025 00:21:22 +0100 Subject: [PATCH 20/43] Specify reshape shape length if unknown --- pytensor/tensor/shape.py | 2 ++ tests/tensor/test_shape.py | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 45a22b8714..1c23a21347 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -644,6 +644,8 @@ def make_node(self, x, shp): x = ptb.as_tensor_variable(x) shp_orig = shp shp = ptb.as_tensor_variable(shp, ndim=1) + if shp.type.shape == (None,): + shp = specify_shape(shp, self.ndim) if not ( shp.dtype in int_dtypes or (isinstance(shp, TensorConstant) and shp.data.size == 0) diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 7700d2b14b..3f0b04d45d 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -98,6 +98,7 @@ def setup_method(self): Shape_i, DimShuffle, Elemwise, + SpecifyShape, ) super().setup_method() @@ -253,9 +254,7 @@ def test_bad_shape(self): f(a_val, [7, 5]) with pytest.raises(ValueError): f(a_val, [-1, -1]) - with pytest.raises( - ValueError, match=".*Shape argument to Reshape has incorrect length.*" - ): + with pytest.raises(AssertionError): f(a_val, [3, 4, 1]) def test_0(self): From dbf5f38e6cef7c2ceb654b44cef147fe16eea684 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 10 Feb 2025 14:05:37 +0100 Subject: [PATCH 21/43] Refactor reshape + dimshuffle rewrites --- pytensor/graph/rewriting/basic.py | 10 - pytensor/tensor/rewriting/shape.py | 311 ++++++++++++++--------------- 2 files changed, 150 insertions(+), 171 deletions(-) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 16b5b65a0e..b91e743bb6 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -2800,16 +2800,6 @@ def _check_chain(r, chain): return r is not None -def check_chain(r, *chain): - """ - WRITEME - - """ - if isinstance(r, Apply): - r = r.outputs[0] - return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) - - def pre_greedy_node_rewriter( fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable ) -> Variable: diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index cc70338559..81e1749131 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -12,16 +12,17 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( GraphRewriter, - check_chain, copy_stack_trace, node_rewriter, ) from pytensor.graph.utils import InconsistencyError, get_variable_trace_string +from pytensor.scalar import ScalarType from pytensor.tensor.basic import ( MakeVector, as_tensor_variable, cast, constant, + expand_dims, get_scalar_constant_value, register_infer_shape, stack, @@ -47,6 +48,7 @@ from pytensor.tensor.subtensor import Subtensor, get_idx_list from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes from pytensor.tensor.type_other import NoneConst, NoneTypeT +from pytensor.tensor.variable import TensorVariable class ShapeFeature(Feature): @@ -755,6 +757,42 @@ def apply(self, fgraph): pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) +@register_canonicalize +@node_rewriter([Reshape]) +def local_useless_dimshuffle_in_reshape(fgraph, node): + """ + Removes useless DimShuffle operation inside Reshape: + + reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp) + reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp) + reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp) + reshape(col.dimshuffle(0), shp) => reshape(col, shp) + + """ + dimshuffled_x, new_shape = node.inputs + + if not ( + dimshuffled_x.owner is not None + and isinstance(dimshuffled_x.owner.op, DimShuffle) + ): + return False + + [inp] = dimshuffled_x.owner.inputs + new_order = dimshuffled_x.owner.op.new_order + new_order_of_nonbroadcast = [] + for i, s in zip(new_order, node.inputs[0].type.shape, strict=True): + if s != 1: + new_order_of_nonbroadcast.append(i) + no_change_in_order = all( + new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] + for i in range(len(new_order_of_nonbroadcast) - 1) + ) + if no_change_in_order: + ret = inp.reshape(new_shape) + copy_stack_trace(node.outputs[0], ret) + return [ret] + + @register_canonicalize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([Reshape]) @@ -763,30 +801,89 @@ def local_reshape_chain(fgraph, node): Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2) """ - if not check_chain(node, Reshape, Reshape): + inner_reshape, final_shape = node.inputs + + if not (inner_reshape.owner and isinstance(inner_reshape.owner.op, Reshape)): + return None + + x, _ = inner_reshape.owner.inputs + new_reshape = node.op(x, final_shape) + + copy_stack_trace(node.outputs, new_reshape) + return [new_reshape] + + +def _is_shape_i_of_x( + var: TensorVariable, + x: TensorVariable, + i: int, + shape_feature: ShapeFeature | None = None, +) -> bool: + if var.type.ndim != 0: return False - rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) - - # Copy over stacktrace from previous output node, as any error - # in new computational graph would have been caused by last op - # in the old computational graph. - copy_stack_trace(node.outputs, rval) - - # It might happen that the desired output of this node has a - # broadcastable pattern that does not match that of 'rval'. This is - # when originally, we were able to figure out that one of the - # dimensions of the reshape is one, but some other transformation - # replaced the shape by one for which this cannot be guessed. - # We should try to figure out why we lost the information about this - # constant value... but in the meantime, better not apply this - # rewrite. - if rval.type.ndim == node.outputs[0].type.ndim and all( - s1 == s2 - for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True) - if s1 == 1 or s2 == 1 - ): - return [rval] + constant_var = get_scalar_constant_value( + var, + only_process_constants=False, + # Don't go through Elemwise to keep things fast + elemwise=False, + raise_not_constant=False, + ) + + # Check var is a constant expression with the same value as x.type.shape[i] + if constant_var == x.type.shape[i]: + return True + + # Match shape_of[x][i] or its constant equivalent + if shape_feature is not None: + i_shape_of_x = shape_feature.get_shape(x, i) + if i_shape_of_x == var or ( + isinstance(i_shape_of_x, Constant) and (i_shape_of_x.data == constant_var) + ): + return True + + if var.owner is None: + # No more constant possibilities + return False + + # Match Shape_i{i}(x) + if isinstance(var.owner.op, Shape_i): + return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore + + # Match Subtensor((ScalarType,))(Shape(input), i) + if isinstance(var.owner.op, Subtensor): + return ( + # Check we have integer indexing operation + # (and not slice or multiple indexing) + len(var.owner.op.idx_list) == 1 + and isinstance(var.owner.op.idx_list[0], ScalarType) + # Check we are indexing on the shape of x + and var.owner.inputs[0].owner is not None + and isinstance(var.owner.inputs[0].owner.op, Shape) + and var.owner.inputs[0].owner.inputs[0] == x + # Check that index == i + and ( + get_scalar_constant_value(var.owner.inputs[1], raise_not_constant=False) + == i + ) + ) + + return False + + +def _unpack_shape_vector(shape: TensorVariable) -> tuple[TensorVariable, ...]: + """Return the elements of a symbolic vector representing a shape. + + Handles the most common constant vector or make_vector cases. + + Returns tuple(shape) as fallback. + """ + if isinstance(shape, Constant): + return tuple(as_tensor_variable(dim, ndim=0) for dim in shape.data) + elif shape.owner and isinstance(shape.owner.op, MakeVector): + return tuple(shape.owner.inputs) + else: + return tuple(shape) @register_useless("shape_unsafe") @@ -821,87 +918,30 @@ def local_useless_reshape(fgraph, node): if shape_input == inp: return [inp] - # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for - # broadcastable and constant dimensions - if isinstance(output_shape, Constant) or ( - output_shape.owner and isinstance(output_shape.owner.op, MakeVector) - ): - if isinstance(output_shape, Constant): - output_shape_is = [ - as_tensor_variable(dim, ndim=0) for dim in output_shape.data - ] - else: - output_shape_is = output_shape.owner.inputs - - shape_feature = getattr(fgraph, "shape_feature", None) - - nb_m1 = 0 - shape_match = [False] * inp.type.ndim - for dim in range(inp.type.ndim): - outshp_i = output_shape_is[dim] - # Match Shape_i{dim}(input) - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Shape_i) - and outshp_i.owner.op.i == dim - and outshp_i.owner.inputs[0] == inp - ): - shape_match[dim] = True - continue + shape_feature = getattr(fgraph, "shape_feature", None) - # Match Shape(input)[dim] - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Subtensor) - and len(outshp_i.owner.inputs) == 2 - and get_scalar_constant_value( - outshp_i.owner.inputs[1], raise_not_constant=False - ) - == dim - ): - subtensor_inp = outshp_i.owner.inputs[0] - if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): - shape_input_i = subtensor_inp.owner.inputs[0] - if shape_input_i == inp: - shape_match[dim] = True - continue - - # Match constant if input.type.shape[dim] == constant - cst_outshp_i = get_scalar_constant_value( - outshp_i, only_process_constants=True, raise_not_constant=False - ) - if inp.type.shape[dim] == cst_outshp_i: - shape_match[dim] = True - continue - - # Match -1 - if cst_outshp_i == -1: - shape_match[dim] = True - nb_m1 += 1 - continue + # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1 + # or cases where all but one dimension are provably preserved + output_shape_is = _unpack_shape_vector(output_shape) - # Match shape_of[input][dim] or its constant equivalent - if shape_feature: - inpshp_i = shape_feature.get_shape(inp, dim) - if inpshp_i == outshp_i or ( - get_scalar_constant_value( - inpshp_i, only_process_constants=True, raise_not_constant=False - ) - == get_scalar_constant_value( - outshp_i, only_process_constants=True, raise_not_constant=False - ) - ): - shape_match[dim] = True - continue + nb_m1 = 0 + shape_match = [False] * inp.type.ndim + for dim in range(inp.type.ndim): + outshp_i = output_shape_is[dim] + if _is_shape_i_of_x(outshp_i, inp, dim, shape_feature=shape_feature): + shape_match[dim] = True + elif isinstance(outshp_i, Constant) and outshp_i.data == -1: + shape_match[dim] = True + nb_m1 += 1 - if nb_m1 <= 1 and all(shape_match): - return [inp] + if nb_m1 <= 1 and all(shape_match): + return [inp] - # There is one missing match, but all other dimensions match - if (nb_m1 == 0) and (shape_match.count(False) == 1): - return [inp] + # There is one missing match, but all other dimensions match + if (nb_m1 == 0) and (shape_match.count(False) == 1): + return [inp] - return False + return False @register_canonicalize @@ -915,39 +955,26 @@ def local_reshape_to_dimshuffle(fgraph, node): For example: - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)) - - reshape(x, (1, m, 1, n, 1, 1)) - -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n))) + - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n))) """ - op = node.op inp, output_shape = node.inputs [output] = node.outputs - dimshuffle_new_order = [] + unpacked_shape = _unpack_shape_vector(output_shape) + expand_axes = [] new_output_shape = [] - index = 0 # index over the output of the new reshape - for i in range(output.ndim): - # Since output_shape is a symbolic vector, we trust get_scalar_constant_value - # to go through however it is formed to see if its i-th element is 1. - # We need only_process_constants=False for that. - dim = get_scalar_constant_value( - output_shape[i], - only_process_constants=False, - elemwise=False, - raise_not_constant=False, - ) - if dim == 1: - dimshuffle_new_order.append("x") + for i, dim in enumerate(unpacked_shape): + if isinstance(dim, Constant) and dim.data == 1: + expand_axes.append(i) else: - dimshuffle_new_order.append(index) new_output_shape.append(dim) - index = index + 1 - if index != output.type.ndim: - inner = op.__class__(len(new_output_shape))(inp, new_output_shape) + if len(new_output_shape) != output.type.ndim: + inner = inp.reshape(new_output_shape) copy_stack_trace(output, inner) - new_node = [inner.dimshuffle(dimshuffle_new_order)] - copy_stack_trace(output, new_node) - return new_node + new_out = expand_dims(inner, expand_axes) + copy_stack_trace(output, new_out) + return [new_out] @register_canonicalize @@ -1187,44 +1214,6 @@ def local_track_shape_i(fgraph, node): return [shape_feature.shape_of[replacement][node.op.i]] -@register_canonicalize -@node_rewriter([Reshape]) -def local_useless_dimshuffle_in_reshape(fgraph, node): - """ - Removes useless DimShuffle operation inside Reshape: - - reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp) - reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp) - reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp) - reshape(col.dimshuffle(0), shp) => reshape(col, shp) - - """ - op = node.op - if not isinstance(op, Reshape): - return False - if not ( - node.inputs[0].owner is not None - and isinstance(node.inputs[0].owner.op, DimShuffle) - ): - return False - - new_order = node.inputs[0].owner.op.new_order - inp = node.inputs[0].owner.inputs[0] - new_order_of_nonbroadcast = [] - for i, s in zip(new_order, node.inputs[0].type.shape, strict=True): - if s != 1: - new_order_of_nonbroadcast.append(i) - no_change_in_order = all( - new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] - for i in range(len(new_order_of_nonbroadcast) - 1) - ) - if no_change_in_order: - shape = node.inputs[1] - ret = op.__class__(node.outputs[0].ndim)(inp, shape) - copy_stack_trace(node.outputs[0], ret) - return [ret] - - @register_useless @register_canonicalize @register_specialize From 65b96c1c32150b86c9237025bcb2d3294ac07703 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 10 Feb 2025 16:05:33 +0100 Subject: [PATCH 22/43] Canonicalize squeeze out of reshape and specialize back --- pytensor/tensor/rewriting/shape.py | 181 +++++++++++++++++------- tests/tensor/rewriting/test_basic.py | 1 - tests/tensor/rewriting/test_elemwise.py | 20 +-- tests/tensor/rewriting/test_shape.py | 56 +++++++- 4 files changed, 200 insertions(+), 58 deletions(-) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 81e1749131..e86411dd9c 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -36,6 +36,7 @@ register_useless, topo_constant_folding, ) +from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.shape import ( Reshape, Shape, @@ -757,40 +758,36 @@ def apply(self, fgraph): pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) +@register_useless @register_canonicalize @node_rewriter([Reshape]) -def local_useless_dimshuffle_in_reshape(fgraph, node): +def local_useless_expand_dims_in_reshape(fgraph, node): """ - Removes useless DimShuffle operation inside Reshape: - - reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp) - reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp) - reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp) - reshape(col.dimshuffle(0), shp) => reshape(col, shp) + Removes useless expand_dims `DimShuffle` operations inside Reshape: + reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp) + reshape(expand_dims(matrix, axis=(0, 2), shp) => reshape(matrix, shp) + Implicit (and useless) squeezes are kept in the graph, as they are + part of the canonical form of the graph. """ - dimshuffled_x, new_shape = node.inputs + expanded_x, new_shape = node.inputs if not ( - dimshuffled_x.owner is not None - and isinstance(dimshuffled_x.owner.op, DimShuffle) + expanded_x.owner is not None + and isinstance(expanded_x.owner.op, DimShuffle) + and expanded_x.owner.op.augment ): return False - [inp] = dimshuffled_x.owner.inputs - new_order = dimshuffled_x.owner.op.new_order - new_order_of_nonbroadcast = [] - for i, s in zip(new_order, node.inputs[0].type.shape, strict=True): - if s != 1: - new_order_of_nonbroadcast.append(i) - no_change_in_order = all( - new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] - for i in range(len(new_order_of_nonbroadcast) - 1) - ) - if no_change_in_order: - ret = inp.reshape(new_shape) - copy_stack_trace(node.outputs[0], ret) - return [ret] + [x] = expanded_x.owner.inputs + + new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x") + if new_order != tuple(range(x.type.ndim)): + x = x.dimshuffle(new_order) + + new_reshaped_x = x.reshape(new_shape) + copy_stack_trace(node.outputs[0], new_reshaped_x) + return [new_reshaped_x] @register_canonicalize("shape_unsafe") @@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node): shape_feature = getattr(fgraph, "shape_feature", None) - # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1 - # or cases where all but one dimension are provably preserved + # Match case where at least (n-1) entries correspond to the original shape: + # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ... x.shape[-1]]) + # Where y can be -1 or anything with an unknown value, since the only valid reshape is still a no reshape. output_shape_is = _unpack_shape_vector(output_shape) - nb_m1 = 0 shape_match = [False] * inp.type.ndim for dim in range(inp.type.ndim): @@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node): nb_m1 += 1 if nb_m1 <= 1 and all(shape_match): - return [inp] + return [inp] # This is provably correct # There is one missing match, but all other dimensions match + # Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y) if (nb_m1 == 0) and (shape_match.count(False) == 1): - return [inp] + return [inp] # This could mask a shape error return False -@register_canonicalize +@register_canonicalize("shape_unsafe") @node_rewriter([Reshape]) def local_reshape_to_dimshuffle(fgraph, node): - r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s. + r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions. - The goal is to avoid using `Reshape` to add or remove broadcastable - dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can - cancel out and/or be removed later on. + It's always valid to squeeze an input before doing the same reshape operation. + Equivalently, it's always valid to remove `1` entries from the reshape shape + and replace them by an expand_dims after the rewritten reshape operation. + + We chose to canonicalize the graph in this way as it allows isolating + operations that are unique to the reshaping operation (mixing dimensions) + from those that can be more legibly encoded by DimShuffle (squeeze and expand_dims). + This can allow further simplifications by other rewrites that target + DimShuffle but not Reshape, as well as facilitate the removal of useless reshape operations. For example: - - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)) - - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n))) + - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n)) + - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0) + - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5)) + """ inp, output_shape = node.inputs [output] = node.outputs - unpacked_shape = _unpack_shape_vector(output_shape) - expand_axes = [] - new_output_shape = [] - for i, dim in enumerate(unpacked_shape): - if isinstance(dim, Constant) and dim.data == 1: - expand_axes.append(i) - else: - new_output_shape.append(dim) + # Remove any broadcastable dimensions from the input + squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast] + + # Trivial case, all dimensions of input/output are known to be broadcastable: + # there's nothing to reshape + if all(inp.type.broadcastable) or all(output.type.broadcastable): + new_output_shape = [] + expand_axes = tuple(range(output.type.ndim)) + + else: + unpacked_shape = _unpack_shape_vector(output_shape) + new_output_shape = [] + expand_axes = [] + for i, dim_length in enumerate(unpacked_shape): + if isinstance(dim_length, Constant) and ( + dim_length.data == 1 + # -1 can be an implicit expand_dims, but it's tricky to prove + # as we would need to check whether all other dimensions + # already explain the full size of the array. + # Example: np.zeros((2, 2, 2)).reshape((8, -1)) + # We rely on the output static shape which will already have figured + # it out for some (but not all) cases + or (dim_length.data == -1 and output.type.shape[i] == 1) + ): + expand_axes.append(i) + else: + new_output_shape.append(dim_length) + + if squeeze_axes or expand_axes: + new_out = inp.squeeze(squeeze_axes) + + if new_output_shape: + new_out = new_out.reshape(new_output_shape) + copy_stack_trace(output, new_out) + + new_out = expand_dims(new_out, expand_axes) + + if not new_output_shape: + # Eagerly merge consecutive squeeze and expand_dims + new_out = apply_local_dimshuffle_lift(fgraph, new_out) - if len(new_output_shape) != output.type.ndim: - inner = inp.reshape(new_output_shape) - copy_stack_trace(output, inner) - new_out = expand_dims(inner, expand_axes) copy_stack_trace(output, new_out) return [new_out] +@register_specialize +@node_rewriter([Reshape]) +def local_fuse_squeeze_reshape(fgraph, node): + r"""If there is a squeeze right before a reshape, merge them. + + This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization. + """ + x, new_shape = node.inputs + + if ( + x.owner is not None + and isinstance(x.owner.op, DimShuffle) + and x.owner.op.is_squeeze + ): + # A reshape can always subsume a squeeze. + x = x.owner.inputs[0] + return [x.reshape(new_shape)] + + +@register_specialize +@node_rewriter([DimShuffle]) +def local_fuse_expand_dims_reshape(fgraph, node): + r"""If there is an expand_dims right after a reshape, merge them. + + This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization. + """ + if not node.op.is_expand_dims: + return None + + reshaped_x = node.inputs[0] + + if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)): + return None + + if len(fgraph.clients[reshaped_x]) > 1: + # The reshape is used elsewhere, don't fuse as it can sometimes require a copy. + # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[: None] * y[None, :]` + return None + + x, new_shape = reshaped_x.owner.inputs + + # Add expand_dims to shape + new_shape = list(_unpack_shape_vector(new_shape)) + for i in node.op.augment: + new_shape.insert(i, 1) + + new_reshaped_x = x.reshape(new_shape) + copy_stack_trace(node.outputs[0], new_reshaped_x) + return [new_reshaped_x] + + @register_canonicalize @register_specialize @node_rewriter([Reshape]) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 8911f56630..ac8576a8a1 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -332,7 +332,6 @@ def test_basic_tile(self): mode = rewrite_mode.including( "local_dimshuffle_lift", - "local_useless_dimshuffle_in_reshape", "local_alloc_sink_dimshuffle", ) f = function([x], [y], mode=mode) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index f1b71949d1..6fb0594ed5 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -56,7 +56,10 @@ from pytensor.tensor.math import round as pt_round from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift -from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape +from pytensor.tensor.rewriting.shape import ( + local_fuse_squeeze_reshape, + local_useless_expand_dims_in_reshape, +) from pytensor.tensor.shape import reshape from pytensor.tensor.type import ( TensorType, @@ -182,7 +185,7 @@ def test_dimshuffle_lift_multi_out_elemwise(self): assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner) -def test_local_useless_dimshuffle_in_reshape(): +def test_local_useless_expand_dims_in_reshape(): vec = TensorType(dtype="float64", shape=(None,))("vector") mat = TensorType(dtype="float64", shape=(None, None))("mat") row = TensorType(dtype="float64", shape=(1, None))("row") @@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape(): clone=False, ) assert len(g.apply_nodes) == 4 * 3 - useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) + useless_dimshuffle_in_reshape = out2in( + local_useless_expand_dims_in_reshape, + # Useless squeeze in reshape is not a canonicalization anymore + local_fuse_squeeze_reshape, + ) useless_dimshuffle_in_reshape.rewrite(g) assert equal_computations( g.outputs, @@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape(): # Check stacktrace was copied over correctly after rewrite was applied assert check_stack_trace(g, ops_to_check="all") - # Check that the rewrite does not get applied when the order - # of dimensions has changed. + # Check that the rewrite does not mess meaningful transpositions before the reshape reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False) assert len(h.apply_nodes) == 3 useless_dimshuffle_in_reshape.rewrite(h) - assert equal_computations( - h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)] - ) + assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)]) class TestFusion: diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index f3120a5001..27678bd630 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -6,7 +6,7 @@ import pytensor.tensor as pt from pytensor import shared from pytensor.compile.function import function -from pytensor.compile.mode import get_default_mode, get_mode +from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import deep_copy_op from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations @@ -426,6 +426,60 @@ def test_basic(self): assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape)) + def test_expand_dims(self): + x = pt.scalar() + # This reshape does an implicit expand_dims + out = x.reshape((1, -1)) + assert isinstance(out.owner.op, Reshape) + new_out = rewrite_graph(out, include=("canonicalize",)) + assert equal_computations([new_out], [pt.expand_dims(x, (0, 1))]) + + def test_squeeze_of_alloc(self): + # This shows up in the graph of repeat + x = pt.vector("x", shape=(9,)) + bcast_x = pt.alloc(x, 1, 12, x.shape[0]) + + # This reshape does an implicit squeeze + out = bcast_x.reshape((12, x.shape[0])) + + new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt")) + assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False) + + +def test_expand_dims_squeeze_reshape_fusion(): + x = pt.tensor("x", shape=(1, 9)) + reshape_x = x.squeeze(0).reshape((3, 3))[..., None] + + assert isinstance(reshape_x.owner.op, DimShuffle) + assert isinstance(reshape_x.owner.inputs[0].owner.op, Reshape) + assert isinstance(reshape_x.owner.inputs[0].owner.inputs[0].owner.op, DimShuffle) + + out = rewrite_graph(reshape_x, include=("specialize",)) + + # In this case we cannot get rid of the reshape, squeeze or expand_dims, + # so we fuse them all in one reshape + assert equal_computations([out], [x.reshape((3, 3, 1))]) + + +def test_implicit_broadcasting_via_repeat(): + x = pt.vector("x", shape=(3,), dtype=int) + y = pt.vector("y", shape=(9,), dtype=int) + out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1) + # There are two Reshapes in the graph + assert isinstance(out.owner.inputs[0].owner.op, Reshape) + assert isinstance(out.owner.inputs[1].owner.op, Reshape) + + new_out = rewrite_graph(out, include=("canonicalize", "specialize")) + assert equal_computations([new_out], [x[None] <= y[:, None]]) + + no_rewrite_mode = Mode(linker="py", optimizer=None) + x_test = np.arange(3) + 1 + y_test = np.arange(9) + np.testing.assert_allclose( + new_out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode), + out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode), + ) + def test_local_reshape_lift(): x = tensor4() From 8e5e8a401aeb1e4a597d9a0b9cbb2bc2372fa20c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 11 Feb 2025 16:10:04 +0100 Subject: [PATCH 23/43] Only do reshapes in `tensordot` when needed --- pytensor/tensor/math.py | 81 +++++++++++++++++++++++---------------- tests/tensor/test_math.py | 39 ++++++++++++++++++- 2 files changed, 86 insertions(+), 34 deletions(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 9fa823feb8..4dbf30685d 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -2158,13 +2158,11 @@ def tensordot( a = as_tensor_variable(a) b = as_tensor_variable(b) runtime_shape_a = a.shape - bcast_a = a.broadcastable static_shape_a = a.type.shape - ndim_a = a.ndim + ndim_a = a.type.ndim runtime_shape_b = b.shape - bcast_b = b.broadcastable static_shape_b = b.type.shape - ndim_b = b.ndim + ndim_b = b.type.ndim if na != nb: raise ValueError( "The number of axes supplied for tensordot must be equal for each tensor. " @@ -2172,48 +2170,67 @@ def tensordot( ) axes_a = list(normalize_axis_tuple(axes_a, ndim_a)) axes_b = list(normalize_axis_tuple(axes_b, ndim_b)) + + # The operation is only valid if the original dimensions match in length + # The ravelling of the dimensions to coerce the operation into a single dot + # could mask such errors, so we add an Assert if needed. must_assert_runtime = False - for k in range(na): - ax_a = axes_a[k] - ax_b = axes_b[k] - if (bcast_a[ax_a] != bcast_b[ax_b]) or ( + for ax_a, ax_b in zip(axes_a, axes_b, strict=True): + if ( static_shape_a[ax_a] is not None and static_shape_b[ax_b] is not None and static_shape_a[ax_a] != static_shape_b[ax_b] ): raise ValueError( - "Input arrays have inconsistent broadcastable pattern or type shape along the axes " + "Input arrays have inconsistent type shape along the axes " "that are to be reduced with tensordot." ) elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None: if must_assert_runtime: a = Assert( "Input array shape along reduced axes of tensordot are not equal" - )(a, eq(a.shape[ax_a], b.shape[ax_b])) + )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b])) must_assert_runtime = True - # Move the axes to sum over to the end of "a" - # and to the front of "b" - notin = [k for k in range(ndim_a) if k not in axes_a] - newaxes_a = notin + axes_a - N2 = 1 - for axis in axes_a: - N2 *= runtime_shape_a[axis] - newshape_a = (-1, N2) - olda = [runtime_shape_a[axis] for axis in notin] - - notin = [k for k in range(ndim_b) if k not in axes_b] - newaxes_b = axes_b + notin - N2 = 1 - for axis in axes_b: - N2 *= runtime_shape_b[axis] - newshape_b = (N2, -1) - oldb = [runtime_shape_b[axis] for axis in notin] - - at = a.transpose(newaxes_a).reshape(newshape_a) - bt = b.transpose(newaxes_b).reshape(newshape_b) - res = _dot(at, bt) - return res.reshape(olda + oldb) + # Convert tensordot into a stacked dot product. + # We stack the summed axes and the non-summed axes of each tensor separately, + # and place the summed axes at the end of a and the beginning of b + non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a] + non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a] + transpose_axes_a = non_summed_axes_a + axes_a + # We only need a reshape when we need to combine summed or non-summed dims + # or introduce a new dimension (expand_dims) when doing a non-scalar outer product (len(axes) = 0) + a_needs_reshape = (ndim_a != 0) and ( + (len(non_summed_axes_a) > 1) or (len(axes_a) != 1) + ) + + non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b] + non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b] + transpose_axes_b = axes_b + non_summed_axes_b + b_needs_reshape = (ndim_b != 0) and ( + (len(non_summed_axes_b) > 1) or (len(axes_b) != 1) + ) + + # summed_size_a and summed_size_b must be the same, + # but to facilitate reasoning about useless reshapes we compute both from their shapes + at = a.transpose(transpose_axes_a) + if a_needs_reshape: + non_summed_size_a = variadic_mul(*non_summed_dims_a) + summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a]) + at = at.reshape((non_summed_size_a, summed_size_a)) + + bt = b.transpose(transpose_axes_b) + if b_needs_reshape: + non_summed_size_b = variadic_mul(*non_summed_dims_b) + summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b]) + bt = bt.reshape((summed_size_b, non_summed_size_b)) + + res = dot(at, bt) + + if a_needs_reshape or b_needs_reshape: + res = res.reshape(non_summed_dims_a + non_summed_dims_b) + + return res def outer(x, y): diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 2d19ef0114..40c505b7b4 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -19,7 +19,7 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, grad, numeric_grad -from pytensor.graph.basic import Variable, ancestors, applys_between +from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import vectorize_node from pytensor.link.c.basic import DualLinker @@ -2278,7 +2278,7 @@ def test_type_shape(self): with pytest.raises( ValueError, - match="Input arrays have inconsistent broadcastable pattern or type shape", + match="Input arrays have inconsistent type shape", ): tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1) @@ -2323,6 +2323,41 @@ def test_shape_assert(self, axes, has_assert, values, expected_fail): else: assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv})) + def test_eager_simplification(self): + # Test that cases where tensordot isn't needed, it returns a simple graph + scl = tensor(shape=()) + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + # scalar product + out = tensordot(scl, scl, axes=[[], []]) + assert equal_computations([out], [scl * scl]) + + # vector-vector product + out = tensordot(vec, vec, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(vec, vec)]) + + # matrix-vector product + out = tensordot(mat, vec, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(mat, vec)]) + + out = tensordot(mat, vec, axes=[[-2], [-1]]) + assert equal_computations([out], [dot(mat.T, vec)]) + + # vector-matrix product + out = tensordot(vec, mat, axes=[[-1], [-2]]) + assert equal_computations([out], [dot(vec, mat)]) + + out = tensordot(vec, mat, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(vec, mat.T)]) + + # matrix-matrix product + out = tensordot(mat, mat, axes=[[-1], [-2]]) + assert equal_computations([out], [dot(mat, mat)]) + + out = tensordot(mat, mat, axes=[[-1], [-1]]) + assert equal_computations([out], [dot(mat, mat.T)]) + def test_smallest(): x = dvector() From bbe663d93b0145befa29ffcad4c94e2cf52ae92e Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 11 Feb 2025 23:48:14 +0800 Subject: [PATCH 24/43] Implement numba dispatch for all `linalg.solve` modes --- pytensor/link/numba/dispatch/_LAPACK.py | 392 ++++++++ pytensor/link/numba/dispatch/basic.py | 2 +- pytensor/link/numba/dispatch/slinalg.py | 1148 +++++++++++++++++++---- pytensor/tensor/slinalg.py | 43 +- tests/link/numba/test_nlinalg.py | 47 +- tests/link/numba/test_slinalg.py | 371 +++++++- tests/tensor/test_slinalg.py | 110 ++- 7 files changed, 1756 insertions(+), 357 deletions(-) create mode 100644 pytensor/link/numba/dispatch/_LAPACK.py diff --git a/pytensor/link/numba/dispatch/_LAPACK.py b/pytensor/link/numba/dispatch/_LAPACK.py new file mode 100644 index 0000000000..ab5561650c --- /dev/null +++ b/pytensor/link/numba/dispatch/_LAPACK.py @@ -0,0 +1,392 @@ +import ctypes + +import numpy as np +from numba.core import cgutils, types +from numba.core.extending import get_cython_function_address, intrinsic +from numba.np.linalg import ensure_lapack, get_blas_kind + + +_PTR = ctypes.POINTER + +_dbl = ctypes.c_double +_float = ctypes.c_float +_char = ctypes.c_char +_int = ctypes.c_int + +_ptr_float = _PTR(_float) +_ptr_dbl = _PTR(_dbl) +_ptr_char = _PTR(_char) +_ptr_int = _PTR(_int) + + +def _get_lapack_ptr_and_ptr_type(dtype, name): + d = get_blas_kind(dtype) + func_name = f"{d}{name}" + float_pointer = _get_float_pointer_for_dtype(d) + lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) + + return lapack_ptr, float_pointer + + +def _get_underlying_float(dtype): + s_dtype = str(dtype) + out_type = s_dtype + if s_dtype == "complex64": + out_type = "float32" + elif s_dtype == "complex128": + out_type = "float64" + + return np.dtype(out_type) + + +def _get_float_pointer_for_dtype(blas_dtype): + if blas_dtype in ["s", "c"]: + return _ptr_float + elif blas_dtype in ["d", "z"]: + return _ptr_dbl + + +def _get_output_ctype(dtype): + s_dtype = str(dtype) + if s_dtype in ["float32", "complex64"]: + return _float + elif s_dtype in ["float64", "complex128"]: + return _dbl + + +@intrinsic +def sptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float32(types.CPointer(types.float32)) + return sig, impl + + +@intrinsic +def dptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float64(types.CPointer(types.float64)) + return sig, impl + + +@intrinsic +def int_ptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.int32(types.CPointer(types.int32)) + return sig, impl + + +@intrinsic +def val_to_int_ptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.int32)(types.int32) + return sig, impl + + +@intrinsic +def val_to_sptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float32)(types.float32) + return sig, impl + + +@intrinsic +def val_to_zptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.complex128)(types.complex128) + return sig, impl + + +@intrinsic +def val_to_dptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float64)(types.float64) + return sig, impl + + +class _LAPACK: + """ + Functions to return type signatures for wrapped LAPACK functions. + + Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 + """ + + def __init__(self): + ensure_lapack() + + @classmethod + def numba_xtrtrs(cls, dtype): + """ + Solve a triangular system of equations of the form A @ X = B or A.T @ X = B. + + Called by scipy.linalg.solve_triangular + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # TRANS + _ptr_int, # DIAG + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + + return functype(lapack_ptr) + + @classmethod + def numba_xpotrf(cls, dtype): + """ + Compute the Cholesky factorization of a real symmetric positive definite matrix. + + Called by scipy.linalg.cholesky + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO, + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xpotrs(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky + factorization computed by numba_potrf. + + Called by scipy.linalg.cho_solve + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xlange(cls, dtype): + """ + Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of + a general M-by-N matrix A. + + Called by scipy.linalg.solve + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange") + output_ctype = _get_output_ctype(dtype) + functype = ctypes.CFUNCTYPE( + output_ctype, # Output + _ptr_int, # NORM + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # WORK + ) + return functype(lapack_ptr) + + @classmethod + def numba_xlamch(cls, dtype): + """ + Determine machine precision for floating point arithmetic. + """ + + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lamch") + output_dtype = _get_output_ctype(dtype) + functype = ctypes.CFUNCTYPE( + output_dtype, # Output + _ptr_int, # CMACH + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgecon(cls, dtype): + """ + Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf. + + Called by scipy.linalg.solve when assume_a == "gen" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # NORM + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgetrf(cls, dtype): + """ + Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges. + + Called by scipy.linalg.lu_factor + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgetrs(cls, dtype): + """ + Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU + factorization computed by GETRF. + + Called by scipy.linalg.lu_solve + """ + ... + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # TRANS + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xsysv(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method, + factorizing A into LDL^T or UDU^T form, depending on the value of UPLO + + Called by scipy.linalg.solve when assume_a == "sym" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # B + _ptr_int, # LDB + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xsycon(cls, dtype): + """ + Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization + computed by xSYTRF. + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # IPIV + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xpocon(cls, dtype): + """ + Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization + computed by potrf. + + Called by scipy.linalg.solve when assume_a == "pos" + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # ANORM + float_pointer, # RCOND + float_pointer, # WORK + _ptr_int, # IWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xposv(cls, dtype): + """ + Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky + factorization computed by potrf. + """ + + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + return functype(lapack_ptr) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0b2b58904a..c66a237f06 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs): def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): - """Create a Numba compatible function from an Aesara `Op`.""" + """Create a Numba compatible function from a Pytensor `Op`.""" warnings.warn( f"Numba will use object mode to run {op}'s perform method", diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 96a8da282e..a3f5ea9491 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -1,136 +1,52 @@ -import ctypes +from collections.abc import Callable import numba import numpy as np -from numba.core import cgutils, types -from numba.extending import get_cython_function_address, intrinsic, overload -from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind +from numba.core import types +from numba.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from numpy.linalg import LinAlgError from scipy import linalg from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) from pytensor.link.numba.dispatch.basic import numba_funcify -from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular +from pytensor.tensor.slinalg import ( + BlockDiagonal, + Cholesky, + CholeskySolve, + Solve, + SolveTriangular, +) -_PTR = ctypes.POINTER - -_dbl = ctypes.c_double -_float = ctypes.c_float -_char = ctypes.c_char -_int = ctypes.c_int - -_ptr_float = _PTR(_float) -_ptr_dbl = _PTR(_dbl) -_ptr_char = _PTR(_char) -_ptr_int = _PTR(_int) - - -@numba.core.extending.register_jitable -def _check_finite_matrix(a, func_name): - for v in np.nditer(a): - if not np.isfinite(v.item()): - raise np.linalg.LinAlgError( - "Non-numeric values (nan or inf) in input to " + func_name +@numba_basic.numba_njit(inline="always") +def _solve_check(n, info, lamch=False, rcond=None): + """ + Check arguments during the different steps of the solution phase + Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38 + """ + if info < 0: + # TODO: figure out how to do an fstring here + msg = "LAPACK reported an illegal value in input" + raise ValueError(msg) + elif 0 < info: + raise LinAlgError("Matrix is singular.") + + if lamch: + E = _xlamch("E") + if rcond < E: + # TODO: This should be a warning, but we can't raise warnings in numba mode + print( # noqa: T201 + "Ill-conditioned matrix, rcond=", rcond, ", result may not be accurate." ) -@intrinsic -def val_to_dptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.float64)(types.float64) - return sig, impl - - -@intrinsic -def val_to_zptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.complex128)(types.complex128) - return sig, impl - - -@intrinsic -def val_to_sptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.float32)(types.float32) - return sig, impl - - -@intrinsic -def val_to_int_ptr(typingctx, data): - def impl(context, builder, signature, args): - ptr = cgutils.alloca_once_value(builder, args[0]) - return ptr - - sig = types.CPointer(types.int32)(types.int32) - return sig, impl - - -@intrinsic -def int_ptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.int32(types.CPointer(types.int32)) - return sig, impl - - -@intrinsic -def dptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.float64(types.CPointer(types.float64)) - return sig, impl - - -@intrinsic -def sptr_to_val(typingctx, data): - def impl(context, builder, signature, args): - val = builder.load(args[0]) - return val - - sig = types.float32(types.CPointer(types.float32)) - return sig, impl - - -def _get_float_pointer_for_dtype(blas_dtype): - if blas_dtype in ["s", "c"]: - return _ptr_float - elif blas_dtype in ["d", "z"]: - return _ptr_dbl - - -def _get_underlying_float(dtype): - s_dtype = str(dtype) - out_type = s_dtype - if s_dtype == "complex64": - out_type = "float32" - elif s_dtype == "complex128": - out_type = "float64" - - return np.dtype(out_type) - - -def _get_lapack_ptr_and_ptr_type(dtype, name): - d = get_blas_kind(dtype) - func_name = f"{d}{name}" - float_pointer = _get_float_pointer_for_dtype(d) - lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) - - return lapack_ptr, float_pointer - - def _check_scipy_linalg_matrix(a, func_name): """ Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 @@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name): raise numba.TypingError(msg, highlighting=False) -class _LAPACK: +def _solve_triangular( + A, B, trans=0, lower=False, unit_diagonal=False, b_ndim=1, overwrite_b=False +): """ - Functions to return type signatures for wrapped LAPACK functions. + Thin wrapper around scipy.linalg.solve_triangular. - Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 - """ - - def __init__(self): - ensure_lapack() + This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who + import pytensor. - @classmethod - def numba_xtrtrs(cls, dtype): - """ - Called by scipy.linalg.solve_triangular - """ - lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") + The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not + used by scipy.linalg.solve_triangular. + """ + return linalg.solve_triangular( + A, + B, + trans=trans, + lower=lower, + unit_diagonal=unit_diagonal, + overwrite_b=overwrite_b, + ) - functype = ctypes.CFUNCTYPE( - None, - _ptr_int, # UPLO - _ptr_int, # TRANS - _ptr_int, # DIAG - _ptr_int, # N - _ptr_int, # NRHS - float_pointer, # A - _ptr_int, # LDA - float_pointer, # B - _ptr_int, # LDB - _ptr_int, # INFO - ) - return functype(lapack_ptr) - - @classmethod - def numba_xpotrf(cls, dtype): - """ - Called by scipy.linalg.cholesky - """ - lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") - functype = ctypes.CFUNCTYPE( - None, - _ptr_int, # UPLO, - _ptr_int, # N - float_pointer, # A - _ptr_int, # LDA - _ptr_int, # INFO - ) - return functype(lapack_ptr) +@numba_basic.numba_njit(inline="always") +def _trans_char_to_int(trans): + if trans not in [0, 1, 2]: + raise ValueError('Parameter "trans" should be one of 0, 1, 2') + if trans == 0: + return ord("N") + elif trans == 1: + return ord("T") + else: + return ord("C") -def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): - return linalg.solve_triangular( - A, B, trans=trans, lower=lower, unit_diagonal=unit_diagonal - ) +@numba_basic.numba_njit(inline="always") +def _solve_check_input_shapes(A, B): + if A.shape[0] != B.shape[0]: + raise linalg.LinAlgError("Dimensions of A and B do not conform") + if A.shape[-2] != A.shape[-1]: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") @overload(_solve_triangular) -def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): +def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): ensure_lapack() _check_scipy_linalg_matrix(A, "solve_triangular") @@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): w_type = _get_underlying_float(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype) - def impl(A, B, trans=0, lower=False, unit_diagonal=False): - B_is_1d = B.ndim == 1 - + def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): _N = np.int32(A.shape[-1]) - if A.shape[-2] != _N: - raise linalg.LinAlgError("Last 2 dimensions of A must be square") + _solve_check_input_shapes(A, B) - if A.shape[0] != B.shape[0]: - raise linalg.LinAlgError("Dimensions of A and B do not conform") + B_is_1d = B.ndim == 1 - if B_is_1d: - B_copy = np.asfortranarray(np.expand_dims(B, -1)) - else: + if not overwrite_b: B_copy = _copy_to_fortran_order(B) - - if trans not in [0, 1, 2]: - raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2') - if trans == 0: - transval = ord("N") - elif trans == 1: - transval = ord("T") else: - transval = ord("C") + B_copy = B - B_NDIM = 1 if B_is_1d else int(B.shape[1]) + if B_is_1d: + B_copy = np.expand_dims(B, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) - TRANS = val_to_int_ptr(transval) + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(B_NDIM) + NRHS = val_to_int_ptr(NRHS) LDA = val_to_int_ptr(_N) LDB = val_to_int_ptr(_N) INFO = val_to_int_ptr(0) @@ -266,19 +158,24 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): INFO, ) + _solve_check(int_ptr_to_val(LDA), int_ptr_to_val(INFO)) + if B_is_1d: - return B_copy[..., 0], int_ptr_to_val(INFO) - return B_copy, int_ptr_to_val(INFO) + return B_copy[..., 0] + + return B_copy return impl @numba_funcify.register(SolveTriangular) def numba_funcify_SolveTriangular(op, node, **kwargs): - trans = op.trans + trans = bool(op.trans) lower = op.lower unit_diagonal = op.unit_diagonal check_finite = op.check_finite + overwrite_b = op.overwrite_b + b_ndim = op.b_ndim dtype = node.inputs[0].dtype if str(dtype).startswith("complex"): @@ -298,11 +195,16 @@ def solve_triangular(a, b): "Non-numeric values (nan or inf) in input b to solve_triangular" ) - res, info = _solve_triangular(a, b, trans, lower, unit_diagonal) - if info != 0: - raise np.linalg.LinAlgError( - "Singular matrix in input A to solve_triangular" - ) + res = _solve_triangular( + a, + b, + trans=trans, + lower=lower, + unit_diagonal=unit_diagonal, + overwrite_b=overwrite_b, + b_ndim=b_ndim, + ) + return res return solve_triangular @@ -429,3 +331,853 @@ def block_diag(*arrs): return out return block_diag + + +def _xlamch(kind: str = "E"): + """ + Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs. + """ + pass + + +@overload(_xlamch) +def xlamch_impl(kind: str = "E") -> Callable[[str], float]: + """ + Compute the machine precision for a given floating point type. + """ + from pytensor import config + + ensure_lapack() + w_type = _get_underlying_float(config.floatX) + + if w_type == "float32": + dtype = types.float32 + elif w_type == "float64": + dtype = types.float64 + else: + raise NotImplementedError("Unsupported dtype") + + numba_lamch = _LAPACK().numba_xlamch(dtype) + + def impl(kind: str = "E") -> float: + KIND = val_to_int_ptr(ord(kind)) + return numba_lamch(KIND) # type: ignore + + return impl + + +def _xlange(A: np.ndarray, order: str | None = None) -> float: + """ + Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode. + """ + return # type: ignore + + +@overload(_xlange) +def xlange_impl( + A: np.ndarray, order: str | None = None +) -> Callable[[np.ndarray, str], float]: + """ + xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of + largest absolute value of a matrix A. + """ + ensure_lapack() + _check_scipy_linalg_matrix(A, "norm") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_lange = _LAPACK().numba_xlange(dtype) + + def impl(A: np.ndarray, order: str | None = None): + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + A_copy = _copy_to_fortran_order(A) + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + + NORM = ( + val_to_int_ptr(ord(order)) + if order is not None + else val_to_int_ptr(ord("1")) + ) + WORK = np.empty(_M, dtype=dtype) # type: ignore + + result = numba_lange( + NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes + ) + + return result + + return impl + + +def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify + graphs. + """ + return # type: ignore + + +@overload(_xgecon) +def xgecon_impl( + A: np.ndarray, A_norm: float, norm: str +) -> Callable[[np.ndarray, float, str], tuple[np.ndarray, int]]: + """ + Compute the condition number of a matrix A. + """ + ensure_lapack() + _check_scipy_linalg_matrix(A, "gecon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_gecon = _LAPACK().numba_xgecon(dtype) + + def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + A_NORM = np.array(A_norm, dtype=dtype) + NORM = val_to_int_ptr(ord(norm)) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(4 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(1) + + numba_gecon( + NORM, + N, + A_copy.view(w_type).ctypes, + LDA, + A_NORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for LU factorization; used by linalg.solve. + + # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. + """ + return # type: ignore + + +@overload(_getrf) +def getrf_impl( + A: np.ndarray, overwrite_a: bool = False +) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "getrf") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_getrf = _LAPACK().numba_xgetrf(dtype) + + def impl( + A: np.ndarray, overwrite_a: bool = False + ) -> tuple[np.ndarray, np.ndarray, int]: + _M, _N = np.int32(A.shape[-2:]) # type: ignore + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + M = val_to_int_ptr(_M) # type: ignore + N = val_to_int_ptr(_N) # type: ignore + LDA = val_to_int_ptr(_M) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + INFO = val_to_int_ptr(0) + + numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO) + + return A_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _getrs( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve. + + # TODO: Implement an LU_solve Op, then dispatch to this function in numba mode. + """ + return # type: ignore + + +@overload(_getrs) +def getrs_impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(LU, "getrs") + _check_scipy_linalg_matrix(B, "getrs") + dtype = LU.dtype + w_type = _get_underlying_float(dtype) + numba_getrs = _LAPACK().numba_xgetrs(dtype) + + def impl( + LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool + ) -> tuple[np.ndarray, int]: + _N = np.int32(LU.shape[-1]) + _solve_check_input_shapes(LU, B) + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + NRHS = 1 if B_is_1d else int(B_copy.shape[-1]) + + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + IPIV = _copy_to_fortran_order(IPIV) + INFO = val_to_int_ptr(0) + + numba_getrs( + TRANS, + N, + NRHS, + LU.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _solve_gen( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects + for users who import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="gen", + transposed=transposed, + ) + + +@overload(_solve_gen) +def solve_gen_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _N = np.int32(A.shape[-1]) + _solve_check_input_shapes(A, B) + + order = "I" if transposed else "1" + norm = _xlange(A, order=order) + + N = A.shape[1] + LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a) + _solve_check(N, INFO) + + X, INFO = _getrs( + LU=LU, B=B, IPIV=IPIV, trans=transposed, overwrite_b=overwrite_b + ) + _solve_check(N, INFO) + + RCOND, INFO = _xgecon(LU, norm, "1") + _solve_check(N, INFO, True, RCOND) + + return X + + return impl + + +def _sysv( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> tuple[np.ndarray, np.ndarray, int]: + """ + Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve. + """ + return # type: ignore + + +@overload(_sysv) +def sysv_impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int] +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sysv") + _check_scipy_linalg_matrix(B, "sysv") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sysv = _LAPACK().numba_xsysv(dtype) + + def impl( + A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool + ): + _LDA, _N = np.int32(A.shape[-2:]) # type: ignore + _solve_check_input_shapes(A, B) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + if B_is_1d: + B_copy = np.asfortranarray(np.expand_dims(B_copy, -1)) + + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) # type: ignore + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_LDA) # type: ignore + IPIV = np.empty(_N, dtype=np.int32) # type: ignore + LDB = val_to_int_ptr(_N) # type: ignore + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + INFO = val_to_int_ptr(0) + + # Workspace query + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + WS_SIZE = np.int32(WORK[0].real) + LWORK = val_to_int_ptr(WS_SIZE) + WORK = np.empty(WS_SIZE, dtype=dtype) + + # Actual solve + numba_sysv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + IPIV.ctypes, + B_copy.view(w_type).ctypes, + LDB, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], IPIV, int_ptr_to_val(INFO) + return B_copy, IPIV, int_ptr_to_val(INFO) + + return impl + + +def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in + python mode. + """ + return # type: ignore + + +@overload(_sycon) +def sycon_impl( + A: np.ndarray, ipiv: np.ndarray, anorm: float +) -> Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "sycon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_sycon = _LAPACK().numba_xsycon(dtype) + + def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + UPLO = val_to_int_ptr(ord("L")) + ANORM = np.array(anorm, dtype=dtype) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(2 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(0) + + numba_sycon( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + ipiv.ctypes, + ANORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _solve_symmetric( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid + unexpected side-effects when users import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + assume_a="sym", + transposed=transposed, + ) + + +@overload(_solve_symmetric) +def solve_symmetric_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _solve_check_input_shapes(A, B) + + x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) + _solve_check(A.shape[-1], info) + + rcond, info = _sycon(A, ipiv, _xlange(A, order="I")) + _solve_check(A.shape[-1], info, True, rcond) + + return x + + return impl + + +def _posv( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> tuple[np.ndarray, int]: + """ + Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. + """ + return # type: ignore + + +@overload(_posv) +def posv_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[ + [np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int] +]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_posv = _LAPACK().numba_xposv(dtype) + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> tuple[np.ndarray, int]: + _solve_check_input_shapes(A, B) + + _N = np.int32(A.shape[-1]) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + B_is_1d = B.ndim == 1 + + if not overwrite_b: + B_copy = _copy_to_fortran_order(B) + else: + B_copy = B + + if B_is_1d: + B_copy = np.expand_dims(B_copy, -1) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_posv( + UPLO, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + return B_copy, int_ptr_to_val(INFO) + + return impl + + +def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]: + """ + Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by + linalg.solve when assume_a = "pos". + """ + return # type: ignore + + +@overload(_pocon) +def pocon_impl( + A: np.ndarray, anorm: float +) -> Callable[[np.ndarray, float], tuple[np.ndarray, int]]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "pocon") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_pocon = _LAPACK().numba_xpocon(dtype) + + def impl(A: np.ndarray, anorm: float): + _N = np.int32(A.shape[-1]) + A_copy = _copy_to_fortran_order(A) + + UPLO = val_to_int_ptr(ord("L")) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + ANORM = np.array(anorm, dtype=dtype) + RCOND = np.empty(1, dtype=dtype) + WORK = np.empty(3 * _N, dtype=dtype) + IWORK = np.empty(_N, dtype=np.int32) + INFO = val_to_int_ptr(0) + + numba_pocon( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + ANORM.view(w_type).ctypes, + RCOND.view(w_type).ctypes, + WORK.view(w_type).ctypes, + IWORK.ctypes, + INFO, + ) + + return RCOND, int_ptr_to_val(INFO) + + return impl + + +def _solve_psd( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +): + """Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to + avoid unexpected side-effects when users import pytensor.""" + return linalg.solve( + A, + B, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + check_finite=check_finite, + transposed=transposed, + assume_a="pos", + ) + + +@overload(_solve_psd) +def solve_psd_impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, +) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]: + ensure_lapack() + _check_scipy_linalg_matrix(A, "solve") + _check_scipy_linalg_matrix(B, "solve") + + def impl( + A: np.ndarray, + B: np.ndarray, + lower: bool, + overwrite_a: bool, + overwrite_b: bool, + check_finite: bool, + transposed: bool, + ) -> np.ndarray: + _solve_check_input_shapes(A, B) + + x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed) + _solve_check(A.shape[-1], info) + + rcond, info = _pocon(x, _xlange(A)) + _solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) + + return x + + return impl + + +@numba_funcify.register(Solve) +def numba_funcify_Solve(op, node, **kwargs): + assume_a = op.assume_a + lower = op.lower + check_finite = op.check_finite + overwrite_a = op.overwrite_a + overwrite_b = op.overwrite_b + transposed = False # TODO: Solve doesnt currently allow the transposed argument + + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by solve in Numba mode" + ) + + if assume_a == "gen": + solve_fn = _solve_gen + elif assume_a == "sym": + solve_fn = _solve_symmetric + elif assume_a == "her": + raise NotImplementedError( + 'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, ' + "please open an issue on github." + ) + elif assume_a == "pos": + solve_fn = _solve_psd + else: + raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode") + + @numba_basic.numba_njit(inline="always") + def solve(a, b): + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input A to solve" + ) + if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input b to solve" + ) + + res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed) + return res + + return solve + + +def _cho_solve(A_and_lower, B, overwrite_a=False, overwrite_b=False, check_finite=True): + """ + Solve a positive-definite linear system using the Cholesky decomposition. + """ + A, lower = A_and_lower + return linalg.cho_solve((A, lower), B) + + +@overload(_cho_solve) +def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): + ensure_lapack() + _check_scipy_linalg_matrix(C, "cho_solve") + _check_scipy_linalg_matrix(B, "cho_solve") + dtype = C.dtype + w_type = _get_underlying_float(dtype) + numba_potrs = _LAPACK().numba_xpotrs(dtype) + + def impl(C, B, lower=False, overwrite_b=False, check_finite=True): + _solve_check_input_shapes(C, B) + + _N = np.int32(C.shape[-1]) + C_copy = _copy_to_fortran_order(C) + + B_is_1d = B.ndim == 1 + if B_is_1d: + B_copy = np.asfortranarray(np.expand_dims(B, -1)) + else: + B_copy = _copy_to_fortran_order(B) + + NRHS = 1 if B_is_1d else int(B.shape[-1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(NRHS) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_potrs( + UPLO, + N, + NRHS, + C_copy.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0], int_ptr_to_val(INFO) + return B_copy, int_ptr_to_val(INFO) + + return impl + + +@numba_funcify.register(CholeskySolve) +def numba_funcify_CholeskySolve(op, node, **kwargs): + lower = op.lower + overwrite_b = op.overwrite_b + check_finite = op.check_finite + + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by cho_solve in Numba mode" + ) + + @numba_basic.numba_njit(inline="always") + def cho_solve(c, b): + if check_finite: + if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input A to cho_solve" + ) + if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input b to cho_solve" + ) + + res, info = _cho_solve( + c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite + ) + + if info < 0: + raise np.linalg.LinAlgError("Illegal values found in input to cho_solve") + elif info > 0: + raise np.linalg.LinAlgError( + "Matrix is not positive definite in input to cho_solve" + ) + return res + + return cho_solve diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 7f0be47656..f101315172 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1,11 +1,11 @@ import logging -import typing import warnings +from collections.abc import Sequence from functools import reduce from typing import Literal, cast import numpy as np -import scipy.linalg +import scipy.linalg as scipy_linalg import pytensor import pytensor.tensor as pt @@ -58,7 +58,7 @@ def make_node(self, x): f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" ) # Call scipy to find output dtype - dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype + dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) def perform(self, node, inputs, outputs): @@ -68,21 +68,21 @@ def perform(self, node, inputs, outputs): # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS # If we have a `C_CONTIGUOUS` array we transpose to benefit from it if self.overwrite_a and x.flags["C_CONTIGUOUS"]: - out[0] = scipy.linalg.cholesky( + out[0] = scipy_linalg.cholesky( x.T, lower=not self.lower, check_finite=self.check_finite, overwrite_a=True, ).T else: - out[0] = scipy.linalg.cholesky( + out[0] = scipy_linalg.cholesky( x, lower=self.lower, check_finite=self.check_finite, overwrite_a=self.overwrite_a, ) - except scipy.linalg.LinAlgError: + except scipy_linalg.LinAlgError: if self.on_error == "raise": raise else: @@ -334,7 +334,7 @@ def __init__(self, **kwargs): def perform(self, node, inputs, output_storage): C, b = inputs - rval = scipy.linalg.cho_solve( + rval = scipy_linalg.cho_solve( (C, self.lower), b, check_finite=self.check_finite, @@ -369,7 +369,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. - b_ndim : int + b_ndim : int Whether the core case of b is a vector (1) or matrix (2). This will influence how batched dimensions are interpreted. """ @@ -401,7 +401,7 @@ def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): def perform(self, node, inputs, outputs): A, b = inputs - outputs[0][0] = scipy.linalg.solve_triangular( + outputs[0][0] = scipy_linalg.solve_triangular( A, b, lower=self.lower, @@ -502,7 +502,7 @@ def __init__(self, *, assume_a="gen", **kwargs): def perform(self, node, inputs, outputs): a, b = inputs - outputs[0][0] = scipy.linalg.solve( + outputs[0][0] = scipy_linalg.solve( a=a, b=b, lower=self.lower, @@ -619,9 +619,9 @@ def make_node(self, a, b): def perform(self, node, inputs, outputs): (w,) = outputs if len(inputs) == 2: - w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) + w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower) else: - w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) + w[0] = scipy_linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower) def grad(self, inputs, g_outputs): a, b = inputs @@ -675,7 +675,7 @@ def make_node(self, a, b, gw): def perform(self, node, inputs, outputs): (a, b, gw) = inputs - w, v = scipy.linalg.eigh(a, b, lower=self.lower) + w, v = scipy_linalg.eigh(a, b, lower=self.lower) gA = v.dot(np.diag(gw).dot(v.T)) gB = -v.dot(np.diag(gw * w).dot(v.T)) @@ -718,7 +718,7 @@ def make_node(self, A): def perform(self, node, inputs, outputs): (A,) = inputs (expm,) = outputs - expm[0] = scipy.linalg.expm(A) + expm[0] = scipy_linalg.expm(A) def grad(self, inputs, outputs): (A,) = inputs @@ -758,8 +758,8 @@ def perform(self, node, inputs, outputs): # this expression. (A, gA) = inputs (out,) = outputs - w, V = scipy.linalg.eig(A, right=True) - U = scipy.linalg.inv(V).T + w, V = scipy_linalg.eig(A, right=True) + U = scipy_linalg.inv(V).T exp_w = np.exp(w) X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w) @@ -800,7 +800,7 @@ def perform(self, node, inputs, output_storage): X = output_storage[0] out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) + X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -870,7 +870,7 @@ def perform(self, node, inputs, output_storage): X = output_storage[0] out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( + X[0] = scipy_linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( out_dtype ) @@ -992,7 +992,7 @@ def perform(self, node, inputs, output_storage): Q = 0.5 * (Q + Q.T) out_dtype = node.outputs[0].type.dtype - X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) + X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -1064,7 +1064,7 @@ def solve_discrete_are( ) -def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: +def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype: return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors]) @@ -1118,7 +1118,7 @@ def make_node(self, *matrices): def perform(self, node, inputs, output_storage, params=None): dtype = node.outputs[0].type.dtype - output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype) + output_storage[0][0] = scipy_linalg.block_diag(*inputs).astype(dtype) def block_diag(*matrices: TensorVariable): @@ -1175,4 +1175,5 @@ def block_diag(*matrices: TensorVariable): "solve_discrete_are", "solve_triangular", "block_diag", + "cho_solve", ] diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 6fbb6e6c58..3dc427cd9c 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -7,58 +7,13 @@ from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph -from pytensor.tensor import nlinalg, slinalg +from pytensor.tensor import nlinalg from tests.link.numba.test_basic import compare_numba_and_py, set_test_value rng = np.random.default_rng(42849) -@pytest.mark.parametrize( - "A, x, lower, exc", - [ - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ( - set_test_value( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ], -) -def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower=lower, b_ndim=1)(A, x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - ) - - @pytest.mark.parametrize( "x, exc", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 8b1f3ececb..8e49627361 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,19 +1,23 @@ import re +from functools import partial +from typing import Literal import numpy as np import pytest +from numpy.testing import assert_allclose +from scipy import linalg as scipy_linalg import pytensor import pytensor.tensor as pt -from pytensor import config from pytensor.graph import FunctionGraph +from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py numba = pytest.importorskip("numba") -ATOL = 0 if config.floatX.endswith("64") else 1e-6 -RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6 +floatX = pytensor.config.floatX + rng = np.random.default_rng(42849) @@ -27,8 +31,8 @@ def transpose_func(x, trans): @pytest.mark.parametrize( - "b_func, b_size", - [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + "b_shape", + [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"], ) @pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) @@ -36,50 +40,88 @@ def transpose_func(x, trans): @pytest.mark.parametrize( "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] ) -@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"]) +@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) @pytest.mark.filterwarnings( 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' ) -def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex): - if complex: +def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_complex): + if is_complex: # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, # why? pytest.skip("Complex inputs currently not supported to solve_triangular") - complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" - dtype = complex_dtype if complex else config.floatX + complex_dtype = "complex64" if floatX.endswith("32") else "complex128" + dtype = complex_dtype if is_complex else floatX A = pt.matrix("A", dtype=dtype) - b = b_func("b", dtype=dtype) + b = pt.tensor("b", shape=b_shape, dtype=dtype) + + def A_func(x): + x = x @ x.conj().T + x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype) - X = pt.linalg.solve_triangular( - A, b, lower=lower, trans=trans, unit_diagonal=unit_diag + if unit_diag: + x_tri[np.diag_indices_from(x_tri)] = 1.0 + + return x_tri.astype(dtype) + + solve_op = partial( + pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag ) + + X = solve_op(A, b) f = pytensor.function([A, b], X, mode="NUMBA") A_val = np.random.normal(size=(5, 5)) - b = np.random.normal(size=b_size) + b_val = np.random.normal(size=b_shape) - if complex: + if is_complex: A_val = A_val + np.random.normal(size=(5, 5)) * 1j - b = b + np.random.normal(size=b_size) * 1j - A_sym = A_val @ A_val.conj().T + b_val = b_val + np.random.normal(size=b_shape) * 1j - A_tri = np.linalg.cholesky(A_sym).astype(dtype) - if unit_diag: - adj_mat = np.ones((5, 5)) - adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri) - A_tri = A_tri * adj_mat + X_np = f(A_func(A_val.copy()), b_val.copy()) - A_tri = A_tri.astype(dtype) - b = b.astype(dtype) + test_input = transpose_func(A_func(A_val.copy()), trans) - if not lower: - A_tri = A_tri.T + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 - X_np = f(A_tri, b) - np.testing.assert_allclose( - transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL + np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) + + compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()]) + + +@pytest.mark.parametrize( + "lower, unit_diag, trans", + [(True, True, True), (False, False, False)], + ids=["lower_unit_trans", "defaults"], +) +def test_solve_triangular_grad(lower, unit_diag, trans): + A_val = np.random.normal(size=(5, 5)).astype(floatX) + b_val = np.random.normal(size=(5, 5)).astype(floatX) + + # utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When + # a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raise, but the result will be + # wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices + # to the space of triangular matrices, and test the gradient of that entire graph. + def A_func_pt(x): + x = x @ x.conj().T + x_tri = pt.linalg.cholesky(x, lower=lower).astype(floatX) + + if unit_diag: + n = A_val.shape[0] + x_tri = x_tri[np.diag_indices(n)].set(1.0) + + return transpose_func(x_tri.astype(floatX), trans) + + solve_op = partial( + pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag + ) + + utt.verify_grad( + lambda A, b: solve_op(A_func_pt(A), b), + [A_val.copy(), b_val.copy()], + mode="NUMBA", ) @@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value): X = pt.linalg.solve_triangular(A, b, check_finite=True) f = pytensor.function([A, b], X, mode="NUMBA") - A_val = np.random.normal(size=(5, 5)) + A_val = np.random.normal(size=(5, 5)).astype(floatX) A_sym = A_val @ A_val.conj().T - A_tri = np.linalg.cholesky(A_sym).astype(config.floatX) - b = np.full((5, 1), value) + A_tri = np.linalg.cholesky(A_sym).astype(floatX) + b = np.full((5, 1), value).astype(floatX) with pytest.raises( np.linalg.LinAlgError, @@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans): fg = FunctionGraph(outputs=[chol]) - x = np.array([0.1, 0.2, 0.3]) - val = np.eye(3) + x[None, :] * x[:, None] + x = np.array([0.1, 0.2, 0.3]).astype(floatX) + val = np.eye(3).astype(floatX) + x[None, :] * x[:, None] compare_numba_and_py(fg, [val]) def test_numba_Cholesky_raises_on_nan_input(): - test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value = rng.random(size=(3, 3)).astype(floatX) test_value[0, 0] = np.nan - x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = pt.tensor(dtype=floatX, shape=(3, 3)) x = x.T.dot(x) - g = pt.linalg.cholesky(x, check_finite=True) + g = pt.linalg.cholesky(x) f = pytensor.function([x], g, mode="NUMBA") with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): @@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input(): @pytest.mark.parametrize("on_error", ["nan", "raise"]) def test_numba_Cholesky_raise_on(on_error): - test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value = rng.random(size=(3, 3)).astype(floatX) - x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = pt.tensor(dtype=floatX, shape=(3, 3)) g = pt.linalg.cholesky(x, on_error=on_error) f = pytensor.function([x], g, mode="NUMBA") @@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error): assert np.all(np.isnan(f(test_value))) +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +def test_numba_Cholesky_grad(lower): + rng = np.random.default_rng(utt.fetch_seed()) + L = rng.normal(size=(5, 5)).astype(floatX) + X = L @ L.T + + chol_op = partial(pt.linalg.cholesky, lower=lower) + utt.verify_grad(chol_op, [X], mode="NUMBA") + + def test_block_diag(): A = pt.matrix("A") B = pt.matrix("B") @@ -162,9 +214,242 @@ def test_block_diag(): D = pt.matrix("D") X = pt.linalg.block_diag(A, B, C, D) - A_val = np.random.normal(size=(5, 5)) - B_val = np.random.normal(size=(3, 3)) - C_val = np.random.normal(size=(2, 2)) - D_val = np.random.normal(size=(4, 4)) + A_val = np.random.normal(size=(5, 5)).astype(floatX) + B_val = np.random.normal(size=(3, 3)).astype(floatX) + C_val = np.random.normal(size=(2, 2)).astype(floatX) + D_val = np.random.normal(size=(4, 4)).astype(floatX) out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X]) compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val]) + + +def test_lamch(): + from scipy.linalg import get_lapack_funcs + + from pytensor.link.numba.dispatch.slinalg import _xlamch + + @numba.njit() + def xlamch(kind): + return _xlamch(kind) + + lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),)) + + np.testing.assert_allclose(xlamch("E"), lamch("E")) + np.testing.assert_allclose(xlamch("S"), lamch("S")) + np.testing.assert_allclose(xlamch("P"), lamch("P")) + np.testing.assert_allclose(xlamch("B"), lamch("B")) + np.testing.assert_allclose(xlamch("R"), lamch("R")) + np.testing.assert_allclose(xlamch("M"), lamch("M")) + + +@pytest.mark.parametrize( + "ord_numba, ord_scipy", [("F", "fro"), ("1", 1), ("I", np.inf)] +) +def test_xlange(ord_numba, ord_scipy): + # xlange is called internally only, we don't dispatch pt.linalg.norm to it + from scipy import linalg + + from pytensor.link.numba.dispatch.slinalg import _xlange + + @numba.njit() + def xlange(x, ord): + return _xlange(x, ord) + + x = np.random.normal(size=(5, 5)).astype(floatX) + np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy)) + + +@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)]) +def test_xgecon(ord_numba, ord_scipy): + # gecon is called internally only, we don't dispatch pt.linalg.norm to it + from scipy.linalg import get_lapack_funcs + + from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange + + @numba.njit() + def gecon(x, norm): + anorm = _xlange(x, norm) + cond, info = _xgecon(x, anorm, norm) + return cond, info + + x = np.random.normal(size=(5, 5)).astype(floatX) + + rcond, info = gecon(x, norm=ord_numba) + + # Test against direct call to the underlying LAPACK functions + # Solution does **not** agree with 1 / np.linalg.cond(x) ! + lange, gecon = get_lapack_funcs(("lange", "gecon"), (x,)) + norm = lange(ord_numba, x) + rcond2, _ = gecon(x, norm, norm=ord_numba) + + assert info == 0 + np.testing.assert_allclose(rcond, rcond2) + + +@pytest.mark.parametrize("overwrite_a", [True, False]) +def test_getrf(overwrite_a): + from scipy.linalg import lu_factor + + from pytensor.link.numba.dispatch.slinalg import _getrf + + # TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor + + @numba.njit() + def getrf(x, overwrite_a): + return _getrf(x, overwrite_a=overwrite_a) + + x = np.random.normal(size=(5, 5)).astype(floatX) + x = np.asfortranarray( + x + ) # x needs to be fortran-contiguous going into getrf for the overwrite option to work + + lu, ipiv = lu_factor(x, overwrite_a=False) + LU, IPIV, info = getrf(x, overwrite_a=overwrite_a) + + assert info == 0 + assert_allclose(LU, lu) + + if overwrite_a: + assert_allclose(x, LU) + + # TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing + # this, though. + assert_allclose(IPIV - 1, ipiv) + + +@pytest.mark.parametrize("trans", [0, 1]) +@pytest.mark.parametrize("overwrite_a", [True, False]) +@pytest.mark.parametrize("overwrite_b", [True, False]) +@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"]) +def test_getrs(trans, overwrite_a, overwrite_b, b_shape): + from scipy.linalg import lu_factor + from scipy.linalg import lu_solve as sp_lu_solve + + from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs + + # TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor + + @numba.njit() + def lu_solve(a, b, trans, overwrite_a, overwrite_b): + lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a) + x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b) + return x, lu, info + + a = np.random.normal(size=(5, 5)).astype(floatX) + b = np.random.normal(size=b_shape).astype(floatX) + + # inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work + a = np.asfortranarray(a) + b = np.asfortranarray(b) + + lu_and_piv = lu_factor(a, overwrite_a=False) + x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False) + + x, lu, info = lu_solve( + a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b + ) + assert info == 0 + if overwrite_a: + assert_allclose(a, lu) + if overwrite_b: + assert_allclose(b, x) + + assert_allclose(x, x_sp) + + +@pytest.mark.parametrize( + "b_shape", + [(5, 1), (5, 5), (5,)], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) +@pytest.mark.filterwarnings( + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' +) +def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]): + A = pt.matrix("A", dtype=floatX) + b = pt.tensor("b", shape=b_shape, dtype=floatX) + + A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX)) + b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX)) + + def A_func(x): + if assume_a == "pos": + x = x @ x.T + elif assume_a == "sym": + x = (x + x.T) / 2 + return x + + X = pt.linalg.solve( + A_func(A), + b, + assume_a=assume_a, + b_ndim=len(b_shape), + ) + f = pytensor.function( + [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA" + ) + op = f.maker.fgraph.outputs[0].owner.op + + compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True) + + # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first. + A_val_copy = A_val.copy() + b_val_copy = b_val.copy() + + X_np = f(A_val, b_val) + + # overwrite_b is preferred when both inputs can be destroyed + assert op.destroy_map == {0: [1]} + + # Confirm inputs were destroyed by checking against the copies + assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) + assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) + + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 + + # Confirm b_val is used to store to solution + np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL) + assert not np.allclose(b_val, b_val_copy) + + # Test that the result is numerically correct. Need to use the unmodified copy + np.testing.assert_allclose( + A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL + ) + + # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here + utt.verify_grad( + lambda A, b: pt.linalg.solve( + A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape) + ), + [A_val_copy, b_val_copy], + mode="NUMBA", + ) + + +@pytest.mark.parametrize( + "b_func, b_size", + [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower = {x}") +def test_cho_solve(b_func, b_size, lower): + A = pt.matrix("A", dtype=floatX) + b = b_func("b", dtype=floatX) + + C = pt.linalg.cholesky(A, lower=lower) + X = pt.linalg.cho_solve((C, lower), b) + f = pytensor.function([A, b], X, mode="NUMBA") + + A = np.random.normal(size=(5, 5)).astype(floatX) + A = A @ A.conj().T + + b = np.random.normal(size=b_size) + b = b.astype(floatX) + + X_np = f(A, b) + + ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if floatX.endswith("64") else 1e-4 + + np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f46d771938..34f1396f4c 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -209,12 +209,12 @@ def test__repr__(self): ) -class TestSolve(utt.InferShapeTester): - def test__init__(self): - with pytest.raises(ValueError) as excinfo: - Solve(assume_a="test", b_ndim=2) - assert "is not a recognized matrix structure" in str(excinfo.value) +def test_solve_raises_on_invalid_A(): + with pytest.raises(ValueError, match="is not a recognized matrix structure"): + Solve(assume_a="test", b_ndim=2) + +class TestSolve(utt.InferShapeTester): @pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) def test_infer_shape(self, b_shape): rng = np.random.default_rng(utt.fetch_seed()) @@ -232,64 +232,78 @@ def test_infer_shape(self, b_shape): warn=False, ) - def test_correctness(self): + @pytest.mark.parametrize( + "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] + ) + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + def test_solve_correctness(self, b_size: tuple[int], assume_a: str): rng = np.random.default_rng(utt.fetch_seed()) - A = matrix() - b = matrix() - y = solve(A, b) - gen_solve_func = pytensor.function([A, b], y) + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("b", shape=b_size) - b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_size).astype(config.floatX) - A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) - A_val = np.dot(A_val.transpose(), A_val) + solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) - np.testing.assert_allclose( - scipy.linalg.solve(A_val, b_val, assume_a="gen"), - gen_solve_func(A_val, b_val), - ) + def A_func(x): + if assume_a == "pos": + return x @ x.T + elif assume_a == "sym": + return (x + x.T) / 2 + else: + return x + + solve_input_val = A_func(A_val) + + y = solve_op(A_func(A), b) + solve_func = pytensor.function([A, b], y) + X_np = solve_func(A_val.copy(), b_val.copy()) + + ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4 + RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4 - A_undef = np.array( - [ - [1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 1, 1], - [0, 0, 0, 1, 0], - ], - dtype=config.floatX, - ) np.testing.assert_allclose( - scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) + scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a), + X_np, + atol=ATOL, + rtol=RTOL, ) + np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL) + @pytest.mark.parametrize( - "m, n, assume_a, lower", - [ - (5, None, "gen", False), - (5, None, "gen", True), - (4, 2, "gen", False), - (4, 2, "gen", True), - ], + "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] ) - def test_solve_grad(self, m, n, assume_a, lower): + @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) + @pytest.mark.skipif( + config.floatX == "float32", reason="Gradients not numerically stable in float32" + ) + def test_solve_gradient(self, b_size: tuple[int], assume_a: str): rng = np.random.default_rng(utt.fetch_seed()) - # Ensure diagonal elements of `A` are relatively large to avoid - # numerical precision issues - A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX) + eps = 2e-8 if config.floatX == "float64" else None - if n is None: - b_val = rng.normal(size=m).astype(config.floatX) - else: - b_val = rng.normal(size=(m, n)).astype(config.floatX) + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + b_val = rng.normal(size=b_size).astype(config.floatX) - eps = None - if config.floatX == "float64": - eps = 2e-8 + def A_func(x): + if assume_a == "pos": + return x @ x.T + elif assume_a == "sym": + return (x + x.T) / 2 + else: + return x - solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2) - utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) + solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) + + # To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices + # (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included, + # the random perturbations used by verify_grad will result in invalid input matrices, and + # LAPACK will silently do the wrong thing, making the gradients wrong + utt.verify_grad( + lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps + ) class TestSolveTriangular(utt.InferShapeTester): From 361280c871a622c22edbbf41b69e8052da29bc2b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 11:16:08 +0200 Subject: [PATCH 25/43] Update numpy deprecated imports - replaced np.AxisError with np.exceptions.AxisError - the `numpy.core` submodule has been renamed to `numpy._core` - some parts of `numpy.core` have been moved to `numpy.lib.array_utils` Except for `AxisError`, the updated imports are conditional on the version of numpy, so the imports should work for numpy >= 1.26. The conditional imports have been added to `npy_2_compat.py`, so the imports elsewhere are unconditonal. --- pytensor/link/c/basic.py | 7 +- pytensor/link/numba/dispatch/elemwise.py | 2 +- pytensor/npy_2_compat.py | 275 +++++++++++++++++++++++ pytensor/tensor/__init__.py | 2 +- pytensor/tensor/basic.py | 8 +- pytensor/tensor/conv/abstract_conv.py | 3 +- pytensor/tensor/einsum.py | 9 +- pytensor/tensor/elemwise.py | 5 +- pytensor/tensor/extra_ops.py | 12 +- pytensor/tensor/math.py | 2 +- pytensor/tensor/nlinalg.py | 2 +- pytensor/tensor/shape.py | 2 +- pytensor/tensor/slinalg.py | 3 +- pytensor/tensor/subtensor.py | 2 + pytensor/tensor/utils.py | 6 +- tests/tensor/test_elemwise.py | 2 +- tests/tensor/test_extra_ops.py | 2 +- tests/tensor/test_io.py | 2 +- 18 files changed, 311 insertions(+), 35 deletions(-) create mode 100644 pytensor/npy_2_compat.py diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index d7f43e7377..d509bd1d76 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -10,8 +10,6 @@ from io import StringIO from typing import TYPE_CHECKING, Any, Optional -import numpy as np - from pytensor.compile.compilelock import lock_ctx from pytensor.configdefaults import config from pytensor.graph.basic import ( @@ -33,6 +31,7 @@ from pytensor.link.c.cmodule import get_module_cache as _get_module_cache from pytensor.link.c.interface import CLinkerObject, CLinkerOp, CLinkerType from pytensor.link.utils import gc_helper, map_storage, raise_with_op, streamline +from pytensor.npy_2_compat import ndarray_c_version from pytensor.utils import difference, uniq @@ -1367,10 +1366,6 @@ def cmodule_key_( # We must always add the numpy ABI version here as # DynamicModule always add the include - if np.lib.NumpyVersion(np.__version__) < "1.16.0a": - ndarray_c_version = np.core.multiarray._get_ndarray_c_version() - else: - ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}") if c_compiler: sig.append("c_compiler_str=" + c_compiler.version_str()) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2a98985efe..03c7084a8f 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -4,7 +4,6 @@ import numba import numpy as np from numba.core.extending import overload -from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic @@ -19,6 +18,7 @@ store_core_outputs, ) from pytensor.link.utils import compile_function_src +from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.scalar.basic import ( AND, OR, diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py new file mode 100644 index 0000000000..30214154a2 --- /dev/null +++ b/pytensor/npy_2_compat.py @@ -0,0 +1,275 @@ +from textwrap import dedent + +import numpy as np + + +# Conditional numpy imports for numpy 1.26 and 2.x compatibility +try: + from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple +except ModuleNotFoundError: + # numpy < 2.0 + from numpy.core.multiarray import normalize_axis_index # type: ignore[no-redef] + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] + + +try: + from numpy._core.einsumfunc import ( # type: ignore[attr-defined] + _find_contraction, + _parse_einsum_input, + ) +except ModuleNotFoundError: + from numpy.core.einsumfunc import ( # type: ignore[no-redef] + _find_contraction, + _parse_einsum_input, + ) + + +# suppress linting warning by "using" the imports here: +__all__ = [ + "_find_contraction", + "_parse_einsum_input", + "normalize_axis_index", + "normalize_axis_tuple", +] + + +numpy_version_tuple = tuple(int(n) for n in np.__version__.split(".")[:2]) +numpy_version = np.lib.NumpyVersion( + np.__version__ +) # used to compare with version strings, e.g. numpy_version < "1.16.0" +using_numpy_2 = numpy_version >= "2.0.0rc1" + + +if using_numpy_2: + ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() +else: + ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined] + + +if using_numpy_2: + UintOverflowError = OverflowError +else: + UintOverflowError = TypeError + + +def npy_2_compat_header() -> str: + """Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x""" + return dedent(""" + #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ + #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ + + + /* + * This header is meant to be included by downstream directly for 1.x compat. + * In that case we need to ensure that users first included the full headers + * and not just `ndarraytypes.h`. + */ + + #ifndef NPY_FEATURE_VERSION + #error "The NumPy 2 compat header requires `import_array()` for which " \\ + "the `ndarraytypes.h` header include is not sufficient. Please " \\ + "include it after `numpy/ndarrayobject.h` or similar." \\ + "" \\ + "To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\ + "which is defined in the compat header and is lightweight (can be)." + #endif + + #if NPY_ABI_VERSION < 0x02000000 + /* + * Define 2.0 feature version as it is needed below to decide whether we + * compile for both 1.x and 2.x (defining it gaurantees 1.x only). + */ + #define NPY_2_0_API_VERSION 0x00000012 + /* + * If we are compiling with NumPy 1.x, PyArray_RUNTIME_VERSION so we + * pretend the `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`. + * This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to. + */ + #define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION + /* Compiling on NumPy 1.x where these are the same: */ + #define PyArray_DescrProto PyArray_Descr + #endif + + + /* + * Define a better way to call `_import_array()` to simplify backporting as + * we now require imports more often (necessary to make ABI flexible). + */ + #ifdef import_array1 + + static inline int + PyArray_ImportNumPyAPI() + { + if (NPY_UNLIKELY(PyArray_API == NULL)) { + import_array1(-1); + } + return 0; + } + + #endif /* import_array1 */ + + + /* + * NPY_DEFAULT_INT + * + * The default integer has changed, `NPY_DEFAULT_INT` is available at runtime + * for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`. + * + * NPY_RAVEL_AXIS + * + * This was introduced in NumPy 2.0 to allow indicating that an axis should be + * raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose. + * + * NPY_MAXDIMS + * + * A constant indicating the maximum number dimensions allowed when creating + * an ndarray. + * + * NPY_NTYPES_LEGACY + * + * The number of built-in NumPy dtypes. + */ + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + #define NPY_DEFAULT_INT NPY_INTP + #define NPY_RAVEL_AXIS NPY_MIN_INT + #define NPY_MAXARGS 64 + + #elif NPY_ABI_VERSION < 0x02000000 + #define NPY_DEFAULT_INT NPY_LONG + #define NPY_RAVEL_AXIS 32 + #define NPY_MAXARGS 32 + + /* Aliases of 2.x names to 1.x only equivalent names */ + #define NPY_NTYPES NPY_NTYPES_LEGACY + #define PyArray_DescrProto PyArray_Descr + #define _PyArray_LegacyDescr PyArray_Descr + /* NumPy 2 definition always works, but add it for 1.x only */ + #define PyDataType_ISLEGACY(dtype) (1) + #else + #define NPY_DEFAULT_INT \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG) + #define NPY_RAVEL_AXIS \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32) + #define NPY_MAXARGS \\ + (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32) + #endif + + + /* + * Access inline functions for descriptor fields. Except for the first + * few fields, these needed to be moved (elsize, alignment) for + * additional space. Or they are descriptor specific and are not generally + * available anymore (metadata, c_metadata, subarray, names, fields). + * + * Most of these are defined via the `DESCR_ACCESSOR` macro helper. + */ + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000 + /* Compiling for 1.x or 2.x only, direct field access is OK: */ + + static inline void + PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size) + { + dtype->elsize = size; + } + + static inline npy_uint64 + PyDataType_FLAGS(const PyArray_Descr *dtype) + { + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + return dtype->flags; + #else + return (unsigned char)dtype->flags; /* Need unsigned cast on 1.x */ + #endif + } + + #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\ + static inline type \\ + PyDataType_##FIELD(const PyArray_Descr *dtype) { \\ + if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\ + return (type)0; \\ + } \\ + return ((_PyArray_LegacyDescr *)dtype)->field; \\ + } + #else /* compiling for both 1.x and 2.x */ + + static inline void + PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + ((_PyArray_DescrNumPy2 *)dtype)->elsize = size; + } + else { + ((PyArray_DescrProto *)dtype)->elsize = (int)size; + } + } + + static inline npy_uint64 + PyDataType_FLAGS(const PyArray_Descr *dtype) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + return ((_PyArray_DescrNumPy2 *)dtype)->flags; + } + else { + return (unsigned char)((PyArray_DescrProto *)dtype)->flags; + } + } + + /* Cast to LegacyDescr always fine but needed when `legacy_only` */ + #define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\ + static inline type \\ + PyDataType_##FIELD(const PyArray_Descr *dtype) { \\ + if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\ + return (type)0; \\ + } \\ + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { \\ + return ((_PyArray_LegacyDescr *)dtype)->field; \\ + } \\ + else { \\ + return ((PyArray_DescrProto *)dtype)->field; \\ + } \\ + } + #endif + + DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0) + DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0) + DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1) + DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1) + DESCR_ACCESSOR(NAMES, names, PyObject *, 1) + DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1) + DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1) + + #undef DESCR_ACCESSOR + + + #if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD) + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return _PyDataType_GetArrFuncs(descr); + } + #elif NPY_ABI_VERSION < 0x02000000 + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return descr->f; + } + #else + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + return _PyDataType_GetArrFuncs(descr); + } + else { + return ((PyArray_DescrProto *)descr)->f; + } + } + #endif + + + #endif /* not internal build */ + + #endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */ + + """) diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 67b6ab071e..88d3f33199 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -123,7 +123,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # isort: on # Allow accessing numpy constants from pytensor.tensor -from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi +from numpy import e, euler_gamma, inf, nan, newaxis, pi from pytensor.tensor.basic import * from pytensor.tensor.blas import batched_dot, batched_tensordot diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 26bd34692b..061a159fc2 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -14,8 +14,7 @@ from typing import cast as type_cast import numpy as np -from numpy.core.multiarray import normalize_axis_index -from numpy.core.numeric import normalize_axis_tuple +from numpy.exceptions import AxisError import pytensor import pytensor.scalar.sharedvar @@ -32,6 +31,7 @@ from pytensor.graph.type import HasShape, Type from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise from pytensor.scalar import int32 @@ -228,7 +228,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: elif x_.ndim > ndim: try: x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim))) - except np.AxisError: + except AxisError: raise ValueError( f"ndarray could not be cast to constant with {int(ndim)} dimensions" ) @@ -4405,7 +4405,7 @@ def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVa axis = (axis,) out_ndim = len(axis) + a.ndim - axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim) + axis = normalize_axis_tuple(axis, out_ndim) if not axis: return a diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index d1dfe44b90..fc937bf404 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -8,6 +8,7 @@ from math import gcd import numpy as np +from numpy.exceptions import ComplexWarning try: @@ -2338,7 +2339,7 @@ def conv( bval = _bvalfromboundary("fill") with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) for b in range(img.shape[0]): for g in range(self.num_groups): for n in range(output_channel_offset): diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index cba40ec6f8..88a6257c9c 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -6,13 +6,14 @@ from typing import cast import numpy as np -from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore -from numpy.core.numeric import ( # type: ignore + +from pytensor.compile.builders import OpFromGraph +from pytensor.npy_2_compat import ( + _find_contraction, + _parse_einsum_input, normalize_axis_index, normalize_axis_tuple, ) - -from pytensor.compile.builders import OpFromGraph from pytensor.tensor import TensorLike from pytensor.tensor.basic import ( arange, diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index c37597906a..a07ec0d9dd 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -4,7 +4,6 @@ from typing import Literal import numpy as np -from numpy.core.numeric import normalize_axis_tuple import pytensor.tensor.basic from pytensor.configdefaults import config @@ -17,6 +16,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.params_type import ParamsType from pytensor.misc.frozendict import frozendict +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type from pytensor.scalar.basic import bool as scalar_bool @@ -41,9 +41,6 @@ from pytensor.utils import uniq -_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] - - class DimShuffle(ExternalCOp): """ Allows to reorder the dimensions of a tensor or insert or remove diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 27eabc5ba4..e9d06ae9c2 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2,7 +2,7 @@ from collections.abc import Collection, Iterable import numpy as np -from numpy.core.multiarray import normalize_axis_index +from numpy.exceptions import AxisError import pytensor import pytensor.scalar.basic as ps @@ -17,6 +17,10 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.link.c.type import EnumList, Generic +from pytensor.npy_2_compat import ( + normalize_axis_index, + normalize_axis_tuple, +) from pytensor.raise_op import Assert from pytensor.scalar import int32 as int_t from pytensor.scalar import upcast @@ -596,9 +600,9 @@ def squeeze(x, axis=None): # scalar inputs are treated as 1D regarding axis in this `Op` try: - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=_x.ndim) + axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) + except AxisError: + raise AxisError(axis, ndim=_x.ndim) if not axis: # Nothing to do diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 4dbf30685d..c4f3dc50a5 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional import numpy as np -from numpy.core.numeric import normalize_axis_tuple from pytensor import config, printing from pytensor import scalar as ps @@ -14,6 +13,7 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.scalar.basic import BinaryScalarOp diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index a9d7016099..ee33f6533c 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -4,13 +4,13 @@ from typing import Literal, cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 1c23a21347..e839ac1f08 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -6,7 +6,6 @@ from typing import cast as typing_cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.gradient import DisconnectedType @@ -16,6 +15,7 @@ from pytensor.graph.type import HasShape from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.scalar import int32 from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import basic as ptb diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index f101315172..94973810fd 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -6,6 +6,7 @@ import numpy as np import scipy.linalg as scipy_linalg +from numpy.exceptions import ComplexWarning import pytensor import pytensor.tensor as pt @@ -767,7 +768,7 @@ def perform(self, node, inputs, outputs): Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T) with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) out[0] = Y.astype(A.dtype) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index a3a81f63bd..46b9cc06fd 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,6 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -2522,6 +2523,7 @@ def c_code(self, node, name, input_names, output_names, sub): numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] if bool(numpy_ver < [1, 8]): raise NotImplementedError + x, y, idx = input_names out = output_names[0] copy_of_x = self.copy_of_x(x) diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index e6451c9236..9ce12296cd 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -3,10 +3,10 @@ from typing import cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.graph import FunctionGraph, Variable +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.utils import hash_from_code @@ -236,8 +236,8 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None: if axis is not None: try: axis = normalize_axis_tuple(axis, ndim=max(1, ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=ndim) + except np.exceptions.AxisError: + raise np.exceptions.AxisError(axis, ndim=ndim) # TODO: If axis tuple is equivalent to None, return None for more canonicalization? return cast(tuple, axis) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index bd208c5848..8555a1d29f 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -672,7 +672,7 @@ def test_scalar_input(self): assert self.op(ps.add, axis=(-1,))(x).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (-2,) is out of bounds for array of dimension 0"), ): self.op(ps.add, axis=(-2,))(x) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index e4f4945393..8bf689bc15 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -469,7 +469,7 @@ def test_scalar_input(self): assert squeeze(x, axis=(0,)).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (1,) is out of bounds for array of dimension 0"), ): squeeze(x, axis=1) diff --git a/tests/tensor/test_io.py b/tests/tensor/test_io.py index cece2af277..4c5e5655fe 100644 --- a/tests/tensor/test_io.py +++ b/tests/tensor/test_io.py @@ -49,7 +49,7 @@ def test_memmap(self): path = Variable(Generic(), None) x = load(path, "int32", (None,), mmap_mode="c") fn = function([path], x) - assert isinstance(fn(self.filename), np.core.memmap) + assert isinstance(fn(self.filename), np.memmap) def teardown_method(self): (pytensor.config.compiledir / "_test.npy").unlink() From e6c26b23f38265ebeac78d03abee7e0c4753f34f Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Apr 2024 10:53:07 -0400 Subject: [PATCH 26/43] Changes for numpy 2.0 deprecations - Replace np.cast with np.asarray: in numpy 2.0, `np.cast[new_dtype](arr)` is deprecated. The literal replacement is `np.asarray(arr, dtype=new_dtype)`. - Replace np.sctype2char and np.obj2sctype. Added try/except to handle change in behavior of `np.dtype` - Replace np.find_common_type with np.result_type Further changes to `TensorType`: TensorType.dtype must be a string, so the code has been changed from `self.dtype = np.dtype(dtype).type`, where the right-hand side is of type `np.generic`, to `self.dtype = str(np.dtype(dtype))`, where the right-hand side is a string that satisfies: `self.dtype == str(np.dtype(self.dtype))` This doesn't change the behavior of `np.array(..., dtype=self.dtype)` etc. --- pytensor/scalar/basic.py | 22 +++++++++++----------- pytensor/tensor/elemwise.py | 2 +- pytensor/tensor/type.py | 27 +++++++++++++++------------ tests/scan/test_rewriting.py | 2 +- tests/tensor/test_extra_ops.py | 6 +++--- tests/tensor/utils.py | 2 +- tests/test_gradient.py | 28 +++++++++++++++------------- tests/typed_list/test_basic.py | 8 ++++---- 8 files changed, 51 insertions(+), 46 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index c13afbd6fa..94039f8091 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -2966,7 +2966,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (x * np.asarray(math.log(2.0)).astype(x.dtype)),) + return (gz / (x * np.array(math.log(2.0), dtype=x.dtype)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3009,7 +3009,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (x * np.asarray(math.log(10.0)).astype(x.dtype)),) + return (gz / (x * np.array(math.log(10.0), dtype=x.dtype)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3124,7 +3124,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * exp2(x) * log(np.cast[x.type](2)),) + return (gz * exp2(x) * log(np.array(2, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3263,7 +3263,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * np.asarray(np.pi / 180, gz.type),) + return (gz * np.array(np.pi / 180, dtype=gz.type),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3298,7 +3298,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * np.asarray(180.0 / np.pi, gz.type),) + return (gz * np.array(180.0 / np.pi, dtype=gz.type),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3371,7 +3371,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (-gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (-gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3445,7 +3445,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3517,7 +3517,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) + sqr(x)),) + return (gz / (np.array(1, dtype=x.type) + sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3640,7 +3640,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) - np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) - np.array(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3717,7 +3717,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) + np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) + np.array(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3795,7 +3795,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) - sqr(x)),) + return (gz / (np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a07ec0d9dd..37acfc8e86 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -668,7 +668,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): and isinstance(self.nfunc, np.ufunc) and node.inputs[0].dtype in discrete_dtypes ): - char = np.sctype2char(out_dtype) + char = np.dtype(out_dtype).char sig = char * node.nin + "->" + char * node.nout node.tag.sig = sig node.tag.fake_node = Apply( diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 0f99fa48aa..d48a7a6f08 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal, Optional import numpy as np +import numpy.typing as npt import pytensor from pytensor import scalar as ps @@ -69,7 +70,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): def __init__( self, - dtype: str | np.dtype, + dtype: str | npt.DTypeLike, shape: Iterable[bool | int | None] | None = None, name: str | None = None, broadcastable: Iterable[bool] | None = None, @@ -101,11 +102,11 @@ def __init__( if str(dtype) == "floatX": self.dtype = config.floatX else: - if np.obj2sctype(dtype) is None: + try: + self.dtype = str(np.dtype(dtype)) + except TypeError: raise TypeError(f"Invalid dtype: {dtype}") - self.dtype = np.dtype(dtype).name - def parse_bcast_and_shape(s): if isinstance(s, bool | np.bool_): return 1 if s else None @@ -789,14 +790,16 @@ def tensor( **kwargs, ) -> "TensorVariable": if name is not None: - # Help catching errors with the new tensor API - # Many single letter strings are valid sctypes - if str(name) == "floatX" or (len(str(name)) > 1 and np.obj2sctype(name)): - np.obj2sctype(name) - raise ValueError( - f"The first and only positional argument of tensor is now `name`. Got {name}.\n" - "This name looks like a dtype, which you should pass as a keyword argument only." - ) + try: + # Help catching errors with the new tensor API + # Many single letter strings are valid sctypes + if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type): + raise ValueError( + f"The first and only positional argument of tensor is now `name`. Got {name}.\n" + "This name looks like a dtype, which you should pass as a keyword argument only." + ) + except TypeError: + pass if dtype is None: dtype = config.floatX diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index 6f77625f2f..fd9c43b129 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -673,7 +673,7 @@ def test_machine_translation(self): zi = tensor3("zi") zi_value = x_value - init = pt.alloc(np.cast[config.floatX](0), batch_size, dim) + init = pt.alloc(np.asarray(0, dtype=config.floatX), batch_size, dim) def rnn_step1( # sequences diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 8bf689bc15..54bb7f4333 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -708,7 +708,7 @@ def test_perform(self, shp): y = scalar() f = function([x, y], fill_diagonal(x, y)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out), val) @@ -720,7 +720,7 @@ def test_perform_3d(self): x = tensor3() y = scalar() f = function([x, y], fill_diagonal(x, y)) - val = np.cast[config.floatX](rng.random() + 10) + val = rng.random(dtype=config.floatX) + 10 out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert out[0, 0, 0] == val @@ -782,7 +782,7 @@ def test_perform(self, test_offset, shp): f = function([x, y, z], fill_diagonal_offset(x, y, z)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val, test_offset) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out, test_offset), val) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 9eb06f28a3..b94750ffe2 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -152,7 +152,7 @@ def upcast_float16_ufunc(fn): """ def ret(*args, **kwargs): - out_dtype = np.find_common_type([a.dtype for a in args], [np.float16]) + out_dtype = np.result_type(np.float16, *args) if out_dtype == "float16": # Force everything to float32 sig = "f" * fn.nin + "->" + "f" * fn.nout diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 79c55caf44..24f5964c92 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -481,12 +481,12 @@ def make_grad_func(X): int_type = imatrix().dtype float_type = "float64" - X = np.cast[int_type](rng.standard_normal((m, d)) * 127.0) - W = np.cast[W.dtype](rng.standard_normal((d, n))) - b = np.cast[b.dtype](rng.standard_normal(n)) + X = np.asarray(rng.standard_normal((m, d)) * 127.0, dtype=int_type) + W = rng.standard_normal((d, n), dtype=W.dtype) + b = rng.standard_normal(n, dtype=b.dtype) int_result = int_func(X, W, b) - float_result = float_func(np.cast[float_type](X), W, b) + float_result = float_func(np.asarray(X, dtype=float_type), W, b) assert np.allclose(int_result, float_result), (int_result, float_result) @@ -508,7 +508,7 @@ def test_grad_disconnected(self): # the output f = pytensor.function([x], g) rng = np.random.default_rng([2012, 9, 5]) - x = np.cast[x.dtype](rng.standard_normal(3)) + x = rng.standard_normal(3, dtype=x.dtype) g = f(x) assert np.allclose(g, np.ones(x.shape, dtype=x.dtype)) @@ -631,7 +631,8 @@ def test_known_grads(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()] values = [ - np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) ] true_grads = grad(cost, inputs, disconnected_inputs="ignore") @@ -679,7 +680,7 @@ def test_known_grads_integers(): f = pytensor.function([g_expected], g_grad) x = -3 - gv = np.cast[config.floatX](0.6) + gv = np.asarray(0.6, dtype=config.floatX) g_actual = f(gv) @@ -746,7 +747,8 @@ def test_subgraph_grad(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(2), rng.standard_normal(3)] values = [ - np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) ] wrt = [w2, w1] @@ -1031,21 +1033,21 @@ def test_jacobian_scalar(): # test when the jacobian is called with a tensor as wrt Jx = jacobian(y, x) f = pytensor.function([x], Jx) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a tuple as wrt Jx = jacobian(y, (x,)) assert isinstance(Jx, tuple) f = pytensor.function([x], Jx[0]) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a list as wrt Jx = jacobian(y, [x]) assert isinstance(Jx, list) f = pytensor.function([x], Jx[0]) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called with a list of two elements @@ -1053,8 +1055,8 @@ def test_jacobian_scalar(): y = x * z Jx = jacobian(y, [x, z]) f = pytensor.function([x, z], Jx) - vx = np.cast[pytensor.config.floatX](rng.uniform()) - vz = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) vJx = f(vx, vz) assert np.allclose(vJx[0], vz) diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py index 466bdc865d..19598bfb21 100644 --- a/tests/typed_list/test_basic.py +++ b/tests/typed_list/test_basic.py @@ -577,10 +577,10 @@ def test_correct_answer(self): x = tensor3() y = tensor3() - A = np.cast[pytensor.config.floatX](np.random.random((5, 3))) - B = np.cast[pytensor.config.floatX](np.random.random((7, 2))) - X = np.cast[pytensor.config.floatX](np.random.random((5, 6, 1))) - Y = np.cast[pytensor.config.floatX](np.random.random((1, 9, 3))) + A = np.random.random((5, 3)).astype(pytensor.config.floatX) + B = np.random.random((7, 2)).astype(pytensor.config.floatX) + X = np.random.random((5, 6, 1)).astype(pytensor.config.floatX) + Y = np.random.random((1, 9, 3)).astype(pytensor.config.floatX) make_list((3.0, 4.0)) c = make_list((a, b)) From 910b27c00ba93ead70a9994d8b651fa179c19380 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Wed, 5 Feb 2025 10:19:16 +0000 Subject: [PATCH 27/43] Updated lazylinker C code Some macros were removed from npy_3k_compat.h. Following numpy, I updated the affected functions to the Python 3 names, and removed support for Python 2. Also updated lazylinker_c version to indicate substantial changes to the C code. --- pytensor/link/c/c_code/lazylinker_c.c | 53 +++++++------------- pytensor/link/c/c_code/pytensor_mod_helper.h | 8 +-- pytensor/link/c/lazylinker_c.py | 2 +- 3 files changed, 21 insertions(+), 42 deletions(-) diff --git a/pytensor/link/c/c_code/lazylinker_c.c b/pytensor/link/c/c_code/lazylinker_c.c index a64614a908..08f3e4d0fb 100644 --- a/pytensor/link/c/c_code/lazylinker_c.c +++ b/pytensor/link/c/c_code/lazylinker_c.c @@ -5,9 +5,6 @@ #if PY_VERSION_HEX >= 0x03000000 #include "numpy/npy_3kcompat.h" -#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr -#define PyCObject_GetDesc NpyCapsule_GetDesc -#define PyCObject_Check NpyCapsule_Check #endif #ifndef Py_TYPE @@ -323,9 +320,9 @@ static int CLazyLinker_init(CLazyLinker *self, PyObject *args, PyObject *kwds) { if (PyObject_HasAttrString(thunk, "cthunk")) { PyObject *cthunk = PyObject_GetAttrString(thunk, "cthunk"); // new reference - assert(cthunk && PyCObject_Check(cthunk)); - self->thunk_cptr_fn[i] = PyCObject_AsVoidPtr(cthunk); - self->thunk_cptr_data[i] = PyCObject_GetDesc(cthunk); + assert(cthunk && NpyCapsule_Check(cthunk)); + self->thunk_cptr_fn[i] = NpyCapsule_AsVoidPtr(cthunk); + self->thunk_cptr_data[i] = NpyCapsule_GetDesc(cthunk); Py_DECREF(cthunk); // cthunk is kept alive by membership in self->thunks } @@ -487,8 +484,8 @@ static PyObject *pycall(CLazyLinker *self, Py_ssize_t node_idx, int verbose) { PyList_SetItem(self->call_times, node_idx, PyFloat_FromDouble(t1 - t0 + ti)); PyObject *count = PyList_GetItem(self->call_counts, node_idx); - long icount = PyInt_AsLong(count); - PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount + 1)); + long icount = PyLong_AsLong(count); + PyList_SetItem(self->call_counts, node_idx, PyLong_FromLong(icount + 1)); } } else { if (verbose) { @@ -512,8 +509,8 @@ static int c_call(CLazyLinker *self, Py_ssize_t node_idx, int verbose) { PyList_SetItem(self->call_times, node_idx, PyFloat_FromDouble(t1 - t0 + ti)); PyObject *count = PyList_GetItem(self->call_counts, node_idx); - long icount = PyInt_AsLong(count); - PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount + 1)); + long icount = PyLong_AsLong(count); + PyList_SetItem(self->call_counts, node_idx, PyLong_FromLong(icount + 1)); } else { err = fn(self->thunk_cptr_data[node_idx]); } @@ -774,20 +771,20 @@ static PyObject *CLazyLinker_call(PyObject *_self, PyObject *args, output_subset = (char *)calloc(self->n_output_vars, sizeof(char)); for (int it = 0; it < output_subset_size; ++it) { PyObject *elem = PyList_GetItem(output_subset_ptr, it); - if (!PyInt_Check(elem)) { + if (!PyLong_Check(elem)) { err = 1; PyErr_SetString(PyExc_RuntimeError, "Some elements of output_subset list are not int"); } - output_subset[PyInt_AsLong(elem)] = 1; + output_subset[PyLong_AsLong(elem)] = 1; } } } self->position_of_error = -1; // create constants used to fill the var_compute_cells - PyObject *one = PyInt_FromLong(1); - PyObject *zero = PyInt_FromLong(0); + PyObject *one = PyLong_FromLong(1); + PyObject *zero = PyLong_FromLong(0); // pre-allocate our return value Py_INCREF(Py_None); @@ -942,11 +939,8 @@ static PyMemberDef CLazyLinker_members[] = { }; static PyTypeObject lazylinker_ext_CLazyLinkerType = { -#if defined(NPY_PY3K) PyVarObject_HEAD_INIT(NULL, 0) -#else - PyObject_HEAD_INIT(NULL) 0, /*ob_size*/ -#endif + "lazylinker_ext.CLazyLinker", /*tp_name*/ sizeof(CLazyLinker), /*tp_basicsize*/ 0, /*tp_itemsize*/ @@ -987,7 +981,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { }; static PyObject *get_version(PyObject *dummy, PyObject *args) { - PyObject *result = PyFloat_FromDouble(0.212); + PyObject *result = PyFloat_FromDouble(0.3); return result; } @@ -996,7 +990,7 @@ static PyMethodDef lazylinker_ext_methods[] = { {NULL, NULL, 0, NULL} /* Sentinel */ }; -#if defined(NPY_PY3K) + static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT, "lazylinker_ext", NULL, @@ -1006,28 +1000,19 @@ static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT, NULL, NULL, NULL}; -#endif -#if defined(NPY_PY3K) -#define RETVAL m + PyMODINIT_FUNC PyInit_lazylinker_ext(void) { -#else -#define RETVAL -PyMODINIT_FUNC initlazylinker_ext(void) { -#endif + PyObject *m; lazylinker_ext_CLazyLinkerType.tp_new = PyType_GenericNew; if (PyType_Ready(&lazylinker_ext_CLazyLinkerType) < 0) - return RETVAL; -#if defined(NPY_PY3K) + return NULL; + m = PyModule_Create(&moduledef); -#else - m = Py_InitModule3("lazylinker_ext", lazylinker_ext_methods, - "Example module that creates an extension type."); -#endif Py_INCREF(&lazylinker_ext_CLazyLinkerType); PyModule_AddObject(m, "CLazyLinker", (PyObject *)&lazylinker_ext_CLazyLinkerType); - return RETVAL; + return m; } diff --git a/pytensor/link/c/c_code/pytensor_mod_helper.h b/pytensor/link/c/c_code/pytensor_mod_helper.h index d3e4b29a2b..2f857e6775 100644 --- a/pytensor/link/c/c_code/pytensor_mod_helper.h +++ b/pytensor/link/c/c_code/pytensor_mod_helper.h @@ -18,14 +18,8 @@ #define PYTENSOR_EXTERN #endif -#if PY_MAJOR_VERSION < 3 -#define PYTENSOR_RTYPE void -#else -#define PYTENSOR_RTYPE PyObject * -#endif - /* We need to redefine PyMODINIT_FUNC to add MOD_PUBLIC in the middle */ #undef PyMODINIT_FUNC -#define PyMODINIT_FUNC PYTENSOR_EXTERN MOD_PUBLIC PYTENSOR_RTYPE +#define PyMODINIT_FUNC PYTENSOR_EXTERN MOD_PUBLIC PyObject * #endif diff --git a/pytensor/link/c/lazylinker_c.py b/pytensor/link/c/lazylinker_c.py index 679cb4e290..ce67190342 100644 --- a/pytensor/link/c/lazylinker_c.py +++ b/pytensor/link/c/lazylinker_c.py @@ -14,7 +14,7 @@ _logger = logging.getLogger(__file__) force_compile = False -version = 0.212 # must match constant returned in function get_version() +version = 0.3 # must match constant returned in function get_version() lazylinker_ext: ModuleType | None = None From 92d96ff372933608f7e175d5c29deed60be41d91 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 11:42:14 +0200 Subject: [PATCH 28/43] Changes for deprecations in numpy 2.0 C-API - replace `->elsize` by `PyArray_ITEMSIZE` - don't use deprecated PyArray_MoveInto --- pytensor/sparse/basic.py | 20 +++---- pytensor/sparse/rewriting.py | 94 ++++++++++++++++----------------- pytensor/tensor/blas.py | 14 ++--- pytensor/tensor/blas_headers.py | 4 +- tests/compile/test_debugmode.py | 6 +-- 5 files changed, 69 insertions(+), 69 deletions(-) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index c590bc804a..7f200b2a7c 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -3610,7 +3610,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3647,11 +3647,11 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp nnz = PyArray_DIMS({_indices})[0]; npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; @@ -3744,7 +3744,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3782,11 +3782,11 @@ def c_code(self, node, name, inputs, outputs, sub): // extract number of rows npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index bf6d6f0bc6..13735d2aca 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -158,8 +158,8 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{y}* ydata = (dtype_{y}*)PyArray_DATA({y}); dtype_{z}* zdata = (dtype_{z}*)PyArray_DATA({z}); - npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_DESCR({y})->elsize; - npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; + npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_ITEMSIZE({y}); + npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); npy_intp pos; if ({format} == 0){{ @@ -186,7 +186,7 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[3]] def c_code_cache_version(self): - return (2,) + return (3,) @node_rewriter([sparse.AddSD]) @@ -361,13 +361,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -436,7 +436,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3,) + return (4,) sd_csc = StructuredDotCSC() @@ -555,13 +555,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -614,7 +614,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (2,) + return (3,) sd_csr = StructuredDotCSR() @@ -845,12 +845,12 @@ def c_code(self, node, name, inputs, outputs, sub): const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA({x_ptr}); const dtype_{alpha} alpha = ((dtype_{alpha}*)PyArray_DATA({alpha}))[0]; - npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_DESCR({zn})->elsize; - npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_DESCR({x_val})->elsize; - npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_DESCR({x_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_DESCR({x_ptr})->elsize; - npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_DESCR({y})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_ITEMSIZE({zn}); + npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_ITEMSIZE({x_val}); + npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_ITEMSIZE({x_ind}); + npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_ITEMSIZE({x_ptr}); + npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_ITEMSIZE({y}); // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL)) @@ -896,7 +896,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3, blas.blas_header_version()) + return (4, blas.blas_header_version()) usmm_csc_dense = UsmmCscDense(inplace=False) @@ -1035,13 +1035,13 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; - npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_DESCR({b_val})->elsize; - npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_DESCR({b_ind})->elsize; - npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_DESCR({b_ptr})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); + npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_ITEMSIZE({b_val}); + npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_ITEMSIZE({b_ind}); + npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_ITEMSIZE({b_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -1086,7 +1086,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (3,) + return (4,) csm_grad_c = CSMGradC() @@ -1482,7 +1482,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (2,) + return (3,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1544,7 +1544,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over rows for (npy_intp j = 0; j < N; ++j) @@ -1655,7 +1655,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (3,) + return (4,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1723,7 +1723,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over columns for (npy_intp j = 0; j < N; ++j) @@ -1868,7 +1868,7 @@ def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): ) def c_code_cache_version(self): - return (4, blas.blas_header_version()) + return (5, blas.blas_header_version()) def c_support_code(self, **kwargs): return blas.blas_header_text() @@ -1995,14 +1995,14 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{z_ind}* __restrict__ Dzi = (dtype_{z_ind}*)PyArray_DATA({z_ind}); dtype_{z_ptr}* __restrict__ Dzp = (dtype_{z_ptr}*)PyArray_DATA({z_ptr}); - const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_DESCR({x})->elsize; - const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; - const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_DESCR({p_data})->elsize; - const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_DESCR({p_ind})->elsize; - const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_DESCR({p_ptr})->elsize; - const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_DESCR({z_data})->elsize; - const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_DESCR({z_ind})->elsize; - const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_DESCR({z_ptr})->elsize; + const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_ITEMSIZE({x}); + const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); + const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_ITEMSIZE({p_data}); + const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_ITEMSIZE({p_ind}); + const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_ITEMSIZE({p_ptr}); + const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_ITEMSIZE({z_data}); + const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_ITEMSIZE({z_ind}); + const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_ITEMSIZE({z_ptr}); memcpy(Dzi, Dpi, PyArray_DIMS({p_ind})[0]*sizeof(dtype_{p_ind})); memcpy(Dzp, Dpp, PyArray_DIMS({p_ptr})[0]*sizeof(dtype_{p_ptr})); diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index d0f524e413..592a4ba27c 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -498,7 +498,7 @@ def c_header_dirs(self, **kwargs): int unit = 0; int type_num = PyArray_DESCR(%(_x)s)->type_num; - int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes npy_intp* Nx = PyArray_DIMS(%(_x)s); npy_intp* Ny = PyArray_DIMS(%(_y)s); @@ -789,7 +789,7 @@ def build_gemm_call(self): ) def build_gemm_version(self): - return (13, blas_header_version()) + return (14, blas_header_version()) class Gemm(GemmRelated): @@ -1030,7 +1030,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(x_new, %(_x)s) == -1) + if(PyArray_CopyInto(x_new, %(_x)s) == -1) { %(fail)s } @@ -1056,7 +1056,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(y_new, %(_y)s) == -1) + if(PyArray_CopyInto(y_new, %(_y)s) == -1) { %(fail)s } @@ -1102,7 +1102,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): gv = self.build_gemm_version() if gv: - return (7, *gv) + return (8, *gv) else: return gv @@ -1538,7 +1538,7 @@ def contiguous(var, ndim): return f""" int type_num = PyArray_DESCR({_x})->type_num; - int type_size = PyArray_DESCR({_x})->elsize; // in bytes + int type_size = PyArray_ITEMSIZE({_x}); // in bytes if (PyArray_NDIM({_x}) != 3) {{ PyErr_Format(PyExc_NotImplementedError, @@ -1598,7 +1598,7 @@ def contiguous(var, ndim): def c_code_cache_version(self): from pytensor.tensor.blas_headers import blas_header_version - return (5, blas_header_version()) + return (6, blas_header_version()) def grad(self, inp, grads): x, y = inp diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas_headers.py index 645f04bfb3..5d49b70ec4 100644 --- a/pytensor/tensor/blas_headers.py +++ b/pytensor/tensor/blas_headers.py @@ -1053,7 +1053,7 @@ def openblas_threads_text(): def blas_header_version(): # Version for the base header - version = (9,) + version = (10,) if detect_macos_sdot_bug(): if detect_macos_sdot_bug.fix_works: # Version with fix @@ -1071,7 +1071,7 @@ def ____gemm_code(check_ab, a_init, b_init): const char * error_string = NULL; int type_num = PyArray_DESCR(_x)->type_num; - int type_size = PyArray_DESCR(_x)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(_x); // in bytes npy_intp* Nx = PyArray_DIMS(_x); npy_intp* Ny = PyArray_DIMS(_y); diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index 95e52d6b53..fae76fab0d 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -146,7 +146,7 @@ def dontuse_perform(self, node, inp, out_): raise ValueError(self.behaviour) def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inp, out, sub): (a,) = inp @@ -165,8 +165,8 @@ def c_code(self, node, name, inp, out, sub): prep_vars = f""" //the output array has size M x N npy_intp M = PyArray_DIMS({a})[0]; - npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_DESCR({a})->elsize; - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; + npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_ITEMSIZE({a}); + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); npy_double * Da = (npy_double*)PyArray_BYTES({a}); npy_double * Dz = (npy_double*)PyArray_BYTES({z}); From b20f4015943ff46c160675d6ab79f05e3bffe581 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 13 Feb 2025 10:59:55 +0000 Subject: [PATCH 29/43] Update type hint for c_code_cache_version Anything `Hashable` should work, but I've made the return type `tuple[Hashable]` to keep with the current style. This means, e.g., we can use strings in the cache version. --- pytensor/link/c/interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/link/c/interface.py b/pytensor/link/c/interface.py index 7e281af947..e9375d2511 100644 --- a/pytensor/link/c/interface.py +++ b/pytensor/link/c/interface.py @@ -1,7 +1,7 @@ import typing import warnings from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Hashable from typing import Optional from pytensor.graph.basic import Apply, Constant @@ -155,7 +155,7 @@ def c_init_code(self, **kwargs) -> list[str]: """Return a list of code snippets to be inserted in module initialization.""" return [] - def c_code_cache_version(self) -> tuple[int, ...]: + def c_code_cache_version(self) -> tuple[Hashable, ...]: """Return a tuple of integers indicating the version of this `Op`. An empty tuple indicates an "unversioned" `Op` that will not be cached @@ -223,7 +223,7 @@ def c_code( """ raise NotImplementedError() - def c_code_cache_version_apply(self, node: Apply) -> tuple[int, ...]: + def c_code_cache_version_apply(self, node: Apply) -> tuple[Hashable, ...]: """Return a tuple of integers indicating the version of this `Op`. An empty tuple indicates an "unversioned" `Op` that will not be From 69713deef2a7d07a86bb0521478238e2057b493f Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Mon, 29 Jul 2024 09:42:41 +0100 Subject: [PATCH 30/43] Make complex scalars work with numpy 2.0 This is done using C++ generic functions to get/set the real/imag parts of complex numbers. This gives us an easy way to support Numpy v < 2.0, and allows the type underlying the bit width types, like pytensor_complex128, to be correctly inferred from the numpy complex types they inherit from. Updated pytensor_complex struct to use get/set real/imag aliases defined above. Also updated operators such as `Abs` to use get_real, get_imag. Macros have been added to ensure compatibility with numpy < 2.0 Note: redefining the complex arithmetic here means that we aren't treating NaNs and infinities as carefully as the C99 standard suggets (see Appendix G of the standard). The code has been like this since it was added to Theano, so we're keeping the existing behavior. --- pytensor/scalar/basic.py | 225 ++++++++++++++++++++++++++++----------- 1 file changed, 161 insertions(+), 64 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 94039f8091..d7d719e2f4 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -349,6 +349,8 @@ def c_headers(self, c_compiler=None, **kwargs): # we declare them here and they will be re-used by TensorType l.append("") l.append("") + l.append("") + if config.lib__amdlibm and c_compiler.supports_amdlibm: l += [""] return l @@ -517,73 +519,167 @@ def c_support_code(self, **kwargs): # In that case we add the 'int' type to the real types. real_types.append("int") + # Macros for backwards compatibility with numpy < 2.0 + # + # In numpy 2.0+, these are defined in npy_math.h, but + # for early versions, they must be vendored by users (e.g. PyTensor) + backwards_compat_macros = """ + #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_ + #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_ + + #include + + #ifndef NPY_CSETREALF + #define NPY_CSETREALF(c, r) (c)->real = (r) + #endif + #ifndef NPY_CSETIMAGF + #define NPY_CSETIMAGF(c, i) (c)->imag = (i) + #endif + #ifndef NPY_CSETREAL + #define NPY_CSETREAL(c, r) (c)->real = (r) + #endif + #ifndef NPY_CSETIMAG + #define NPY_CSETIMAG(c, i) (c)->imag = (i) + #endif + #ifndef NPY_CSETREALL + #define NPY_CSETREALL(c, r) (c)->real = (r) + #endif + #ifndef NPY_CSETIMAGL + #define NPY_CSETIMAGL(c, i) (c)->imag = (i) + #endif + + #endif + """ + + def _make_get_set_real_imag(scalar_type: str) -> str: + """Make overloaded getter/setter functions for real/imag parts of numpy complex types. + + The functions called by these getter/setter functions are defining in npy_math.h, or + in the `backward_compat_macros` defined above. + + Args: + scalar_type: float, double, or longdouble + + Returns: + C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the + given type. + """ + complex_type = "npy_c" + scalar_type + suffix = "" if scalar_type == "double" else scalar_type[0] + + if scalar_type == "longdouble": + scalar_type = "npy_" + scalar_type + + return_type = scalar_type + + template = f""" + static inline {return_type} get_real(const {complex_type} z) + {{ + return npy_creal{suffix}(z); + }} + + static inline void set_real({complex_type} *z, const {scalar_type} r) + {{ + NPY_CSETREAL{suffix.upper()}(z, r); + }} + + static inline {return_type} get_imag(const {complex_type} z) + {{ + return npy_cimag{suffix}(z); + }} + + static inline void set_imag({complex_type} *z, const {scalar_type} i) + {{ + NPY_CSETIMAG{suffix.upper()}(z, i); + }} + """ + return template + + get_set_aliases = "\n".join( + _make_get_set_real_imag(stype) + for stype in ["float", "double", "longdouble"] + ) + + get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases + + # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes + # + # The npy_complex64, npy_complex128 types are aliases defined at run time based on + # the size of floats and doubles on the machine. This means that both types are + # not necessarily defined on every machine, but a machine with 32-bit floats and + # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128 + # as an alias of npy_complex128. + # + # In any case, the get/set real/imag functions defined above will always work for + # npy_complex64 and npy_complex128. template = """ - struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s - { - typedef pytensor_complex%(nbits)s complex_type; - typedef npy_float%(half_nbits)s scalar_type; - - complex_type operator +(const complex_type &y) const { - complex_type ret; - ret.real = this->real + y.real; - ret.imag = this->imag + y.imag; - return ret; - } - - complex_type operator -() const { - complex_type ret; - ret.real = -this->real; - ret.imag = -this->imag; - return ret; - } - bool operator ==(const complex_type &y) const { - return (this->real == y.real) && (this->imag == y.imag); - } - bool operator ==(const scalar_type &y) const { - return (this->real == y) && (this->imag == 0); - } - complex_type operator -(const complex_type &y) const { - complex_type ret; - ret.real = this->real - y.real; - ret.imag = this->imag - y.imag; - return ret; - } - complex_type operator *(const complex_type &y) const { - complex_type ret; - ret.real = this->real * y.real - this->imag * y.imag; - ret.imag = this->real * y.imag + this->imag * y.real; - return ret; - } - complex_type operator /(const complex_type &y) const { - complex_type ret; - scalar_type y_norm_square = y.real * y.real + y.imag * y.imag; - ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square; - ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square; - return ret; - } - template - complex_type& operator =(const T& y); - - pytensor_complex%(nbits)s() {} - - template - pytensor_complex%(nbits)s(const T& y) { *this = y; } - - template - pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; } + struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s { + typedef pytensor_complex%(nbits)s complex_type; + typedef npy_float%(half_nbits)s scalar_type; + + complex_type operator+(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) + get_real(y)); + set_imag(&ret, get_imag(*this) + get_imag(y)); + return ret; + } + + complex_type operator-() const { + complex_type ret; + set_real(&ret, -get_real(*this)); + set_imag(&ret, -get_imag(*this)); + return ret; + } + bool operator==(const complex_type &y) const { + return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y)); + } + bool operator==(const scalar_type &y) const { + return (get_real(*this) == y) && (get_real(*this) == 0); + } + complex_type operator-(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) - get_real(y)); + set_imag(&ret, get_imag(*this) - get_imag(y)); + return ret; + } + complex_type operator*(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y)); + set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y)); + return ret; + } + complex_type operator/(const complex_type &y) const { + complex_type ret; + scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y); + set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square); + set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square); + return ret; + } + template complex_type &operator=(const T &y); + + + pytensor_complex%(nbits)s() {} + + template pytensor_complex%(nbits)s(const T &y) { *this = y; } + + template + pytensor_complex%(nbits)s(const TR &r, const TI &i) { + set_real(this, r); + set_imag(this, i); + } }; """ def operator_eq_real(mytype, othertype): return f""" template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y; this->imag=0; return *this; }} + {{ set_real(this, y); set_imag(this, 0); return *this; }} """ def operator_eq_cplx(mytype, othertype): return f""" template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y.real; this->imag=y.imag; return *this; }} + {{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }} """ operator_eq = "".join( @@ -605,10 +701,10 @@ def operator_eq_cplx(mytype, othertype): def operator_plus_real(mytype, othertype): return f""" const {mytype} operator+(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real+y, x.imag); }} + {{ return {mytype}(get_real(x) + y, get_imag(x)); }} const {mytype} operator+(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real+y, x.imag); }} + {{ return {mytype}(get_real(x) + y, get_imag(x)); }} """ operator_plus = "".join( @@ -620,10 +716,10 @@ def operator_plus_real(mytype, othertype): def operator_minus_real(mytype, othertype): return f""" const {mytype} operator-(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real-y, x.imag); }} + {{ return {mytype}(get_real(x) - y, get_imag(x)); }} const {mytype} operator-(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(y-x.real, -x.imag); }} + {{ return {mytype}(y - get_real(x), -get_imag(x)); }} """ operator_minus = "".join( @@ -635,10 +731,10 @@ def operator_minus_real(mytype, othertype): def operator_mul_real(mytype, othertype): return f""" const {mytype} operator*(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real*y, x.imag*y); }} + {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }} const {mytype} operator*(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real*y, x.imag*y); }} + {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }} """ operator_mul = "".join( @@ -648,7 +744,8 @@ def operator_mul_real(mytype, othertype): ) return ( - template % dict(nbits=64, half_nbits=32) + get_set_aliases + + template % dict(nbits=64, half_nbits=32) + template % dict(nbits=128, half_nbits=64) + operator_eq + operator_plus @@ -663,7 +760,7 @@ def c_init_code(self, **kwargs): return ["import_array();"] def c_code_cache_version(self): - return (13, np.__version__) + return (14, np.__version__) def get_shape_info(self, obj): return obj.itemsize @@ -2567,7 +2664,7 @@ def c_code(self, node, name, inputs, outputs, sub): if type in float_types: return f"{z} = fabs({x});" if type in complex_types: - return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);" + return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));" if node.outputs[0].type == bool: return f"{z} = ({x}) ? 1 : 0;" if type in uint_types: From 9416df28376888596ac3bc8719064b2df11bd381 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Mon, 29 Jul 2024 09:37:56 +0100 Subject: [PATCH 31/43] Use Python implementation for AdvancedInSubtensor1 MapIter was removed from the public numpy C-API in version 2.0, so we raise a not implemented error to default to the python code for the AdvancedInSubtensor1. The python version, defined in `AdvancedInSubtensor1.perform` calls `np.add.at`, which uses `MapIter` behind the scenes. There is active development on Numpy to improve the efficiency of `np.add.at`. To skip the C implementation and use the Python implementation, we raise a NotImplementedError for this op's c code if numpy>=2.0. --- pytensor/tensor/subtensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 46b9cc06fd..51e6dba0d8 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2520,8 +2520,7 @@ def gen_num(typen): return code def c_code(self, node, name, input_names, output_names, sub): - numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] - if bool(numpy_ver < [1, 8]): + if numpy_version < "1.8.0" or using_numpy_2: raise NotImplementedError x, y, idx = input_names From f4f58a4c42f200a4137946c4a4e70508f6017d29 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 6 Aug 2024 11:59:24 +0100 Subject: [PATCH 32/43] Changed copy to deepcopy for rng This was done for the python linker and numba linker. deepcopy seems to be the recommended method for copying a numpy Generator. After this numpy PR: https://github.com/numpy/numpy/pull/26293/commits/44ba7ca07984557f2006f9a6916adb8e3ecfca61 `copy` didn't seem to actually make an independent copy of the `np.random.Generator` objects spawned by `RandomStream`. This was causing the "test values" computed by e.g. `RandomStream.uniform` to increment the RNG state, which was causing tests that rely on `RandomStream` to fail. Here is some related discussion: https://github.com/numpy/numpy/issues/24086 I didn't see any official documentation about a change in numpy that would make copy stop working. --- pytensor/link/numba/dispatch/random.py | 4 ++-- pytensor/tensor/random/op.py | 4 ++-- tests/tensor/random/test_basic.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index e80a033c82..e20d99c605 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from copy import copy +from copy import copy, deepcopy from functools import singledispatch from textwrap import dedent @@ -34,7 +34,7 @@ def copy_NumPyRandomGenerator(rng): def impl(rng): # TODO: Open issue on Numba? with numba.objmode(new_rng=types.npy_rng): - new_rng = copy(rng) + new_rng = deepcopy(rng) return new_rng diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index c76d250c9e..a8b67dee4f 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Sequence -from copy import copy +from copy import deepcopy from typing import Any, cast import numpy as np @@ -395,7 +395,7 @@ def perform(self, node, inputs, outputs): # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. if not self.inplace: - rng = copy(rng) + rng = deepcopy(rng) outputs[0][0] = rng outputs[1][0] = np.asarray( diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 23d1b87020..4192a6c473 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1,6 +1,6 @@ import pickle import re -from copy import copy +from copy import deepcopy import numpy as np import pytest @@ -114,7 +114,9 @@ def test_fn(*args, random_state=None, **kwargs): pt_rng = shared(rng, borrow=True) - numpy_res = np.asarray(test_fn(*param_vals, random_state=copy(rng), **kwargs_vals)) + numpy_res = np.asarray( + test_fn(*param_vals, random_state=deepcopy(rng), **kwargs_vals) + ) pytensor_res = rv(*params, rng=pt_rng, **kwargs) From 2944552d6216954fac4ec923132d05c199e86474 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Wed, 29 Jan 2025 11:11:52 +0000 Subject: [PATCH 33/43] Change rng.__getstate__ to rng.bit_generator.state numpy.random.Generator.__getstate__() now returns none; to see the state of the bit generator, you need to use Generator.bit_generator.state. This change affects `RandomGeneratorType`, and several of the random tests (including some for Jax.) --- pytensor/link/jax/dispatch/random.py | 2 +- pytensor/tensor/random/type.py | 4 ++-- tests/link/jax/test_random.py | 4 +++- tests/tensor/random/test_type.py | 10 +++++----- tests/tensor/random/test_utils.py | 12 +++++++++--- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index d66ddc049d..8a33dfac13 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -56,7 +56,7 @@ def assert_size_argument_jax_compatible(node): @jax_typify.register(Generator) def jax_typify_Generator(rng, **kwargs): - state = rng.__getstate__() + state = rng.bit_generator.state state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] # XXX: Is this a reasonable approach? diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 88d5e6197f..df8e3b691d 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -87,8 +87,8 @@ def filter(self, data, strict=False, allow_downcast=None): @staticmethod def values_eq(a, b): - sa = a if isinstance(a, dict) else a.__getstate__() - sb = b if isinstance(b, dict) else b.__getstate__() + sa = a if isinstance(a, dict) else a.bit_generator.state + sb = b if isinstance(b, dict) else b.bit_generator.state def _eq(sa, sb): for key in sa: diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 2c0e4231c8..fa25f3aac0 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -63,7 +63,9 @@ def test_random_updates(rng_ctor): assert all( a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b) for a, b in zip( - rng.get_value().__getstate__(), original_value.__getstate__(), strict=True + rng.get_value().bit_generator.state, + original_value.bit_generator.state, + strict=True, ) ) diff --git a/tests/tensor/random/test_type.py b/tests/tensor/random/test_type.py index d289862347..d358f2a93a 100644 --- a/tests/tensor/random/test_type.py +++ b/tests/tensor/random/test_type.py @@ -52,7 +52,7 @@ def test_filter(self): with pytest.raises(TypeError): rng_type.filter(1) - rng_dict = rng.__getstate__() + rng_dict = rng.bit_generator.state assert rng_type.is_valid_value(rng_dict) is False assert rng_type.is_valid_value(rng_dict, strict=False) @@ -88,13 +88,13 @@ def test_values_eq(self): assert rng_type.values_eq(bitgen_g, bitgen_h) assert rng_type.is_valid_value(bitgen_a, strict=True) - assert rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_b.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_c, strict=True) - assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_d.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_e, strict=True) - assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_f.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_g, strict=True) - assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_h.bit_generator.state, strict=False) def test_may_share_memory(self): bg_a = np.random.PCG64() diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 70e8a710e9..f7d8731c1b 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -165,14 +165,20 @@ def test_seed(self, rng_ctor): state_rng = random.state_updates[0][0].get_value(borrow=True) if hasattr(state_rng, "get_state"): - ref_state = ref_rng.get_state() random_state = state_rng.get_state() + + # hack to try to get something reasonable for ref_rng + try: + ref_state = ref_rng.get_state() + except AttributeError: + ref_state = list(ref_rng.bit_generator.state.values()) + assert np.array_equal(random_state[1], ref_state[1]) assert random_state[0] == ref_state[0] assert random_state[2:] == ref_state[2:] else: - ref_state = ref_rng.__getstate__() - random_state = state_rng.__getstate__() + ref_state = ref_rng.bit_generator.state + random_state = state_rng.bit_generator.state assert random_state["bit_generator"] == ref_state["bit_generator"] assert random_state["state"] == ref_state["state"] From 0aa10c0bccf2b573e906073962cb6ef15eea4eb8 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Wed, 7 Aug 2024 10:22:12 +0100 Subject: [PATCH 34/43] Replace use of `np.MAXDIMS` `np.MAXDIMS` was removed from the public API and no replacement is given in the migration docs. In numpy <= 1.26, the value of `np.MAXDIMS` was 32. This was often used as a flag to mean `axis=None`. In numpy >= 2.0, the maximum number of dims of an array has been increased to 64; simultaneously, a constant `NPY_RAVEL_AXIS` was added to the C-API to indicate that `axis=None`. In most cases, the use of `np.MAXDIMS` to check for `axis=None` can be replaced by the new constant `NPY_RAVEL_AXIS`. To make this constant accessible when using numpy <= 1.26, I added a function to insert `npy_2_compat.h` into the support code for the affected ops. --- pytensor/npy_2_compat.py | 15 ++++++-- pytensor/tensor/extra_ops.py | 47 ++++++++++++++++--------- pytensor/tensor/math.py | 14 ++++++-- pytensor/tensor/special.py | 66 +++++++++++++++++++++++------------ pytensor/tensor/subtensor.py | 10 +++--- tests/tensor/test_elemwise.py | 4 ++- 6 files changed, 106 insertions(+), 50 deletions(-) diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py index 30214154a2..facc3b8865 100644 --- a/pytensor/npy_2_compat.py +++ b/pytensor/npy_2_compat.py @@ -46,10 +46,21 @@ ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined] +# used in tests: the type of error thrown if a value is too large for the specified +# numpy data type is different in numpy 2.x +UintOverflowError = OverflowError if using_numpy_2 else TypeError + + +# to patch up some of the C code, we need to use these special values... if using_numpy_2: - UintOverflowError = OverflowError + numpy_axis_is_none_flag = np.iinfo(np.int32).min # the value of "NPY_RAVEL_AXIS" else: - UintOverflowError = TypeError + # 32 is the value used to mark axis = None in Numpy C-API prior to version 2.0 + numpy_axis_is_none_flag = 32 + + +# max number of dims is 64 in numpy 2.x; 32 in older versions +numpy_maxdims = 64 if using_numpy_2 else 32 def npy_2_compat_header() -> str: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index e9d06ae9c2..7c6dfb9876 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2,7 +2,6 @@ from collections.abc import Collection, Iterable import numpy as np -from numpy.exceptions import AxisError import pytensor import pytensor.scalar.basic as ps @@ -19,10 +18,11 @@ from pytensor.link.c.type import EnumList, Generic from pytensor.npy_2_compat import ( normalize_axis_index, - normalize_axis_tuple, + npy_2_compat_header, + numpy_axis_is_none_flag, ) from pytensor.raise_op import Assert -from pytensor.scalar import int32 as int_t +from pytensor.scalar import int64 as int_t from pytensor.scalar import upcast from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb @@ -47,6 +47,7 @@ from pytensor.tensor.shape import Shape_i from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector +from pytensor.tensor.utils import normalize_reduce_axis from pytensor.tensor.variable import TensorVariable from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH @@ -302,7 +303,11 @@ def __init__(self, axis: int | None = None, mode="add"): self.axis = axis self.mode = mode - c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis) + @property + def c_axis(self) -> int: + if self.axis is None: + return numpy_axis_is_none_flag + return self.axis def make_node(self, x): x = ptb.as_tensor_variable(x) @@ -359,24 +364,37 @@ def infer_shape(self, fgraph, node, shapes): return shapes + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inames, onames, sub): (x,) = inames (z,) = onames fail = sub["fail"] params = sub["params"] - code = f""" - int axis = {params}->c_axis; + if self.axis is None: + axis_code = "int axis = NPY_RAVEL_AXIS;\n" + else: + axis_code = f"int axis = {params}->c_axis;\n" + + code = ( + axis_code + + f""" + #undef NPY_UF_DBG_TRACING + #define NPY_UF_DBG_TRACING 1 + if (axis == 0 && PyArray_NDIM({x}) == 1) - axis = NPY_MAXDIMS; + axis = NPY_RAVEL_AXIS; npy_intp shape[1] = {{ PyArray_SIZE({x}) }}; - if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0])) + if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0])) {{ Py_XDECREF({z}); - {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x})); + {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x})); }} - else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) + else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) {{ Py_XDECREF({z}); {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x})); @@ -403,11 +421,12 @@ def c_code(self, node, name, inames, onames, sub): Py_XDECREF(t); }} """ + ) return code def c_code_cache_version(self): - return (8,) + return (9,) def __str__(self): return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}" @@ -598,11 +617,7 @@ def squeeze(x, axis=None): elif not isinstance(axis, Collection): axis = (axis,) - # scalar inputs are treated as 1D regarding axis in this `Op` - try: - axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) - except AxisError: - raise AxisError(axis, ndim=_x.ndim) + axis = normalize_reduce_axis(axis, ndim=_x.ndim) if not axis: # Nothing to do diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index c4f3dc50a5..a88d678392 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -13,7 +13,11 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.npy_2_compat import normalize_axis_tuple +from pytensor.npy_2_compat import ( + normalize_axis_tuple, + npy_2_compat_header, + numpy_axis_is_none_flag, +) from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.scalar.basic import BinaryScalarOp @@ -160,7 +164,7 @@ def get_params(self, node): c_axis = np.int64(self.axis[0]) else: # The value here doesn't matter, it won't be used - c_axis = np.int64(-1) + c_axis = numpy_axis_is_none_flag return self.params_type.get_params(c_axis=c_axis) def make_node(self, x): @@ -203,13 +207,17 @@ def perform(self, node, inp, outs): max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (argmax,) = out fail = sub["fail"] params = sub["params"] if self.axis is None: - axis_code = "axis = NPY_MAXDIMS;" + axis_code = "axis = NPY_RAVEL_AXIS;" else: if len(self.axis) != 1: raise NotImplementedError() diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index a2f02fabd8..5b05ad03f4 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -6,6 +6,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp +from pytensor.npy_2_compat import npy_2_compat_header from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.math import gamma, gammaln, log, neg, sum @@ -60,12 +61,16 @@ def infer_shape(self, fgraph, node, shape): return [shape[1]] def c_code_cache_version(self): - return (4,) + return (5,) + + def c_support_code_apply(self, node: Apply, name: str) -> str: + # return super().c_support_code_apply(node, name) + return npy_2_compat_header() def c_code(self, node, name, inp, out, sub): dy, sm = inp (dx,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -79,7 +84,7 @@ def c_code(self, node, name, inp, out, sub): int sm_ndim = PyArray_NDIM({sm}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || sm_ndim == 1); // Validate inputs if ((PyArray_TYPE({dy}) != NPY_DOUBLE) && @@ -95,13 +100,15 @@ def c_code(self, node, name, inp, out, sub): {fail}; }} - if (axis < 0) axis = sm_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); - {fail}; + if (axis < 0) axis = sm_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); + {fail}; + }} }} - if (({dx} == NULL) || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim))) {{ @@ -289,10 +296,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return ["", ""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] # dtype = node.inputs[0].type.dtype_specs()[1] # TODO: put this into a templated function, in the support code @@ -309,7 +320,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -319,11 +330,14 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax"); + {fail} + }} }} // Allocate Output Array @@ -481,7 +495,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (4,) + return (5,) def softmax(c, axis=None): @@ -541,10 +555,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return [""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -558,7 +576,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -568,13 +586,15 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); + {fail} + }} }} - // Allocate Output Array if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim))) {{ @@ -730,7 +750,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (1,) + return (2,) def log_softmax(c, axis=None): diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 51e6dba0d8..c1fdb463b6 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,7 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.npy_2_compat import numpy_version, using_numpy_2 +from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -2149,7 +2149,7 @@ def infer_shape(self, fgraph, node, ishapes): def c_support_code(self, **kwargs): # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG, # which is not defined. It should be NPY_MIN_LONG instead in that case. - return dedent( + return npy_2_compat_header() + dedent( """\ #ifndef MIN_LONG #define MIN_LONG NPY_MIN_LONG @@ -2174,7 +2174,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (!PyArray_CanCastSafely(i_type, NPY_INTP) && PyArray_SIZE({i_name}) > 0) {{ npy_int64 min_val, max_val; - PyObject* py_min_val = PyArray_Min({i_name}, NPY_MAXDIMS, + PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS, NULL); if (py_min_val == NULL) {{ {fail}; @@ -2184,7 +2184,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (min_val == -1 && PyErr_Occurred()) {{ {fail}; }} - PyObject* py_max_val = PyArray_Max({i_name}, NPY_MAXDIMS, + PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS, NULL); if (py_max_val == NULL) {{ {fail}; @@ -2243,7 +2243,7 @@ def c_code(self, node, name, input_names, output_names, sub): """ def c_code_cache_version(self): - return (0, 1, 2) + return (0, 1, 2, 3) advanced_subtensor1 = AdvancedSubtensor1() diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 8555a1d29f..45a7f53c2c 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -18,6 +18,7 @@ from pytensor.graph.replace import vectorize_node from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker +from pytensor.npy_2_compat import numpy_maxdims from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -121,7 +122,8 @@ def test_infer_shape(self): def test_too_big_rank(self): x = self.type(self.dtype, shape=())() - y = x.dimshuffle(("x",) * (np.MAXDIMS + 1)) + y = x.dimshuffle(("x",) * (numpy_maxdims + 1)) + with pytest.raises(ValueError): y.eval({x: 0}) From b349a9a763f2a054016b152a05bca2c8bfc1cc77 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 27 Aug 2024 13:22:14 +0100 Subject: [PATCH 35/43] Fixed failed test due to uint8 overflow In numpy 2.0, -1 as uint8 is out of bounds, whereas previously it would be converted to 255. This affected the test helper function `reduced_bitwise_and`. The helper function was changed to use 255 instead of -1 if the dtype was uint8, since this is what is needed to match the behavior of the "bitwise and" op. `reduced_bitwise_and` was only used by `TestCAReduce` in `tests/tensor/test_elemwise.py`, so it was moved there from `tests/tensor/test_math.py` --- tests/compile/function/test_function.py | 9 +++++---- tests/compile/function/test_pfunc.py | 17 ++++++++++------- tests/tensor/test_elemwise.py | 22 +++++++++++++++++++++- tests/tensor/test_math.py | 16 ---------------- 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py index f835953b19..9f75ef15d8 100644 --- a/tests/compile/function/test_function.py +++ b/tests/compile/function/test_function.py @@ -11,6 +11,7 @@ from pytensor.compile.function import function, function_dump from pytensor.compile.io import In from pytensor.configdefaults import config +from pytensor.npy_2_compat import UintOverflowError from pytensor.tensor.type import ( bscalar, bvector, @@ -166,12 +167,12 @@ def test_in_allow_downcast_int(self): # Value too big for a, silently ignored assert np.array_equal(f([2**20], np.ones(1, dtype="int8"), 1), [2]) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError (in numpy >= 2.0... TypeError in numpy < 2.0) + with pytest.raises(UintOverflowError): f([3], [312], 1) - # Value too big for c, raises TypeError - with pytest.raises(TypeError): + # Value too big for c, raises OverflowError + with pytest.raises(UintOverflowError): f([3], [6], 806) def test_in_allow_downcast_floatX(self): diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 0a9bda9846..249f230d81 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -9,6 +9,7 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.graph.utils import MissingInputError +from pytensor.npy_2_compat import UintOverflowError from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.type import ( bscalar, @@ -237,12 +238,12 @@ def test_param_allow_downcast_int(self): # Value too big for a, silently ignored assert np.all(f([2**20], np.ones(1, dtype="int8"), 1) == 2) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): f([3], [312], 1) - # Value too big for c, raises TypeError - with pytest.raises(TypeError): + # Value too big for c, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): f([3], [6], 806) def test_param_allow_downcast_floatX(self): @@ -327,8 +328,8 @@ def test_allow_input_downcast_int(self): with pytest.raises(TypeError): g([3], np.array([6], dtype="int16"), 0) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): g([3], [312], 0) h = pfunc([a, b, c], (a + b + c)) # Default: allow_input_downcast=None @@ -336,7 +337,9 @@ def test_allow_input_downcast_int(self): assert np.all(h([3], [6], 0) == 9) with pytest.raises(TypeError): h([3], np.array([6], dtype="int16"), 0) - with pytest.raises(TypeError): + + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): h([3], [312], 0) def test_allow_downcast_floatX(self): diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 45a7f53c2c..5ce533d3a3 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -40,7 +40,27 @@ ) from tests import unittest_tools from tests.link.test_link import make_function -from tests.tensor.test_math import reduce_bitwise_and + + +def reduce_bitwise_and(x, axis=-1, dtype="int8"): + """Helper function for TestCAReduce""" + if dtype == "uint8": + # in numpy version >= 2.0, out of bounds uint8 values are not converted + identity = np.array((255,), dtype=dtype)[0] + else: + identity = np.array((-1,), dtype=dtype)[0] + + shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis) + if 0 in shape_without_axis: + return np.empty(shape=shape_without_axis, dtype=x.dtype) + + def custom_reduce(a): + out = identity + for i in range(a.size): + out = np.bitwise_and(a[i], out) + return out + + return np.apply_along_axis(custom_reduce, axis, x) class TestDimShuffle(unittest_tools.InferShapeTester): diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 40c505b7b4..64af7057a5 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -3444,22 +3444,6 @@ def test_var_axes(self): x.var(a) -def reduce_bitwise_and(x, axis=-1, dtype="int8"): - identity = np.array((-1,), dtype=dtype)[0] - - shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis) - if 0 in shape_without_axis: - return np.empty(shape=shape_without_axis, dtype=x.dtype) - - def custom_reduce(a): - out = identity - for i in range(a.size): - out = np.bitwise_and(a[i], out) - return out - - return np.apply_along_axis(custom_reduce, axis, x) - - def test_clip_grad(): # test the gradient of clip def func(x, y, z): From 9e919c74d65d9b27445765b7211ec18d20be246b Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Fri, 30 Aug 2024 12:01:19 +0100 Subject: [PATCH 36/43] Changes due to new numpy scalar promotion rules 1. Changed autocaster due to new promotion rules With "weak promotion" of python types in Numpy 2.0, the statement `1.1 == np.asarray(1.1).astype('float32')` is True, whereas in Numpy 1.26, it was false. However, in numpy 1.26, `1.1 == np.asarray([1.1]).astype('float32')` was true, so the scalar behavior and array behavior are the same in Numpy 2.0, while they were different in numpy 1.26. Essentially, in Numpy 2.0, if python floats are used in operations with numpy floats or arrays, then the type of the numpy object will be used (i.e. the python value will be treated as the type of the numpy objects). To preserve the behavior of `NumpyAutocaster` from numpy <= 1.26, I've added an explicit conversion of the value to be converted to a numpy type using `np.asarray` during the check that decides what dtype to cast to. 2. Updates due to new numpy conversion rules for out-of-bounds python ints In numpy 2.0, out of bounds python ints will not be automatically converted, and will raise an `OverflowError` instead. For instance, converting 255 to int8 will raise an error, instead of returning -1. To explicitly force conversion, we must use `np.asarray(value).astype(dtype)`, rather than `np.asarray(value, dtype=dtype)`. The code in `TensorType.filter` has been changed to the new recommended way to downcast, and the error type caught by some tests has been changed to OverflowError from TypeError --- pytensor/scalar/basic.py | 4 +++- pytensor/tensor/type.py | 2 +- tests/compile/function/test_pfunc.py | 1 + tests/tensor/test_basic.py | 1 - 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d7d719e2f4..f8ecabd7b2 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -183,7 +183,9 @@ def __call__(self, x): for dtype in try_dtypes: x_ = np.asarray(x).astype(dtype=dtype) - if np.all(x == x_): + if np.all( + np.asarray(x) == x_ + ): # use np.asarray(x) to match TensorType.filter break # returns either an exact x_==x, or the last cast x_ return x_ diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index d48a7a6f08..b96113c8e3 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -178,7 +178,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: else: if allow_downcast: # Convert to self.dtype, regardless of the type of data - data = np.asarray(data, dtype=self.dtype) + data = np.asarray(data).astype(self.dtype) # TODO: consider to pad shape with ones to make it consistent # with self.broadcastable... like vector->row type thing else: diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 249f230d81..3e23b12f74 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -335,6 +335,7 @@ def test_allow_input_downcast_int(self): h = pfunc([a, b, c], (a + b + c)) # Default: allow_input_downcast=None # Everything here should behave like with False assert np.all(h([3], [6], 0) == 9) + with pytest.raises(TypeError): h([3], np.array([6], dtype="int16"), 0) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 6b5ec48112..467dc66407 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3198,7 +3198,6 @@ def test_autocast_custom(): assert (dvector() + 1.1).dtype == "float64" assert (fvector() + np.float32(1.1)).dtype == "float32" assert (fvector() + np.float64(1.1)).dtype == "float64" - assert (fvector() + 1.1).dtype == config.floatX assert (lvector() + np.int64(1)).dtype == "int64" assert (lvector() + np.int32(1)).dtype == "int64" assert (lvector() + np.int16(1)).dtype == "int64" From bce361323bddf4be19280c4b364184b748eca372 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Sun, 10 Nov 2024 14:57:16 +0000 Subject: [PATCH 37/43] Fix for NameError in test I was getting a NameError from the list comprehensions saying that e.g. `pytensor_scalar` was not defined. I'm not sure why, but this is another (more verbose) way to do the same thing. --- tests/tensor/test_math.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 64af7057a5..374a22ab5d 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -2492,11 +2492,22 @@ def pytensor_i_scalar(dtype): def numpy_i_scalar(dtype): return numpy_scalar(dtype) + pytensor_funcs = { + "scalar": pytensor_scalar, + "array": pytensor_array, + "i_scalar": pytensor_i_scalar, + } + numpy_funcs = { + "scalar": numpy_scalar, + "array": numpy_array, + "i_scalar": numpy_i_scalar, + } + with config.change_flags(cast_policy="numpy+floatX"): # We will test all meaningful combinations of # scalar and array operations. - pytensor_args = [eval(f"pytensor_{c}") for c in combo] - numpy_args = [eval(f"numpy_{c}") for c in combo] + pytensor_args = [pytensor_funcs[c] for c in combo] + numpy_args = [numpy_funcs[c] for c in combo] pytensor_arg_1 = pytensor_args[0](a_type) pytensor_arg_2 = pytensor_args[1](b_type) pytensor_dtype = op( From 45c3a0182cc394eda1f3ac05b9d545e50d6e8a4b Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Fri, 24 Jan 2025 15:43:23 +0000 Subject: [PATCH 38/43] Updated doctests From numpy PR https://github.com/numpy/numpy/pull/22449, the repr of scalar values has changed, e.g. from "1" to "np.int64(1)", which caused two doctests to fail. --- pytensor/tensor/einsum.py | 2 +- pytensor/tensor/subtensor.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 88a6257c9c..660c16d387 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -256,7 +256,7 @@ def _general_dot( .. testoutput:: - (3, 4, 2) + (np.int64(3), np.int64(4), np.int64(2)) """ # Shortcut for non batched case if not batch_axes[0] and not batch_axes[1]: diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index c1fdb463b6..3a2304eb7b 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -757,13 +757,15 @@ def get_constant_idx( Example usage where `v` and `a` are appropriately typed PyTensor variables : >>> from pytensor.scalar import int64 >>> from pytensor.tensor import matrix + >>> import numpy as np + >>> >>> v = int64("v") >>> a = matrix("a") >>> b = a[v, 1:3] >>> b.owner.op.idx_list (ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None)) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True) - [v, slice(1, 3, None)] + [v, slice(np.int64(1), np.int64(3), None)] >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs) Traceback (most recent call last): pytensor.tensor.exceptions.NotScalarConstantError From 2bfe6dd86be5957c5469d3ef112bbd18db3edc01 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 30 Jan 2025 14:02:45 +0000 Subject: [PATCH 39/43] Preserve numpy < 2.0 Unique inverse output shape In numpy 2.0, if axis=None, then np.unique does not flatten the inverse indices returned if return_inverse=True A helper function has been added to npy_2_compat.py to mimic the output of `np.unique` from version of numpy before 2.0 --- pytensor/npy_2_compat.py | 22 ++++++++++++++++++++++ pytensor/tensor/extra_ops.py | 19 ++++++++++++++++--- tests/tensor/test_extra_ops.py | 17 +++++++++-------- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py index facc3b8865..667a5c074e 100644 --- a/pytensor/npy_2_compat.py +++ b/pytensor/npy_2_compat.py @@ -63,6 +63,28 @@ numpy_maxdims = 64 if using_numpy_2 else 32 +# function that replicates np.unique from numpy < 2.0 +def old_np_unique( + arr, return_index=False, return_inverse=False, return_counts=False, axis=None +): + """Replicate np.unique from numpy versions < 2.0""" + if not return_inverse or not using_numpy_2: + return np.unique(arr, return_index, return_inverse, return_counts, axis) + + outs = list(np.unique(arr, return_index, return_inverse, return_counts, axis)) + + inv_idx = 2 if return_index else 1 + + if axis is None: + outs[inv_idx] = np.ravel(outs[inv_idx]) + else: + inv_shape = (arr.shape[axis],) + outs[inv_idx] = outs[inv_idx].reshape(inv_shape) + + return tuple(outs) + + +# compatibility header for C code def npy_2_compat_header() -> str: """Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x""" return dedent(""" diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 7c6dfb9876..7a1bc75b0b 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -20,6 +20,7 @@ normalize_axis_index, npy_2_compat_header, numpy_axis_is_none_flag, + old_np_unique, ) from pytensor.raise_op import Assert from pytensor.scalar import int64 as int_t @@ -1226,6 +1227,9 @@ class Unique(Op): """ Wraps `numpy.unique`. + The indices returned when `return_inverse` is True are ravelled + to match the behavior of `numpy.unique` from before numpy version 2.0. + Examples -------- >>> import numpy as np @@ -1271,17 +1275,21 @@ def make_node(self, x): outputs = [TensorType(dtype=x.dtype, shape=out_shape)()] typ = TensorType(dtype="int64", shape=(None,)) + if self.return_index: outputs.append(typ()) + if self.return_inverse: outputs.append(typ()) + if self.return_counts: outputs.append(typ()) + return Apply(self, [x], outputs) def perform(self, node, inputs, output_storage): [x] = inputs - outs = np.unique( + outs = old_np_unique( x, return_index=self.return_index, return_inverse=self.return_inverse, @@ -1306,9 +1314,14 @@ def infer_shape(self, fgraph, node, i0_shapes): out_shapes[0] = tuple(shape) if self.return_inverse: - shape = prod(x_shape) if self.axis is None else x_shape[axis] return_index_out_idx = 2 if self.return_index else 1 - out_shapes[return_index_out_idx] = (shape,) + + if self.axis is not None: + shape = (x_shape[axis],) + else: + shape = (prod(x_shape),) + + out_shapes[return_index_out_idx] = shape return out_shapes diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 54bb7f4333..6a93f3c7fd 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -9,6 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph.basic import Constant, applys_between, equal_computations +from pytensor.npy_2_compat import old_np_unique from pytensor.raise_op import Assert from pytensor.tensor import alloc from pytensor.tensor.elemwise import DimShuffle @@ -899,14 +900,14 @@ def setup_method(self): ) def test_basic_vector(self, x, inp, axis): list_outs_expected = [ - np.unique(inp, axis=axis), - np.unique(inp, True, axis=axis), - np.unique(inp, False, True, axis=axis), - np.unique(inp, True, True, axis=axis), - np.unique(inp, False, False, True, axis=axis), - np.unique(inp, True, False, True, axis=axis), - np.unique(inp, False, True, True, axis=axis), - np.unique(inp, True, True, True, axis=axis), + old_np_unique(inp, axis=axis), + old_np_unique(inp, True, axis=axis), + old_np_unique(inp, False, True, axis=axis), + old_np_unique(inp, True, True, axis=axis), + old_np_unique(inp, False, False, True, axis=axis), + old_np_unique(inp, True, False, True, axis=axis), + old_np_unique(inp, False, True, True, axis=axis), + old_np_unique(inp, True, True, True, axis=axis), ] for params, outs_expected in zip( self.op_params, list_outs_expected, strict=True From cd75f954fb11b502da41a73bd0046b3ef2019b4d Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 30 Jan 2025 14:55:29 +0000 Subject: [PATCH 40/43] Fix test for neg on unsigned Due to changes in numpy conversion rules (NEP 50), overflows are not ignored; in particular, negating a unsigned int causes an overflow error. The test for `neg` has been changed to check that this error is raised. --- tests/tensor/test_math.py | 12 +++++++++++- tests/tensor/utils.py | 21 +++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 374a22ab5d..f2331be62e 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -23,6 +23,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import vectorize_node from pytensor.link.c.basic import DualLinker +from pytensor.npy_2_compat import using_numpy_2 from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.tensor import blas, blas_c @@ -391,11 +392,20 @@ def test_maximum_minimum_grad(): grad=_grad_broadcast_unary_normal, ) + +# in numpy >= 2.0, negating a uint raises an error +neg_good = _good_broadcast_unary_normal.copy() +if using_numpy_2: + neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")} +else: + neg_bad = None + TestNegBroadcast = makeBroadcastTester( op=neg, expected=lambda x: -x, - good=_good_broadcast_unary_normal, + good=neg_good, grad=_grad_broadcast_unary_normal, + bad_compile=neg_bad, ) TestSgnBroadcast = makeBroadcastTester( diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index b94750ffe2..1a8b2455ec 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -339,6 +339,7 @@ def makeTester( good=None, bad_build=None, bad_runtime=None, + bad_compile=None, grad=None, mode=None, grad_rtol=None, @@ -373,6 +374,7 @@ def makeTester( _test_memmap = test_memmap _check_name = check_name _grad_eps = grad_eps + _bad_compile = bad_compile or {} class Checker: op = staticmethod(_op) @@ -382,6 +384,7 @@ class Checker: good = _good bad_build = _bad_build bad_runtime = _bad_runtime + bad_compile = _bad_compile grad = _grad mode = _mode skip = skip_ @@ -539,6 +542,24 @@ def test_bad_build(self): # instantiated on the following bad inputs: %s" # % (self.op, testname, node, inputs)) + @config.change_flags(compute_test_value="off") + @pytest.mark.skipif(skip, reason="Skipped") + def test_bad_compile(self): + for testname, inputs in self.bad_compile.items(): + inputrs = [shared(input) for input in inputs] + try: + node = safe_make_node(self.op, *inputrs) + except Exception as exc: + err_msg = ( + f"Test {self.op}::{testname}: Error occurred while trying" + f" to make a node with inputs {inputs}" + ) + exc.args += (err_msg,) + raise + + with pytest.raises(Exception): + inplace_func([], node.outputs, mode=mode, name="test_bad_runtime") + @config.change_flags(compute_test_value="off") @pytest.mark.skipif(skip, reason="Skipped") def test_bad_runtime(self): From 93dd7c8d60a1908192fd00d45699e08e12ecb4d5 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 4 Feb 2025 13:56:51 +0000 Subject: [PATCH 41/43] Split up TestMinMax::test_uint I split this test up to test uint64 separately, since this is the case discussed in Issue #770. I also added a test for the exact example used in that issue. The uint dtypes with lower precision should pass. The uint64 case started passing for me locally on Mac OSX, but still fails on CI. I'm not sure why this is, but at least the test will be more specific now if it fails in the future. --- tests/tensor/test_math.py | 41 ++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index f2331be62e..9ab4fd104d 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -1403,18 +1403,37 @@ def _grad_list(self): # check_grad_max(data, eval_outputs(grad(max_and_argmax(n, # axis=1)[0], n)),axis=1) + @pytest.mark.parametrize( + "dtype", + ( + "uint8", + "uint16", + "uint32", + pytest.param("uint64", marks=pytest.mark.xfail(reason="Fails due to #770")), + ), + ) + def test_uint(self, dtype): + itype = np.iinfo(dtype) + data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) + n = as_tensor_variable(data) + + assert min(n).dtype == dtype + i_min = eval_outputs(min(n)) + assert i_min == itype.min + + assert max(n).dtype == dtype + i_max = eval_outputs(max(n)) + assert i_max == itype.max + @pytest.mark.xfail(reason="Fails due to #770") - def test_uint(self): - for dtype in ("uint8", "uint16", "uint32", "uint64"): - itype = np.iinfo(dtype) - data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) - n = as_tensor_variable(data) - assert min(n).dtype == dtype - i = eval_outputs(min(n)) - assert i == itype.min - assert max(n).dtype == dtype - i = eval_outputs(max(n)) - assert i == itype.max + def test_uint64_special_value(self): + """Example from issue #770""" + dtype = "uint64" + data = np.array([0, 9223372036854775], dtype=dtype) + n = as_tensor_variable(data) + + i_max = eval_outputs(max(n)) + assert i_max == data.max() def test_bool(self): data = np.array([True, False], "bool") From 720568cfbd936ca4d0436281c9c477fe30ab2dd8 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 27 Aug 2024 11:23:43 +0100 Subject: [PATCH 42/43] Unpinned numpy Also added ruff numpy2 transition rule. --- environment-osx-arm64.yml | 2 +- environment.yml | 2 +- pyproject.toml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/environment-osx-arm64.yml b/environment-osx-arm64.yml index 13a68faaaa..c9dc703dcc 100644 --- a/environment-osx-arm64.yml +++ b/environment-osx-arm64.yml @@ -9,7 +9,7 @@ channels: dependencies: - python=>3.10 - compilers - - numpy>=1.17.0,<2 + - numpy>=1.17.0 - scipy>=1,<2 - filelock>=3.15 - etuples diff --git a/environment.yml b/environment.yml index 1571ae0d11..9bdddfb6f6 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ channels: dependencies: - python>=3.10 - compilers - - numpy>=1.17.0,<2 + - numpy>=1.17.0 - scipy>=1,<2 - filelock>=3.15 - etuples diff --git a/pyproject.toml b/pyproject.toml index e82c42753a..e796e35a10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ keywords = [ dependencies = [ "setuptools>=59.0.0", "scipy>=1,<2", - "numpy>=1.17.0,<2", + "numpy>=1.17.0", "filelock>=3.15", "etuples", "logical-unification", @@ -129,7 +129,7 @@ exclude = ["doc/", "pytensor/_version.py"] docstring-code-format = true [tool.ruff.lint] -select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20"] +select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] unfixable = [ # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead From b633bcacd6c3b68d56b482bccd91cbcb65df84d5 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 4 Feb 2025 15:29:05 +0000 Subject: [PATCH 43/43] Added numpy 1.26.* to CI Remaining tests now run on latest numpy, except for Numba jobs, which need numpy 2.1.0 --- .github/workflows/test.yml | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 53f1e16606..5bb416f893 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -65,7 +65,7 @@ jobs: - uses: pre-commit/action@v3.0.1 test: - name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}" + name: "${{ matrix.os }} test py${{ matrix.python-version }} numpy${{ matrix.numpy-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}" needs: - changes - style @@ -76,6 +76,7 @@ jobs: matrix: os: ["ubuntu-latest"] python-version: ["3.10", "3.12"] + numpy-version: ["~=1.26.0", ">=2.0"] fast-compile: [0, 1] float32: [0, 1] install-numba: [0] @@ -105,45 +106,68 @@ jobs: float32: 1 - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" fast-compile: 1 + - numpy-version: "~=1.26.0" + fast-compile: 1 + - numpy-version: "~=1.26.0" + float32: 1 + - numpy-version: "~=1.26.0" + python-version: "3.12" + - numpy-version: "~=1.26.0" + part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" include: - install-numba: 1 os: "ubuntu-latest" python-version: "3.10" + numpy-version: "~=2.1.0" fast-compile: 0 float32: 0 part: "tests/link/numba" - install-numba: 1 os: "ubuntu-latest" python-version: "3.12" + numpy-version: "~=2.1.0" fast-compile: 0 float32: 0 part: "tests/link/numba" - install-jax: 1 os: "ubuntu-latest" python-version: "3.10" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/jax" - install-jax: 1 os: "ubuntu-latest" python-version: "3.12" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/jax" - install-torch: 1 os: "ubuntu-latest" python-version: "3.10" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/pytorch" - os: macos-15 python-version: "3.12" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 install-numba: 0 install-jax: 0 install-torch: 0 part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py" + - os: "ubuntu-latest" + python-version: "3.10" + numpy-version: "~=1.26.0" + fast-compile: 0 + float32: 0 + install-numba: 0 + install-jax: 0 + install-torch: 0 + part: "tests/tensor/test_math.py" steps: - uses: actions/checkout@v4 @@ -174,9 +198,9 @@ jobs: run: | if [[ $OS == "macos-15" ]]; then - micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" numpy scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate; + micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate; else - micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; + micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi @@ -193,6 +217,7 @@ jobs: fi env: PYTHON_VERSION: ${{ matrix.python-version }} + NUMPY_VERSION: ${{ matrix.numpy-version }} INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}}