Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 69 additions & 15 deletions malariagen_data/anoph/dipclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@
cnv_params,
)
from .snp_frq import AnophelesSnpFrequencyAnalysis
from .cnv_data import AnophelesCnvData
from .cnv_frq import AnophelesCnvFrequencyAnalysis

AA_CHANGE_QUERY = (
"effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']"
)


class AnophelesDipClustAnalysis(AnophelesSnpFrequencyAnalysis, AnophelesCnvData):
class AnophelesDipClustAnalysis(
AnophelesCnvFrequencyAnalysis, AnophelesSnpFrequencyAnalysis
):
def __init__(
self,
**kwargs,
Expand Down Expand Up @@ -190,7 +192,7 @@ def plot_diplotype_clustering(
else:
return {
"figure": fig,
"dendro_sample_id_order": leaf_data["sample_id"].to_list(),
"dendro_sample_id_order": np.asarray(leaf_data["sample_id"].to_list()),
"n_snps": n_snps_used,
}

Expand Down Expand Up @@ -319,7 +321,7 @@ def _dipclust_het_bar_trace(
sample_sets: Optional[base_params.sample_sets],
sample_query: Optional[base_params.sample_query],
sample_query_options: Optional[base_params.sample_query_options],
site_mask: base_params.site_mask,
site_mask: Optional[base_params.site_mask],
cohort_size: Optional[base_params.cohort_size],
random_seed: base_params.random_seed,
color_continuous_scale: Optional[plotly_params.color_continuous_scale],
Expand Down Expand Up @@ -547,11 +549,52 @@ def _dipclust_concat_subplots(

return fig

def _insert_dipclust_snp_trace(
self,
*,
figures,
subplot_heights,
snp_row_height: plotly_params.height = 25,
transcript: base_params.transcript,
snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
sample_sets: Optional[base_params.sample_sets],
sample_query: Optional[base_params.sample_query],
sample_query_options: Optional[base_params.sample_query_options],
site_mask: Optional[base_params.site_mask],
dendro_sample_id_order: np.ndarray,
snp_filter_min_maf: float,
snp_colorscale: Optional[plotly_params.color_continuous_scale],
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
):
snp_trace, n_snps_transcript = self._dipclust_snp_trace(
transcript=transcript,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
snp_query=snp_query,
site_mask=site_mask,
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
chunks=chunks,
inline_array=inline_array,
)

if snp_trace:
figures.append(snp_trace)
subplot_heights.append(snp_row_height * n_snps_transcript)
else:
print(
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
)
return figures, subplot_heights

@doc(
summary="Perform diplotype clustering, annotated with heterozygosity, gene copy number and amino acid variants.",
parameters=dict(
heterozygosity="Plot heterozygosity track.",
snp_transcript="Plot amino acid variants for this transcript.",
snp_transcript="Plot amino acid variants for these transcripts.",
cnv_region="Plot gene CNV calls for this region.",
snp_filter_min_maf="Filter amino acid variants with alternate allele frequency below this threshold.",
),
Expand All @@ -561,7 +604,7 @@ def plot_diplotype_clustering_advanced(
region: base_params.regions,
heterozygosity: bool = True,
heterozygosity_colorscale: plotly_params.color_continuous_scale = "Greys",
snp_transcript: Optional[base_params.transcript] = None,
snp_transcript: Optional[dipclust_params.snp_transcript] = None,
snp_colorscale: plotly_params.color_continuous_scale = "Greys",
snp_filter_min_maf: float = 0.05,
snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
Expand Down Expand Up @@ -682,9 +725,11 @@ def plot_diplotype_clustering_advanced(
figures.append(cnv_trace)
subplot_heights.append(cnv_row_height * n_cnv_genes)

if snp_transcript:
snp_trace, n_snps_transcript = self._dipclust_snp_trace(
if isinstance(snp_transcript, str):
figures, subplot_heights = self._insert_dipclust_snp_trace(
transcript=snp_transcript,
figures=figures,
subplot_heights=subplot_heights,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
Expand All @@ -696,13 +741,22 @@ def plot_diplotype_clustering_advanced(
chunks=chunks,
inline_array=inline_array,
)

if snp_trace:
figures.append(snp_trace)
subplot_heights.append(snp_row_height * n_snps_transcript)
else:
print(
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
elif isinstance(snp_transcript, list):
for st in snp_transcript:
figures, subplot_heights = self._insert_dipclust_snp_trace(
transcript=st,
figures=figures,
subplot_heights=subplot_heights,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
snp_query=snp_query,
site_mask=site_mask,
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
chunks=chunks,
inline_array=inline_array,
)

# Calculate total height based on subplot heights, plus a fixed
Expand Down
8 changes: 8 additions & 0 deletions malariagen_data/anoph/dipclust_params.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Parameters for diplotype clustering functions."""

from typing_extensions import Annotated, TypeAlias, Union, Sequence

from .distance_params import distance_metric
from .clustering_params import linkage_method
from .base_params import transcript


linkage_method_default: linkage_method = "complete"

distance_metric_default: distance_metric = "cityblock"

snp_transcript: TypeAlias = Annotated[
Union[transcript, Sequence[transcript]],
"A transcript or a list of transcripts",
]
9 changes: 0 additions & 9 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
import plotly.graph_objects as go # type: ignore
from numpydoc_decorator import doc # type: ignore

from malariagen_data.anoph.snp_frq import (
AnophelesSnpFrequencyAnalysis,
)

from .anoph.cnv_frq import AnophelesCnvFrequencyAnalysis

from .anoph import (
aim_params,
Expand All @@ -32,7 +27,6 @@
from .anoph.karyotype import AnophelesKaryotypeAnalysis
from .anoph.aim_data import AnophelesAimData
from .anoph.base import AnophelesBase
from .anoph.cnv_data import AnophelesCnvData
from .anoph.genome_features import AnophelesGenomeFeaturesData
from .anoph.genome_sequence import AnophelesGenomeSequenceData
from .anoph.hap_data import AnophelesHapData, hap_params
Expand Down Expand Up @@ -88,8 +82,6 @@ class AnophelesDataResource(
AnophelesH12Analysis,
AnophelesG123Analysis,
AnophelesFstAnalysis,
AnophelesCnvFrequencyAnalysis,
AnophelesSnpFrequencyAnalysis,
AnophelesHapFrequencyAnalysis,
AnophelesDistanceAnalysis,
AnophelesPca,
Expand All @@ -99,7 +91,6 @@ class AnophelesDataResource(
AnophelesAimData,
AnophelesHapData,
AnophelesSnpData,
AnophelesCnvData,
AnophelesSampleMetadata,
AnophelesGenomeFeaturesData,
AnophelesGenomeSequenceData,
Expand Down
Loading
Loading