1
- from typing import Optional , Tuple , Sequence
1
+ from typing import Optional , Tuple
2
2
3
3
import allel # type: ignore
4
4
import numpy as np
@@ -549,11 +549,52 @@ def _dipclust_concat_subplots(
549
549
550
550
return fig
551
551
552
+ def _insert_dipclust_snp_trace (
553
+ self ,
554
+ * ,
555
+ figures ,
556
+ subplot_heights ,
557
+ snp_row_height : plotly_params .height = 25 ,
558
+ transcript : base_params .transcript ,
559
+ snp_query : Optional [base_params .snp_query ] = AA_CHANGE_QUERY ,
560
+ sample_sets : Optional [base_params .sample_sets ],
561
+ sample_query : Optional [base_params .sample_query ],
562
+ sample_query_options : Optional [base_params .sample_query_options ],
563
+ site_mask : Optional [base_params .site_mask ],
564
+ dendro_sample_id_order : np .ndarray ,
565
+ snp_filter_min_maf : float ,
566
+ snp_colorscale : Optional [plotly_params .color_continuous_scale ],
567
+ chunks : base_params .chunks = base_params .native_chunks ,
568
+ inline_array : base_params .inline_array = base_params .inline_array_default ,
569
+ ):
570
+ snp_trace , n_snps_transcript = self ._dipclust_snp_trace (
571
+ transcript = transcript ,
572
+ sample_sets = sample_sets ,
573
+ sample_query = sample_query ,
574
+ sample_query_options = sample_query_options ,
575
+ snp_query = snp_query ,
576
+ site_mask = site_mask ,
577
+ dendro_sample_id_order = dendro_sample_id_order ,
578
+ snp_filter_min_maf = snp_filter_min_maf ,
579
+ snp_colorscale = snp_colorscale ,
580
+ chunks = chunks ,
581
+ inline_array = inline_array ,
582
+ )
583
+
584
+ if snp_trace :
585
+ figures .append (snp_trace )
586
+ subplot_heights .append (snp_row_height * n_snps_transcript )
587
+ else :
588
+ print (
589
+ f"No SNPs were found below { snp_filter_min_maf } allele frequency. Omitting SNP genotype plot."
590
+ )
591
+ return figures , subplot_heights
592
+
552
593
@doc (
553
594
summary = "Perform diplotype clustering, annotated with heterozygosity, gene copy number and amino acid variants." ,
554
595
parameters = dict (
555
596
heterozygosity = "Plot heterozygosity track." ,
556
- snp_transcripts = "Plot amino acid variants for these transcripts." ,
597
+ snp_transcript = "Plot amino acid variants for these transcripts." ,
557
598
cnv_region = "Plot gene CNV calls for this region." ,
558
599
snp_filter_min_maf = "Filter amino acid variants with alternate allele frequency below this threshold." ,
559
600
),
@@ -563,7 +604,7 @@ def plot_diplotype_clustering_advanced(
563
604
region : base_params .regions ,
564
605
heterozygosity : bool = True ,
565
606
heterozygosity_colorscale : plotly_params .color_continuous_scale = "Greys" ,
566
- snp_transcripts : Sequence [ base_params . transcript ] = [] ,
607
+ snp_transcript : dipclust_params . snp_transcript = None ,
567
608
snp_colorscale : plotly_params .color_continuous_scale = "Greys" ,
568
609
snp_filter_min_maf : float = 0.05 ,
569
610
snp_query : Optional [base_params .snp_query ] = AA_CHANGE_QUERY ,
@@ -603,7 +644,7 @@ def plot_diplotype_clustering_advanced(
603
644
chunks : base_params .chunks = base_params .native_chunks ,
604
645
inline_array : base_params .inline_array = base_params .inline_array_default ,
605
646
):
606
- if cohort_size and snp_transcripts :
647
+ if cohort_size and snp_transcript :
607
648
cohort_size = None
608
649
print (
609
650
"Cohort size is not supported with amino acid heatmap. Overriding cohort size to None."
@@ -684,9 +725,11 @@ def plot_diplotype_clustering_advanced(
684
725
figures .append (cnv_trace )
685
726
subplot_heights .append (cnv_row_height * n_cnv_genes )
686
727
687
- for snp_transcript in snp_transcripts :
688
- snp_trace , n_snps_transcript = self ._dipclust_snp_trace (
728
+ if isinstance ( snp_transcript , str ) :
729
+ figures , subplot_heights = self ._insert_dipclust_snp_trace (
689
730
transcript = snp_transcript ,
731
+ figures = figures ,
732
+ subplot_heights = subplot_heights ,
690
733
sample_sets = sample_sets ,
691
734
sample_query = sample_query ,
692
735
sample_query_options = sample_query_options ,
@@ -698,13 +741,22 @@ def plot_diplotype_clustering_advanced(
698
741
chunks = chunks ,
699
742
inline_array = inline_array ,
700
743
)
701
-
702
- if snp_trace :
703
- figures .append (snp_trace )
704
- subplot_heights .append (snp_row_height * n_snps_transcript )
705
- else :
706
- print (
707
- f"No SNPs were found below { snp_filter_min_maf } allele frequency. Omitting SNP genotype plot."
744
+ elif isinstance (snp_transcript , list ):
745
+ for st in snp_transcript :
746
+ figures , subplot_heights = self ._insert_dipclust_snp_trace (
747
+ transcript = st ,
748
+ figures = figures ,
749
+ subplot_heights = subplot_heights ,
750
+ sample_sets = sample_sets ,
751
+ sample_query = sample_query ,
752
+ sample_query_options = sample_query_options ,
753
+ snp_query = snp_query ,
754
+ site_mask = site_mask ,
755
+ dendro_sample_id_order = dendro_sample_id_order ,
756
+ snp_filter_min_maf = snp_filter_min_maf ,
757
+ snp_colorscale = snp_colorscale ,
758
+ chunks = chunks ,
759
+ inline_array = inline_array ,
708
760
)
709
761
710
762
# Calculate total height based on subplot heights, plus a fixed
0 commit comments