Skip to content

Commit e5161b1

Browse files
committed
Fixed some changes inlcuding putting FMS calculations out of figure files.
1 parent 758439a commit e5161b1

File tree

9 files changed

+196
-316
lines changed

9 files changed

+196
-316
lines changed

pf2rnaseq/figures/commonFuncs/plotGeneral.py

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
import anndata
12
import numpy as np
23
import pandas as pd
3-
import seaborn as sns
44
import scanpy as sc
5-
from scipy.stats import ranksums
6-
import anndata
7-
from matplotlib.axes import Axes
8-
from ...factorization import pf2_pca_r2x
9-
import matplotlib.pyplot as plt
105
import scipy.sparse
6+
import seaborn as sns
7+
from matplotlib.axes import Axes
8+
from tensorly.cp_tensor import CPTensor
9+
from tlviz.factor_tools import factor_match_score as fms
10+
11+
from ...factorization import pf2, pf2_pca_r2x
1112

1213

1314
def plot_r2x(data, rank_vec, ax: Axes):
1415
"""Creates R2X plot for parafac2 tensor decomposition and pca"""
16+
1517
r2xError = pf2_pca_r2x(data, rank_vec)
1618
labelNames = ["Fit: Pf2", "Fit: PCA"]
1719
colorDecomp = ["r", "b"]
@@ -440,3 +442,119 @@ def plot_boxplot_gene_celltype(
440442
ax.set(title=gene)
441443
ax.set_xticks(ax.get_xticks())
442444
ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=45)
445+
446+
447+
def calculateFMS(A: anndata.AnnData, B: anndata.AnnData):
448+
"""Calculates FMS between 2 factors"""
449+
factors = [A.uns["Pf2_A"], A.uns["Pf2_B"], A.varm["Pf2_C"]]
450+
A_CP = CPTensor(
451+
(
452+
A.uns["Pf2_weights"],
453+
factors,
454+
)
455+
)
456+
457+
factors = [B.uns["Pf2_A"], B.uns["Pf2_B"], B.varm["Pf2_C"]]
458+
B_CP = CPTensor(
459+
(
460+
B.uns["Pf2_weights"],
461+
factors,
462+
)
463+
)
464+
465+
return fms(A_CP, B_CP, consider_weights=False, skip_mode=1) # type: ignore
466+
467+
468+
def plot_fms_percent_drop(
469+
X: anndata.AnnData,
470+
ax: Axes,
471+
percentList: np.ndarray,
472+
runs: int,
473+
rank: int = 30,
474+
):
475+
# Plots FMS score when percentage is removed from data
476+
dataX = pf2(X, rank, doEmbedding=False)
477+
478+
fmsLists = []
479+
480+
for j in range(0, runs, 1):
481+
scores = [1.0]
482+
483+
for i in percentList[1:]:
484+
sampled_data: anndata.AnnData = sc.pp.subsample(
485+
X, fraction=1 - (i / 100), random_state=j, copy=True
486+
) # type: ignore
487+
sampledX = pf2(sampled_data, rank, random_state=j + 2, doEmbedding=False)
488+
489+
fmsScore = calculateFMS(dataX, sampledX)
490+
scores.append(fmsScore)
491+
492+
fmsLists.append(scores)
493+
494+
runsList_df = []
495+
for i in range(0, runs):
496+
for j in range(0, len(percentList)):
497+
runsList_df.append(i)
498+
percentList_df = []
499+
for i in range(0, runs):
500+
for j in range(0, len(percentList)):
501+
percentList_df.append(percentList[j])
502+
fmsList_df = []
503+
for sublist in fmsLists:
504+
fmsList_df += sublist
505+
df = pd.DataFrame(
506+
{
507+
"Run": runsList_df,
508+
"Percentage of Data Dropped": percentList_df,
509+
"FMS": fmsList_df,
510+
}
511+
)
512+
513+
sns.lineplot(data=df, x="Percentage of Data Dropped", y="FMS", ax=ax)
514+
ax.set_ylim(0, 1)
515+
516+
517+
def resample(data: anndata.AnnData) -> anndata.AnnData:
518+
"""Bootstrapping dataset"""
519+
indices = np.random.randint(0, data.shape[0], size=(data.shape[0],))
520+
data = data[indices].copy()
521+
return data
522+
523+
524+
def plot_fms_diff_ranks(
525+
X: anndata.AnnData,
526+
ax: Axes,
527+
ranksList: list[int],
528+
runs: int,
529+
):
530+
# Plots FMS when using different Pf2 components
531+
fmsLists = []
532+
533+
for j in range(0, runs, 1):
534+
scores = []
535+
for i in ranksList:
536+
dataX = pf2(X, rank=i, random_state=j, doEmbedding=False)
537+
538+
sampledX = pf2(resample(X), rank=i, random_state=j, doEmbedding=False)
539+
540+
fmsScore = calculateFMS(dataX, sampledX)
541+
scores.append(fmsScore)
542+
fmsLists.append(scores)
543+
544+
runsList_df = []
545+
for i in range(0, runs):
546+
for j in range(0, len(ranksList)):
547+
runsList_df.append(i)
548+
ranksList_df = []
549+
for i in range(0, runs):
550+
for j in range(0, len(ranksList)):
551+
ranksList_df.append(ranksList[j])
552+
fmsList_df = []
553+
for sublist in fmsLists:
554+
fmsList_df += sublist
555+
df = pd.DataFrame(
556+
{"Run": runsList_df, "Component": ranksList_df, "FMS": fmsList_df}
557+
)
558+
559+
sns.lineplot(data=df, x="Component", y="FMS", ax=ax)
560+
ax.set_ylim(0, 1)

pf2rnaseq/figures/figureHeiserFMS.py

Lines changed: 3 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,11 @@
22
factorization score
33
"""
44

5-
import anndata
65
import numpy as np
7-
import pandas as pd
8-
import scanpy as sc
9-
import seaborn as sns
10-
from matplotlib.axes import Axes
11-
from tensorly.cp_tensor import CPTensor
12-
from tlviz.factor_tools import factor_match_score as fms
136

14-
from ..factorization import pf2
157
from ..imports import import_Heiser
168
from .common import getSetup, subplotLabel
9+
from .commonFuncs.plotGeneral import plot_fms_diff_ranks, plot_fms_percent_drop
1710

1811

1912
def makeFigure():
@@ -22,125 +15,9 @@ def makeFigure():
2215

2316
X = import_Heiser()
2417
percentList = np.arange(0.0, 55.0, 5.0)
25-
# plot_fms_percent_drop(X, ax[0], percentList=percentList, runs=2)
18+
plot_fms_percent_drop(X, ax[0], percentList=percentList, runs=2)
2619

27-
ranks = list(range(30, 51))
20+
ranks = np.arange(10, 101, 10)
2821
plot_fms_diff_ranks(X, ax[1], ranksList=ranks, runs=2)
2922

3023
return f
31-
32-
33-
def calculateFMS(A: anndata.AnnData, B: anndata.AnnData):
34-
"""Calculates FMS between 2 factors"""
35-
factors = [A.uns["Pf2_A"], A.uns["Pf2_B"], A.varm["Pf2_C"]]
36-
A_CP = CPTensor(
37-
(
38-
A.uns["Pf2_weights"],
39-
factors,
40-
)
41-
)
42-
43-
factors = [B.uns["Pf2_A"], B.uns["Pf2_B"], B.varm["Pf2_C"]]
44-
B_CP = CPTensor(
45-
(
46-
B.uns["Pf2_weights"],
47-
factors,
48-
)
49-
)
50-
51-
return fms(A_CP, B_CP, consider_weights=False, skip_mode=1) # type: ignore
52-
53-
54-
def plot_fms_percent_drop(
55-
X: anndata.AnnData,
56-
ax: Axes,
57-
percentList: np.ndarray,
58-
runs: int,
59-
rank: int = 30,
60-
):
61-
# Plots FMS score when percentage is removed from data
62-
dataX = pf2(X, rank, doEmbedding=False)
63-
64-
fmsLists = []
65-
66-
for j in range(0, runs, 1):
67-
scores = [1.0]
68-
69-
for i in percentList[1:]:
70-
sampled_data: anndata.AnnData = sc.pp.subsample(
71-
X, fraction=1 - (i / 100), random_state=j, copy=True
72-
) # type: ignore
73-
sampledX = pf2(sampled_data, rank, random_state=j + 2, doEmbedding=False)
74-
75-
fmsScore = calculateFMS(dataX, sampledX)
76-
scores.append(fmsScore)
77-
78-
fmsLists.append(scores)
79-
80-
runsList_df = []
81-
for i in range(0, runs):
82-
for j in range(0, len(percentList)):
83-
runsList_df.append(i)
84-
percentList_df = []
85-
for i in range(0, runs):
86-
for j in range(0, len(percentList)):
87-
percentList_df.append(percentList[j])
88-
fmsList_df = []
89-
for sublist in fmsLists:
90-
fmsList_df += sublist
91-
df = pd.DataFrame(
92-
{
93-
"Run": runsList_df,
94-
"Percentage of Data Dropped": percentList_df,
95-
"FMS": fmsList_df,
96-
}
97-
)
98-
99-
sns.lineplot(data=df, x="Percentage of Data Dropped", y="FMS", ax=ax)
100-
ax.set_ylim(0, 1)
101-
102-
103-
def resample(data: anndata.AnnData) -> anndata.AnnData:
104-
"""Bootstrapping dataset"""
105-
indices = np.random.randint(0, data.shape[0], size=(data.shape[0],))
106-
data = data[indices].copy()
107-
return data
108-
109-
110-
def plot_fms_diff_ranks(
111-
X: anndata.AnnData,
112-
ax: Axes,
113-
ranksList: list[int],
114-
runs: int,
115-
):
116-
# Plots FMS when using different Pf2 components
117-
fmsLists = []
118-
119-
for j in range(0, runs, 1):
120-
scores = []
121-
for i in ranksList:
122-
dataX = pf2(X, rank=i, random_state=j, doEmbedding=False)
123-
124-
sampledX = pf2(resample(X), rank=i, random_state=j, doEmbedding=False)
125-
126-
fmsScore = calculateFMS(dataX, sampledX)
127-
scores.append(fmsScore)
128-
fmsLists.append(scores)
129-
130-
runsList_df = []
131-
for i in range(0, runs):
132-
for j in range(0, len(ranksList)):
133-
runsList_df.append(i)
134-
ranksList_df = []
135-
for i in range(0, runs):
136-
for j in range(0, len(ranksList)):
137-
ranksList_df.append(ranksList[j])
138-
fmsList_df = []
139-
for sublist in fmsLists:
140-
fmsList_df += sublist
141-
df = pd.DataFrame(
142-
{"Run": runsList_df, "Component": ranksList_df, "FMS": fmsList_df}
143-
)
144-
145-
sns.lineplot(data=df, x="Component", y="FMS", ax=ax)
146-
ax.set_ylim(0, 1)

0 commit comments

Comments
 (0)