11"""Decompositon plots like pca, umap, tsne, etc."""
22
3+ import itertools
34from typing import Optional
45
56import matplotlib
7+ import matplotlib .pyplot as plt
68import pandas as pd
79import sklearn .decomposition
810
@@ -17,3 +19,53 @@ def plot_explained_variance(
1719 exp_var .index .name = "PC"
1820 ax = exp_var .plot (ax = ax )
1921 return ax
22+
23+
24+ def pca_grid (
25+ PCs : pd .DataFrame ,
26+ meta_column : pd .Series ,
27+ n_components : int = 4 ,
28+ meta_col_name : Optional [str ] = None ,
29+ figsize = (6 , 8 ),
30+ ) -> plt .Figure :
31+ """Plot a grid of scatter plots for the first n_components of PCA, per default 4.
32+
33+ Parameters
34+ ----------
35+ PCs : pd.DataFrame
36+ DataFrame with the principal components as columns.
37+ meta_column : pd.Series
38+ Series with categorical data to color the scatter plots.
39+ n_components : int, optional
40+ Number of first n components to plot, by default 4
41+ meta_col_name : Optional[str], optional
42+ If another name than the default series name shoudl be used, by default None
43+
44+ Returns
45+ -------
46+ plt.Figure
47+ Matplotlib figure with the scatter plots.
48+ """
49+ if meta_col_name is None :
50+ meta_col_name = meta_column .name
51+ else :
52+ meta_column = meta_column .rename (meta_col_name )
53+ up_to = min (PCs .shape [- 1 ], n_components )
54+ fig , axes = plt .subplots (up_to - 1 , 2 , figsize = figsize , layout = "constrained" )
55+ PCs = PCs .join (
56+ meta_column .astype ("category" )
57+ ) # ! maybe add a check that it's not continous
58+ for k , (pos , ax ) in enumerate (
59+ zip (itertools .combinations (range (up_to ), 2 ), axes .flatten ())
60+ ):
61+ i , j = pos
62+ plot_heatmap = bool (k % 2 )
63+ PCs .plot .scatter (
64+ i ,
65+ j ,
66+ c = meta_col_name ,
67+ cmap = "Paired" ,
68+ ax = ax ,
69+ colorbar = plot_heatmap ,
70+ )
71+ return fig
0 commit comments