22factorization score
33"""
44
5- import anndata
65import numpy as np
7- import pandas as pd
8- import scanpy as sc
9- import seaborn as sns
10- from matplotlib .axes import Axes
11- from tensorly .cp_tensor import CPTensor
12- from tlviz .factor_tools import factor_match_score as fms
136
14- from ..factorization import pf2
157from ..imports import import_Heiser
168from .common import getSetup , subplotLabel
9+ from .commonFuncs .plotGeneral import plot_fms_diff_ranks , plot_fms_percent_drop
1710
1811
1912def makeFigure ():
@@ -22,125 +15,9 @@ def makeFigure():
2215
2316 X = import_Heiser ()
2417 percentList = np .arange (0.0 , 55.0 , 5.0 )
25- # plot_fms_percent_drop(X, ax[0], percentList=percentList, runs=2)
18+ plot_fms_percent_drop (X , ax [0 ], percentList = percentList , runs = 2 )
2619
27- ranks = list ( range ( 30 , 51 ) )
20+ ranks = np . arange ( 10 , 101 , 10 )
2821 plot_fms_diff_ranks (X , ax [1 ], ranksList = ranks , runs = 2 )
2922
3023 return f
31-
32-
33- def calculateFMS (A : anndata .AnnData , B : anndata .AnnData ):
34- """Calculates FMS between 2 factors"""
35- factors = [A .uns ["Pf2_A" ], A .uns ["Pf2_B" ], A .varm ["Pf2_C" ]]
36- A_CP = CPTensor (
37- (
38- A .uns ["Pf2_weights" ],
39- factors ,
40- )
41- )
42-
43- factors = [B .uns ["Pf2_A" ], B .uns ["Pf2_B" ], B .varm ["Pf2_C" ]]
44- B_CP = CPTensor (
45- (
46- B .uns ["Pf2_weights" ],
47- factors ,
48- )
49- )
50-
51- return fms (A_CP , B_CP , consider_weights = False , skip_mode = 1 ) # type: ignore
52-
53-
54- def plot_fms_percent_drop (
55- X : anndata .AnnData ,
56- ax : Axes ,
57- percentList : np .ndarray ,
58- runs : int ,
59- rank : int = 30 ,
60- ):
61- # Plots FMS score when percentage is removed from data
62- dataX = pf2 (X , rank , doEmbedding = False )
63-
64- fmsLists = []
65-
66- for j in range (0 , runs , 1 ):
67- scores = [1.0 ]
68-
69- for i in percentList [1 :]:
70- sampled_data : anndata .AnnData = sc .pp .subsample (
71- X , fraction = 1 - (i / 100 ), random_state = j , copy = True
72- ) # type: ignore
73- sampledX = pf2 (sampled_data , rank , random_state = j + 2 , doEmbedding = False )
74-
75- fmsScore = calculateFMS (dataX , sampledX )
76- scores .append (fmsScore )
77-
78- fmsLists .append (scores )
79-
80- runsList_df = []
81- for i in range (0 , runs ):
82- for j in range (0 , len (percentList )):
83- runsList_df .append (i )
84- percentList_df = []
85- for i in range (0 , runs ):
86- for j in range (0 , len (percentList )):
87- percentList_df .append (percentList [j ])
88- fmsList_df = []
89- for sublist in fmsLists :
90- fmsList_df += sublist
91- df = pd .DataFrame (
92- {
93- "Run" : runsList_df ,
94- "Percentage of Data Dropped" : percentList_df ,
95- "FMS" : fmsList_df ,
96- }
97- )
98-
99- sns .lineplot (data = df , x = "Percentage of Data Dropped" , y = "FMS" , ax = ax )
100- ax .set_ylim (0 , 1 )
101-
102-
103- def resample (data : anndata .AnnData ) -> anndata .AnnData :
104- """Bootstrapping dataset"""
105- indices = np .random .randint (0 , data .shape [0 ], size = (data .shape [0 ],))
106- data = data [indices ].copy ()
107- return data
108-
109-
110- def plot_fms_diff_ranks (
111- X : anndata .AnnData ,
112- ax : Axes ,
113- ranksList : list [int ],
114- runs : int ,
115- ):
116- # Plots FMS when using different Pf2 components
117- fmsLists = []
118-
119- for j in range (0 , runs , 1 ):
120- scores = []
121- for i in ranksList :
122- dataX = pf2 (X , rank = i , random_state = j , doEmbedding = False )
123-
124- sampledX = pf2 (resample (X ), rank = i , random_state = j , doEmbedding = False )
125-
126- fmsScore = calculateFMS (dataX , sampledX )
127- scores .append (fmsScore )
128- fmsLists .append (scores )
129-
130- runsList_df = []
131- for i in range (0 , runs ):
132- for j in range (0 , len (ranksList )):
133- runsList_df .append (i )
134- ranksList_df = []
135- for i in range (0 , runs ):
136- for j in range (0 , len (ranksList )):
137- ranksList_df .append (ranksList [j ])
138- fmsList_df = []
139- for sublist in fmsLists :
140- fmsList_df += sublist
141- df = pd .DataFrame (
142- {"Run" : runsList_df , "Component" : ranksList_df , "FMS" : fmsList_df }
143- )
144-
145- sns .lineplot (data = df , x = "Component" , y = "FMS" , ax = ax )
146- ax .set_ylim (0 , 1 )
0 commit comments