@@ -195,6 +195,8 @@ class COO(SparseArray, NDArrayOperatorsMixin): # lgtm [py/missing-equals]
195
195
196
196
__array_priority__ = 12
197
197
198
+ __array_members__ = ("data" , "coords" , "fill_value" )
199
+
198
200
def __init__ (
199
201
self ,
200
202
coords ,
@@ -207,6 +209,8 @@ def __init__(
207
209
fill_value = None ,
208
210
idx_dtype = None ,
209
211
):
212
+ from .._common import _coerce_to_supported_dense
213
+
210
214
if isinstance (coords , COO ):
211
215
self ._make_shallow_copy_of (coords )
212
216
if data is not None or shape is not None :
@@ -226,8 +230,8 @@ def __init__(
226
230
self .enable_caching ()
227
231
return
228
232
229
- self .data = np . asarray (data )
230
- self .coords = np . asarray (coords )
233
+ self .data = _coerce_to_supported_dense (data )
234
+ self .coords = _coerce_to_supported_dense (coords )
231
235
232
236
if self .coords .ndim == 1 :
233
237
if self .coords .size == 0 and shape is not None :
@@ -236,7 +240,7 @@ def __init__(
236
240
self .coords = self .coords [None , :]
237
241
238
242
if self .data .ndim == 0 :
239
- self .data = np .broadcast_to (self .data , self .coords .shape [1 ])
243
+ self .data = self . _component_namespace .broadcast_to (self .data , self .coords .shape [1 ])
240
244
241
245
if self .data .ndim != 1 :
242
246
raise ValueError ("`data` must be a scalar or 1-dimensional." )
@@ -251,7 +255,9 @@ def __init__(
251
255
shape = tuple (shape )
252
256
253
257
if shape and not self .coords .size :
254
- self .coords = np .zeros ((len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp )
258
+ self .coords = self ._component_namespace .zeros (
259
+ (len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp
260
+ )
255
261
super ().__init__ (shape , fill_value = fill_value )
256
262
if idx_dtype :
257
263
if not can_store (idx_dtype , max (shape )):
@@ -369,7 +375,7 @@ def from_numpy(cls, x, fill_value=None, idx_dtype=None):
369
375
x = np .asanyarray (x ).view (type = np .ndarray )
370
376
371
377
if fill_value is None :
372
- fill_value = _zero_of_dtype (x .dtype ) if x .shape else x
378
+ fill_value = _zero_of_dtype (x .dtype , x . device ) if x .shape else x
373
379
374
380
coords = np .atleast_2d (np .flatnonzero (~ equivalent (x , fill_value )))
375
381
data = x .ravel ()[tuple (coords )]
@@ -407,7 +413,9 @@ def todense(self):
407
413
>>> np.array_equal(x, x2)
408
414
True
409
415
"""
410
- x = np .full (self .shape , self .fill_value , self .dtype )
416
+ x = self ._component_namespace .full (
417
+ self .shape , fill_value = self .fill_value , dtype = self .dtype , device = self .data .device
418
+ )
411
419
412
420
coords = tuple ([self .coords [i , :] for i in range (self .ndim )])
413
421
data = self .data
@@ -446,14 +454,16 @@ def from_scipy_sparse(cls, x, /, *, fill_value=None):
446
454
>>> np.array_equal(x.todense(), s.todense())
447
455
True
448
456
"""
457
+ import array_api_compat
458
+
449
459
x = x .asformat ("coo" )
450
460
if not x .has_canonical_format :
451
461
x .eliminate_zeros ()
452
462
x .sum_duplicates ()
453
463
454
- coords = np . empty (( 2 , x . nnz ), dtype = x . row . dtype )
455
- coords [ 0 , :] = x . row
456
- coords [ 1 , :] = x . col
464
+ xp = array_api_compat . array_namespace ( x . data )
465
+
466
+ coords = xp . stack (( x . row , x . col ))
457
467
return COO (
458
468
coords ,
459
469
x .data ,
@@ -1184,14 +1194,19 @@ def to_scipy_sparse(self, /, *, accept_fv=None):
1184
1194
- [`sparse.COO.tocsr`][] : Convert to a [`scipy.sparse.csr_matrix`][].
1185
1195
- [`sparse.COO.tocsc`][] : Convert to a [`scipy.sparse.csc_matrix`][].
1186
1196
"""
1187
- import scipy .sparse
1197
+ from .._settings import NUMPY_DEVICE
1198
+
1199
+ if self .device == NUMPY_DEVICE :
1200
+ import scipy .sparse as sps
1201
+ else :
1202
+ import cupyx .scipy .sparse as sps
1188
1203
1189
1204
check_fill_value (self , accept_fv = accept_fv )
1190
1205
1191
1206
if self .ndim != 2 :
1192
1207
raise ValueError ("Can only convert a 2-dimensional array to a Scipy sparse matrix." )
1193
1208
1194
- result = scipy . sparse .coo_matrix ((self .data , (self .coords [0 ], self .coords [1 ])), shape = self .shape )
1209
+ result = sps .coo_matrix ((self .data , (self .coords [0 ], self .coords [1 ])), shape = self .shape )
1195
1210
result .has_canonical_format = True
1196
1211
return result
1197
1212
@@ -1307,10 +1322,10 @@ def _sort_indices(self):
1307
1322
"""
1308
1323
linear = self .linear_loc ()
1309
1324
1310
- if (np .diff (linear ) >= 0 ).all (): # already sorted
1325
+ if (self . _component_namespace .diff (linear ) >= 0 ).all (): # already sorted
1311
1326
return
1312
1327
1313
- order = np .argsort (linear , kind = "mergesort" )
1328
+ order = self . _component_namespace .argsort (linear , kind = "mergesort" )
1314
1329
self .coords = self .coords [:, order ]
1315
1330
self .data = self .data [order ]
1316
1331
@@ -1336,16 +1351,16 @@ def _sum_duplicates(self):
1336
1351
# Inspired by scipy/sparse/coo.py::sum_duplicates
1337
1352
# See https://github.com/scipy/scipy/blob/main/LICENSE.txt
1338
1353
linear = self .linear_loc ()
1339
- unique_mask = np .diff (linear ) != 0
1354
+ unique_mask = self . _component_namespace .diff (linear ) != 0
1340
1355
1341
1356
if unique_mask .sum () == len (unique_mask ): # already unique
1342
1357
return
1343
1358
1344
- unique_mask = np .append (True , unique_mask )
1359
+ unique_mask = self . _component_namespace .append (True , unique_mask )
1345
1360
1346
1361
coords = self .coords [:, unique_mask ]
1347
- (unique_inds ,) = np .nonzero (unique_mask )
1348
- data = np .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
1362
+ (unique_inds ,) = self . _component_namespace .nonzero (unique_mask )
1363
+ data = self . _component_namespace .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
1349
1364
1350
1365
self .data = data
1351
1366
self .coords = coords
0 commit comments