20 changes: 7 additions & 13 deletions pyproject.toml
@@ -144,27 +144,21 @@ lines-after-imports = 2
# TODO: Get rid of these:
"**/__init__.py" = ["F401", "E402", "F403"]
"pytensor/tensor/linalg.py" = ["F403"]
# Modules that use print-statements, skip "T201"
"pytensor/link/c/cmodule.py" = ["PTH", "T201"]
"pytensor/misc/elemwise_time_test.py" = ["T201"]
"pytensor/misc/elemwise_openmp_speedup.py" = ["T201"]
"pytensor/misc/check_duplicate_key.py" = ["T201"]
"pytensor/misc/check_blas.py" = ["T201"]
"pytensor/bin/pytensor_cache.py" = ["T201"]
# For the tests we skip `E402` because `pytest.importorskip` is used:
"tests/link/jax/test_scalar.py" = ["E402"]
"tests/link/jax/test_tensor_basic.py" = ["E402"]
"tests/link/numba/test_basic.py" = ["E402"]
"tests/link/numba/test_cython_support.py" = ["E402"]
"tests/link/numba/test_performance.py" = ["E402"]
"tests/link/numba/test_sparse.py" = ["E402"]
"tests/link/numba/test_tensor_basic.py" = ["E402"]
"tests/tensor/test_math_scipy.py" = ["E402"]
"tests/sparse/test_basic.py" = ["E402"]
"tests/sparse/test_sp2.py" = ["E402"]
"tests/sparse/test_utils.py" = ["E402"]
"tests/sparse/sandbox/test_sp.py" = ["E402", "F401"]
"tests/compile/test_monitormode.py" = ["T201"]
"scripts/run_mypy.py" = ["T201"]
# Test modules of optional backends that use `pytest.importorskip`, skip "E402"
"tests/link/jax/**/test_*.py" = ["E402"]
"tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"]



[tool.mypy]
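A minimal sketch of the pattern that motivates the consolidated `E402` globs above: these test modules call `pytest.importorskip` before their remaining imports, so the later imports are no longer at the top of the file and ruff flags them. The module chosen below is illustrative, not taken from this diff.

    import pytest

    numba = pytest.importorskip("numba")  # skip the whole test module when the optional backend is missing

    # Any import placed after the importorskip call is no longer at the top of the
    # module, so ruff reports E402 unless the file matches one of the globs above.
    from pytensor.link.numba.dispatch import basic as numba_basic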
21 changes: 14 additions & 7 deletions pytensor/link/numba/dispatch/basic.py
@@ -13,13 +13,12 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
fgraph_to_python,
)
from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import DenseTensorType
from pytensor.tensor.utils import hash_from_ndarray


@@ -80,7 +79,7 @@ def get_numba_type(
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""

if isinstance(pytensor_type, TensorType):
if isinstance(pytensor_type, DenseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if force_scalar or (
@@ -93,12 +92,20 @@
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
elif isinstance(pytensor_type, SparseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
from pytensor.link.numba.dispatch.sparse.basic import (
CSCMatrixType,
CSRMatrixType,
)

data_array = numba.types.Array(
numba.from_dtype(pytensor_type.numpy_dtype), 1, layout
)
indices_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout)
indptr_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout)
if pytensor_type.format == "csr":
return CSRMatrixType(numba_dtype)
return CSRMatrixType(data_array, indices_array, indptr_array)
if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype)
return CSCMatrixType(data_array, indices_array, indptr_array)

raise NotImplementedError()
else:
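For reference, a small sketch of how the member types used by the new `SparseTensorType` branch are assembled with stock Numba APIs. The helper name is hypothetical, and `CSRMatrixType`/`CSCMatrixType` are assumed to accept the three member array types, as in the diff above.

    import numba
    import numpy as np

    def sparse_member_types(numpy_dtype, layout="C"):
        # scipy CSR/CSC matrices are backed by three 1-d arrays: data, indices, indptr
        data = numba.types.Array(numba.from_dtype(np.dtype(numpy_dtype)), 1, layout)
        indices = numba.types.Array(numba.from_dtype(np.dtype(np.int32)), 1, layout)
        indptr = numba.types.Array(numba.from_dtype(np.dtype(np.int32)), 1, layout)
        return data, indices, indptr

    data_t, indices_t, indptr_t = sparse_member_types(np.float64)
    # CSRMatrixType(data_t, indices_t, indptr_t) / CSCMatrixType(...) would then
    # wrap these member types, matching the calls in get_numba_type above.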
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/compile_ops.py
@@ -64,6 +64,7 @@ def identity(x):

@register_funcify_default_op_cache_key(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
# FIXME: SparseTensorType will match on this condition, but `np.copy` doesn't work with sparse inputs
if isinstance(node.inputs[0].type, TensorType):

@numba_basic.numba_njit
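A hedged illustration of the FIXME above: scipy sparse containers need their own `.copy()`, so a deep-copy path that also accepts `SparseTensorType` inputs would have to branch on the runtime value rather than reuse the dense `np.copy` branch. The helper below is a sketch, not the dispatcher's actual implementation.

    import numpy as np
    import scipy.sparse as sp

    def deepcopy_value(x):
        if sp.issparse(x):
            # scipy sparse matrices/arrays implement their own copy method
            return x.copy()
        # dense ndarrays and NumPy scalars go through NumPy's copy
        return np.copy(x)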
206 changes: 0 additions & 206 deletions pytensor/link/numba/dispatch/sparse.py

This file was deleted.

1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/sparse/__init__.py
@@ -0,0 +1 @@
from pytensor.link.numba.dispatch.sparse import basic, math