1+ use std:: ops:: Deref ;
2+
13use anndata:: { container:: Axis , data:: DynCsrMatrix , ArrayData } ;
2- use nalgebra_sparse:: { CooMatrix , CsrMatrix } ;
34use anndata_memory:: { IMAnnData , IMArrayElement } ;
5+ use nalgebra_sparse:: { CooMatrix , CsrMatrix } ;
46
57fn create_test_data ( ) -> ( ArrayData , Vec < String > , Vec < String > ) {
68 let nrows = 3 ;
@@ -10,20 +12,109 @@ fn create_test_data() -> (ArrayData, Vec<String>, Vec<String>) {
1012 let mut coo_matrix = CooMatrix :: new ( nrows, ncols) ;
1113
1214 // Add some non-zero elements (row, col, value)
13- coo_matrix. push ( 0 , 0 , 1.0 ) ; // element at (0, 0) = 1.0
14- coo_matrix. push ( 1 , 2 , 2.0 ) ; // element at (1, 2) = 2.0
15- coo_matrix. push ( 2 , 1 , 3.0 ) ; // element at (2, 1) = 3.0
16- coo_matrix. push ( 2 , 2 , 4.0 ) ; // element at (2, 2) = 4.0
15+ coo_matrix. push ( 0 , 0 , 1.0 ) ; // element at (0, 0) = 1.0
16+ coo_matrix. push ( 1 , 2 , 2.0 ) ; // element at (1, 2) = 2.0
17+ coo_matrix. push ( 2 , 1 , 3.0 ) ; // element at (2, 1) = 3.0
18+ coo_matrix. push ( 2 , 2 , 4.0 ) ; // element at (2, 2) = 4.0
1719
1820 // Optionally, you can convert the COO matrix to a more efficient CSR format
1921 let csr_matrix: CsrMatrix < f64 > = CsrMatrix :: from ( & coo_matrix) ;
20-
22+
2123 let matrix = DynCsrMatrix :: from ( csr_matrix) ;
2224 let obs_names = vec ! [ "obs1" . to_string( ) , "obs2" . to_string( ) , "obs3" . to_string( ) ] ;
2325 let var_names = vec ! [ "var1" . to_string( ) , "var2" . to_string( ) , "var3" . to_string( ) ] ;
2426 ( ArrayData :: CsrMatrix ( matrix) , obs_names, var_names)
2527}
2628
29+ #[ test]
30+ fn test_convert_matrix_format ( ) {
31+ // Create test data using CooMatrix
32+ let coo = CooMatrix :: try_from_triplets (
33+ 5 ,
34+ 4 , // 5x4 matrix
35+ vec ! [ 0 , 1 , 1 , 2 , 3 , 4 ] , // row indices
36+ vec ! [ 0 , 1 , 2 , 3 , 1 , 3 ] , // column indices
37+ vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] , // values
38+ )
39+ . unwrap ( ) ;
40+
41+ // Convert to CSR format
42+ let csr = CsrMatrix :: from ( & coo) ;
43+ let array_data = ArrayData :: CsrMatrix ( DynCsrMatrix :: F64 ( csr) ) ;
44+ let mut matrix = IMArrayElement :: new ( array_data) ;
45+
46+ // Convert CSR to CSC
47+ matrix. convert_matrix_format ( ) . unwrap ( ) ;
48+
49+ // Verify it's now CSC
50+ {
51+ let read_guard = matrix. 0 . read_inner ( ) ;
52+ match read_guard. deref ( ) {
53+ ArrayData :: CscMatrix ( _) => ( ) ,
54+ _ => panic ! ( "Matrix should be in CSC format" ) ,
55+ }
56+ } // read_guard is dropped here
57+
58+ // Convert CSC back to CSR
59+ matrix. convert_matrix_format ( ) . unwrap ( ) ;
60+
61+ // Verify it's back to CSR and check content
62+ {
63+ let read_guard = matrix. 0 . read_inner ( ) ;
64+ match read_guard. deref ( ) {
65+ ArrayData :: CsrMatrix ( csr) => {
66+ if let DynCsrMatrix :: F64 ( m) = csr {
67+ // Verify the matrix content is preserved
68+ assert_eq ! ( m. nrows( ) , 5 ) ;
69+ assert_eq ! ( m. ncols( ) , 4 ) ;
70+ assert_eq ! ( m. nnz( ) , 6 ) ;
71+
72+ // Check specific values
73+ assert_eq ! (
74+ m. triplet_iter( )
75+ . find( |& ( i, j, & v) | i == 0 && j == 0 )
76+ . map( |( _, _, & v) | v) ,
77+ Some ( 1.0 )
78+ ) ;
79+ assert_eq ! (
80+ m. triplet_iter( )
81+ . find( |& ( i, j, & v) | i == 1 && j == 1 )
82+ . map( |( _, _, & v) | v) ,
83+ Some ( 2.0 )
84+ ) ;
85+ assert_eq ! (
86+ m. triplet_iter( )
87+ . find( |& ( i, j, & v) | i == 1 && j == 2 )
88+ . map( |( _, _, & v) | v) ,
89+ Some ( 3.0 )
90+ ) ;
91+ assert_eq ! (
92+ m. triplet_iter( )
93+ . find( |& ( i, j, & v) | i == 2 && j == 3 )
94+ . map( |( _, _, & v) | v) ,
95+ Some ( 4.0 )
96+ ) ;
97+ assert_eq ! (
98+ m. triplet_iter( )
99+ . find( |& ( i, j, & v) | i == 3 && j == 1 )
100+ . map( |( _, _, & v) | v) ,
101+ Some ( 5.0 )
102+ ) ;
103+ assert_eq ! (
104+ m. triplet_iter( )
105+ . find( |& ( i, j, & v) | i == 4 && j == 3 )
106+ . map( |( _, _, & v) | v) ,
107+ Some ( 6.0 )
108+ ) ;
109+ } else {
110+ panic ! ( "Expected F64 matrix" ) ;
111+ }
112+ }
113+ _ => panic ! ( "Matrix should be in CSR format" ) ,
114+ }
115+ } // read_guard is dropped here
116+ }
117+
27118#[ test]
28119fn test_new_basic ( ) {
29120 let ( matrix, obs_names, var_names) = create_test_data ( ) ;
@@ -119,4 +210,4 @@ fn test_uns() {
119210
120211 let uns = adata. uns ( ) ;
121212 assert ! ( uns. get_data( "test_key" ) . is_err( ) ) ;
122- }
213+ }
0 commit comments