-
Notifications
You must be signed in to change notification settings - Fork 159
Basic Sparse functionality in Numba #1676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5075dfa
90c7195
db84708
fe8b80f
447af64
63c8171
0037f84
b6a01cc
feecbfd
38e27e4
89b4f7a
5333b3e
0e078a7
df200e0
b4cccc7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from pytensor.link.numba.dispatch.sparse import basic, math, variable |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| import numpy as np | ||
| import scipy as sp | ||
| from numba.extending import overload | ||
|
|
||
| from pytensor import config | ||
| from pytensor.link.numba.dispatch import basic as numba_basic | ||
| from pytensor.link.numba.dispatch.basic import ( | ||
| generate_fallback_impl, | ||
| register_funcify_default_op_cache_key, | ||
| ) | ||
| from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy | ||
| from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType | ||
| from pytensor.sparse import CSM, Cast, CSMProperties | ||
|
|
||
|
|
||
| @overload(numba_deepcopy) | ||
| def numba_deepcopy_sparse(x): | ||
| if isinstance(x, CSMatrixType): | ||
|
|
||
| def sparse_deepcopy(x): | ||
| return x.copy() | ||
|
|
||
| return sparse_deepcopy | ||
|
|
||
|
|
||
| @register_funcify_default_op_cache_key(CSMProperties) | ||
| def numba_funcify_CSMProperties(op, node, **kwargs): | ||
| @numba_basic.numba_njit | ||
| def csm_properties(x): | ||
| # Reconsider this int32/int64. Scipy/base PyTensor use int32 for indices/indptr. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we able to just go to int64 ourselves, or do we need to wait for upstream to change?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would need to change stuff in the pre-existing Ops so that fallback to obj mode is compatible. Would leave that for a later PR if we decide
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would need to change stuff in the pre-existing Ops so that fallback to obj mode is compatible. Would leave that for a later PR if we decide |
||
| # But this seems to be legacy mistake and devs would choose int64 nowadays, and may move there. | ||
| return x.data, x.indices, x.indptr, np.asarray(x.shape, dtype="int32") | ||
|
|
||
| return csm_properties | ||
|
|
||
|
|
||
| @register_funcify_default_op_cache_key(CSM) | ||
| def numba_funcify_CSM(op, node, **kwargs): | ||
| format = op.format | ||
|
|
||
| @numba_basic.numba_njit | ||
| def csm_constructor(data, indices, indptr, shape): | ||
| constructor_arg = (data, indices, indptr) | ||
| shape_arg = (shape[0], shape[1]) | ||
| if format == "csr": | ||
| return sp.sparse.csr_matrix(constructor_arg, shape=shape_arg) | ||
| else: | ||
| return sp.sparse.csc_matrix(constructor_arg, shape=shape_arg) | ||
|
|
||
| return csm_constructor | ||
|
|
||
|
|
||
| @register_funcify_default_op_cache_key(Cast) | ||
| def numba_funcify_Cast(op, node, **kwargs): | ||
| inp_dtype = node.inputs[0].type.dtype | ||
| out_dtype = np.dtype(op.out_type) | ||
| if not np.can_cast(inp_dtype, out_dtype): | ||
| if config.compiler_verbose: | ||
| print( # noqa: T201 | ||
| f"Sparse Cast fallback to obj mode due to unsafe casting from {inp_dtype} to {out_dtype}" | ||
| ) | ||
| return generate_fallback_impl(op, node, **kwargs) | ||
|
|
||
| @numba_basic.numba_njit | ||
| def cast(x): | ||
| return x.astype(out_dtype) | ||
|
|
||
| return cast | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's deep about this?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sparse_matrix.copy() does a deepcopy just like array.copy(). But for other types like list or rng there's a difference between copy and deepcopy hence the more explicit name