3131)
3232
3333
34- def assert_csx_equal (
35- expected : sps .csr_array | sps .csc_array ,
36- actual : sps .csr_array | sps .csc_array ,
34+ def assert_sps_equal (
35+ expected : sps .csr_array | sps .csc_array | sps . coo_array ,
36+ actual : sps .csr_array | sps .csc_array | sps . coo_array ,
3737) -> None :
3838 assert expected .format == actual .format
3939 expected .eliminate_zeros ()
@@ -42,8 +42,13 @@ def assert_csx_equal(
4242 actual .eliminate_zeros ()
4343 actual .sum_duplicates ()
4444
45- np .testing .assert_array_equal (expected .indptr , actual .indptr )
46- np .testing .assert_array_equal (expected .indices , actual .indices )
45+ if expected .format != "coo" :
46+ np .testing .assert_array_equal (expected .indptr , actual .indptr )
47+ np .testing .assert_array_equal (expected .indices , actual .indices )
48+ else :
49+ np .testing .assert_array_equal (expected .row , actual .col )
50+ np .testing .assert_array_equal (expected .row , actual .col )
51+
4752 np .testing .assert_array_equal (expected .data , actual .data )
4853
4954
@@ -121,10 +126,10 @@ def test_2d_constructors(rng, dtype):
121126 dense_2_tensor = sparse .asarray (np .arange (100 , dtype = dtype ).reshape ((25 , 4 )) + 10 )
122127
123128 csr_retured = sparse .to_scipy (csr_tensor )
124- assert_csx_equal (csr_retured , csr )
129+ assert_sps_equal (csr_retured , csr )
125130
126131 csc_retured = sparse .to_scipy (csc_tensor )
127- assert_csx_equal (csc_retured , csc )
132+ assert_sps_equal (csc_retured , csc )
128133
129134 dense_returned = sparse .to_numpy (dense_tensor )
130135 np .testing .assert_equal (dense_returned , dense )
@@ -157,15 +162,15 @@ def test_add(rng, dtype):
157162
158163 actual = sparse .to_scipy (sparse .add (csr_tensor , csr_2_tensor ))
159164 expected = csr + csr_2
160- assert_csx_equal (expected , actual )
165+ assert_sps_equal (expected , actual )
161166
162167 actual = sparse .to_scipy (sparse .add (csc_tensor , csc_tensor ))
163168 expected = csc + csc
164- assert_csx_equal (expected , actual )
169+ assert_sps_equal (expected , actual )
165170
166171 actual = sparse .to_scipy (sparse .add (csc_tensor , csr_tensor ))
167172 expected = (csc + csr ).asformat ("csr" )
168- assert_csx_equal (expected , actual )
173+ assert_sps_equal (expected , actual )
169174
170175 actual = sparse .to_numpy (sparse .add (csr_tensor , dense_tensor ))
171176 expected = csr + dense
@@ -183,7 +188,7 @@ def test_add(rng, dtype):
183188
184189 actual = sparse .to_scipy (sparse .add (csr_2_tensor , coo_tensor ))
185190 expected = csr_2 + coo
186- assert_csx_equal (expected , actual )
191+ assert_sps_equal (expected , actual )
187192
188193 # This ends up being DCSR, not COO
189194 actual_tensor = sparse .add (coo_tensor , coo_tensor )
@@ -307,7 +312,7 @@ def test_copy():
307312 [
308313 "csr" ,
309314 pytest .param ("csc" , marks = pytest .mark .xfail (reason = "https://github.com/llvm/llvm-project/pull/109641" )),
310- pytest . param ( "coo" , marks = pytest . mark . xfail ( reason = "https://github.com/llvm/llvm-project/pull/109135" )) ,
315+ "coo" ,
311316 ],
312317)
313318@pytest .mark .parametrize (
@@ -390,6 +395,45 @@ def test_reshape_csf(dtype):
390395 for actual , expected in zip (result .get_constituent_arrays (), expected_arrs , strict = True ):
391396 np .testing .assert_array_equal (actual , expected )
392397
393- # DENSE
394- # NOTE: dense reshape is probably broken in MLIR in 19.x branch
395- # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
398+
399+ @parametrize_dtypes
400+ def test_reshape_dense (dtype ):
401+ SHAPE = (2 , 2 , 4 )
402+
403+ np_arr = np .arange (math .prod (SHAPE ), dtype = dtype ).reshape (SHAPE )
404+ sp_arr = sparse .asarray (np_arr )
405+
406+ for new_shape in [
407+ (4 , 4 , 1 ),
408+ (2 , 1 , 8 ),
409+ ]:
410+ expected = np_arr .reshape (new_shape )
411+ actual = sparse .reshape (sp_arr , new_shape )
412+
413+ actual_np = sparse .to_numpy (actual )
414+
415+ assert actual_np .dtype == expected .dtype
416+ np .testing .assert_equal (actual_np , expected )
417+
418+
419+ @pytest .mark .skip (reason = "Segfault" )
420+ @pytest .mark .parametrize ("src_fmt" , ["csr" , "csc" ])
421+ @pytest .mark .parametrize ("dst_fmt" , ["csr" , "csc" ])
422+ def test_asformat (rng , src_fmt , dst_fmt ):
423+ SHAPE = (100 , 50 )
424+ DENSITY = 0.5
425+ sampler = generate_sampler (np .float64 , rng )
426+
427+ sps_arr = sps .random_array (
428+ SHAPE , density = DENSITY , format = src_fmt , dtype = np .float64 , random_state = rng , data_sampler = sampler
429+ )
430+ sp_arr = sparse .asarray (sps_arr )
431+
432+ expected = sps_arr .asformat (dst_fmt )
433+
434+ actual_fmt = sparse .asarray (expected , copy = False ).format
435+ actual = sp_arr .asformat (actual_fmt )
436+ actual_sps = sparse .to_scipy (actual )
437+
438+ assert actual_sps .format == dst_fmt
439+ assert_sps_equal (expected , actual_sps )
0 commit comments