Skip to content

Commit 45b5ee1

Browse files
authored
Merge pull request #404 from monarch-initiative/add-violinplot-preset
Update plot colors, allow plotting violin plot
2 parents ed99d7a + 7fc254a commit 45b5ee1

File tree

11 files changed

+369
-90
lines changed

11 files changed

+369
-90
lines changed

docs/user-guide/analyses/partitioning/genotype/allele_count.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ based on presence of zero or one *EGFR* mutation allele:
6161
... target=affects_egfr,
6262
... )
6363
>>> gt_clf.class_labels
64-
('0', '1')
64+
('0 alleles', '1 allele')
6565

6666
The ``allele_count`` needs two inputs.
6767
The ``counts`` takes a tuple of the target allele counts,
@@ -102,9 +102,9 @@ and we will compare the individuals with one allele with those with two alleles:
102102
... target=affects_lmna,
103103
... )
104104
>>> gt_clf.class_labels
105-
('1', '2')
105+
('1 allele', '2 alleles')
106106

107107

108108
The classifier assigns the individuals into one of two classes:
109109
those with one *LMNA* variant allele and those with two *LMNA* variant alleles.
110-
Any cohort member with other allele counts (e.g. `0` or `3`) is ignored.
110+
Any cohort member with other allele counts (e.g. `0 alleles` or `3 alleles`) is ignored.

docs/user-guide/analyses/survival.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ We can plot Kaplan-Meier curves:
151151
... )
152152
>>> _ = ax.set(
153153
... xlabel=endpoint.name + " [years]",
154-
... ylabel="Empirical survival",
154+
... ylabel="Event-free proportion",
155155
... )
156156
>>> _ = ax.grid(axis="y")
157157

src/gpsea/analysis/_base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,26 @@ def statistic(self) -> Statistic:
154154
"""
155155
return self._statistic
156156

157+
@staticmethod
def _choose_palette_idxs(
    n_categories: int,
    n_colors: int,
) -> typing.Sequence[int]:
    """
    Choose the color indices for coloring `n_categories` using a palette with `n_colors`.

    The indices are spread evenly across the palette. With two or more categories,
    the first category maps to the first palette color and the last category
    maps to the last palette color.

    :param n_categories: the number of categories to color.
    :param n_colors: the number of colors available in the palette.
    :returns: a `tuple` with `n_categories` zero-based palette indices (plain Python `int`\\ s).
    :raises ValueError: if the palette has less than 2 colors
        or if it has fewer colors than there are categories.
    """
    if n_colors < 2:
        raise ValueError(
            f"Expected a palette with at least 2 colors but got {n_colors}"
        )
    if n_colors < n_categories:
        raise ValueError(
            f"The predicate produces {n_categories} categories but the palette includes only {n_colors} colors!"
        )

    # Evenly spaced 1-based positions, truncated to integers by `dtype=int`.
    positions = np.linspace(start=1, stop=n_colors, num=n_categories, dtype=int)

    # Shift to 0-based indices and convert the NumPy scalars to built-in `int`
    # to honor the declared return type.
    return tuple((positions - 1).tolist())
176+
157177
def __eq__(self, value: object) -> bool:
158178
return (
159179
isinstance(value, AnalysisResult)
@@ -399,7 +419,7 @@ class MonoPhenotypeAnalysisResult(AnalysisResult, metaclass=abc.ABCMeta):
399419
"""
400420
Name of the data index.
401421
"""
402-
422+
403423
GT_COL = "genotype"
404424
"""
405425
Name of column for storing genotype data.

src/gpsea/analysis/clf/_gt_classifiers.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def _build_ac_to_cat(
338338

339339
ac2cat = {}
340340
for i, partition in enumerate(partitions):
341-
name = " OR ".join(str(j) for j in partition)
341+
name = " OR ".join(_pluralize(count=j, base="allele") for j in partition)
342342
description = " OR ".join(labels[j] for j in partition)
343343
cat = Categorization(
344344
PatientCategory(cat_id=i, name=name, description=description),
@@ -348,6 +348,14 @@ def _build_ac_to_cat(
348348

349349
return ac2cat
350350

351+
def _pluralize(
    count: int,
    base: str,
) -> str:
    """
    Format `count` followed by `base`, appending ``s`` to `base` unless `count` equals one.

    :param count: the number of items.
    :param base: the singular form of the item name.
    :returns: a `str` such as ``'1 allele'`` or ``'2 alleles'``.
    """
    suffix = "" if count == 1 else "s"
    return f"{count} {base}{suffix}"
351359

352360
def allele_count(
353361
counts: typing.Collection[typing.Union[int, typing.Collection[int]]],
@@ -372,20 +380,20 @@ def allele_count(
372380
>>> from gpsea.analysis.clf import allele_count
373381
>>> zero_vs_one = allele_count(counts=(0, 1))
374382
>>> zero_vs_one.summarize_classes()
375-
'Allele count: 0, 1'
383+
'Allele count: 0 alleles, 1 allele'
376384
377385
These counts will create three classes for individuals with zero, one or two alleles:
378386
379387
>>> zero_vs_one_vs_two = allele_count(counts=(0, 1, 2))
380388
>>> zero_vs_one_vs_two.summarize_classes()
381-
'Allele count: 0, 1, 2'
389+
'Allele count: 0 alleles, 1 allele, 2 alleles'
382390
383391
Last, the counts below will create two groups, one for the individuals with zero target variant type alleles,
384392
and one for the individuals with one or two alleles:
385393
386394
>>> zero_vs_one_vs_two = allele_count(counts=(0, {1, 2}))
387395
>>> zero_vs_one_vs_two.summarize_classes()
388-
'Allele count: 0, 1 OR 2'
396+
'Allele count: 0 alleles, 1 allele OR 2 alleles'
389397
390398
Note that we wrap the last two allele counts in a set.
391399

src/gpsea/analysis/clf/_test__gt_classifiers.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,32 +61,32 @@ def test_build_count_to_cat(
6161
(
6262
((0,), (1,), (2,)),
6363
{
64-
0: "0",
65-
1: "1",
66-
2: "2",
64+
0: "0 alleles",
65+
1: "1 allele",
66+
2: "2 alleles",
6767
},
6868
),
6969
(
7070
((0, 1), (2,)),
7171
{
72-
0: "0 OR 1",
73-
1: "0 OR 1",
74-
2: "2",
72+
0: "0 alleles OR 1 allele",
73+
1: "0 alleles OR 1 allele",
74+
2: "2 alleles",
7575
},
7676
),
7777
(
7878
((0,), (1, 2)),
7979
{
80-
0: "0",
81-
1: "1 OR 2",
82-
2: "1 OR 2",
80+
0: "0 alleles",
81+
1: "1 allele OR 2 alleles",
82+
2: "1 allele OR 2 alleles",
8383
},
8484
),
8585
(
8686
((1,), (2,)),
8787
{
88-
1: "1",
89-
2: "2",
88+
1: "1 allele",
89+
2: "2 alleles",
9090
},
9191
),
9292
],

src/gpsea/analysis/pscore/_api.py

Lines changed: 130 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import abc
2+
import math
23
import typing
34

5+
import numpy as np
46
import pandas as pd
57

68
from gpsea.model import Patient
9+
from gpsea.config import PALETTE_DATA, PALETTE_SPECIAL
710
from ..clf import GenotypeClassifier
811
from .stats import PhenotypeScoreStatistic
912

@@ -128,6 +131,16 @@ def __init__(
128131
super().__init__(gt_clf, phenotype, statistic, data, statistic_result)
129132
assert isinstance(phenotype, PhenotypeScorer)
130133

134+
# Check that the provided genotype predicate defines the same categories
135+
# as those found in `data`.
136+
actual = set(
137+
int(val)
138+
for val in data[MonoPhenotypeAnalysisResult.GT_COL].unique()
139+
if val is not None and not math.isnan(val)
140+
)
141+
expected = set(c.cat_id for c in self._gt_clf.get_categories())
142+
assert actual == expected, "Mismatch in the genotype classes"
143+
131144
def phenotype_scorer(self) -> PhenotypeScorer:
132145
"""
133146
Get the scorer that computed the phenotype score.
@@ -137,31 +150,21 @@ def phenotype_scorer(self) -> PhenotypeScorer:
137150
# being a subclass of `Partitioning`.
138151
return self._phenotype # type: ignore
139152

140-
def plot_boxplots(
153+
def _make_data_df(
141154
self,
142-
ax,
143-
colors=("darksalmon", "honeydew"),
144-
median_color: str = "black",
145-
):
146-
"""
147-
Draw box plot with distributions of phenotype scores for the genotype groups.
148-
149-
:param gt_predicate: the genotype predicate used to produce the genotype groups.
150-
:param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
151-
:param colors: a sequence with colors to use for coloring the box patches of the box plot.
152-
:param median_color: a `str` with the color for the boxplot median line.
153-
"""
155+
) -> pd.DataFrame:
154156
# skip the patients with unassigned genotype group
155-
bla = self._data.notna()
156-
not_na_gts = bla.all(axis="columns")
157-
data = self._data.loc[not_na_gts]
158-
159-
# Check that the provided genotype predicate defines the same categories
160-
# as those found in `data.`
161-
actual = set(data[MonoPhenotypeAnalysisResult.GT_COL].unique())
162-
expected = set(c.cat_id for c in self._gt_clf.get_categories())
163-
assert actual == expected, "Mismatch in the genotype classes"
157+
not_na = self._data.notna()
158+
not_na_gts = not_na.all(axis="columns")
159+
return self._data.loc[not_na_gts]
164160

161+
def _make_x_and_tick_labels(
162+
self,
163+
data: pd.DataFrame,
164+
) -> typing.Tuple[
165+
typing.Sequence[typing.Sequence[float]],
166+
typing.Sequence[str],
167+
]:
165168
x = [
166169
data.loc[
167170
data[MonoPhenotypeAnalysisResult.GT_COL] == c.category.cat_id,
@@ -171,19 +174,116 @@ def plot_boxplots(
171174
]
172175

173176
gt_cat_names = [c.category.name for c in self._gt_clf.get_categorizations()]
177+
178+
return x, gt_cat_names
179+
180+
def plot_boxplots(
181+
self,
182+
ax,
183+
colors: typing.Sequence[str] = PALETTE_DATA,
184+
median_color: str = PALETTE_SPECIAL,
185+
**boxplot_kwargs,
186+
):
187+
"""
188+
Draw box plot with distributions of phenotype scores for the genotype groups.
189+
190+
:param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
191+
:param colors: a sequence with color palette for the box plot patches.
192+
:param median_color: a `str` with the color for the boxplot median line.
193+
:param boxplot_kwargs: arguments to pass into :func:`matplotlib.axes.Axes.boxplot` function.
194+
"""
195+
data = self._make_data_df()
196+
197+
x, gt_cat_names = self._make_x_and_tick_labels(data)
198+
patch_artist = boxplot_kwargs.pop("patch_artist", True)
199+
tick_labels = boxplot_kwargs.pop("tick_labels", gt_cat_names)
200+
174201
bplot = ax.boxplot(
175202
x=x,
176-
patch_artist=True,
177-
tick_labels=gt_cat_names,
203+
patch_artist=patch_artist,
204+
tick_labels=tick_labels,
205+
**boxplot_kwargs,
178206
)
179207

180208
# Set face colors of the boxes
181-
for patch, color in zip(bplot["boxes"], colors):
182-
patch.set_facecolor(color)
209+
col_idxs = self._choose_palette_idxs(
210+
n_categories=self._gt_clf.n_categorizations(), n_colors=len(colors)
211+
)
212+
for patch, col_idx in zip(bplot["boxes"], col_idxs):
213+
patch.set_facecolor(colors[col_idx])
183214

184-
for median in bplot['medians']:
215+
for median in bplot["medians"]:
185216
median.set_color(median_color)
186217

218+
def plot_violins(
    self,
    ax,
    colors: typing.Sequence[str] = PALETTE_DATA,
    **violinplot_kwargs,
):
    """
    Draw a violin plot with distributions of phenotype scores for the genotype groups.

    Each violin is overlaid with a white dot at the group median,
    a thick line spanning the first and the third quartile,
    and a thin line spanning the whiskers.

    :param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
    :param colors: a sequence with color palette for the violin patches.
    :param violinplot_kwargs: arguments to pass into :func:`matplotlib.axes.Axes.violinplot` function.
    """
    scores, tick_labels = self._make_x_and_tick_labels(self._make_data_df())

    # The means and the extrema are hidden unless the caller asks for them.
    violinplot_kwargs.setdefault("showmeans", False)
    violinplot_kwargs.setdefault("showextrema", False)
    violins = ax.violinplot(dataset=scores, **violinplot_kwargs)

    # Summary statistics of each genotype group.
    lower_quartiles = [np.percentile(group, 25) for group in scores]
    medians = [np.median(group) for group in scores]
    upper_quartiles = [np.percentile(group, 75) for group in scores]

    # One `(lower, upper)` whisker pair per group; the helper expects sorted values.
    whisker_bounds = np.array(
        [
            PhenotypeScoreAnalysisResult._adjacent_values(sorted(group), q1, q3)
            for group, q1, q3 in zip(scores, lower_quartiles, upper_quartiles)
        ]
    )
    whisker_lows = whisker_bounds[:, 0]
    whisker_highs = whisker_bounds[:, 1]

    # The overlay is drawn at x positions 1, 2, ..., one per group.
    positions = np.arange(1, len(medians) + 1)
    ax.scatter(positions, medians, marker="o", color="white", s=30, zorder=3)
    ax.vlines(
        positions, lower_quartiles, upper_quartiles, color="k", linestyle="-", lw=5
    )
    ax.vlines(positions, whisker_lows, whisker_highs, color="k", linestyle="-", lw=1)

    # Label the x axis ticks with the genotype category names.
    ax.xaxis.set(
        ticks=np.arange(1, len(tick_labels) + 1),
        ticklabels=tick_labels,
    )

    # Paint the violin bodies with colors picked from the palette.
    palette_idxs = self._choose_palette_idxs(
        n_categories=self._gt_clf.n_categorizations(),
        n_colors=len(colors),
    )
    for body, palette_idx in zip(violins["bodies"], palette_idxs):
        body.set(
            facecolor=colors[palette_idx],
            edgecolor=None,
            alpha=1,
        )
277+
278+
@staticmethod
279+
def _adjacent_values(vals, q1, q3):
280+
upper_adjacent_value = q3 + (q3 - q1) * 1.5
281+
upper_adjacent_value = np.clip(upper_adjacent_value, q3, vals[-1])
282+
283+
lower_adjacent_value = q1 - (q3 - q1) * 1.5
284+
lower_adjacent_value = np.clip(lower_adjacent_value, vals[0], q1)
285+
return lower_adjacent_value, upper_adjacent_value
286+
187287
def __eq__(self, value: object) -> bool:
188288
return isinstance(value, PhenotypeScoreAnalysisResult) and super(
189289
MonoPhenotypeAnalysisResult, self
@@ -254,7 +354,9 @@ def compare_genotype_vs_phenotype_score(
254354
for individual in cohort:
255355
gt_cat = gt_clf.test(individual)
256356
if gt_cat is None:
257-
data.loc[individual.patient_id, MonoPhenotypeAnalysisResult.GT_COL] = None
357+
data.loc[individual.patient_id, MonoPhenotypeAnalysisResult.GT_COL] = (
358+
None
359+
)
258360
else:
259361
data.loc[individual.patient_id, MonoPhenotypeAnalysisResult.GT_COL] = (
260362
gt_cat.category.cat_id

0 commit comments

Comments
 (0)