11import abc
2+ import math
23import typing
34
5+ import numpy as np
46import pandas as pd
57
68from gpsea .model import Patient
9+ from gpsea .config import PALETTE_DATA , PALETTE_SPECIAL
710from ..clf import GenotypeClassifier
811from .stats import PhenotypeScoreStatistic
912
@@ -128,6 +131,16 @@ def __init__(
128131 super ().__init__ (gt_clf , phenotype , statistic , data , statistic_result )
129132 assert isinstance (phenotype , PhenotypeScorer )
130133
134+ # Check that the provided genotype predicate defines the same categories
135+ # as those found in `data.`
136+ actual = set (
137+ int (val )
138+ for val in data [MonoPhenotypeAnalysisResult .GT_COL ].unique ()
139+ if val is not None and not math .isnan (val )
140+ )
141+ expected = set (c .cat_id for c in self ._gt_clf .get_categories ())
142+ assert actual == expected , "Mismatch in the genotype classes"
143+
131144 def phenotype_scorer (self ) -> PhenotypeScorer :
132145 """
133146 Get the scorer that computed the phenotype score.
@@ -137,31 +150,21 @@ def phenotype_scorer(self) -> PhenotypeScorer:
137150 # being a subclass of `Partitioning`.
138151 return self ._phenotype # type: ignore
139152
def _make_data_df(
    self,
) -> pd.DataFrame:
    """
    Get the rows of the analysis data frame that can be plotted.

    A row is kept only if *none* of its columns is missing. This drops
    the individuals with unassigned genotype group, and also any row with
    a missing value in another column of ``self._data``.

    :returns: the subset of ``self._data`` without rows containing
        missing values; the row index and columns are left as they are.
    """
    # `notna()` yields a boolean frame; `all(axis="columns")` reduces it
    # to one flag per row that is `True` only if the entire row is present.
    complete_rows = self._data.notna().all(axis="columns")
    return self._data.loc[complete_rows]
164160
161+ def _make_x_and_tick_labels (
162+ self ,
163+ data : pd .DataFrame ,
164+ ) -> typing .Tuple [
165+ typing .Sequence [typing .Sequence [float ]],
166+ typing .Sequence [str ],
167+ ]:
165168 x = [
166169 data .loc [
167170 data [MonoPhenotypeAnalysisResult .GT_COL ] == c .category .cat_id ,
@@ -171,19 +174,116 @@ def plot_boxplots(
171174 ]
172175
173176 gt_cat_names = [c .category .name for c in self ._gt_clf .get_categorizations ()]
177+
178+ return x , gt_cat_names
179+
def plot_boxplots(
    self,
    ax,
    colors: typing.Sequence[str] = PALETTE_DATA,
    median_color: str = PALETTE_SPECIAL,
    **boxplot_kwargs,
):
    """
    Draw box plot with distributions of phenotype scores for the genotype groups.

    The `patch_artist` and `tick_labels` options default to `True`
    and to the genotype category names, respectively, but both can be
    overridden through `boxplot_kwargs`.

    :param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
    :param colors: a sequence with color palette for the box plot patches.
    :param median_color: a `str` with the color for the boxplot median line.
    :param boxplot_kwargs: arguments to pass into :func:`matplotlib.axes.Axes.boxplot` function.
    """
    data = self._make_data_df()

    x, gt_cat_names = self._make_x_and_tick_labels(data)
    # Pop the options we provide defaults for, so that they are not passed
    # to `boxplot` twice if the caller sets them explicitly.
    patch_artist = boxplot_kwargs.pop("patch_artist", True)
    tick_labels = boxplot_kwargs.pop("tick_labels", gt_cat_names)

    bplot = ax.boxplot(
        x=x,
        patch_artist=patch_artist,
        tick_labels=tick_labels,
        **boxplot_kwargs,
    )

    # Set colors of the boxes
    col_idxs = self._choose_palette_idxs(
        n_categories=self._gt_clf.n_categorizations(), n_colors=len(colors)
    )
    for box, col_idx in zip(bplot["boxes"], col_idxs):
        if patch_artist:
            # The boxes are patches that have a face to fill.
            box.set_facecolor(colors[col_idx])
        else:
            # Without `patch_artist` the boxes are plain lines with no face,
            # so we color the outline instead of failing on `set_facecolor`.
            box.set_color(colors[col_idx])

    for median in bplot["medians"]:
        median.set_color(median_color)
186217
def plot_violins(
    self,
    ax,
    colors: typing.Sequence[str] = PALETTE_DATA,
    **violinplot_kwargs,
):
    """
    Draw a violin plot with distributions of phenotype scores for the genotype groups.

    On top of each violin, the method draws a white dot for the median,
    a thick bar spanning the first and the third quartile, and a thin bar
    spanning the whiskers (see :meth:`_adjacent_values`).

    The `showmeans` and `showextrema` options default to `False`
    but can be overridden through `violinplot_kwargs`.

    .. note::

      The overlay is drawn at x positions `1..n` with vertical lines,
      hence it matches the violins only for the default positions
      and the vertical orientation.

    :param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
    :param colors: a sequence with color palette for the violin patches.
    :param violinplot_kwargs: arguments to pass into :func:`matplotlib.axes.Axes.violinplot` function.
    """
    data = self._make_data_df()

    x, gt_cat_names = self._make_x_and_tick_labels(data)

    # Pop the options we provide defaults for, so that they are not passed
    # to `violinplot` twice if the caller sets them explicitly.
    showmeans = violinplot_kwargs.pop("showmeans", False)
    showextrema = violinplot_kwargs.pop("showextrema", False)

    parts = ax.violinplot(
        dataset=x,
        showmeans=showmeans,
        showextrema=showextrema,
        **violinplot_kwargs,
    )

    # The summary statistics are computed per group rather than with
    # a single vectorized call using `axis=1`, which would require
    # all groups to have the same number of scores.
    quartile1 = [np.percentile(v, 25) for v in x]
    medians = [np.median(v) for v in x]
    quartile3 = [np.percentile(v, 75) for v in x]
    # `_adjacent_values` reads the first and the last item as the extremes.
    sorted_groups = [sorted(v) for v in x]
    whiskers = np.array(
        [
            PhenotypeScoreAnalysisResult._adjacent_values(sorted_group, q1, q3)
            for sorted_group, q1, q3 in zip(sorted_groups, quartile1, quartile3)
        ]
    )
    whiskers_min, whiskers_max = whiskers[:, 0], whiskers[:, 1]

    inds = np.arange(1, len(medians) + 1)
    ax.scatter(inds, medians, marker="o", color="white", s=30, zorder=3)
    ax.vlines(inds, quartile1, quartile3, color="k", linestyle="-", lw=5)
    ax.vlines(inds, whiskers_min, whiskers_max, color="k", linestyle="-", lw=1)

    ax.xaxis.set(
        ticks=np.arange(1, len(gt_cat_names) + 1),
        ticklabels=gt_cat_names,
    )

    # Set colors of the violin bodies
    col_idxs = self._choose_palette_idxs(
        n_categories=self._gt_clf.n_categorizations(), n_colors=len(colors)
    )
    for body, color_idx in zip(parts["bodies"], col_idxs):
        body.set(
            facecolor=colors[color_idx],
            edgecolor=None,
            alpha=1,
        )
277+
@staticmethod
def _adjacent_values(vals, q1, q3):
    """
    Compute the whisker ends for a group of values.

    The whiskers reach 1.5 times the interquartile range beyond
    the quartiles, but the lower whisker is clipped to `[vals[0], q1]`
    and the upper whisker is clipped to `[q3, vals[-1]]`.

    :param vals: a non-empty sequence of values sorted in ascending order;
        only the first and the last item are read.
    :param q1: the first quartile of `vals`.
    :param q3: the third quartile of `vals`.
    :returns: a `tuple` with the lower and the upper whisker end.
    """
    whisker_reach = (q3 - q1) * 1.5

    upper_adjacent_value = np.clip(q3 + whisker_reach, q3, vals[-1])
    lower_adjacent_value = np.clip(q1 - whisker_reach, vals[0], q1)
    return lower_adjacent_value, upper_adjacent_value
286+
187287 def __eq__ (self , value : object ) -> bool :
188288 return isinstance (value , PhenotypeScoreAnalysisResult ) and super (
189289 MonoPhenotypeAnalysisResult , self
@@ -254,7 +354,9 @@ def compare_genotype_vs_phenotype_score(
254354 for individual in cohort :
255355 gt_cat = gt_clf .test (individual )
256356 if gt_cat is None :
257- data .loc [individual .patient_id , MonoPhenotypeAnalysisResult .GT_COL ] = None
357+ data .loc [individual .patient_id , MonoPhenotypeAnalysisResult .GT_COL ] = (
358+ None
359+ )
258360 else :
259361 data .loc [individual .patient_id , MonoPhenotypeAnalysisResult .GT_COL ] = (
260362 gt_cat .category .cat_id
0 commit comments