Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,21 @@ lines-after-imports = 2
# TODO: Get rid of these:
"**/__init__.py" = ["F401", "E402", "F403"]
"pytensor/tensor/linalg.py" = ["F403"]
# Modules that use print-statements, skip "T201"
"pytensor/link/c/cmodule.py" = ["PTH", "T201"]
"pytensor/misc/elemwise_time_test.py" = ["T201"]
"pytensor/misc/elemwise_openmp_speedup.py" = ["T201"]
"pytensor/misc/check_duplicate_key.py" = ["T201"]
"pytensor/misc/check_blas.py" = ["T201"]
"pytensor/bin/pytensor_cache.py" = ["T201"]
# For the tests we skip `E402` because `pytest.importorskip` is used:
"tests/link/jax/test_scalar.py" = ["E402"]
"tests/link/jax/test_tensor_basic.py" = ["E402"]
"tests/link/numba/test_basic.py" = ["E402"]
"tests/link/numba/test_cython_support.py" = ["E402"]
"tests/link/numba/test_performance.py" = ["E402"]
"tests/link/numba/test_sparse.py" = ["E402"]
"tests/link/numba/test_tensor_basic.py" = ["E402"]
"tests/tensor/test_math_scipy.py" = ["E402"]
"tests/sparse/test_basic.py" = ["E402"]
"tests/sparse/test_sp2.py" = ["E402"]
"tests/sparse/test_utils.py" = ["E402"]
"tests/sparse/sandbox/test_sp.py" = ["E402", "F401"]
"tests/compile/test_monitormode.py" = ["T201"]
"scripts/run_mypy.py" = ["T201"]
# Test modules of optional backends that use `pytest.importorskip`, skip "E402"
"tests/link/jax/**/test_*.py" = ["E402"]
"tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"]



[tool.mypy]
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class PerformLinker(LocalLinker):
"""

required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
incompatible_rewrites: tuple[str, ...] = ("cxx",)
incompatible_rewrites: tuple[str, ...] = ("cxx_only",)

def __init__(
self, allow_gc: bool | None = None, schedule: Callable | None = None
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/jax/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class JAXLinker(JITLinker):
"jax",
) # TODO: Distinguish between optional "jax" and "minimum_compile_jax"
incompatible_rewrites = (
"cxx",
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
Expand Down
25 changes: 16 additions & 9 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
fgraph_to_python,
)
from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType
from pytensor.tensor.random.type import RandomGeneratorType
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import DenseTensorType
from pytensor.tensor.utils import hash_from_ndarray
from pytensor.typed_list import TypedListType

Expand Down Expand Up @@ -112,7 +111,7 @@ def get_numba_type(
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""

if isinstance(pytensor_type, TensorType):
if isinstance(pytensor_type, DenseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if force_scalar or (
Expand All @@ -125,18 +124,26 @@ def get_numba_type(
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
elif isinstance(pytensor_type, SparseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
from pytensor.link.numba.dispatch.sparse.variable import (
CSCMatrixType,
CSRMatrixType,
)

data_array = numba.types.Array(
numba.from_dtype(pytensor_type.numpy_dtype), 1, layout
)
indices_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout)
indptr_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout)
if pytensor_type.format == "csr":
return CSRMatrixType(numba_dtype)
return CSRMatrixType(data_array, indices_array, indptr_array)
if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype)
return CSCMatrixType(data_array, indices_array, indptr_array)
elif isinstance(pytensor_type, RandomGeneratorType):
return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
elif isinstance(pytensor_type, TypedListType):
return numba.types.List(get_numba_type(pytensor_type.ttype))
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")

raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")


def create_numba_signature(
Expand Down
206 changes: 0 additions & 206 deletions pytensor/link/numba/dispatch/sparse.py

This file was deleted.

1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pytensor.link.numba.dispatch.sparse import basic, math, variable
68 changes: 68 additions & 0 deletions pytensor/link/numba/dispatch/sparse/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
import scipy as sp
from numba.extending import overload

from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType
from pytensor.sparse import CSM, Cast, CSMProperties


@overload(numba_deepcopy)
def numba_deepcopy_sparse(x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's deep about this?

Copy link
Member Author

@ricardoV94 ricardoV94 Jan 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sparse_matrix.copy() does a deepcopy just like array.copy(). But for other types like list or rng there's a difference between copy and deepcopy hence the more explicit name

if isinstance(x, CSMatrixType):

def sparse_deepcopy(x):
return x.copy()

return sparse_deepcopy


@register_funcify_default_op_cache_key(CSMProperties)
def numba_funcify_CSMProperties(op, node, **kwargs):
@numba_basic.numba_njit
def csm_properties(x):
# Reconsider this int32/int64. Scipy/base PyTensor use int32 for indices/indptr.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we able to just go to int64 ourselves, or do we need to wait for upstream to change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would need to change stuff in the pre-existing Ops so that fallback to obj mode is compatible. Would leave that for a later PR if we decide

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would need to change stuff in the pre-existing Ops so that fallback to obj mode is compatible. Would leave that for a later PR if we decide

# But this seems to be legacy mistake and devs would choose int64 nowadays, and may move there.
return x.data, x.indices, x.indptr, np.asarray(x.shape, dtype="int32")

return csm_properties


@register_funcify_default_op_cache_key(CSM)
def numba_funcify_CSM(op, node, **kwargs):
format = op.format

@numba_basic.numba_njit
def csm_constructor(data, indices, indptr, shape):
constructor_arg = (data, indices, indptr)
shape_arg = (shape[0], shape[1])
if format == "csr":
return sp.sparse.csr_matrix(constructor_arg, shape=shape_arg)
else:
return sp.sparse.csc_matrix(constructor_arg, shape=shape_arg)

return csm_constructor


@register_funcify_default_op_cache_key(Cast)
def numba_funcify_Cast(op, node, **kwargs):
inp_dtype = node.inputs[0].type.dtype
out_dtype = np.dtype(op.out_type)
if not np.can_cast(inp_dtype, out_dtype):
if config.compiler_verbose:
print( # noqa: T201
f"Sparse Cast fallback to obj mode due to unsafe casting from {inp_dtype} to {out_dtype}"
)
return generate_fallback_impl(op, node, **kwargs)

@numba_basic.numba_njit
def cast(x):
return x.astype(out_dtype)

return cast
Loading
Loading