Skip to content

Commit 18c4ebd

Browse files
authored
ENH: sparse: make two sputils public for easier index array casting (scipy#22043)
* make index tools public * update release notes * Add examples to docstring
1 parent a56b885 commit 18c4ebd

File tree

4 files changed

+55
-10
lines changed

4 files changed

+55
-10
lines changed

doc/source/reference/sparse.migration_to_sparray.rst

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,19 +449,23 @@ The function signatures are::
449449
def get_index_dtype(arrays=(), maxval=None, check_contents=False):
450450
def safely_cast_index_arrays(A, idx_dtype=np.int32, msg=""):
451451

452-
Example idioms include the following::
452+
Example idioms include the following for ``get_index_dtype``:
453453

454454
.. code-block:: python
455455

456456
# select index dtype before construction based on shape
457457
shape = (3, 3)
458-
idx_dtype = scipy.sparse._sputils.get_index_dtype(maxval=max(shape))
458+
idx_dtype = scipy.sparse.get_index_dtype(maxval=max(shape))
459459
indices = np.array([0, 1, 0], dtype=idx_dtype)
460460
indptr = np.arange(3, dtype=idx_dtype)
461461
A = csr_array((data, indices, indptr), shape=shape)
462462

463-
# rescast after construction, raising exception before overflow
464-
indices, indptr = scipy.sparse._sputils.safely_cast_index_arrays(B, np.int32)
463+
and for ``safely_cast_index_arrays``:
464+
465+
.. code-block:: python
466+
467+
# recast after construction, raising exception if shape too big
468+
indices, indptr = scipy.sparse.safely_cast_index_arrays(B, np.int32)
465469
B.indices, B.indptr = indices, indptr
466470

467471
Other

doc/source/release/1.15.0-notes.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ and support several Array API compatible array libraries in addition to NumPy
194194
incompatible data types, such as ``float16``.
195195
- ``min``, ``max``, ``argmin``, and ``argmax`` now support computation
196196
over nonzero elements only via the new ``explicit`` argument.
197-
- New function ``safely_cast_index_arrays`` has been added
198-
to facilitate casting challenges in ``sparse``.
197+
- New functions ``get_index_dtype`` and ``safely_cast_index_arrays`` are
198+
available to facilitate index array casting in ``sparse``.
199199

200200

201201
`scipy.spatial` improvements

scipy/sparse/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
save_npz - Save a sparse array to a file using ``.npz`` format.
103103
load_npz - Load a sparse array from a file using ``.npz`` format.
104104
find - Return the indices and values of the nonzero elements
105+
get_index_dtype - Determine a good dtype for index arrays.
106+
safely_cast_index_arrays - Cast index array dtype or raise if shape too big.
105107
106108
Identifying sparse arrays
107109
-------------------------
@@ -307,6 +309,7 @@
307309
from ._extract import *
308310
from ._matrix import spmatrix
309311
from ._matrix_io import *
312+
from ._sputils import get_index_dtype, safely_cast_index_arrays
310313

311314
# For backward compatibility with v0.19.
312315
from . import csgraph

scipy/sparse/_sputils.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def upcast(*args):
3232
--------
3333
>>> from scipy.sparse._sputils import upcast
3434
>>> upcast('int32')
35-
<type 'numpy.int32'>
35+
<class 'numpy.int32'>
3636
>>> upcast('bool')
37-
<type 'numpy.bool_'>
37+
<class 'numpy.bool'>
3838
>>> upcast('int32','float32')
39-
<type 'numpy.float64'>
39+
<class 'numpy.float64'>
4040
>>> upcast('bool',complex,float)
41-
<type 'numpy.complex128'>
41+
<class 'numpy.complex128'>
4242
4343
"""
4444

@@ -172,6 +172,7 @@ def safely_cast_index_arrays(A, idx_dtype=np.int32, msg=""):
172172
The array for which index arrays should be downcast.
173173
idx_dtype : dtype
174174
Desired dtype. Should be an integer dtype (default: ``np.int32``).
175+
Most of scipy.sparse uses either int64 or int32.
175176
msg : string, optional
176177
A string to be added to the end of the ValueError message
177178
if the array shape is too big to fit in `idx_dtype`.
@@ -193,6 +194,24 @@ def safely_cast_index_arrays(A, idx_dtype=np.int32, msg=""):
193194
ValueError
194195
If the array has shape that would not fit in the new dtype, or if
195196
the sparse format does not use index arrays.
197+
198+
Examples
199+
--------
200+
>>> import numpy as np
201+
>>> from scipy import sparse
202+
>>> data = [3]
203+
>>> coords = (np.array([3]), np.array([1])) # Note: int64 arrays
204+
>>> A = sparse.coo_array((data, coords))
205+
>>> A.coords[0].dtype
206+
dtype('int64')
207+
208+
>>> # recast after construction, raising exception if shape too big
209+
>>> coords = sparse.safely_cast_index_arrays(A, np.int32)
210+
>>> A.coords[0] is coords[0] # False if casting is needed
211+
False
212+
>>> A.coords = coords # set the index dtype of A
213+
>>> A.coords[0].dtype
214+
dtype('int32')
196215
"""
197216
if not msg:
198217
msg = f"dtype {idx_dtype}"
@@ -262,6 +281,25 @@ def get_index_dtype(arrays=(), maxval=None, check_contents=False):
262281
dtype : dtype
263282
Suitable index data type (int32 or int64)
264283
284+
Examples
285+
--------
286+
>>> import numpy as np
287+
>>> from scipy import sparse
288+
>>> # select index dtype based on shape
289+
>>> shape = (3, 3)
290+
>>> idx_dtype = sparse.get_index_dtype(maxval=max(shape))
291+
>>> data = [1.1, 3.0, 1.5]
292+
>>> indices = np.array([0, 1, 0], dtype=idx_dtype)
293+
>>> indptr = np.array([0, 2, 3, 3], dtype=idx_dtype)
294+
>>> A = sparse.csr_array((data, indices, indptr), shape=shape)
295+
>>> A.indptr.dtype
296+
dtype('int32')
297+
298+
>>> # select based on larger of existing arrays and shape
299+
>>> shape = (3, 3)
300+
>>> idx_dtype = sparse.get_index_dtype(A.indptr, maxval=max(shape))
301+
>>> idx_dtype
302+
<class 'numpy.int32'>
265303
"""
266304
# not using intc directly due to misinteractions with pythran
267305
if np.intc().itemsize != 4:

0 commit comments

Comments
 (0)