Skip to content

Commit 3f9fb48

Browse files
fyrestone刘宝
andauthored
Fix error message when sparse data format not supported (#3046)
Co-authored-by: 刘宝 <[email protected]>
1 parent 4dabb97 commit 3f9fb48

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

mars/lib/sparse/array.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@ def __new__(cls, *args, **kwargs):
4646
return object.__new__(SparseMatrix)
4747

4848
else:
49-
from .coo import COONDArray
50-
51-
return object.__new__(COONDArray)
49+
if cls is not SparseNDArray:
50+
return object.__new__(cls)
51+
else:
52+
raise ValueError(
53+
f"The construct params of {cls.__name__} are invalid: "
54+
f"args={args}, kwargs={kwargs}"
55+
)
5256

5357
@property
5458
def raw(self):
@@ -229,6 +233,12 @@ def __getattr__(self, attr):
229233

230234
return super().__getattribute__(attr)
231235

236+
def __getstate__(self):
237+
return self.spmatrix
238+
239+
def __setstate__(self, state):
240+
self.spmatrix = state
241+
232242
def astype(self, dtype, **_):
233243
dtype = np.dtype(dtype)
234244
if self.dtype == dtype:

mars/lib/sparse/tests/test_sparse.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import pickle
18+
1719
import numpy as np
20+
import pytest
1821
import scipy.sparse as sps
1922

2023
from ... import sparse as mls
@@ -47,12 +50,20 @@ def assert_array_equal(a, b, almost=False):
4750

4851

4952
def test_sparse_creation():
53+
with pytest.raises(ValueError):
54+
SparseNDArray()
55+
5056
s = SparseNDArray(s1_data)
5157
assert s.ndim == 2
5258
assert isinstance(s, SparseMatrix)
5359
assert_array_equal(s.toarray(), s1_data.A)
5460
assert_array_equal(s.todense(), s1_data.A)
5561

62+
ss = pickle.loads(pickle.dumps(s))
63+
assert s == ss
64+
assert_array_equal(ss.toarray(), s1_data.A)
65+
assert_array_equal(ss.todense(), s1_data.A)
66+
5667
v = SparseNDArray(v1, shape=(3,))
5768
assert s.ndim
5869
assert isinstance(v, SparseVector)
@@ -61,6 +72,12 @@ def test_sparse_creation():
6172
assert_array_equal(v.toarray(), v1_data)
6273
assert_array_equal(v, v1_data)
6374

75+
vv = pickle.loads(pickle.dumps(v))
76+
assert v == vv
77+
assert_array_equal(vv.todense(), v1_data)
78+
assert_array_equal(vv.toarray(), v1_data)
79+
assert_array_equal(vv, v1_data)
80+
6481

6582
def test_sparse_add():
6683
s1 = SparseNDArray(s1_data)

mars/tensor/base/tests/test_base_execution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def test_copyto_execution(setup):
101101
assert res.flags["C_CONTIGUOUS"] is False
102102

103103

104+
@pytest.mark.ray_dag
104105
def test_astype_execution(setup):
105106
raw = np.random.random((10, 5))
106107
arr = tensor(raw, chunk_size=3)

0 commit comments

Comments
 (0)