Skip to content

Commit 7706a08

Browse files
committed
added inplace CSC <-> CSR conversion version bumped to 0.1.2-alpha.0
1 parent 6dc27d2 commit 7706a08

File tree

4 files changed

+164
-11
lines changed

4 files changed

+164
-11
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "anndata-memory"
3-
version = "0.1.1-alpha.2"
3+
version = "0.1.2-alpha.0"
44
edition = "2021"
55
readme = "README.md"
66
repository = "https://github.com/SingleRust/Anndata-Memory"
@@ -9,7 +9,6 @@ include = [
99
"**/*.rs",
1010
"Cargo.toml",
1111
]
12-
license = "BSD-3-Clause"
1312
license-file = "LICENSE.md"
1413

1514
[dependencies]

src/ad/helpers.rs

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@ use std::{
22
collections::HashMap,
33
fmt,
44
ops::{Deref, DerefMut},
5+
sync::Arc,
56
};
67

78
use anndata::{
89
backend::DataType,
910
container::{Axis, Dim},
10-
data::{DataFrameIndex, SelectInfoElem, Shape},
11+
data::{DataFrameIndex, DynArray, DynCscMatrix, DynCsrMatrix, SelectInfoElem, Shape},
1112
ArrayData, ArrayOp, Data, HasShape, WriteData,
1213
};
14+
use anyhow::{anyhow, bail};
15+
16+
use ndarray::Array2;
1317
use polars::{
1418
frame::DataFrame,
1519
prelude::{IdxCa, NamedFrom},
@@ -61,6 +65,65 @@ impl IMArrayElement {
6165
Ok(())
6266
}
6367

68+
pub fn convert_matrix_format(&self) -> anyhow::Result<()> {
69+
let mut write_guard = self.0.write_inner();
70+
let d = write_guard.deref_mut();
71+
72+
// Create a placeholder that we can swap with - use an empty dense array as it's likely the smallest
73+
let ddata: Array2<f64> = Array2::zeros((0, 0));
74+
let placeholder = ArrayData::Array(DynArray::from(ddata));
75+
76+
// Take ownership using replace
77+
let matrix_data = std::mem::replace(d, placeholder);
78+
79+
let converted = match matrix_data {
80+
ArrayData::CsrMatrix(dyn_csr_matrix) => {
81+
let csc = match dyn_csr_matrix {
82+
DynCsrMatrix::F64(m) => DynCscMatrix::F64(m.transpose_as_csc()),
83+
DynCsrMatrix::F32(m) => DynCscMatrix::F32(m.transpose_as_csc()),
84+
DynCsrMatrix::I64(m) => DynCscMatrix::I64(m.transpose_as_csc()),
85+
DynCsrMatrix::I32(m) => DynCscMatrix::I32(m.transpose_as_csc()),
86+
DynCsrMatrix::I16(m) => DynCscMatrix::I16(m.transpose_as_csc()),
87+
DynCsrMatrix::I8(m) => DynCscMatrix::I8(m.transpose_as_csc()),
88+
DynCsrMatrix::U64(m) => DynCscMatrix::U64(m.transpose_as_csc()),
89+
DynCsrMatrix::U32(m) => DynCscMatrix::U32(m.transpose_as_csc()),
90+
DynCsrMatrix::U16(m) => DynCscMatrix::U16(m.transpose_as_csc()),
91+
DynCsrMatrix::U8(m) => DynCscMatrix::U8(m.transpose_as_csc()),
92+
DynCsrMatrix::Bool(m) => DynCscMatrix::Bool(m.transpose_as_csc()),
93+
DynCsrMatrix::String(m) => DynCscMatrix::String(m.transpose_as_csc()),
94+
DynCsrMatrix::Usize(m) => DynCscMatrix::Usize(m.transpose_as_csc()),
95+
};
96+
ArrayData::CscMatrix(csc)
97+
}
98+
ArrayData::CscMatrix(dyn_csc_matrix) => {
99+
let csr = match dyn_csc_matrix {
100+
DynCscMatrix::F64(m) => DynCsrMatrix::F64(m.transpose_as_csr()),
101+
DynCscMatrix::F32(m) => DynCsrMatrix::F32(m.transpose_as_csr()),
102+
DynCscMatrix::I64(m) => DynCsrMatrix::I64(m.transpose_as_csr()),
103+
DynCscMatrix::I32(m) => DynCsrMatrix::I32(m.transpose_as_csr()),
104+
DynCscMatrix::I16(m) => DynCsrMatrix::I16(m.transpose_as_csr()),
105+
DynCscMatrix::I8(m) => DynCsrMatrix::I8(m.transpose_as_csr()),
106+
DynCscMatrix::U64(m) => DynCsrMatrix::U64(m.transpose_as_csr()),
107+
DynCscMatrix::U32(m) => DynCsrMatrix::U32(m.transpose_as_csr()),
108+
DynCscMatrix::U16(m) => DynCsrMatrix::U16(m.transpose_as_csr()),
109+
DynCscMatrix::U8(m) => DynCsrMatrix::U8(m.transpose_as_csr()),
110+
DynCscMatrix::Bool(m) => DynCsrMatrix::Bool(m.transpose_as_csr()),
111+
DynCscMatrix::String(m) => DynCsrMatrix::String(m.transpose_as_csr()),
112+
DynCscMatrix::Usize(m) => DynCsrMatrix::Usize(m.transpose_as_csr()),
113+
};
114+
ArrayData::CsrMatrix(csr)
115+
}
116+
_ => {
117+
// Put back the original value since we're erroring
118+
*d = matrix_data;
119+
bail!("This datatype is not supported, only CSC and CSR matrices are supported.")
120+
}
121+
};
122+
123+
*d = converted;
124+
Ok(())
125+
}
126+
64127
pub fn subset(&self, s: &[&SelectInfoElem]) -> anyhow::Result<Self> {
65128
let read_guard = self.0.read_inner();
66129
let d = read_guard.deref();

tests/test_basic.rs

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use std::ops::Deref;
2+
13
use anndata::{container::Axis, data::DynCsrMatrix, ArrayData};
2-
use nalgebra_sparse::{CooMatrix, CsrMatrix};
34
use anndata_memory::{IMAnnData, IMArrayElement};
5+
use nalgebra_sparse::{CooMatrix, CsrMatrix};
46

57
fn create_test_data() -> (ArrayData, Vec<String>, Vec<String>) {
68
let nrows = 3;
@@ -10,20 +12,109 @@ fn create_test_data() -> (ArrayData, Vec<String>, Vec<String>) {
1012
let mut coo_matrix = CooMatrix::new(nrows, ncols);
1113

1214
// Add some non-zero elements (row, col, value)
13-
coo_matrix.push(0, 0, 1.0); // element at (0, 0) = 1.0
14-
coo_matrix.push(1, 2, 2.0); // element at (1, 2) = 2.0
15-
coo_matrix.push(2, 1, 3.0); // element at (2, 1) = 3.0
16-
coo_matrix.push(2, 2, 4.0); // element at (2, 2) = 4.0
15+
coo_matrix.push(0, 0, 1.0); // element at (0, 0) = 1.0
16+
coo_matrix.push(1, 2, 2.0); // element at (1, 2) = 2.0
17+
coo_matrix.push(2, 1, 3.0); // element at (2, 1) = 3.0
18+
coo_matrix.push(2, 2, 4.0); // element at (2, 2) = 4.0
1719

1820
// Optionally, you can convert the COO matrix to a more efficient CSR format
1921
let csr_matrix: CsrMatrix<f64> = CsrMatrix::from(&coo_matrix);
20-
22+
2123
let matrix = DynCsrMatrix::from(csr_matrix);
2224
let obs_names = vec!["obs1".to_string(), "obs2".to_string(), "obs3".to_string()];
2325
let var_names = vec!["var1".to_string(), "var2".to_string(), "var3".to_string()];
2426
(ArrayData::CsrMatrix(matrix), obs_names, var_names)
2527
}
2628

29+
#[test]
30+
fn test_convert_matrix_format() {
31+
// Create test data using CooMatrix
32+
let coo = CooMatrix::try_from_triplets(
33+
5,
34+
4, // 5x4 matrix
35+
vec![0, 1, 1, 2, 3, 4], // row indices
36+
vec![0, 1, 2, 3, 1, 3], // column indices
37+
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], // values
38+
)
39+
.unwrap();
40+
41+
// Convert to CSR format
42+
let csr = CsrMatrix::from(&coo);
43+
let array_data = ArrayData::CsrMatrix(DynCsrMatrix::F64(csr));
44+
let mut matrix = IMArrayElement::new(array_data);
45+
46+
// Convert CSR to CSC
47+
matrix.convert_matrix_format().unwrap();
48+
49+
// Verify it's now CSC
50+
{
51+
let read_guard = matrix.0.read_inner();
52+
match read_guard.deref() {
53+
ArrayData::CscMatrix(_) => (),
54+
_ => panic!("Matrix should be in CSC format"),
55+
}
56+
} // read_guard is dropped here
57+
58+
// Convert CSC back to CSR
59+
matrix.convert_matrix_format().unwrap();
60+
61+
// Verify it's back to CSR and check content
62+
{
63+
let read_guard = matrix.0.read_inner();
64+
match read_guard.deref() {
65+
ArrayData::CsrMatrix(csr) => {
66+
if let DynCsrMatrix::F64(m) = csr {
67+
// Verify the matrix content is preserved
68+
assert_eq!(m.nrows(), 5);
69+
assert_eq!(m.ncols(), 4);
70+
assert_eq!(m.nnz(), 6);
71+
72+
// Check specific values
73+
assert_eq!(
74+
m.triplet_iter()
75+
.find(|&(i, j, &v)| i == 0 && j == 0)
76+
.map(|(_, _, &v)| v),
77+
Some(1.0)
78+
);
79+
assert_eq!(
80+
m.triplet_iter()
81+
.find(|&(i, j, &v)| i == 1 && j == 1)
82+
.map(|(_, _, &v)| v),
83+
Some(2.0)
84+
);
85+
assert_eq!(
86+
m.triplet_iter()
87+
.find(|&(i, j, &v)| i == 1 && j == 2)
88+
.map(|(_, _, &v)| v),
89+
Some(3.0)
90+
);
91+
assert_eq!(
92+
m.triplet_iter()
93+
.find(|&(i, j, &v)| i == 2 && j == 3)
94+
.map(|(_, _, &v)| v),
95+
Some(4.0)
96+
);
97+
assert_eq!(
98+
m.triplet_iter()
99+
.find(|&(i, j, &v)| i == 3 && j == 1)
100+
.map(|(_, _, &v)| v),
101+
Some(5.0)
102+
);
103+
assert_eq!(
104+
m.triplet_iter()
105+
.find(|&(i, j, &v)| i == 4 && j == 3)
106+
.map(|(_, _, &v)| v),
107+
Some(6.0)
108+
);
109+
} else {
110+
panic!("Expected F64 matrix");
111+
}
112+
}
113+
_ => panic!("Matrix should be in CSR format"),
114+
}
115+
} // read_guard is dropped here
116+
}
117+
27118
#[test]
28119
fn test_new_basic() {
29120
let (matrix, obs_names, var_names) = create_test_data();
@@ -119,4 +210,4 @@ fn test_uns() {
119210

120211
let uns = adata.uns();
121212
assert!(uns.get_data("test_key").is_err());
122-
}
213+
}

0 commit comments

Comments
 (0)