Skip to content

Commit ee81cdc

Browse files
authored
Merge pull request #657 from pydata/update-finch-backend
Update Finch backend
2 parents c532a35 + d93a6cf commit ee81cdc

File tree

2 files changed

+97
-16
lines changed

2 files changed

+97
-16
lines changed

sparse/finch_backend/__init__.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,68 @@
33
except ModuleNotFoundError as e:
44
raise ImportError("Finch not installed. Run `pip install sparse[finch]` to enable Finch backend") from e
55

6-
from finch import Tensor, astype, permute_dims
6+
from finch import (
7+
add,
8+
astype,
9+
bool,
10+
compiled,
11+
complex64,
12+
complex128,
13+
compute,
14+
divide,
15+
float16,
16+
float32,
17+
float64,
18+
int8,
19+
int16,
20+
int32,
21+
int64,
22+
int_,
23+
lazy,
24+
multiply,
25+
negative,
26+
permute_dims,
27+
positive,
28+
prod,
29+
random,
30+
subtract,
31+
sum,
32+
uint,
33+
uint8,
34+
uint16,
35+
uint32,
36+
uint64,
37+
)
738

8-
__all__ = ["Tensor", "astype", "permute_dims"]
9-
10-
11-
class COO:
12-
def from_numpy(self):
13-
raise NotImplementedError
39+
__all__ = [
40+
"add",
41+
"astype",
42+
"bool",
43+
"compiled",
44+
"complex64",
45+
"complex128",
46+
"compute",
47+
"divide",
48+
"float16",
49+
"float32",
50+
"float64",
51+
"int8",
52+
"int16",
53+
"int32",
54+
"int64",
55+
"int_",
56+
"lazy",
57+
"multiply",
58+
"negative",
59+
"permute_dims",
60+
"positive",
61+
"prod",
62+
"random",
63+
"subtract",
64+
"sum",
65+
"uint",
66+
"uint8",
67+
"uint16",
68+
"uint32",
69+
"uint64",
70+
]

sparse/tests/test_backends.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,56 @@
11
import sparse
22

3-
import pytest
4-
53
import numpy as np
64
import scipy.sparse as sp
5+
from numpy.testing import assert_equal
76

87

98
def test_backend_contex_manager(backend):
9+
rng = np.random.default_rng(0)
10+
x = sparse.random((100, 10, 100), density=0.01, random_state=rng)
11+
y = sparse.random((100, 10, 100), density=0.01, random_state=rng)
12+
1013
if backend == sparse.BackendType.Finch:
11-
with pytest.raises(NotImplementedError):
12-
sparse.COO.from_numpy(np.eye(5))
14+
import finch
15+
16+
def storage():
17+
return finch.Storage(finch.Dense(finch.SparseList(finch.SparseList(finch.Element(0.0)))), order="C")
18+
19+
x = x.to_device(storage())
20+
y = y.to_device(storage())
1321
else:
14-
sparse.COO.from_numpy(np.eye(5))
22+
x.asformat("gcxs")
23+
y.asformat("gcxs")
24+
25+
z = x + y
26+
result = sparse.sum(z)
27+
assert result.shape == ()
1528

1629

1730
def test_finch_backend():
1831
np_eye = np.eye(5)
1932
sp_arr = sp.csr_matrix(np_eye)
2033

2134
with sparse.Backend(backend=sparse.BackendType.Finch):
22-
finch_dense = sparse.Tensor(np_eye)
35+
import finch
36+
37+
finch_dense = finch.Tensor(np_eye)
2338

2439
assert np.shares_memory(finch_dense.todense(), np_eye)
2540

26-
finch_arr = sparse.Tensor(sp_arr)
41+
finch_arr = finch.Tensor(sp_arr)
2742

28-
np.testing.assert_equal(finch_arr.todense(), np_eye)
43+
assert_equal(finch_arr.todense(), np_eye)
2944

3045
transposed = sparse.permute_dims(finch_arr, (1, 0))
3146

32-
np.testing.assert_equal(transposed.todense(), np_eye.T)
47+
assert_equal(transposed.todense(), np_eye.T)
48+
49+
@sparse.compiled
50+
def my_fun(tns1, tns2):
51+
tmp = sparse.add(tns1, tns2)
52+
return sparse.sum(tmp, axis=0)
53+
54+
result = my_fun(finch_dense, finch_arr)
55+
56+
assert_equal(result.todense(), np.sum(2 * np_eye, axis=0))

0 commit comments

Comments
 (0)