1313from pytensor .graph .fg import FunctionGraph
1414from pytensor .graph .type import Type
1515from pytensor .link .numba .cache import compile_numba_function_src , hash_from_pickle_dump
16- from pytensor .link .numba .dispatch .sparse import CSCMatrixType , CSRMatrixType
1716from pytensor .link .utils import (
1817 fgraph_to_python ,
1918)
2019from pytensor .scalar .basic import ScalarType
2120from pytensor .sparse import SparseTensorType
22- from pytensor .tensor .type import TensorType
21+ from pytensor .tensor .type import DenseTensorType
2322from pytensor .tensor .utils import hash_from_ndarray
2423
2524
@@ -80,7 +79,7 @@ def get_numba_type(
8079 Return Numba scalars for zero dimensional :class:`TensorType`\s.
8180 """
8281
83- if isinstance (pytensor_type , TensorType ):
82+ if isinstance (pytensor_type , DenseTensorType ):
8483 dtype = pytensor_type .numpy_dtype
8584 numba_dtype = numba .from_dtype (dtype )
8685 if force_scalar or (
@@ -93,12 +92,20 @@ def get_numba_type(
9392 numba_dtype = numba .from_dtype (dtype )
9493 return numba_dtype
9594 elif isinstance (pytensor_type , SparseTensorType ):
96- dtype = pytensor_type .numpy_dtype
97- numba_dtype = numba .from_dtype (dtype )
95+ from pytensor .link .numba .dispatch .sparse .basic import (
96+ CSCMatrixType ,
97+ CSRMatrixType ,
98+ )
99+
100+ data_array = numba .types .Array (
101+ numba .from_dtype (pytensor_type .numpy_dtype ), 1 , layout
102+ )
103+ indices_array = numba .types .Array (numba .from_dtype (np .int32 ), 1 , layout )
104+ indptr_array = numba .types .Array (numba .from_dtype (np .int32 ), 1 , layout )
98105 if pytensor_type .format == "csr" :
99- return CSRMatrixType (numba_dtype )
106+ return CSRMatrixType (data_array , indices_array , indptr_array )
100107 if pytensor_type .format == "csc" :
101- return CSCMatrixType (numba_dtype )
108+ return CSCMatrixType (data_array , indices_array , indptr_array )
102109
103110 raise NotImplementedError ()
104111 else :
0 commit comments