@@ -93,10 +93,15 @@ def stratified_subsample_adata(adata, group_column, target_fraction=0.20, random
9393 return adata_subsampled
9494
9595
96- def normalize_adata (adata , target_sum = 10_000 , n_counts_key = 'n_counts' , copy = False ):
96+ from scipy import sparse
97+ from scipy .sparse import issparse , csr_matrix , hstack
98+
99+
def normalize_adata(adata, target_sum=10_000, n_counts_key='n_counts',
                    chunk_size=None, copy=False):
    """
    Normalize an AnnData object so every cell sums to ``target_sum``.

    Memory-efficient: sparse matrices are scaled by left-multiplying with a
    diagonal scaling matrix, never densified, and can optionally be processed
    in row chunks to bound peak memory.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix; ``adata.X`` may be dense or scipy-sparse.
    target_sum : float, optional (default: 10_000)
        Total count each cell is scaled to.
    n_counts_key : str, optional (default: 'n_counts')
        Key in ``adata.obs`` containing the total counts for each cell.
        If missing, counts are computed from ``adata.X`` and stored under
        ``'total_counts'`` instead.
    chunk_size : int or None, optional (default: None)
        If None, process the entire matrix at once (faster, more memory).
        If an int, process the matrix in chunks of this many rows (slower,
        less memory). Recommended for very large datasets (>1M cells).
    copy : bool, optional (default: False)
        If True, operate on and return a copy of ``adata``; otherwise
        ``adata`` is modified in place and nothing is returned.

    Returns
    -------
    AnnData or None
        The normalized copy when ``copy=True``, else None.
    """
    if copy:
        # Honor the documented contract: leave the caller's object untouched.
        adata = adata.copy()

    # Check if total counts are already calculated
    if n_counts_key not in adata.obs.columns:
        warnings.warn(f"{n_counts_key} not found in adata.obs. Calculating total counts.", UserWarning)
        n_counts_key = 'total_counts'  # scanpy uses 'total_counts' as the key
        # Chunk the row-sum only when the matrix is sparse and chunking was
        # requested; the plain sum covers both the sparse and dense cases.
        if sparse.issparse(adata.X) and chunk_size is not None:
            n_cells = adata.X.shape[0]
            counts = np.zeros(n_cells)
            for start in range(0, n_cells, chunk_size):
                end = min(start + chunk_size, n_cells)
                counts[start:end] = np.asarray(adata.X[start:end].sum(axis=1)).ravel()
            adata.obs[n_counts_key] = counts
        else:
            # asarray(...).ravel() flattens both the np.matrix returned by a
            # sparse sum and a plain ndarray to 1-D.
            adata.obs[n_counts_key] = np.asarray(adata.X.sum(axis=1)).ravel()

    warnings.warn("Normalizing data.", UserWarning)

    # Per-cell scaling factors. Guard against zero-count cells: scale them by
    # 0 (the row stays all-zero) instead of producing inf/NaN from a division
    # by zero, which the previous `target_sum / n_counts` did.
    n_counts = np.asarray(adata.obs[n_counts_key].values, dtype=float)
    scaling_factors = np.divide(target_sum, n_counts,
                                out=np.zeros_like(n_counts),
                                where=n_counts > 0)

    if sparse.issparse(adata.X):
        if chunk_size is not None:
            # Chunked processing for very large sparse matrices.
            n_cells = adata.X.shape[0]
            normalized_chunks = []
            for start in range(0, n_cells, chunk_size):
                end = min(start + chunk_size, n_cells)
                chunk_scaling = sparse.diags(scaling_factors[start:end], 0, format='csr')
                normalized_chunks.append(chunk_scaling @ adata.X[start:end])
            adata.X = sparse.vstack(normalized_chunks, format='csr')
        else:
            # Left-multiplying by a diagonal matrix scales each row at C
            # speed without densifying (most efficient path).
            adata.X = sparse.diags(scaling_factors, 0, format='csr') @ adata.X
    else:
        # Dense matrix: broadcast the per-cell factors down the rows.
        adata.X = adata.X * scaling_factors[:, None]

    # Record how the normalization was performed.
    adata.uns['normalization'] = {
        'method': 'total_counts',
        'target_sum': target_sum,
        'n_counts_key': n_counts_key,
        'chunked': chunk_size is not None,
    }
    if chunk_size is not None:
        adata.uns['normalization']['chunk_size'] = chunk_size

    if copy:
        return adata
0 commit comments