Skip to content

Commit 183ca27

Browse files
committed
Use the `_prep_ratio_data` function to generalise how volcano plot data are prepared (i.e. reuse the custom threshold-range handling from the ratio plots).
1 parent 94cc2e9 commit 183ca27

File tree

1 file changed

+42
-42
lines changed

1 file changed

+42
-42
lines changed

autoprot/visualization/basic.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def boxplot(
492492
figsize: tuple = (15, 5),
493493
ax: Union[plt.axis, None] = None,
494494
**kwargs: object,
495-
) -> plt.figure:
495+
) -> plt.Figure | None:
496496
# noinspection PyUnresolvedReferences
497497
r"""
498498
Plot intensity boxplots.
@@ -699,6 +699,9 @@ def intensity_rank(
699699
If list, must be the same length as highlight.
700700
ascending : bool, optional
701701
Whether to sort the data in ascending order.
702+
annotate_density : int, optional
703+
Number of points to consider for density-based annotation.
704+
The default is 100.
702705
**kwargs :
703706
Passed to seaborn.scatterplot.
704707
@@ -1031,7 +1034,7 @@ def _init_scatter(
10311034
figsize: tuple[float, float],
10321035
pointsize_colname: str,
10331036
pointsize_scaler: float,
1034-
) -> tuple[plt.figure, plt.Axes, pd.DataFrame]:
1037+
) -> tuple[plt.Figure, plt.Axes, pd.DataFrame]:
10351038
"""
10361039
Initialize a scatter plot.
10371040
@@ -1350,9 +1353,6 @@ def _prep_volcano_data(
13501353
If neither a p-score nor a p value is provided by the user.
13511354
13521355
"""
1353-
# Work with a copy of the dataframe
1354-
df = df.copy()
1355-
13561356
if score_colname is None and p_colname is None:
13571357
raise ValueError("You have to provide either a score or a (adjusted) p value.")
13581358
elif score_colname is None:
@@ -1367,41 +1367,40 @@ def _prep_volcano_data(
13671367
# four groups of points are present in a volcano plot:
13681368
# (1) non-significant
13691369
df["SigCat"] = "not significant"
1370+
p_sig_idx = pd.Index([])
1371+
logfc_sig_idx = pd.Index([])
1372+
both_sig_idx = pd.Index([])
1373+
13701374
if p_thresh is not None:
13711375
# (2) significant by score
1372-
df.loc[df[p_colname] < p_thresh, "SigCat"] = "p-value"
1376+
df, _, _, _, p_sig_idx = _prep_ratio_data(
1377+
df, p_colname, None, p_thresh, signficance_label="p-value"
1378+
)
13731379

13741380
if log_fc_thresh is not None:
13751381
# (3) significant above or below fc-thresh
1376-
df.loc[
1377-
(df["SigCat"] == "not significant")
1378-
& (abs(df[log_fc_colname]) > log_fc_thresh),
1379-
"SigCat",
1380-
] = "log2FC"
1382+
df, _, _, _, logfc_sig_idx = _prep_ratio_data(
1383+
df, log_fc_colname, None, log_fc_thresh, signficance_label="log2FC"
1384+
)
13811385

13821386
if p_thresh is not None and log_fc_thresh is not None:
13831387
# (4) significant by both
1384-
df.loc[
1385-
(df["SigCat"] == "p-value") & (abs(df[log_fc_colname]) > log_fc_thresh),
1386-
"SigCat",
1387-
] = "p-value and log2FC"
1388+
both_sig_idx = p_sig_idx.intersection(logfc_sig_idx)
1389+
df.loc[both_sig_idx, "SigCat"] = "p-value and log2FC"
13881390

13891391
unsig = df[df["SigCat"] == "not significant"].index
1390-
sig_fc = df[df["SigCat"] == "log2FC"].index
1391-
sig_p = df[df["SigCat"] == "p-value"].index
1392-
sig_both = df[df["SigCat"] == "p-value and log2FC"].index
13931392

1394-
return df, score_colname, unsig, sig_fc, sig_p, sig_both
1393+
return df, score_colname, unsig, logfc_sig_idx, p_sig_idx, both_sig_idx
13951394

13961395

13971396
def volcano(
13981397
df: pd.DataFrame,
13991398
log_fc_colname: str,
14001399
p_colname: str = None,
14011400
score_colname: str = None,
1402-
p_thresh: float or None = 0.05,
1403-
log_fc_thresh: float or None = np.log2(2),
1404-
pointsize_colname: str or float = None,
1401+
p_thresh: float | None = 0.05,
1402+
log_fc_thresh: float | None = np.log2(2),
1403+
pointsize_colname: str | float = None,
14051404
pointsize_scaler: float = 1,
14061405
highlight: Union[pd.Index, list[pd.Index], None] = None,
14071406
title: str = None,
@@ -1756,10 +1755,11 @@ def volcano(
17561755

17571756
if show_thresh:
17581757
if log_fc_thresh is not None:
1759-
ax.axvline(x=log_fc_thresh, color="black", linestyle="--")
1760-
ax.axvline(x=-log_fc_thresh, color="black", linestyle="--")
1758+
_ratio_plot_style_axes(
1759+
ax, ratio_thresh_x=log_fc_thresh, ratio_thresh_y=None
1760+
)
17611761
if p_thresh is not None:
1762-
ax.axhline(y=-np.log10(p_thresh), color="black", linestyle="--")
1762+
_ratio_plot_style_axes(ax, ratio_thresh_x=None, ratio_thresh_y=p_thresh)
17631763

17641764
if ret_fig:
17651765
return fig
@@ -1771,10 +1771,10 @@ def ivolcano(
17711771
log_fc_colname: str,
17721772
p_colname: str = None,
17731773
score_colname: str = None,
1774-
p_thresh: float or None = 0.05,
1775-
log_fc_thresh: float or None = None,
1774+
p_thresh: float | None = 0.05,
1775+
log_fc_thresh: float | None = None,
17761776
annotate_colname: str = None,
1777-
pointsize_colname: str or float = None,
1777+
pointsize_colname: str | float = None,
17781778
highlight: pd.Index = None,
17791779
title: str = "Volcano Plot",
17801780
show_legend: bool = True,
@@ -1935,29 +1935,32 @@ def _prep_ratio_data(
19351935
col_name1: str,
19361936
col_name2: str | None,
19371937
ratio_thresh: tuple[float | None, float | None] | float | None,
1938+
signficance_label: str = "ratio_thresh",
19381939
) -> tuple[pd.DataFrame, str, str, pd.Index, pd.Index]:
19391940
"""
19401941
Prepare ratio data for analysis.
19411942
1942-
This function takes a DataFrame and two column names representing ratios,
1943-
and filters the DataFrame based on a given ratio threshold. It returns a
1944-
new DataFrame containing only the rows where the absolute value of the
1945-
ratio between the two specified columns exceeds the threshold.
1943+
This function takes a DataFrame and two column names
1944+
and labels the DataFrame based on a given threshold. It returns a
1945+
DataFrame with an additional 'SigCat' column indicating whether
1946+
the value exceeds the threshold in any of the columns.
19461947
19471948
Parameters
19481949
----------
19491950
df : pd.DataFrame
19501951
The input DataFrame containing the data.
19511952
col_name1 : str
19521953
The name of the first column to be used in the ratio calculation.
1953-
col_name2 : str
1954+
col_name2 : str or None
19541955
The name of the second column to be used in the ratio calculation.
19551956
ratio_thresh : float or None or tuple of float
19561957
The threshold value for filtering the ratios. If None, no filtering is applied.
19571958
If a tuple is provided, it should contain two float values representing the
19581959
lower and upper bounds for the ratio threshold. If a single float is provided,
19591960
it is treated as both the lower and upper bound. If None is included in the tuple,
19601961
that bound is ignored.
1962+
signficance_label : str
1963+
The label to assign to significant ratios in the 'SigCat' column. The default is "ratio_thresh".
19611964
19621965
Returns
19631966
-------
@@ -1973,9 +1976,6 @@ def _prep_ratio_data(
19731976
pd.Index
19741977
The indices of the rows where the ratio is above the threshold.
19751978
"""
1976-
# Work with a copy of the dataframe
1977-
df: pd.DataFrame = df.copy() # noqa
1978-
19791979
# check that ratio_thresh is a number or a tuple of numbers
19801980
if not isinstance(ratio_thresh, (int, float, type(None))):
19811981
if isinstance(ratio_thresh, tuple) and len(ratio_thresh) == 2:
@@ -2022,7 +2022,7 @@ def _prep_ratio_data(
20222022

20232023
if col_name2 is None:
20242024
# significantly up or down
2025-
df.loc[sig_up | sig_down, "SigCat"] = "ratio_thresh"
2025+
df.loc[sig_up | sig_down, "SigCat"] = signficance_label
20262026
else:
20272027

20282028
if ratio_thresh[1] is None:
@@ -2039,10 +2039,10 @@ def _prep_ratio_data(
20392039
df.loc[
20402040
(sig_up & sig_up_2) | (sig_down & sig_down_2),
20412041
"SigCat",
2042-
] = "ratio_thresh"
2042+
] = signficance_label
20432043

20442044
non_sig_idx = df[df["SigCat"] == "not significant"].index
2045-
sig_idx = df[df["SigCat"] == "ratio_thresh"].index
2045+
sig_idx = df[df["SigCat"] == signficance_label].index
20462046

20472047
return df, col_name1, col_name2, non_sig_idx, sig_idx
20482048

@@ -2055,15 +2055,15 @@ def _ratio_plot_style_axes(
20552055

20562056
if ratio_thresh_x is not None:
20572057
if not isinstance(ratio_thresh_x, tuple):
2058-
ratio_thresh_x = (ratio_thresh_x, ratio_thresh_x)
2058+
ratio_thresh_x = (-ratio_thresh_x, ratio_thresh_x)
20592059
for thresh in ratio_thresh_x:
20602060
if thresh is None: # Skip boundaries that should not be plotted
20612061
continue
20622062
ax.axvline(x=thresh, color="grey", linestyle="--", alpha=0.8)
20632063

20642064
if ratio_thresh_y is not None:
20652065
if not isinstance(ratio_thresh_y, tuple):
2066-
ratio_thresh_y = (ratio_thresh_y, ratio_thresh_y)
2066+
ratio_thresh_y = (-ratio_thresh_y, ratio_thresh_y)
20672067
for thresh in ratio_thresh_y:
20682068
if thresh is None:
20692069
continue
@@ -2274,7 +2274,7 @@ def iratio_plot(
22742274
ratio_thresh: float = None,
22752275
xlabel: str = "Ratio col1",
22762276
ylabel: str = "Ratio col2",
2277-
pointsize_colname: str or float = None,
2277+
pointsize_colname: str | float = None,
22782278
highlight: pd.Index = None,
22792279
title: str = None,
22802280
show_legend: bool = True,

0 commit comments

Comments
 (0)