1+ use anndata:: {
2+ backend:: { AttributeOp , Backend , DataContainer , DatasetOp , GroupOp , ScalarType } ,
3+ data:: { ArrayData , SelectInfoElem } ,
4+ ArrayElemOp ,
5+ } ;
6+ use nalgebra_sparse:: { pattern:: SparsityPattern , CsrMatrix } ;
7+ use ndarray:: Ix1 ;
8+
9+ use crate :: { utils:: { read_array_as_usize_optimized, read_array_slice_as_usize} , LoadingConfig } ;
10+
11+ pub fn load_csr_chunked < B : Backend > (
12+ container : & DataContainer < B > ,
13+ config : & LoadingConfig ,
14+ ) -> anyhow:: Result < ArrayData > {
15+ let group = container. as_group ( ) ?;
16+ let shape: Vec < u64 > = group. get_attr ( "shape" ) ?;
17+ let ( nrows, ncols) = ( shape[ 0 ] as usize , shape[ 1 ] as usize ) ;
18+
19+ let data_ds = group. open_dataset ( "data" ) ?;
20+ let indices_ds = group. open_dataset ( "indices" ) ?;
21+ let indptr_ds = group. open_dataset ( "indptr" ) ?;
22+
23+ let indptr = read_array_as_usize_optimized :: < B > ( & indptr_ds) ?;
24+ let nnz = data_ds. shape ( ) [ 0 ] ;
25+
26+ if config. show_progress && nnz > 10_000_000 {
27+ println ! ( "Loading CSR matrix: {} rows, {} cols, {} non-zeros" , nrows, ncols, nnz) ;
28+ }
29+
30+ use ScalarType :: * ;
31+ match data_ds. dtype ( ) ? {
32+ F64 => load_csr_typed :: < B , f64 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
33+ F32 => load_csr_typed :: < B , f32 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
34+ I64 => load_csr_typed :: < B , i64 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
35+ I32 => load_csr_typed :: < B , i32 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
36+ I16 => load_csr_typed :: < B , i16 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
37+ I8 => load_csr_typed :: < B , i8 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
38+ U64 => load_csr_typed :: < B , u64 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
39+ U32 => load_csr_typed :: < B , u32 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
40+ U16 => load_csr_typed :: < B , u16 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
41+ U8 => load_csr_typed :: < B , u8 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
42+ dt => anyhow:: bail!( "Unsupported data type for CSR matrix: {:?}" , dt) ,
43+ }
44+ }
45+
46+ fn load_csr_typed < B : Backend , T : anndata:: backend:: BackendData > (
47+ nrows : usize ,
48+ ncols : usize ,
49+ nnz : usize ,
50+ indptr : Vec < usize > ,
51+ data_ds : & B :: Dataset ,
52+ indices_ds : & B :: Dataset ,
53+ config : & LoadingConfig ,
54+ ) -> anyhow:: Result < ArrayData >
55+ where
56+ ArrayData : From < CsrMatrix < T > > ,
57+ {
58+ let chunk_size = ( ( config. chunk_size_mb << 20 ) / ( std:: mem:: size_of :: < T > ( ) + 8 ) ) . max ( 1000 ) ;
59+
60+ let mut data = Vec :: with_capacity ( nnz) ;
61+ let mut indices = Vec :: with_capacity ( nnz) ;
62+
63+ let show_progress = config. show_progress && nnz > 10_000_000 ;
64+ let progress_interval = if show_progress { nnz / 10 } else { usize:: MAX } ;
65+ let mut next_progress = progress_interval;
66+
67+ let mut offset = 0 ;
68+ while offset < nnz {
69+ let chunk_end = ( offset + chunk_size) . min ( nnz) ;
70+ let range = [ SelectInfoElem :: from ( offset..chunk_end) ] ;
71+ let data_array = data_ds. read_array_slice :: < T , _ , Ix1 > ( & range) ?;
72+ let ( data_vec, data_offset) = data_array. into_raw_vec_and_offset ( ) ;
73+ if data_offset. is_none ( ) {
74+ data. extend ( data_vec) ;
75+ } else {
76+ data. extend ( data_vec) ;
77+ }
78+
79+ let indices_chunk = read_array_slice_as_usize :: < B > ( indices_ds, & range) ?;
80+ indices. extend ( indices_chunk) ;
81+
82+ offset = chunk_end;
83+
84+ if show_progress && offset >= next_progress {
85+ println ! ( "Loading CSR matrix: {}%" , offset * 100 / nnz) ;
86+ next_progress += progress_interval;
87+ }
88+ }
89+
90+ if show_progress {
91+ println ! ( "Constructing CSR matrix structure..." ) ;
92+ }
93+
94+ let pattern = unsafe { SparsityPattern :: from_offset_and_indices_unchecked ( nrows, ncols, indptr, indices) } ;
95+ CsrMatrix :: try_from_pattern_and_values ( pattern, data)
96+ . map ( ArrayData :: from)
97+ . map_err ( |e| anyhow:: anyhow!( "Failed to construct CSR matrix: {}" , e) )
98+ }
0 commit comments