Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 57 additions & 60 deletions malariagen_data/anoph/dipclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
multiallelic_diplotype_mean_sqeuclidean,
multiallelic_diplotype_mean_cityblock,
)
from ..plotly_dendrogram import plot_dendrogram
from ..plotly_dendrogram import plot_dendrogram, concat_clustering_subplots
from . import (
base_params,
plotly_params,
Expand All @@ -22,13 +22,9 @@
dipclust_params,
cnv_params,
)
from .snp_frq import AnophelesSnpFrequencyAnalysis
from .snp_frq import AnophelesSnpFrequencyAnalysis, AA_CHANGE_QUERY
from .cnv_frq import AnophelesCnvFrequencyAnalysis

AA_CHANGE_QUERY = (
"effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']"
)


class AnophelesDipClustAnalysis(
AnophelesCnvFrequencyAnalysis, AnophelesSnpFrequencyAnalysis
Expand Down Expand Up @@ -503,56 +499,6 @@ def _dipclust_snp_trace(

return snp_trace, n_snps

def _dipclust_concat_subplots(
    self,
    figures,
    width,
    height,
    row_heights,
    region: base_params.regions,
    n_snps: int,
    sample_sets: Optional[base_params.sample_sets],
    sample_query: Optional[base_params.sample_query],
):
    """Stack several figures/traces vertically into one subplot figure.

    Each item of `figures` gets its own row in a single-column grid whose
    rows share the x axis. The figure title is assembled from the sample
    sets, the sample query and the genomic region.

    Parameters
    ----------
    figures
        Sequence of items to stack, one per row, top to bottom. An item
        that is a `go.Figure` contributes all of its traces to its row;
        anything else is assumed to be a single trace and is added as-is.
    width
        Width passed through to the figure layout.
    height
        Height passed through to the figure layout.
    row_heights
        Relative row heights, forwarded to `make_subplots`.
    region
        Genomic region, shown in the title.
    n_snps
        Number of SNPs, shown in the title alongside the region.
    sample_sets
        If not None, shown on its own title line.
    sample_query
        If not None, shown on its own title line.

    Returns
    -------
    fig
        The combined plotly figure, with x axes hidden, a white plot
        background and hovermode set to "closest".
    """
    from plotly.subplots import make_subplots  # type: ignore

    # Build the multi-line title; plotly renders "<br>" as a line break.
    title_lines = []
    if sample_sets is not None:
        title_lines.append(f"sample sets: {sample_sets}")
    if sample_query is not None:
        title_lines.append(f"sample query: {sample_query}")
    title_lines.append(f"genomic region: {region} ({n_snps} SNPs)")
    title = "<br>".join(title_lines)

    # One row per input item, all rows sharing the same x axis.
    fig = make_subplots(
        rows=len(figures),
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
        row_heights=row_heights,
    )

    # Subplot rows are 1-based, hence start=1.
    for row, figure in enumerate(figures, start=1):
        if isinstance(figure, go.Figure):
            # This is a figure, add each of the traces within it.
            # `add_trace` is used because `append_trace` is deprecated
            # and merely forwards to `add_trace` after warning.
            for trace in figure.data:
                fig.add_trace(trace, row=row, col=1)
        else:
            # Assume this is a trace, add directly.
            fig.add_trace(figure, row=row, col=1)

    fig.update_xaxes(visible=False)
    fig.update_layout(
        title=title,
        width=width,
        height=height,
        hovermode="closest",
        plot_bgcolor="white",
    )

    return fig

def _insert_dipclust_snp_trace(
self,
*,
Expand Down Expand Up @@ -592,7 +538,7 @@ def _insert_dipclust_snp_trace(
print(
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
)
return figures, subplot_heights
return figures, subplot_heights, n_snps_transcript

@doc(
summary=""""
Expand Down Expand Up @@ -733,8 +679,13 @@ def plot_diplotype_clustering_advanced(
figures.append(cnv_trace)
subplot_heights.append(cnv_row_height * n_cnv_genes)

n_snps_transcripts = []
if isinstance(snp_transcript, str):
figures, subplot_heights = self._insert_dipclust_snp_trace(
(
figures,
subplot_heights,
n_snps_transcript,
) = self._insert_dipclust_snp_trace(
transcript=snp_transcript,
figures=figures,
subplot_heights=subplot_heights,
Expand All @@ -746,12 +697,18 @@ def plot_diplotype_clustering_advanced(
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
snp_row_height=snp_row_height,
chunks=chunks,
inline_array=inline_array,
)
n_snps_transcripts.append(n_snps_transcript)
elif isinstance(snp_transcript, list):
for st in snp_transcript:
figures, subplot_heights = self._insert_dipclust_snp_trace(
(
figures,
subplot_heights,
n_snps_transcript,
) = self._insert_dipclust_snp_trace(
transcript=st,
figures=figures,
subplot_heights=subplot_heights,
Expand All @@ -763,14 +720,16 @@ def plot_diplotype_clustering_advanced(
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
snp_row_height=snp_row_height,
chunks=chunks,
inline_array=inline_array,
)
n_snps_transcripts.append(n_snps_transcript)

# Calculate total height based on subplot heights, plus a fixed
# additional component to allow for title, axes etc.
height = sum(subplot_heights) + 50
fig = self._dipclust_concat_subplots(
fig = concat_clustering_subplots(
figures=figures,
width=width,
height=height,
Expand All @@ -791,6 +750,44 @@ def plot_diplotype_clustering_advanced(
legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)

# Draw grey lines between the rows of each SNP (presumably amino acid change) subplot and around its axes - looks neater.
if snp_transcript:
n_transcripts = (
len(snp_transcript) if isinstance(snp_transcript, list) else 1
)
for i in range(n_transcripts):
tx_idx = len(figures) - n_transcripts + i + 1
if n_snps_transcripts[i] > 0:
fig.add_hline(
y=-0.5, line_width=1, line_color="grey", row=tx_idx, col=1
)
for j in range(n_snps_transcripts[i]):
fig.add_hline(
y=j + 0.5,
line_width=1,
line_color="grey",
row=tx_idx,
col=1,
)

fig.update_xaxes(
showline=True,
linecolor="grey",
linewidth=1,
row=tx_idx,
col=1,
mirror=True,
)

fig.update_yaxes(
showline=True,
linecolor="grey",
linewidth=1,
row=tx_idx,
col=1,
mirror=True,
)

if show:
fig.show(renderer=renderer)
return None
Expand Down
Loading
Loading