Skip to content

Commit 7a8aaf6

Browse files
committed
Made snp_transcript polymorphic to avoid breaking the API
1 parent 911010b commit 7a8aaf6

File tree

4 files changed

+91
-31
lines changed

4 files changed

+91
-31
lines changed

malariagen_data/anoph/dipclust.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Tuple, Sequence
1+
from typing import Optional, Tuple
22

33
import allel # type: ignore
44
import numpy as np
@@ -549,11 +549,52 @@ def _dipclust_concat_subplots(
549549

550550
return fig
551551

552+
def _insert_dipclust_snp_trace(
553+
self,
554+
*,
555+
figures,
556+
subplot_heights,
557+
snp_row_height: plotly_params.height = 25,
558+
transcript: base_params.transcript,
559+
snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
560+
sample_sets: Optional[base_params.sample_sets],
561+
sample_query: Optional[base_params.sample_query],
562+
sample_query_options: Optional[base_params.sample_query_options],
563+
site_mask: Optional[base_params.site_mask],
564+
dendro_sample_id_order: np.ndarray,
565+
snp_filter_min_maf: float,
566+
snp_colorscale: Optional[plotly_params.color_continuous_scale],
567+
chunks: base_params.chunks = base_params.native_chunks,
568+
inline_array: base_params.inline_array = base_params.inline_array_default,
569+
):
570+
snp_trace, n_snps_transcript = self._dipclust_snp_trace(
571+
transcript=transcript,
572+
sample_sets=sample_sets,
573+
sample_query=sample_query,
574+
sample_query_options=sample_query_options,
575+
snp_query=snp_query,
576+
site_mask=site_mask,
577+
dendro_sample_id_order=dendro_sample_id_order,
578+
snp_filter_min_maf=snp_filter_min_maf,
579+
snp_colorscale=snp_colorscale,
580+
chunks=chunks,
581+
inline_array=inline_array,
582+
)
583+
584+
if snp_trace:
585+
figures.append(snp_trace)
586+
subplot_heights.append(snp_row_height * n_snps_transcript)
587+
else:
588+
print(
589+
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
590+
)
591+
return figures, subplot_heights
592+
552593
@doc(
553594
summary="Perform diplotype clustering, annotated with heterozygosity, gene copy number and amino acid variants.",
554595
parameters=dict(
555596
heterozygosity="Plot heterozygosity track.",
556-
snp_transcripts="Plot amino acid variants for these transcripts.",
597+
snp_transcript="Plot amino acid variants for these transcripts.",
557598
cnv_region="Plot gene CNV calls for this region.",
558599
snp_filter_min_maf="Filter amino acid variants with alternate allele frequency below this threshold.",
559600
),
@@ -563,7 +604,7 @@ def plot_diplotype_clustering_advanced(
563604
region: base_params.regions,
564605
heterozygosity: bool = True,
565606
heterozygosity_colorscale: plotly_params.color_continuous_scale = "Greys",
566-
snp_transcripts: Sequence[base_params.transcript] = [],
607+
snp_transcript: dipclust_params.snp_transcript = None,
567608
snp_colorscale: plotly_params.color_continuous_scale = "Greys",
568609
snp_filter_min_maf: float = 0.05,
569610
snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
@@ -603,7 +644,7 @@ def plot_diplotype_clustering_advanced(
603644
chunks: base_params.chunks = base_params.native_chunks,
604645
inline_array: base_params.inline_array = base_params.inline_array_default,
605646
):
606-
if cohort_size and snp_transcripts:
647+
if cohort_size and snp_transcript:
607648
cohort_size = None
608649
print(
609650
"Cohort size is not supported with amino acid heatmap. Overriding cohort size to None."
@@ -684,9 +725,11 @@ def plot_diplotype_clustering_advanced(
684725
figures.append(cnv_trace)
685726
subplot_heights.append(cnv_row_height * n_cnv_genes)
686727

687-
for snp_transcript in snp_transcripts:
688-
snp_trace, n_snps_transcript = self._dipclust_snp_trace(
728+
if isinstance(snp_transcript, str):
729+
figures, subplot_heights = self._insert_dipclust_snp_trace(
689730
transcript=snp_transcript,
731+
figures=figures,
732+
subplot_heights=subplot_heights,
690733
sample_sets=sample_sets,
691734
sample_query=sample_query,
692735
sample_query_options=sample_query_options,
@@ -698,13 +741,22 @@ def plot_diplotype_clustering_advanced(
698741
chunks=chunks,
699742
inline_array=inline_array,
700743
)
701-
702-
if snp_trace:
703-
figures.append(snp_trace)
704-
subplot_heights.append(snp_row_height * n_snps_transcript)
705-
else:
706-
print(
707-
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
744+
elif isinstance(snp_transcript, list):
745+
for st in snp_transcript:
746+
figures, subplot_heights = self._insert_dipclust_snp_trace(
747+
transcript=st,
748+
figures=figures,
749+
subplot_heights=subplot_heights,
750+
sample_sets=sample_sets,
751+
sample_query=sample_query,
752+
sample_query_options=sample_query_options,
753+
snp_query=snp_query,
754+
site_mask=site_mask,
755+
dendro_sample_id_order=dendro_sample_id_order,
756+
snp_filter_min_maf=snp_filter_min_maf,
757+
snp_colorscale=snp_colorscale,
758+
chunks=chunks,
759+
inline_array=inline_array,
708760
)
709761

710762
# Calculate total height based on subplot heights, plus a fixed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
"""Parameters for diplotype clustering functions."""
22

3+
from typing_extensions import Annotated, TypeAlias, Union, Sequence
4+
35
from .distance_params import distance_metric
46
from .clustering_params import linkage_method
7+
from .base_params import transcript
58

69

710
linkage_method_default: linkage_method = "complete"
811

912
distance_metric_default: distance_metric = "cityblock"
13+
14+
snp_transcript: TypeAlias = Annotated[
15+
Union[None, transcript, Sequence[transcript]],
16+
"A transcript or a list of transcripts",
17+
]

notebooks/plot_diplotype_clustering.ipynb

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636
{
3737
"cell_type": "code",
3838
"execution_count": null,
39-
"id": "92a9bdfd-d808-4c07-8bdd-f2eb3d1e5614",
39+
"id": "15d0bf3e-bbbe-4b67-a39d-c7ed5622cc3b",
4040
"metadata": {},
4141
"outputs": [],
4242
"source": [
43-
"ag3.plot_diplotype_clustering_advanced(\n",
44-
" region='2L:2,350,000-2,680,000',\n",
45-
" snp_transcripts=['AGAP004707-RD', 'AGAP004717-RA'],\n",
46-
" snp_query=\"effect == 'NON_SYNONYMOUS_CODING'\",\n",
43+
"fig = ag3.plot_diplotype_clustering_advanced(\n",
44+
" region=\"2R:28,480,000-28,500,000\",\n",
45+
" cnv_region=\"2R:28,480,000-28,500,000\",\n",
46+
" snp_transcript='AGAP002862-RA',\n",
4747
" snp_filter_min_maf=0.05,\n",
4848
" sample_sets=\"AG1000G-GH\",\n",
4949
" site_mask=\"gamb_colu\",\n",
@@ -52,31 +52,31 @@
5252
" linkage_method=\"complete\",\n",
5353
" count_sort=True,\n",
5454
" distance_sort=False,\n",
55-
")"
55+
" show=False,\n",
56+
")\n",
57+
"fig"
5658
]
5759
},
5860
{
5961
"cell_type": "code",
6062
"execution_count": null,
61-
"id": "15d0bf3e-bbbe-4b67-a39d-c7ed5622cc3b",
63+
"id": "92a9bdfd-d808-4c07-8bdd-f2eb3d1e5614",
6264
"metadata": {},
6365
"outputs": [],
6466
"source": [
65-
"fig = ag3.plot_diplotype_clustering_advanced(\n",
66-
" region=\"2R:28,480,000-28,500,000\",\n",
67-
" cnv_region=\"2R:28,480,000-28,500,000\",\n",
68-
" snp_transcripts=['AGAP002862-RA'],\n",
69-
" snp_filter_min_maf=0.05,\n",
67+
"ag3.plot_diplotype_clustering_advanced(\n",
68+
" region='2R:28,480,000-28,490,000',\n",
69+
" snp_transcript=['AGAP002862-RA', 'AGAP002864-RA'],\n",
70+
" snp_query=\"effect == 'NON_SYNONYMOUS_CODING'\",\n",
71+
" snp_filter_min_maf=0.1,\n",
7072
" sample_sets=\"AG1000G-GH\",\n",
7173
" site_mask=\"gamb_colu\",\n",
7274
" color=\"taxon\",\n",
7375
" symbol=\"country\",\n",
7476
" linkage_method=\"complete\",\n",
7577
" count_sort=True,\n",
7678
" distance_sort=False,\n",
77-
" show=False,\n",
78-
")\n",
79-
"fig"
79+
")"
8080
]
8181
},
8282
{
@@ -89,7 +89,7 @@
8989
"ag3.plot_diplotype_clustering_advanced(\n",
9090
" region=\"2R:28,480,000-28,500,000\",\n",
9191
" cnv_region = \"2R:28,480,000-28,500,000\",\n",
92-
" snp_transcripts=['AGAP002862-RA'],\n",
92+
" snp_transcript=None,\n",
9393
" sample_sets=[\"AG1000G-GH\", 'AG1000G-BF-A'],\n",
9494
" snp_filter_min_maf=0.05,\n",
9595
" site_mask=\"gamb_colu\",\n",
@@ -412,7 +412,7 @@
412412
"source": [
413413
"af1.plot_diplotype_clustering_advanced(\n",
414414
" region = \"X:8,438,477-8,460,887\",\n",
415-
" snp_transcripts=[\"LOC125764232_t1\"],\n",
415+
" snp_transcript=[\"LOC125764232_t1\"],\n",
416416
" cnv_region=\"X:8,418,477-8,480,887\",\n",
417417
" sample_sets=[\"1232-VO-KE-OCHOMO-VMF00044\", \"1231-VO-MULTI-WONDJI-VMF00043\", \"1236-VO-TZ-OKUMU-VMF00090\"],\n",
418418
" sample_query=\"country in ['Kenya', 'Uganda', 'Tanzania'] and taxon == 'funestus'\"\n",

tests/anoph/test_dipclust.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_plot_diplotype_clustering_advanced_with_transcript(
159159
sample_queries = (None, "sex_call == 'F'")
160160
dipclust_params = dict(
161161
region=contig,
162-
snp_transcripts=transcripts,
162+
snp_transcript=transcripts,
163163
sample_sets=[random.choice(all_sample_sets)],
164164
linkage_method=random.choice(linkage_methods),
165165
distance_metric="cityblock",

0 commit comments

Comments
 (0)