Skip to content

Commit ffe08e5

Browse files
committed
Implement basic numba sparse Ops
1 parent 29697f5 commit ffe08e5

File tree

12 files changed

+602
-317
lines changed

12 files changed

+602
-317
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.type import Type
1515
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
16-
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
1716
from pytensor.link.utils import (
1817
fgraph_to_python,
1918
)
@@ -93,12 +92,14 @@ def get_numba_type(
9392
numba_dtype = numba.from_dtype(dtype)
9493
return numba_dtype
9594
elif isinstance(pytensor_type, SparseTensorType):
95+
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
96+
9697
dtype = pytensor_type.numpy_dtype
97-
numba_dtype = numba.from_dtype(dtype)
98+
# numba_dtype = numba.from_dtype(dtype)
9899
if pytensor_type.format == "csr":
99-
return CSRMatrixType(numba_dtype)
100+
return CSRMatrixType()
100101
if pytensor_type.format == "csc":
101-
return CSCMatrixType(numba_dtype)
102+
return CSCMatrixType()
102103

103104
raise NotImplementedError()
104105
else:

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def identity(x):
6464

6565
@register_funcify_default_op_cache_key(DeepCopyOp)
6666
def numba_funcify_DeepCopyOp(op, node, **kwargs):
67+
# FIXME: SparseTensorType will match on this condition, but `np.copy` doesn't work with them
6768
if isinstance(node.inputs[0].type, TensorType):
6869

6970
@numba_basic.numba_njit

pytensor/link/numba/dispatch/sparse.py

Lines changed: 0 additions & 206 deletions
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pytensor.link.numba.dispatch.sparse import basic, math

0 commit comments

Comments
 (0)