Skip to content

Commit bbf7b8e

Browse files
committed
Implement basic numba sparse Ops
1 parent 17c0864 commit bbf7b8e

File tree

12 files changed

+611
-320
lines changed

12 files changed

+611
-320
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.type import Type
1515
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
16-
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
1716
from pytensor.link.utils import (
1817
fgraph_to_python,
1918
)
2019
from pytensor.scalar.basic import ScalarType
2120
from pytensor.sparse import SparseTensorType
22-
from pytensor.tensor.type import TensorType
21+
from pytensor.tensor.type import DenseTensorType
2322
from pytensor.tensor.utils import hash_from_ndarray
2423

2524

@@ -80,7 +79,7 @@ def get_numba_type(
8079
Return Numba scalars for zero dimensional :class:`TensorType`\s.
8180
"""
8281

83-
if isinstance(pytensor_type, TensorType):
82+
if isinstance(pytensor_type, DenseTensorType):
8483
dtype = pytensor_type.numpy_dtype
8584
numba_dtype = numba.from_dtype(dtype)
8685
if force_scalar or (
@@ -93,12 +92,20 @@ def get_numba_type(
9392
numba_dtype = numba.from_dtype(dtype)
9493
return numba_dtype
9594
elif isinstance(pytensor_type, SparseTensorType):
96-
dtype = pytensor_type.numpy_dtype
97-
numba_dtype = numba.from_dtype(dtype)
95+
from pytensor.link.numba.dispatch.sparse.basic import (
96+
CSCMatrixType,
97+
CSRMatrixType,
98+
)
99+
100+
data_array = numba.types.Array(
101+
numba.from_dtype(pytensor_type.numpy_dtype), 1, layout
102+
)
103+
indices_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout)
104+
indptr_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout)
98105
if pytensor_type.format == "csr":
99-
return CSRMatrixType(numba_dtype)
106+
return CSRMatrixType(data_array, indices_array, indptr_array)
100107
if pytensor_type.format == "csc":
101-
return CSCMatrixType(numba_dtype)
108+
return CSCMatrixType(data_array, indices_array, indptr_array)
102109

103110
raise NotImplementedError()
104111
else:

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def identity(x):
6464

6565
@register_funcify_default_op_cache_key(DeepCopyOp)
6666
def numba_funcify_DeepCopyOp(op, node, **kwargs):
67+
# FIXME: SparseTensorType will match on this condition, but `np.copy` doesn't work with them
6768
if isinstance(node.inputs[0].type, TensorType):
6869

6970
@numba_basic.numba_njit

pytensor/link/numba/dispatch/sparse.py

Lines changed: 0 additions & 206 deletions
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pytensor.link.numba.dispatch.sparse import basic, math

0 commit comments

Comments
 (0)