Skip to content

Commit cd4f936

Browse files
authored
Merge pull request #5 from SingleRust/feature-dev-load-speedup
Feature dev load speedup
2 parents cec8577 + dec5a04 commit cd4f936

File tree

7 files changed

+963
-5
lines changed

7 files changed

+963
-5
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "anndata-memory"
3-
version = "1.0.2"
3+
version = "1.0.3"
44
edition = "2021"
55
readme = "README.md"
66
repository = "https://github.com/SingleRust/Anndata-Memory"

src/chunked_loader.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
use anndata::{
2+
backend::{AttributeOp, Backend, DataContainer, DatasetOp, GroupOp, ScalarType},
3+
data::{ArrayData, SelectInfoElem},
4+
ArrayElemOp,
5+
};
6+
use nalgebra_sparse::{pattern::SparsityPattern, CsrMatrix};
7+
use ndarray::Ix1;
8+
9+
use crate::{utils::{read_array_as_usize_optimized, read_array_slice_as_usize}, LoadingConfig};
10+
11+
pub fn load_csr_chunked<B: Backend>(
12+
container: &DataContainer<B>,
13+
config: &LoadingConfig,
14+
) -> anyhow::Result<ArrayData> {
15+
let group = container.as_group()?;
16+
let shape: Vec<u64> = group.get_attr("shape")?;
17+
let (nrows, ncols) = (shape[0] as usize, shape[1] as usize);
18+
19+
let data_ds = group.open_dataset("data")?;
20+
let indices_ds = group.open_dataset("indices")?;
21+
let indptr_ds = group.open_dataset("indptr")?;
22+
23+
let indptr = read_array_as_usize_optimized::<B>(&indptr_ds)?;
24+
let nnz = data_ds.shape()[0];
25+
26+
if config.show_progress && nnz > 10_000_000 {
27+
println!("Loading CSR matrix: {} rows, {} cols, {} non-zeros", nrows, ncols, nnz);
28+
}
29+
30+
use ScalarType::*;
31+
match data_ds.dtype()? {
32+
F64 => load_csr_typed::<B, f64>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
33+
F32 => load_csr_typed::<B, f32>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
34+
I64 => load_csr_typed::<B, i64>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
35+
I32 => load_csr_typed::<B, i32>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
36+
I16 => load_csr_typed::<B, i16>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
37+
I8 => load_csr_typed::<B, i8>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
38+
U64 => load_csr_typed::<B, u64>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
39+
U32 => load_csr_typed::<B, u32>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
40+
U16 => load_csr_typed::<B, u16>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
41+
U8 => load_csr_typed::<B, u8>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
42+
dt => anyhow::bail!("Unsupported data type for CSR matrix: {:?}", dt),
43+
}
44+
}
45+
46+
fn load_csr_typed<B: Backend, T: anndata::backend::BackendData>(
47+
nrows: usize,
48+
ncols: usize,
49+
nnz: usize,
50+
indptr: Vec<usize>,
51+
data_ds: &B::Dataset,
52+
indices_ds: &B::Dataset,
53+
config: &LoadingConfig,
54+
) -> anyhow::Result<ArrayData>
55+
where
56+
ArrayData: From<CsrMatrix<T>>,
57+
{
58+
let chunk_size = ((config.chunk_size_mb << 20) / (std::mem::size_of::<T>() + 8)).max(1000);
59+
60+
let mut data = Vec::with_capacity(nnz);
61+
let mut indices = Vec::with_capacity(nnz);
62+
63+
let show_progress = config.show_progress && nnz > 10_000_000;
64+
let progress_interval = if show_progress { nnz / 10 } else { usize::MAX };
65+
let mut next_progress = progress_interval;
66+
67+
let mut offset = 0;
68+
while offset < nnz {
69+
let chunk_end = (offset + chunk_size).min(nnz);
70+
let range = [SelectInfoElem::from(offset..chunk_end)];
71+
let data_array = data_ds.read_array_slice::<T, _, Ix1>(&range)?;
72+
let (data_vec, data_offset) = data_array.into_raw_vec_and_offset();
73+
if data_offset.is_none() {
74+
data.extend(data_vec);
75+
} else {
76+
data.extend(data_vec);
77+
}
78+
79+
let indices_chunk = read_array_slice_as_usize::<B>(indices_ds, &range)?;
80+
indices.extend(indices_chunk);
81+
82+
offset = chunk_end;
83+
84+
if show_progress && offset >= next_progress {
85+
println!("Loading CSR matrix: {}%", offset * 100 / nnz);
86+
next_progress += progress_interval;
87+
}
88+
}
89+
90+
if show_progress {
91+
println!("Constructing CSR matrix structure...");
92+
}
93+
94+
let pattern = unsafe { SparsityPattern::from_offset_and_indices_unchecked(nrows, ncols, indptr, indices) };
95+
CsrMatrix::try_from_pattern_and_values(pattern, data)
96+
.map(ArrayData::from)
97+
.map_err(|e| anyhow::anyhow!("Failed to construct CSR matrix: {}", e))
98+
}

src/lib.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ mod ad;
22
mod base;
33
mod converter;
44
pub(crate) mod utils;
5+
pub(crate) mod chunked_loader;
6+
pub(crate) mod optimized_loader;
7+
mod loader;
58

69
pub use ad::IMAnnData;
710
pub use ad::helpers::IMArrayElement;
@@ -12,4 +15,32 @@ pub use ad::helpers::IMAxisArrays;
1215
pub use converter::convert_to_in_memory;
1316
pub use converter::convert_to_backed;
1417
pub use converter::convert_to_new_backed_h5;
15-
pub use base::DeepClone;
18+
pub use base::DeepClone;
19+
20+
/// Selects how matrix data is read into memory.
///
/// NOTE(review): the code that interprets these variants lives in the loader module, which
/// is not visible here; the per-variant descriptions below are inferred from the names and
/// should be confirmed against it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LoadingStrategy {
    /// Let the loader choose the strategy (presumably based on
    /// `LoadingConfig::memory_threshold_mb`).
    Auto,
    /// Always read the whole matrix in one go (presumably).
    ForceComplete,
    /// Always read the matrix in chunks (presumably via the chunked CSR loader).
    ForceChunked,
}
26+
27+
#[derive(Clone, Debug)]
28+
pub struct LoadingConfig {
29+
pub loading_strategy: LoadingStrategy,
30+
pub chunk_size_mb: usize,
31+
pub memory_threshold_mb: usize,
32+
pub show_progress: bool,
33+
}
34+
35+
impl Default for LoadingConfig {
36+
fn default() -> Self {
37+
Self {
38+
loading_strategy: LoadingStrategy::Auto,
39+
chunk_size_mb: 100,
40+
memory_threshold_mb: 1024,
41+
show_progress: true,
42+
}
43+
}
44+
}
45+
46+
pub use loader::{load_h5ad, load_h5ad_fast, load_h5ad_conservative, load_h5ad_with_config};

0 commit comments

Comments
 (0)