@@ -85,7 +85,7 @@ def sampler_complex_floating(size: tuple[int, ...]):
8585 raise NotImplementedError (f"{ dtype = } not yet supported." )
8686
8787
88- def get_exampe_csf_arrays (dtype : np .dtype ) -> tuple :
88+ def get_example_csf_arrays (dtype : np .dtype ) -> tuple :
8989 pos_1 = np .array ([0 , 1 , 3 ], dtype = np .int64 )
9090 crd_1 = np .array ([1 , 0 , 1 ], dtype = np .int64 )
9191 pos_2 = np .array ([0 , 3 , 5 , 7 ], dtype = np .int64 )
@@ -207,7 +207,7 @@ def test_csf_format(dtype):
207207 )
208208
209209 SHAPE = (2 , 2 , 4 )
210- pos_1 , crd_1 , pos_2 , crd_2 , data = get_exampe_csf_arrays (dtype )
210+ pos_1 , crd_1 , pos_2 , crd_2 , data = get_example_csf_arrays (dtype )
211211 constituent_arrays = (pos_1 , crd_1 , pos_2 , crd_2 , data )
212212
213213 csf_array = sparse .from_constituent_arrays (format = format , arrays = constituent_arrays , shape = SHAPE )
@@ -297,3 +297,85 @@ def test_copy():
297297 np .testing .assert_array_equal (sparse .to_numpy (arr_sp1 ), arr_np_orig )
298298 np .testing .assert_array_equal (sparse .to_numpy (arr_sp2 ), arr_np_orig )
299299 np .testing .assert_array_equal (sparse .to_numpy (arr_sp3 ), arr_np_copy )
300+
301+
302+ @parametrize_dtypes
303+ def test_reshape (rng , dtype ):
304+ DENSITY = 0.5
305+ sampler = generate_sampler (dtype , rng )
306+
307+ # CSR, CSC, COO
308+ for shape , new_shape in [
309+ ((100 , 50 ), (25 , 200 )),
310+ # ((100, 50), (10, 500, 1)),
311+ ((80 , 1 ), (8 , 10 )),
312+ # ((80, 1), (80,)),
313+ ]:
314+ for format in ["csr" , "csc" , "coo" ]:
315+ if format == "coo" :
316+ # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
317+ continue
318+ if format == "csc" :
319+ # NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
320+ continue
321+
322+ arr = sps .random_array (
323+ shape , density = DENSITY , format = format , dtype = dtype , random_state = rng , data_sampler = sampler
324+ )
325+ arr .eliminate_zeros ()
326+ arr .sum_duplicates ()
327+ tensor = sparse .asarray (arr )
328+
329+ actual = sparse .to_scipy (sparse .reshape (tensor , shape = new_shape ))
330+ expected = arr .todense ().reshape (new_shape )
331+
332+ np .testing .assert_array_equal (actual .todense (), expected )
333+
334+ # CSF
335+ csf_shape = (2 , 2 , 4 )
336+ csf_format = sparse .levels .get_storage_format (
337+ levels = (
338+ sparse .levels .Level (sparse .levels .LevelFormat .Dense ),
339+ sparse .levels .Level (sparse .levels .LevelFormat .Compressed ),
340+ sparse .levels .Level (sparse .levels .LevelFormat .Compressed ),
341+ ),
342+ order = "C" ,
343+ pos_width = 64 ,
344+ crd_width = 64 ,
345+ dtype = sparse .asdtype (dtype ),
346+ )
347+ for shape , new_shape , expected_arrs in [
348+ (
349+ csf_shape ,
350+ (4 , 4 , 1 ),
351+ [
352+ np .array ([0 , 0 , 3 , 5 , 7 ]),
353+ np .array ([0 , 1 , 3 , 0 , 3 , 0 , 1 ]),
354+ np .array ([0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
355+ np .array ([0 , 0 , 0 , 0 , 0 , 0 , 0 ]),
356+ np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
357+ ],
358+ ),
359+ (
360+ csf_shape ,
361+ (2 , 1 , 8 ),
362+ [
363+ np .array ([0 , 1 , 2 ]),
364+ np .array ([0 , 0 ]),
365+ np .array ([0 , 3 , 7 ]),
366+ np .array ([4 , 5 , 7 , 0 , 3 , 4 , 5 ]),
367+ np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
368+ ],
369+ ),
370+ ]:
371+ arrs = get_example_csf_arrays (dtype )
372+ csf_tensor = sparse .from_constituent_arrays (format = csf_format , arrays = arrs , shape = shape )
373+
374+ result = sparse .reshape (csf_tensor , shape = new_shape )
375+
376+ for actual , expected in zip (result .get_constituent_arrays (), expected_arrs , strict = True ):
377+ np .testing .assert_array_equal (actual , expected )
378+
379+ # DENSE
380+ # NOTE: dense reshape is probably broken in MLIR in 19.x branch
381+ # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
0 commit comments