Skip to content

Commit c2d576e

Browse files
committed
Update FMS: move the FMS calculations out of the figure file and into the factorization module and shared plotting functions
1 parent bdf1714 commit c2d576e

File tree

3 files changed

+145
-129
lines changed

3 files changed

+145
-129
lines changed

pf2rnaseq/factorization.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import anndata
22
import cupy
33
import numpy as np
4+
import pandas as pd
5+
import scanpy as sc
46
import scipy.sparse as sps
57
from pacmap import PaCMAP
68
from parafac2.parafac2 import parafac2_nd, store_pf2
79
from scipy.stats import gmean
810
from sklearn.decomposition import PCA
911
from sklearn.linear_model import LinearRegression
12+
from tensorly.cp_tensor import CPTensor
13+
from tlviz.factor_tools import factor_match_score as fms
1014
from tqdm import tqdm
1115

1216

@@ -37,7 +41,6 @@ def pf2(
3741
random_state=1,
3842
doEmbedding: bool = True,
3943
tolerance=1e-9,
40-
regParam=0.0,
4144
r2x=False,
4245
):
4346
cupy.cuda.Device(1).use()
@@ -47,7 +50,6 @@ def pf2(
4750
random_state=random_state,
4851
tol=tolerance,
4952
n_iter_max=500,
50-
l2=regParam,
5153
)
5254

5355
X = store_pf2(X, pf_out)
@@ -76,3 +78,115 @@ def pf2_pca_r2x(X: anndata.AnnData, ranks):
7678
r2x_pca = np.cumsum(pca.explained_variance_ratio_)
7779

7880
return r2x_pf2, r2x_pca[np.array(ranks) - 1]
81+
82+
83+
def calculateFMS(A: anndata.AnnData, B: anndata.AnnData):
    """Return the factor match score (FMS) between two stored Pf2 results.

    Both inputs must hold ``uns["Pf2_weights"]``, ``uns["Pf2_A"]``,
    ``uns["Pf2_B"]`` and ``varm["Pf2_C"]``. Each set of factors is wrapped in
    a ``CPTensor`` and the pair is passed to tlviz's ``factor_match_score``
    with ``consider_weights=False`` and ``skip_mode=1`` (index 1 is the
    ``Pf2_B`` entry of the factor list built below).
    """
    cp_tensors = []
    for adata in (A, B):
        # The position of each factor in this list is its mode index, which
        # is what ``skip_mode`` refers to in the call below.
        modes = [adata.uns["Pf2_A"], adata.uns["Pf2_B"], adata.varm["Pf2_C"]]
        cp_tensors.append(CPTensor((adata.uns["Pf2_weights"], modes)))

    first_cp, second_cp = cp_tensors
    return fms(first_cp, second_cp, consider_weights=False, skip_mode=1)  # type: ignore
102+
103+
104+
def fms_percent_drop(
    X: anndata.AnnData,
    percentList: np.ndarray,
    runs: int,
    rank: int = 30,
):
    """Compute FMS between a full-data Pf2 fit and fits on subsampled data.

    The reference factorization is fitted once on ``X``. For every run and
    every entry of ``percentList`` the data are subsampled with
    ``sc.pp.subsample`` (keeping ``1 - percent / 100`` of the observations),
    refitted with ``pf2`` and compared to the reference with ``calculateFMS``.

    Args:
        X: Dataset to factorize.
        percentList: Percentages (0-100) of the data to drop.
        runs: Number of repetitions; run ``j`` subsamples with seed ``j`` and
            fits with seed ``j + 2``.
        rank: Number of Pf2 components.

    Returns:
        A long-format ``pandas.DataFrame`` with one row per (run, percentage)
        and the columns ``"Run"``, ``"Percentage of Data Dropped"`` and
        ``"FMS"``.
    """
    # Reference fit on the complete dataset (uses pf2's default seed).
    dataX = pf2(X, rank, doEmbedding=False)

    run_col = []
    percent_col = []
    fms_col = []

    for run in range(runs):
        for position, percent in enumerate(percentList):
            if position == 0 and percent == 0:
                # A leading 0% entry is scored 1.0 without refitting, as
                # before. The shortcut is now taken only when the entry
                # really is 0, so a list that does not start at 0 is no
                # longer mis-scored.
                score = 1.0
            else:
                sampled_data: anndata.AnnData = sc.pp.subsample(
                    X, fraction=1 - (percent / 100), random_state=run, copy=True
                )  # type: ignore
                sampledX = pf2(
                    sampled_data, rank, random_state=run + 2, doEmbedding=False
                )
                score = calculateFMS(dataX, sampledX)

            # Collect the row directly so the three columns always have the
            # same length (including for an empty ``percentList``).
            run_col.append(run)
            percent_col.append(percent)
            fms_col.append(score)

    return pd.DataFrame(
        {
            "Run": run_col,
            "Percentage of Data Dropped": percent_col,
            "FMS": fms_col,
        }
    )
149+
150+
151+
def resample(data: anndata.AnnData, random_state=None) -> anndata.AnnData:
    """Return a bootstrap sample of ``data`` (rows drawn with replacement).

    Args:
        data: Object to resample along its first axis; the result has the
            same number of rows as the input.
        random_state: ``None`` (default) draws from NumPy's global random
            state, exactly as before. Anything accepted by
            ``numpy.random.default_rng`` (an int seed or an existing
            ``Generator``) makes the draw reproducible.

    Returns:
        A copy of the selected rows.
    """
    n_obs = data.shape[0]
    if random_state is None:
        # Legacy path: keep using the global RNG so existing callers that
        # rely on ``np.random.seed`` see unchanged behavior.
        indices = np.random.randint(0, n_obs, size=(n_obs,))
    else:
        rng = np.random.default_rng(random_state)
        indices = rng.integers(0, n_obs, size=n_obs)
    return data[indices].copy()


def fms_diff_ranks(
    X: anndata.AnnData,
    ranksList: list[int],
    runs: int,
):
    """Compute FMS between Pf2 fits of the data and of a bootstrap sample.

    For every run and every rank in ``ranksList`` the full dataset and a
    bootstrap resample of it are both factorized with ``pf2`` (seeded with
    the run index) and compared with ``calculateFMS``.

    Args:
        X: Dataset to factorize.
        ranksList: Numbers of Pf2 components to evaluate.
        runs: Number of repetitions.

    Returns:
        A long-format ``pandas.DataFrame`` with one row per (run, rank) and
        the columns ``"Run"``, ``"Component"`` and ``"FMS"``.
    """
    run_col = []
    rank_col = []
    fms_col = []

    for run in range(runs):
        # One generator per run: bootstrap draws are reproducible from the
        # run index (like the pf2 seed) yet differ from rank to rank.
        rng = np.random.default_rng(run)
        for rank in ranksList:
            dataX = pf2(X, rank=rank, random_state=run, doEmbedding=False)
            sampledX = pf2(
                resample(X, random_state=rng),
                rank=rank,
                random_state=run,
                doEmbedding=False,
            )

            run_col.append(run)
            rank_col.append(rank)
            fms_col.append(calculateFMS(dataX, sampledX))

    return pd.DataFrame({"Run": run_col, "Component": rank_col, "FMS": fms_col})

pf2rnaseq/figures/commonFuncs/plotGeneral.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import seaborn as sns
77
from matplotlib.axes import Axes
88

9-
from ...factorization import pf2_pca_r2x
9+
from ...factorization import fms_percent_drop, pf2_pca_r2x, fms_diff_ranks
1010

1111

1212
def plot_r2x(data, rank_vec, ax: Axes):
@@ -439,3 +439,24 @@ def plot_boxplot_gene_celltype(
439439
ax.set(title=gene)
440440
ax.set_xticks(ax.get_xticks())
441441
ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=45)
442+
443+
444+
def plot_fms_diff_ranks(
    X: anndata.AnnData,
    ax: Axes,
    ranksList: list[int],
    runs=3,
):
    """Draw FMS against the number of Pf2 components on ``ax``.

    The scores come from ``fms_diff_ranks(X, ranksList, runs)``; they are
    drawn as a seaborn line plot of ``"FMS"`` over ``"Component"`` and the
    y-axis is fixed to the range 0-1.
    """
    fms_by_rank = fms_diff_ranks(X, ranksList, runs)
    sns.lineplot(data=fms_by_rank, x="Component", y="FMS", ax=ax)
    ax.set_ylim(0, 1)
454+
455+
456+
def plot_fms_percent_drop(
    X: anndata.AnnData,
    ax: Axes,
    percentList: np.ndarray,
    runs=3,
    rank: int = 30,
):
    """Draw FMS against the percentage of data dropped on ``ax``.

    The scores come from ``fms_percent_drop(X, percentList, runs, rank)``;
    they are drawn as a seaborn line plot of ``"FMS"`` over
    ``"Percentage of Data Dropped"`` and the y-axis is fixed to the range 0-1.
    """
    fms_by_percent = fms_percent_drop(X, percentList, runs, rank)
    sns.lineplot(
        data=fms_by_percent,
        x="Percentage of Data Dropped",
        y="FMS",
        ax=ax,
    )
    ax.set_ylim(0, 1)
Lines changed: 7 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
"""
22
factorization score
3+
34
"""
45

5-
import anndata
66
import numpy as np
7-
import pandas as pd
8-
import scanpy as sc
9-
import seaborn as sns
10-
from matplotlib.axes import Axes
11-
from tensorly.cp_tensor import CPTensor
12-
from tlviz.factor_tools import factor_match_score as fms
137

14-
from ..factorization import pf2
158
from ..imports import import_Heiser
169
from .common import getSetup, subplotLabel
10+
from .commonFuncs.plotGeneral import (
11+
plot_fms_diff_ranks,
12+
plot_fms_percent_drop,
13+
)
1714

1815

1916
def makeFigure():
@@ -22,125 +19,9 @@ def makeFigure():
2219

2320
X = import_Heiser()
2421
percentList = np.arange(0.0, 55.0, 5.0)
25-
# plot_fms_percent_drop(X, ax[0], percentList=percentList, runs=2)
22+
plot_fms_percent_drop(X, ax[0], percentList=percentList, runs=2, rank=30)
2623

27-
ranks = list(range(30, 51))
24+
ranks = list(range(1, 31))
2825
plot_fms_diff_ranks(X, ax[1], ranksList=ranks, runs=2)
2926

3027
return f
31-
32-
33-
def calculateFMS(A: anndata.AnnData, B: anndata.AnnData):
34-
"""Calculates FMS between 2 factors"""
35-
factors = [A.uns["Pf2_A"], A.uns["Pf2_B"], A.varm["Pf2_C"]]
36-
A_CP = CPTensor(
37-
(
38-
A.uns["Pf2_weights"],
39-
factors,
40-
)
41-
)
42-
43-
factors = [B.uns["Pf2_A"], B.uns["Pf2_B"], B.varm["Pf2_C"]]
44-
B_CP = CPTensor(
45-
(
46-
B.uns["Pf2_weights"],
47-
factors,
48-
)
49-
)
50-
51-
return fms(A_CP, B_CP, consider_weights=False, skip_mode=1) # type: ignore
52-
53-
54-
def plot_fms_percent_drop(
55-
X: anndata.AnnData,
56-
ax: Axes,
57-
percentList: np.ndarray,
58-
runs: int,
59-
rank: int = 30,
60-
):
61-
# Plots FMS score when percentage is removed from data
62-
dataX = pf2(X, rank, doEmbedding=False)
63-
64-
fmsLists = []
65-
66-
for j in range(0, runs, 1):
67-
scores = [1.0]
68-
69-
for i in percentList[1:]:
70-
sampled_data: anndata.AnnData = sc.pp.subsample(
71-
X, fraction=1 - (i / 100), random_state=j, copy=True
72-
) # type: ignore
73-
sampledX = pf2(sampled_data, rank, random_state=j + 2, doEmbedding=False)
74-
75-
fmsScore = calculateFMS(dataX, sampledX)
76-
scores.append(fmsScore)
77-
78-
fmsLists.append(scores)
79-
80-
runsList_df = []
81-
for i in range(0, runs):
82-
for j in range(0, len(percentList)):
83-
runsList_df.append(i)
84-
percentList_df = []
85-
for i in range(0, runs):
86-
for j in range(0, len(percentList)):
87-
percentList_df.append(percentList[j])
88-
fmsList_df = []
89-
for sublist in fmsLists:
90-
fmsList_df += sublist
91-
df = pd.DataFrame(
92-
{
93-
"Run": runsList_df,
94-
"Percentage of Data Dropped": percentList_df,
95-
"FMS": fmsList_df,
96-
}
97-
)
98-
99-
sns.lineplot(data=df, x="Percentage of Data Dropped", y="FMS", ax=ax)
100-
ax.set_ylim(0, 1)
101-
102-
103-
def resample(data: anndata.AnnData) -> anndata.AnnData:
104-
"""Bootstrapping dataset"""
105-
indices = np.random.randint(0, data.shape[0], size=(data.shape[0],))
106-
data = data[indices].copy()
107-
return data
108-
109-
110-
def plot_fms_diff_ranks(
111-
X: anndata.AnnData,
112-
ax: Axes,
113-
ranksList: list[int],
114-
runs: int,
115-
):
116-
# Plots FMS when using different Pf2 components
117-
fmsLists = []
118-
119-
for j in range(0, runs, 1):
120-
scores = []
121-
for i in ranksList:
122-
dataX = pf2(X, rank=i, random_state=j, doEmbedding=False)
123-
124-
sampledX = pf2(resample(X), rank=i, random_state=j, doEmbedding=False)
125-
126-
fmsScore = calculateFMS(dataX, sampledX)
127-
scores.append(fmsScore)
128-
fmsLists.append(scores)
129-
130-
runsList_df = []
131-
for i in range(0, runs):
132-
for j in range(0, len(ranksList)):
133-
runsList_df.append(i)
134-
ranksList_df = []
135-
for i in range(0, runs):
136-
for j in range(0, len(ranksList)):
137-
ranksList_df.append(ranksList[j])
138-
fmsList_df = []
139-
for sublist in fmsLists:
140-
fmsList_df += sublist
141-
df = pd.DataFrame(
142-
{"Run": runsList_df, "Component": ranksList_df, "FMS": fmsList_df}
143-
)
144-
145-
sns.lineplot(data=df, x="Component", y="FMS", ax=ax)
146-
ax.set_ylim(0, 1)

0 commit comments

Comments
 (0)