Skip to content

Commit 35b15d8

Browse files
committed
Clean up imports, remove the regularization parameter from Pf2, and remove doublets when reading in the mouse data
1 parent bdf1714 commit 35b15d8

File tree

2 files changed

+64
-70
lines changed

2 files changed

+64
-70
lines changed

pf2rnaseq/factorization.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,11 @@ def pf2(
3737
random_state=1,
3838
doEmbedding: bool = True,
3939
tolerance=1e-9,
40-
regParam=0.0,
4140
r2x=False,
4241
):
4342
cupy.cuda.Device(1).use()
4443
pf_out, R2X = parafac2_nd(
45-
X,
46-
rank=rank,
47-
random_state=random_state,
48-
tol=tolerance,
49-
n_iter_max=500,
50-
l2=regParam,
44+
X, rank=rank, random_state=random_state, tol=tolerance, n_iter_max=500
5145
)
5246

5347
X = store_pf2(X, pf_out)

pf2rnaseq/imports.py

Lines changed: 63 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,53 @@
1-
import glob
21
from concurrent.futures import ProcessPoolExecutor
3-
from pathlib import Path
42

53
import anndata
64
import numpy as np
75
import pandas as pd
86
import scanpy as sc
9-
from scipy.sparse import csr_matrix, spmatrix
7+
from scipy.sparse import csr_array, spmatrix
8+
from sklearn.preprocessing import scale
109
from sklearn.utils.sparsefuncs import inplace_column_scale, mean_variance_axis
1110

1211

12+
def prepare_dataset_deviance(
    X: anndata.AnnData, condition_name, geneThreshold
) -> anndata.AnnData:
    """Filter a raw-count AnnData and replace X with scaled deviance residuals.

    Steps, in order:
      1. Store ``X.X`` as a sparse CSR array and check that no stored count
         is negative.
      2. Keep only cells and genes whose total count exceeds 10, then keep
         only genes whose mean count exceeds ``geneThreshold``.
      3. Replace the counts by signed square-root binomial deviance
         residuals, computed against a null model in which every cell has
         the same per-gene proportions.
      4. Standardize the residuals with ``sklearn.preprocessing.scale``.
      5. Record an integer-coded condition label per cell in
         ``X.obs["condition_unique_idxs"]`` and a zero-filled
         ``X.var["means"]``.

    Args:
        X: AnnData whose ``X`` holds non-negative counts. It is modified.
        condition_name: name passed to ``X.obs_vector`` to obtain the
            per-cell condition label.
        geneThreshold: genes with mean count at or below this value are
            dropped.

    Returns:
        The filtered AnnData with a dense, finite, transformed ``X``.

    Raises:
        AssertionError: if a stored count is negative or the transformed
            matrix contains non-finite values.
    """
    X.X = csr_array(X.X)  # type: ignore
    assert np.amin(X.X.data) >= 0.0

    # Keep only cells and genes with more than 10 reads in total
    X = X[X.X.sum(axis=1) > 10, X.X.sum(axis=0) > 10]
    readmean, _ = mean_variance_axis(X.X, axis=0)  # type: ignore
    X = X[:, readmean > geneThreshold]

    # Copy so that the subsetting is preserved
    X._init_as_actual(X.copy())

    # Deviance transform on the dense count matrix
    y_ij = X.X.toarray()  # type: ignore
    # Counts per cell
    n_i = y_ij.sum(axis=1)
    # MLE of gene expression proportions under the null model
    pi_j = y_ij.sum(axis=0) / np.sum(n_i)

    non_y_ij = n_i[:, None] - y_ij
    mu_ij = n_i[:, None] * pi_j[None, :]
    non_mu_ij = n_i[:, None] - mu_ij
    signs = np.sign(y_ij - mu_ij)

    # Both deviance terms have the form k * log(k / m). Their limit as k -> 0
    # is 0, but floating point evaluates 0 * log(0) (or 0 * log(1 / 0)) to
    # NaN. That happens when a cell's remaining counts sit in a single gene,
    # or when gene filtering leaves a cell with no counts. Evaluate the raw
    # expression quietly, then substitute the limit where k == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        first_raw = 2 * y_ij * np.log(np.maximum(y_ij, 1.0) / mu_ij)
        second_raw = 2 * non_y_ij * np.log(non_y_ij / non_mu_ij)
    first_term = np.where(y_ij > 0, first_raw, 0.0)
    second_term = np.where(non_y_ij > 0, second_raw, 0.0)

    # Clamp tiny negative round-off before taking the square root
    X.X = signs * np.sqrt(np.maximum(first_term + second_term, 0.0))

    X.X = scale(X.X)

    _, sgIndex = np.unique(X.obs_vector(condition_name), return_inverse=True)
    X.obs["condition_unique_idxs"] = sgIndex
    X.obs["condition_unique_idxs"] = X.obs["condition_unique_idxs"].astype("category")

    # Gene means are stored as zeros here.
    # NOTE(review): presumably because the scaled matrix is already centered
    # and downstream code expects this column; confirm against the callers.
    X.var["means"] = np.zeros(X.shape[1])

    assert np.all(np.isfinite(X.X))  # type: ignore
    return X
49+
50+
1351
def prepare_dataset(
1452
X: anndata.AnnData, condition_name: str, geneThreshold: float
1553
) -> anndata.AnnData:
@@ -65,58 +103,6 @@ def import_citeseq() -> anndata.AnnData:
65103
return prepare_dataset(X, "Condition", geneThreshold=0.1)
66104

67105

68-
def import_HTAN() -> anndata.AnnData:
69-
"""Imports Vanderbilt's HTAN 10X data."""
70-
files = glob.glob("/opt/extra-storage/HTAN/*.mtx.gz")
71-
futures = []
72-
data = {}
73-
74-
with ProcessPoolExecutor(max_workers=10) as executor:
75-
for filename in files:
76-
future = executor.submit(
77-
sc.read_10x_mtx,
78-
"/opt/extra-storage/HTAN/",
79-
gex_only=False,
80-
make_unique=True,
81-
prefix=filename.split("/")[-1].split("matrix.")[0],
82-
)
83-
futures.append(future)
84-
85-
for i, k in enumerate(files):
86-
result = futures[i].result()
87-
data[k.split("/")[-1].split("_matrix.")[0]] = result
88-
89-
X = anndata.concat(data, merge="same", label="Condition")
90-
91-
return prepare_dataset(X, "Condition", geneThreshold=0.1)
92-
93-
94-
def import_CCLE() -> anndata.AnnData:
95-
"""Imports barcoded cell data."""
96-
# TODO: Still need to add gene names and barcodes.
97-
folder = "/opt/extra-storage/asm/Heiser-barcode/CCLE/"
98-
99-
adatas = {
100-
"HCT116_1": anndata.read_text(
101-
Path(folder + "HCT116_tracing_T1.count_mtx.tsv")
102-
).T,
103-
"HCT116_2": anndata.read_text(
104-
Path(folder + "HCT116_tracing_T2.count_mtx.tsv")
105-
).T,
106-
"MDA-MB-231_1": anndata.read_text(
107-
Path(folder + "MDA-MB-231_tracing_T1.count_mtx.tsv")
108-
).T,
109-
"MDA-MB-231_2": anndata.read_text(
110-
Path(folder + "MDA-MB-231_tracing_T2.count_mtx.tsv")
111-
).T,
112-
}
113-
114-
X = anndata.concat(adatas, label="sample")
115-
X.X = csr_matrix(X.X)
116-
117-
return prepare_dataset(X, "sample", geneThreshold=0.1)
118-
119-
120106
def import_cytokine() -> anndata.AnnData:
121107
"""Import Meyer Cytokine PBMC dataset.
122108
-- columns from observation data:
@@ -140,25 +126,39 @@ def import_pf2Cytokine30() -> anndata.AnnData:
140126
return X
141127

142128

143-
def import_Heiser(deviance=False) -> anndata.AnnData:
    """Import the Heiser C3TAg dataset and preprocess it once.

    The file's ``anndata.X`` holds raw counts.

    Args:
        deviance: when True, preprocess with ``prepare_dataset_deviance``
            (deviance transformation); otherwise use ``prepare_dataset``
            (standard normalization and scaling).

    Returns:
        The preprocessed AnnData, with cells grouped by ``"sample_id"``.
    """
    data = anndata.read_h5ad("/home/nicoleb/C3TAg.h5ad")

    if deviance:
        # Apply deviance transformation
        data = prepare_dataset_deviance(data, "sample_id", geneThreshold=0.1)
    else:
        # Apply standard normalization and scaling
        data = prepare_dataset(data, "sample_id", geneThreshold=0.1)

    # Return the prepared data as-is. Passing it through prepare_dataset
    # again would filter and normalize it a second time, and would push
    # deviance-transformed values through the count-based pipeline.
    return data
151143

152144

153145
def import_MouseImmune() -> anndata.AnnData:
    """Load the Mouse Immune Dictionary cytokine dataset without doublets.

    Observation columns, as listed by the dataset's author:
        biosample_id: cytokine and replicate info
        rep: replicate
        species: mouse species
        cytokine_family: cytokine family label
        cyt: cytokine the mouse was treated with
        sex: sex of the mouse
        celltype: cell type label
        organ__ontology_label: organ label
        ... (further columns not listed)

    Returns:
        The result of ``prepare_dataset`` on the doublet-free data, using
        ``"biosample_id"`` as the condition column.
    """
    adata = anndata.read_h5ad("/home/nicoleb/MouseCytok.h5ad")

    # Keep every cell whose cell-type label is not "doublet"
    not_doublet = adata.obs["celltype"] != "doublet"
    singlets = adata[not_doublet, :]

    return prepare_dataset(singlets, "biosample_id", geneThreshold=0.1)  # 0.01
162162

163163

164164
def pseudobulk_lupus(X, cellType="Cell Type"):

0 commit comments

Comments
 (0)