Skip to content

Commit d2bf5f1

Browse files
authored
Merge pull request #17 from earmingol/dev
Implemented memory-efficient normalization
2 parents 8dd0b48 + 83fb4c9 commit d2bf5f1

File tree

3 files changed

+101
-25
lines changed

3 files changed

+101
-25
lines changed

sccellfie/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
from .reaction_activity import (compute_reaction_activity)
1717
from .sccellfie_pipeline import (run_sccellfie_pipeline)
1818

19-
__version__ = "0.4.5"
19+
__version__ = "0.4.6"

sccellfie/preprocessing/adata_utils.py

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,15 @@ def stratified_subsample_adata(adata, group_column, target_fraction=0.20, random
9393
return adata_subsampled
9494

9595

96-
def normalize_adata(adata, target_sum=10_000, n_counts_key='n_counts', copy=False):
96+
from scipy import sparse
97+
from scipy.sparse import issparse, csr_matrix, hstack
98+
99+
100+
def normalize_adata(adata, target_sum=10_000, n_counts_key='n_counts',
101+
chunk_size=None, copy=False):
97102
"""
98-
Preprocesses an AnnData object by normalizing the data to a target sum.
99-
Original adata object is updated in place.
103+
Memory-efficient normalization of AnnData object.
104+
Works directly on sparse matrices without converting to dense.
100105
101106
Parameters
102107
----------
@@ -109,6 +114,11 @@ def normalize_adata(adata, target_sum=10_000, n_counts_key='n_counts', copy=Fals
109114
n_counts_key : str, optional (default: 'n_counts')
110115
The key in adata.obs containing the total counts for each cell.
111116
117+
chunk_size : int or None, optional (default: None)
118+
If None, process entire matrix at once (faster, more memory).
119+
If int, process matrix in chunks of this size (slower, less memory).
120+
Recommended for very large datasets (>1M cells).
121+
112122
copy : bool, optional (default: False)
113123
If True, returns a copy of adata with the normalized data.
114124
"""
@@ -118,37 +128,65 @@ def normalize_adata(adata, target_sum=10_000, n_counts_key='n_counts', copy=Fals
118128
# Check if total counts are already calculated
119129
if n_counts_key not in adata.obs.columns:
120130
warnings.warn(f"{n_counts_key} not found in adata.obs. Calculating total counts.", UserWarning)
121-
n_counts_key = 'total_counts' # scanpy uses 'total_counts' as the key
122-
# Calculate total counts from the raw expression matrix
123-
adata.obs[n_counts_key] = adata.X.sum(axis=1)
124-
125-
# Input data
126-
X_view = adata.X
131+
n_counts_key = 'total_counts'
127132

128-
warnings.warn("Normalizing data.", UserWarning)
133+
if sparse.issparse(adata.X):
134+
if chunk_size is not None:
135+
# Chunked calculation for very large matrices
136+
n_cells = adata.X.shape[0]
137+
counts = np.zeros(n_cells)
129138

130-
# Check if matrix is sparse
131-
is_sparse = sparse.issparse(X_view)
139+
for start in range(0, n_cells, chunk_size):
140+
end = min(start + chunk_size, n_cells)
141+
counts[start:end] = np.array(adata.X[start:end].sum(axis=1)).flatten()
132142

133-
# Convert to dense if sparse
134-
if is_sparse:
135-
X_view = X_view.toarray()
143+
adata.obs[n_counts_key] = counts
144+
else:
145+
# Standard calculation
146+
adata.obs[n_counts_key] = np.array(adata.X.sum(axis=1)).flatten()
147+
else:
148+
# Dense matrix
149+
adata.obs[n_counts_key] = np.array(adata.X.sum(axis=1)).flatten()
136150

137-
# Normalize
138-
n_counts = adata.obs[n_counts_key].values[:, None]
139-
X_norm = X_view / n_counts * target_sum
151+
warnings.warn("Normalizing data.", UserWarning)
140152

141-
# Convert back to sparse if original was sparse
142-
if is_sparse:
143-
X_norm = sparse.csr_matrix(X_norm)
153+
# Get counts and calculate scaling factors
154+
n_counts = adata.obs[n_counts_key].values
155+
scaling_factors = target_sum / n_counts
156+
157+
# Perform normalization
158+
if sparse.issparse(adata.X):
159+
if chunk_size is not None:
160+
# Chunked processing for very large sparse matrices
161+
n_cells = adata.X.shape[0]
162+
normalized_chunks = []
163+
164+
for start in range(0, n_cells, chunk_size):
165+
end = min(start + chunk_size, n_cells)
166+
chunk_scaling = sparse.diags(scaling_factors[start:end], 0, format='csr')
167+
normalized_chunk = chunk_scaling @ adata.X[start:end]
168+
normalized_chunks.append(normalized_chunk)
169+
170+
# Combine chunks
171+
adata.X = sparse.vstack(normalized_chunks, format='csr')
172+
else:
173+
# Standard sparse matrix normalization (most efficient)
174+
scaling_matrix = sparse.diags(scaling_factors, 0, format='csr')
175+
adata.X = scaling_matrix @ adata.X
176+
else:
177+
# Dense matrix normalization
178+
adata.X = adata.X / n_counts[:, None] * target_sum
144179

145-
# Update adata
146-
adata.X = X_norm
180+
# Update metadata
147181
adata.uns['normalization'] = {
148182
'method': 'total_counts',
149183
'target_sum': target_sum,
150-
'n_counts_key': n_counts_key
184+
'n_counts_key': n_counts_key,
185+
'chunked': chunk_size is not None
151186
}
187+
if chunk_size is not None:
188+
adata.uns['normalization']['chunk_size'] = chunk_size
189+
152190
if copy:
153191
return adata
154192

sccellfie/preprocessing/tests/test_adata_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,44 @@ def test_normalize_adata_dense():
9898
normalize_adata(adata, target_sum=1000)
9999

100100

101+
def test_normalize_adata_chunked():
    """Chunked and whole-matrix normalization must produce identical results."""
    # Controlled fixture with known expression values per cell
    adata = create_controlled_adata()

    # Provide precomputed per-cell totals so no counting branch is triggered
    adata.obs['n_counts'] = np.array([3, 9, 21, 21])

    # Reference: normalize the whole matrix in one pass
    whole = adata.copy()
    normalize_adata(whole, target_sum=1000, n_counts_key='n_counts', copy=False)

    # Candidate: normalize in chunks (chunk of 2 over 4 cells exercises >1 chunk)
    chunked = adata.copy()
    normalize_adata(chunked, target_sum=1000, n_counts_key='n_counts',
                    chunk_size=2, copy=False)

    # Both code paths must agree to high precision
    np.testing.assert_array_almost_equal(
        whole.X.toarray(),
        chunked.X.toarray(),
        decimal=10
    )

    # Spot-check against hand-computed values: value / n_counts * 1000
    expected_normalized_X = np.array([
        [333.33, 666.67, 0],
        [333.33, 444.44, 222.22],
        [238.10, 285.71, 476.19],
        [333.33, 380.95, 285.71]
    ])
    np.testing.assert_array_almost_equal(
        chunked.X.toarray(),
        expected_normalized_X,
        decimal=2
    )
137+
138+
101139
# Transform gene names tests
102140
# Mock data for testing
103141
MOCK_ENSEMBL2SYMBOL = {

0 commit comments

Comments
 (0)