@@ -300,37 +300,50 @@ def test_copy():
300300
301301
302302@parametrize_dtypes
303- def test_reshape (rng , dtype ):
303+ @pytest .mark .parametrize (
304+ "format" ,
305+ [
306+ "csr" ,
307+ pytest .param ("csc" , marks = pytest .mark .xfail (reason = "https://github.com/llvm/llvm-project/pull/109135" )),
308+ pytest .param ("coo" , marks = pytest .mark .xfail (reason = "https://github.com/llvm/llvm-project/pull/109641" )),
309+ ],
310+ )
311+ @pytest .mark .parametrize (
312+ ("shape" , "new_shape" ),
313+ [
314+ ((100 , 50 ), (25 , 200 )),
315+ ((100 , 50 ), (10 , 500 , 1 )),
316+ ((80 , 1 ), (8 , 10 )),
317+ ((80 , 1 ), (80 ,)),
318+ ],
319+ )
320+ def test_reshape (rng , dtype , format , shape , new_shape ):
304321 DENSITY = 0.5
305322 sampler = generate_sampler (dtype , rng )
306323
307- # CSR, CSC, COO
308- for shape , new_shape in [
309- ((100 , 50 ), (25 , 200 )),
310- # ((100, 50), (10, 500, 1)),
311- ((80 , 1 ), (8 , 10 )),
312- # ((80, 1), (80,)),
313- ]:
314- for format in ["csr" , "csc" , "coo" ]:
315- if format == "coo" :
316- # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
317- continue
318- if format == "csc" :
319- # NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
320- continue
321-
322- arr = sps .random_array (
323- shape , density = DENSITY , format = format , dtype = dtype , random_state = rng , data_sampler = sampler
324- )
325- arr .eliminate_zeros ()
326- arr .sum_duplicates ()
327- tensor = sparse .asarray (arr )
328-
329- actual = sparse .to_scipy (sparse .reshape (tensor , shape = new_shape ))
330- expected = arr .todense ().reshape (new_shape )
331-
332- np .testing .assert_array_equal (actual .todense (), expected )
324+ arr_sps = sps .random_array (
325+ shape , density = DENSITY , format = format , dtype = dtype , random_state = rng , data_sampler = sampler
326+ )
327+ arr_sps .eliminate_zeros ()
328+ arr_sps .sum_duplicates ()
329+ arr = sparse .asarray (arr_sps )
330+
331+ actual = sparse .reshape (arr , shape = new_shape )
332+ assert actual .shape == new_shape
333+
334+ try :
335+ scipy_format = sparse .to_scipy (actual ).format
336+ except RuntimeError :
337+ pytest .xfail ("No library to compare to." )
338+
339+ expected = sparse .asarray (arr_sps .reshape (new_shape ).asformat (scipy_format )) if scipy_format is not None else arr
333340
341+ for x , y in zip (expected .get_constituent_arrays (), actual .get_constituent_arrays (), strict = True ):
342+ np .testing .assert_array_equal (x , y )
343+
344+
345+ @parametrize_dtypes
346+ def test_reshape_csf (dtype ):
334347 # CSF
335348 csf_shape = (2 , 2 , 4 )
336349 csf_format = sparse .levels .get_storage_format (
@@ -372,7 +385,6 @@ def test_reshape(rng, dtype):
372385 csf_tensor = sparse .from_constituent_arrays (format = csf_format , arrays = arrs , shape = shape )
373386
374387 result = sparse .reshape (csf_tensor , shape = new_shape )
375-
376388 for actual , expected in zip (result .get_constituent_arrays (), expected_arrs , strict = True ):
377389 np .testing .assert_array_equal (actual , expected )
378390
0 commit comments