diff --git a/malariagen_data/anoph/dipclust.py b/malariagen_data/anoph/dipclust.py index b96ee428..825c9286 100644 --- a/malariagen_data/anoph/dipclust.py +++ b/malariagen_data/anoph/dipclust.py @@ -13,7 +13,7 @@ multiallelic_diplotype_mean_sqeuclidean, multiallelic_diplotype_mean_cityblock, ) -from ..plotly_dendrogram import plot_dendrogram +from ..plotly_dendrogram import plot_dendrogram, concat_clustering_subplots from . import ( base_params, plotly_params, @@ -22,13 +22,9 @@ dipclust_params, cnv_params, ) -from .snp_frq import AnophelesSnpFrequencyAnalysis +from .snp_frq import AnophelesSnpFrequencyAnalysis, AA_CHANGE_QUERY from .cnv_frq import AnophelesCnvFrequencyAnalysis -AA_CHANGE_QUERY = ( - "effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']" -) - class AnophelesDipClustAnalysis( AnophelesCnvFrequencyAnalysis, AnophelesSnpFrequencyAnalysis @@ -503,56 +499,6 @@ def _dipclust_snp_trace( return snp_trace, n_snps - def _dipclust_concat_subplots( - self, - figures, - width, - height, - row_heights, - region: base_params.regions, - n_snps: int, - sample_sets: Optional[base_params.sample_sets], - sample_query: Optional[base_params.sample_query], - ): - from plotly.subplots import make_subplots # type: ignore - - title_lines = [] - if sample_sets is not None: - title_lines.append(f"sample sets: {sample_sets}") - if sample_query is not None: - title_lines.append(f"sample query: {sample_query}") - title_lines.append(f"genomic region: {region} ({n_snps} SNPs)") - title = "
".join(title_lines) - - # make subplots - fig = make_subplots( - rows=len(figures), - cols=1, - shared_xaxes=True, - vertical_spacing=0.02, - row_heights=row_heights, - ) - - for i, figure in enumerate(figures): - if isinstance(figure, go.Figure): - # This is a figure, access the traces within it. - for trace in range(len(figure["data"])): - fig.append_trace(figure["data"][trace], row=i + 1, col=1) - else: - # Assume this is a trace, add directly. - fig.append_trace(figure, row=i + 1, col=1) - - fig.update_xaxes(visible=False) - fig.update_layout( - title=title, - width=width, - height=height, - hovermode="closest", - plot_bgcolor="white", - ) - - return fig - def _insert_dipclust_snp_trace( self, *, @@ -592,7 +538,7 @@ def _insert_dipclust_snp_trace( print( f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot." ) - return figures, subplot_heights + return figures, subplot_heights, n_snps_transcript @doc( summary="""" @@ -733,8 +679,13 @@ def plot_diplotype_clustering_advanced( figures.append(cnv_trace) subplot_heights.append(cnv_row_height * n_cnv_genes) + n_snps_transcripts = [] if isinstance(snp_transcript, str): - figures, subplot_heights = self._insert_dipclust_snp_trace( + ( + figures, + subplot_heights, + n_snps_transcript, + ) = self._insert_dipclust_snp_trace( transcript=snp_transcript, figures=figures, subplot_heights=subplot_heights, @@ -746,12 +697,18 @@ def plot_diplotype_clustering_advanced( dendro_sample_id_order=dendro_sample_id_order, snp_filter_min_maf=snp_filter_min_maf, snp_colorscale=snp_colorscale, + snp_row_height=snp_row_height, chunks=chunks, inline_array=inline_array, ) + n_snps_transcripts.append(n_snps_transcript) elif isinstance(snp_transcript, list): for st in snp_transcript: - figures, subplot_heights = self._insert_dipclust_snp_trace( + ( + figures, + subplot_heights, + n_snps_transcript, + ) = self._insert_dipclust_snp_trace( transcript=st, figures=figures, subplot_heights=subplot_heights, @@ -763,14 +720,16 @@ def plot_diplotype_clustering_advanced( dendro_sample_id_order=dendro_sample_id_order, snp_filter_min_maf=snp_filter_min_maf, snp_colorscale=snp_colorscale, + snp_row_height=snp_row_height, chunks=chunks, inline_array=inline_array, ) + n_snps_transcripts.append(n_snps_transcript) # Calculate total height based on subplot heights, plus a fixed # additional component to allow for title, axes etc. height = sum(subplot_heights) + 50 - fig = self._dipclust_concat_subplots( + fig = concat_clustering_subplots( figures=figures, width=width, height=height, @@ -791,6 +750,44 @@ def plot_diplotype_clustering_advanced( legend=dict(itemsizing=legend_sizing, tracegroupgap=0), ) + # add lines to aa plot - looks neater + if snp_transcript: + n_transcripts = ( + len(snp_transcript) if isinstance(snp_transcript, list) else 1 + ) + for i in range(n_transcripts): + tx_idx = len(figures) - n_transcripts + i + 1 + if n_snps_transcripts[i] > 0: + fig.add_hline( + y=-0.5, line_width=1, line_color="grey", row=tx_idx, col=1 + ) + for j in range(n_snps_transcripts[i]): + fig.add_hline( + y=j + 0.5, + line_width=1, + line_color="grey", + row=tx_idx, + col=1, + ) + + fig.update_xaxes( + showline=True, + linecolor="grey", + linewidth=1, + row=tx_idx, + col=1, + mirror=True, + ) + + fig.update_yaxes( + showline=True, + linecolor="grey", + linewidth=1, + row=tx_idx, + col=1, + mirror=True, + ) + if show: fig.show(renderer=renderer) return None diff --git a/malariagen_data/anoph/hapclust.py b/malariagen_data/anoph/hapclust.py index 05eb6acb..792b3350 100644 --- a/malariagen_data/anoph/hapclust.py +++ b/malariagen_data/anoph/hapclust.py @@ -5,8 +5,8 @@ import pandas as pd from numpydoc_decorator import doc # type: ignore -from ..util import CacheMiss, check_types, pdist_abs_hamming -from ..plotly_dendrogram import plot_dendrogram +from ..util import CacheMiss, check_types, pdist_abs_hamming, pandas_apply +from ..plotly_dendrogram import plot_dendrogram, concat_clustering_subplots from . import ( base_params, plotly_params, @@ -14,12 +14,16 @@ hap_params, clustering_params, hapclust_params, + dipclust_params, ) from .snp_data import AnophelesSnpData +from .snp_frq import AnophelesFrequencyAnalysis, AA_CHANGE_QUERY, _make_snp_label_effect from .hap_data import AnophelesHapData -class AnophelesHapClustAnalysis(AnophelesHapData, AnophelesSnpData): +class AnophelesHapClustAnalysis( + AnophelesHapData, AnophelesSnpData, AnophelesFrequencyAnalysis +): def __init__( self, **kwargs, @@ -47,6 +51,7 @@ def plot_haplotype_clustering( color: plotly_params.color = None, symbol: plotly_params.symbol = None, linkage_method: hapclust_params.linkage_method = hapclust_params.linkage_method_default, + distance_metric: hapclust_params.distance_metric = hapclust_params.distance_metric_default, count_sort: Optional[tree_params.count_sort] = None, distance_sort: Optional[tree_params.distance_sort] = None, title: plotly_params.title = True, @@ -66,7 +71,7 @@ def plot_haplotype_clustering( legend_sizing: plotly_params.legend_sizing = "constant", chunks: base_params.chunks = base_params.native_chunks, inline_array: base_params.inline_array = base_params.inline_array_default, - ) -> plotly_params.figure: + ) -> Optional[dict]: import sys # Normalise params. @@ -89,6 +94,7 @@ def plot_haplotype_clustering( dist, phased_samples, n_snps_used = self.haplotype_pairwise_distances( region=region, analysis=analysis, + distance_metric=distance_metric, sample_sets=sample_sets, sample_query=sample_query, sample_query_options=sample_query_options, @@ -127,6 +133,7 @@ def plot_haplotype_clustering( # Repeat the dataframe so there is one row of metadata for each haplotype. df_haps = pd.DataFrame(np.repeat(df_samples_phased.values, 2, axis=0)) df_haps.columns = df_samples_phased.columns + leaf_data = df_haps.assign(sample_id=_make_unique(df_haps.sample_id)) # Configure hover data. hover_data = self._setup_sample_hover_data_plotly( @@ -145,7 +152,7 @@ def plot_haplotype_clustering( # Create the plot. with self._spinner("Plot dendrogram"): - fig, _ = plot_dendrogram( + fig, leaf_data = plot_dendrogram( dist=dist, linkage_method=linkage_method, count_sort=count_sort, @@ -157,7 +164,7 @@ def plot_haplotype_clustering( line_width=line_width, line_color=line_color, marker_size=marker_size, - leaf_data=df_haps, + leaf_data=leaf_data, leaf_hover_name="sample_id", leaf_hover_data=hover_data, leaf_color=color_prepped, @@ -166,7 +173,7 @@ def plot_haplotype_clustering( leaf_color_discrete_map=color_discrete_map_prepped, leaf_category_orders=category_orders_prepped, template="simple_white", - y_axis_title="Distance (no. SNPs)", + y_axis_title=f"Distance ({distance_metric})", y_axis_buffer=1, ) @@ -182,7 +189,13 @@ def plot_haplotype_clustering( fig.show(renderer=renderer) return None else: - return fig + return { + "figure": fig, + "n_snps": n_snps_used, + "dist": dist, + "dist_samples": phased_samples, + "leaf_data": leaf_data, + } @doc( summary=""" @@ -197,6 +210,7 @@ def plot_haplotype_clustering( def haplotype_pairwise_distances( self, region: base_params.regions, + distance_metric: hapclust_params.distance_metric = hapclust_params.distance_metric_default, analysis: hap_params.analysis = base_params.DEFAULT, sample_sets: Optional[base_params.sample_sets] = None, sample_query: Optional[base_params.sample_query] = None, @@ -215,6 +229,7 @@ def haplotype_pairwise_distances( region_prepped = self._prep_region_cache_param(region=region) params = dict( region=region_prepped, + distance_metric=distance_metric, analysis=analysis, sample_sets=sample_sets_prepped, sample_query=sample_query, @@ -244,6 +259,7 @@ def _haplotype_pairwise_distances( self, *, region, + distance_metric, analysis, sample_sets, sample_query, @@ -287,8 +303,494 @@ def _haplotype_pairwise_distances( # to allow these to be saved to the results cache. phased_samples = ds_haps["sample_id"].values.astype("U") + # Number of sites + n_total_sites = region["end"] - region["start"] + 1 + + # Adjust distances if dxy requested + if distance_metric == "dxy": + # Normalize by total sites (common definition of dxy) + dist = dist / n_total_sites + elif distance_metric == "hamming": + # Leave as raw SNP differences + pass + else: + raise ValueError( + f"Unsupported distance_metric: {distance_metric}. " + "Choose from {'hamming', 'dxy'}." + ) + return dict( dist=dist, phased_samples=phased_samples, - n_snps=np.array(ht.shape[0]), + n_snps=np.array(n_total_sites), + n_seg_sites=np.array(ht.shape[0]), + ) + + @check_types + @doc( + summary=""" + Hierarchically cluster haplotypes in region, and produce an interactive plot + with optional SNP haplotype heatmap and/or cluster assignments. + """, + returns=""" + If `show` is False, returns a tuple (fig, leaf_data, df_haps) where + `fig` is a plotly Figure object, `leaf_data` is a DataFrame with + metadata for each haplotype in the dendrogram, and `df_haps` is a DataFrame + of haplotype calls for each sample at each SNP in the specified transcript. + If `show` is True, displays the figure and returns None. + """, + parameters=dict( + snp_transcript="Plot amino acid variants for these transcripts.", + snp_filter_min_maf="Filter amino acid variants with alternate allele frequency below this threshold.", + snp_query="Query to filter SNPs for amino acid heatmap. Default is to include all non-synonymous SNPs.", + cluster_threshold="Height at which to cut the dendrogram to form clusters. If not provided, no clusters assignment is not performed.", + min_cluster_size="Minimum number of haplotypes required in a cluster to be included when cutting the dendrogram. Default is 5.", + cluster_criterion="The cluster_criterion to use in forming flat clusters. One of 'inconsistent', 'distance', 'maxclust', 'maxclust_monochronic', 'monocrit'. See scipy.cluster.hierarchy.fcluster for details.", + ), + ) + def plot_haplotype_clustering_advanced( + self, + region: base_params.regions, + analysis: hap_params.analysis = base_params.DEFAULT, + snp_transcript: Optional[dipclust_params.snp_transcript] = None, + snp_colorscale: Optional[plotly_params.color_continuous_scale] = "Greys", + snp_filter_min_maf: float = 0.05, + snp_query=AA_CHANGE_QUERY, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + random_seed: base_params.random_seed = 42, + cohort_size: Optional[base_params.cohort_size] = None, + distance_metric: hapclust_params.distance_metric = hapclust_params.distance_metric_default, + cluster_threshold: Optional[float] = None, + min_cluster_size: Optional[int] = 5, + cluster_criterion="distance", + color: plotly_params.color = None, + symbol: plotly_params.symbol = None, + linkage_method: dipclust_params.linkage_method = "complete", + count_sort: Optional[tree_params.count_sort] = None, + distance_sort: Optional[tree_params.distance_sort] = None, + title: plotly_params.title = True, + title_font_size: plotly_params.title_font_size = 14, + width: plotly_params.fig_width = None, + dendrogram_height: plotly_params.height = 300, + snp_row_height: plotly_params.height = 25, + show: plotly_params.show = True, + renderer: plotly_params.renderer = None, + render_mode: plotly_params.render_mode = "svg", + leaf_y: clustering_params.leaf_y = 0, + marker_size: plotly_params.marker_size = 5, + line_width: plotly_params.line_width = 0.5, + line_color: plotly_params.line_color = "black", + color_discrete_sequence: plotly_params.color_discrete_sequence = None, + color_discrete_map: plotly_params.color_discrete_map = None, + category_orders: plotly_params.category_order = None, + legend_sizing: plotly_params.legend_sizing = "constant", + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + ): + import plotly.express as px + import plotly.graph_objects as go + + if cohort_size and snp_transcript: + cohort_size = None + print( + "Cohort size is not supported with amino acid heatmap. Overriding cohort size to None." + ) + + res = self.plot_haplotype_clustering( + region=region, + analysis=analysis, + sample_sets=sample_sets, + sample_query=sample_query, + sample_query_options=sample_query_options, + count_sort=count_sort, + cohort_size=cohort_size, + distance_sort=distance_sort, + distance_metric=distance_metric, + linkage_method=linkage_method, + color=color, + symbol=symbol, + title=title, + title_font_size=title_font_size, + width=width, + height=dendrogram_height, + show=False, + renderer=renderer, + render_mode=render_mode, + leaf_y=leaf_y, + marker_size=marker_size, + line_width=line_width, + line_color=line_color, + color_discrete_sequence=color_discrete_sequence, + color_discrete_map=color_discrete_map, + category_orders=category_orders, + legend_sizing=legend_sizing, + random_seed=random_seed, + chunks=chunks, + inline_array=inline_array, + ) + + fig_dendro = res["figure"] + n_snps_cluster = res["n_snps"] + leaf_data = res["leaf_data"] + dendro_sample_id_order = np.asarray(leaf_data["sample_id"].to_list()) + + figures = [fig_dendro] + subplot_heights = [dendrogram_height] + + if cluster_threshold and min_cluster_size: + df_clusters = self.cut_dist_tree( + dist=res["dist"], + dist_samples=_make_unique(np.repeat(res["dist_samples"], 2)), + dendro_sample_id_order=dendro_sample_id_order, + linkage_method=linkage_method, + cluster_threshold=cluster_threshold, + min_cluster_size=min_cluster_size, + cluster_criterion=cluster_criterion, + ) + + leaf_data = leaf_data.merge(df_clusters.T.reset_index()) + + # if more than 8 clusters, use px.colors.qualitative.Alphabet + if df_clusters.max().max() > 8: + cluster_col_list = px.colors.qualitative.Alphabet.copy() + cluster_col_list.insert(0, "white") + else: + cluster_col_list = px.colors.qualitative.Dark2.copy() + cluster_col_list.insert(0, "white") + + snp_trace = go.Heatmap( + z=df_clusters.values, + y=df_clusters.index.to_list(), + colorscale=cluster_col_list, + showlegend=False, + showscale=False, + ) + figures.append(snp_trace) + subplot_heights.append(25) + + n_snps_transcripts = [] + if isinstance(snp_transcript, str): + ( + figures, + subplot_heights, + n_snps_transcript, + ) = self._insert_hapclust_snp_trace( + transcript=snp_transcript, + snp_query=snp_query, + figures=figures, + subplot_heights=subplot_heights, + sample_sets=sample_sets, + sample_query=sample_query, + analysis=analysis, + dendro_sample_id_order=dendro_sample_id_order, + snp_filter_min_maf=snp_filter_min_maf, + snp_colorscale=snp_colorscale, + snp_row_height=snp_row_height, + chunks=chunks, + inline_array=inline_array, + ) + n_snps_transcripts.append(n_snps_transcript) + elif isinstance(snp_transcript, list): + for st in snp_transcript: + ( + figures, + subplot_heights, + n_snps_transcript, + ) = self._insert_hapclust_snp_trace( + transcript=st, + snp_query=snp_query, + figures=figures, + subplot_heights=subplot_heights, + sample_sets=sample_sets, + sample_query=sample_query, + analysis=analysis, + dendro_sample_id_order=dendro_sample_id_order, + snp_filter_min_maf=snp_filter_min_maf, + snp_colorscale=snp_colorscale, + snp_row_height=snp_row_height, + chunks=chunks, + inline_array=inline_array, + ) + n_snps_transcripts.append(n_snps_transcript) + + # Calculate total height based on subplot heights, plus a fixed + # additional component to allow for title, axes etc. + height = sum(subplot_heights) + 50 + fig = concat_clustering_subplots( + figures=figures, + width=width, + height=height, + row_heights=subplot_heights, + sample_sets=sample_sets, + sample_query=sample_query, # Only uses query for title. + region=region, + n_snps=n_snps_cluster, + ) + + fig["layout"]["yaxis"]["title"] = f"Distance ({distance_metric})" + fig.update_layout( + title_font=dict( + size=title_font_size, + ), + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), ) + + # add lines to aa plot - looks neater + if snp_transcript: + n_transcripts = ( + len(snp_transcript) if isinstance(snp_transcript, list) else 1 + ) + for i in range(n_transcripts): + tx_idx = len(figures) - n_transcripts + i + 1 + if n_snps_transcripts[i] > 0: + fig.add_hline( + y=-0.5, line_width=1, line_color="grey", row=tx_idx, col=1 + ) + for j in range(n_snps_transcripts[i]): + fig.add_hline( + y=j + 0.5, + line_width=1, + line_color="grey", + row=tx_idx, + col=1, + ) + + fig.update_xaxes( + showline=True, + linecolor="grey", + linewidth=1, + row=tx_idx, + col=1, + mirror=True, + ) + + fig.update_yaxes( + showline=True, + linecolor="grey", + linewidth=1, + row=tx_idx, + col=1, + mirror=True, + ) + + if show: + fig.show(renderer=renderer) + return None + else: + return fig, leaf_data + + def transcript_haplotypes( + self, + transcript, + sample_sets, + sample_query, + analysis, + snp_query, + chunks, + inline_array, + ): + """ + Extract haplotype calls for a given transcript. + """ + + # Get SNP genotype allele counts for the transcript, applying snp_query + df_eff = ( + self.snp_effects( + transcript=transcript, + ) + .query(snp_query) + .reset_index(drop=True) + ) + + df_eff["label"] = pandas_apply( + _make_snp_label_effect, + df_eff, + columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"], + ) + + # Add a unique variant identifier: "position-alt_allele" + df_eff = df_eff.assign( + pos_alt=lambda x: x.position.astype(str) + "-" + x.alt_allele + ) + + # Get haplotypes for the transcript + ds_haps = self.haplotypes( + region=transcript, + sample_sets=sample_sets, + sample_query=sample_query, + analysis=analysis, + chunks=chunks, + inline_array=inline_array, + ) + + # Convert genotype calls to haplotypes + haps = allel.GenotypeArray(ds_haps["call_genotype"].values).to_haplotypes() + h_pos = allel.SortedIndex(ds_haps["variant_position"].values) + h_alts = ds_haps["variant_allele"].values.astype(str)[:, 1] + h_pos_alts = np.array([f"{pos}-{h_alts[i]}" for i, pos in enumerate(h_pos)]) + + # Filter df_eff to haplotypes, and filter haplotypes to SNPs present in df_eff + df_eff = df_eff.query("pos_alt in @h_pos_alts") + label = df_eff.label.values + haps_bool = np.isin(h_pos_alts, df_eff.pos_alt) + haps = haps.compress(haps_bool) + + # Build haplotype DataFrame + df_haps = pd.DataFrame( + haps, + index=label, + columns=_make_unique( + np.repeat(ds_haps["sample_id"].values, 2) + ), # two haplotypes per sample + ) + + return df_haps + + def _insert_hapclust_snp_trace( + self, + figures, + subplot_heights, + transcript, + analysis, + dendro_sample_id_order: np.ndarray, + snp_filter_min_maf: float, + snp_colorscale, + snp_query, + snp_row_height, + sample_sets, + sample_query, + chunks, + inline_array, + ): + from plotly import graph_objects as go + + # load genotype allele counts at SNP variants for each sample + df_haps = self.transcript_haplotypes( + transcript=transcript, + snp_query=snp_query, + sample_query=sample_query, + sample_sets=sample_sets, + analysis=analysis, + chunks=chunks, + inline_array=inline_array, + ) + + # set to diplotype cluster order + df_haps = df_haps.loc[:, dendro_sample_id_order] + + if snp_filter_min_maf: + df_haps = df_haps.assign(af=lambda x: x.sum(axis=1) / x.shape[1]) + df_haps = df_haps.query("af > @snp_filter_min_maf").drop(columns="af") + + n_snps_transcript = df_haps.shape[0] + + if not df_haps.empty: + snp_trace = go.Heatmap( + z=df_haps.values, + y=df_haps.index.to_list(), + colorscale=snp_colorscale, + showlegend=False, + showscale=False, + ) + else: + snp_trace = None + + if snp_trace: + figures.append(snp_trace) + subplot_heights.append(snp_row_height * df_haps.shape[0]) + else: + print( + f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot." + ) + return figures, subplot_heights, n_snps_transcript + + def cut_dist_tree( + self, + dist, + dist_samples, + dendro_sample_id_order, + linkage_method, + cluster_threshold, + cluster_criterion, + min_cluster_size, + ): + """ + Create a one-row DataFrame with haplotype_ids as columns and cluster assignments as values + + Parameters: + ----------- + dist : ndarray + distance array + dist_samples : array-like + List/array of individual identifiers (haplotype_ids) + linkage_method : str + Method used to calculate the linkage matrix + cluster_threshold : float + Height at which to cut the dendrogram + min_cluster_size : int, default=1 + Minimum number of individuals required in a cluster to be included + dendro_sample_id_order : array-like + List/array of individual identifiers (haplotype_ids) in the order they appear in + the dendrogram + cluster_criterion : str, default='distance' + The cluster_criterion to use in forming flat clusters. One of + 'inconsistent', 'distance', 'maxclust', 'maxclust_monochronic', 'monocrit' + See scipy.cluster.hierarchy.fcluster for details. + + Returns: + -------- + pd.DataFrame + One-row DataFrame with haplotype_ids as columns and assigned cluster numbers (1...n) as values + """ + from scipy.cluster.hierarchy import linkage, fcluster + + Z = linkage(dist, method=linkage_method) + + # Get cluster assignments for each individual + cluster_assignments = fcluster( + Z, t=cluster_threshold, criterion=cluster_criterion + ) + + # Create initial DataFrame + df = ( + pd.DataFrame( + { + "sample_id": dist_samples, + "Cluster ID": _filter_and_remap( + cluster_assignments, x=min_cluster_size + ), + } + ) + .set_index("sample_id") + .T.loc[:, dendro_sample_id_order] + ) + + return df + + +def _filter_and_remap(arr, x): + from collections import Counter + + # Get unique values that appear >= x times + valid_values = [val for val, count in Counter(arr).items() if count >= x] + # Create mapping to 1, 2, 3, ..., n + mapping = {val: i + 1 for i, val in enumerate(sorted(valid_values))} + # Apply transformation + return np.array([mapping.get(val, 0) for val in arr]) + + +def _make_unique(values): + value_counts = {} + unique_values = [] + + for value in values: + if value in value_counts: + value_counts[value] += 1 + unique_values.append(f"{value}_{value_counts[value]}") + else: + value_counts[value] = 0 + unique_values.append(f"{value}_{value_counts[value]}") + + return np.array(unique_values) diff --git a/malariagen_data/anoph/hapclust_params.py b/malariagen_data/anoph/hapclust_params.py index 4f53fc82..71c5c2b4 100644 --- a/malariagen_data/anoph/hapclust_params.py +++ b/malariagen_data/anoph/hapclust_params.py @@ -1,5 +1,17 @@ """Parameters for haplotype clustering functions.""" from .clustering_params import linkage_method +from typing_extensions import Annotated, TypeAlias, Literal linkage_method_default: linkage_method = "single" + +distance_metric: TypeAlias = Annotated[ + Literal["hamming", "dxy"], + """ + The distance metric to use for calculating pairwise distances between haplotypes. + 'hamming' computes the Hamming distance (number of differing SNPs) between haplotypes. + 'dxy' computes the average number of nucleotide differences per site between haplotypes. + """, +] + +distance_metric_default: Literal["hamming", "dxy"] = "hamming" diff --git a/malariagen_data/plotly_dendrogram.py b/malariagen_data/plotly_dendrogram.py index a172c469..8a94907d 100644 --- a/malariagen_data/plotly_dendrogram.py +++ b/malariagen_data/plotly_dendrogram.py @@ -130,3 +130,54 @@ def plot_dendrogram( ) return fig, leaf_data + + +def concat_clustering_subplots( + figures, + width, + height, + row_heights, + region, + n_snps, + sample_sets, + sample_query, +): + from plotly.subplots import make_subplots # type: ignore + import plotly.graph_objects as go # type: ignore + + title_lines = [] + if sample_sets is not None: + title_lines.append(f"sample sets: {sample_sets}") + if sample_query is not None: + title_lines.append(f"sample query: {sample_query}") + title_lines.append(f"genomic region: {region} ({n_snps} SNPs)") + title = "
".join(title_lines) + + # make subplots + fig = make_subplots( + rows=len(figures), + cols=1, + shared_xaxes=True, + vertical_spacing=0.02, + row_heights=row_heights, + ) + + for i, figure in enumerate(figures): + if isinstance(figure, go.Figure): + # This is a figure, access the traces within it. + for trace in range(len(figure["data"])): + fig.append_trace(figure["data"][trace], row=i + 1, col=1) + else: + # Assume this is a trace, add directly. + fig.append_trace(figure, row=i + 1, col=1) + + fig.update_xaxes(visible=False) + fig.update_layout( + title=title, + width=width, + height=height, + hovermode="closest", + plot_bgcolor="white", + ) + + return fig diff --git a/notebooks/plot_haplotype_clustering.ipynb b/notebooks/plot_haplotype_clustering.ipynb index c6dbdf4f..f035681b 100644 --- a/notebooks/plot_haplotype_clustering.ipynb +++ b/notebooks/plot_haplotype_clustering.ipynb @@ -314,7 +314,7 @@ "metadata": {}, "outputs": [], "source": [ - "ag3.plot_haplotype_clustering(\n", + "ag3.plot_haplotype_clustering_advanced(\n", " region=\"2R:28,480,000-28,490,000\",\n", " sample_sets=[\"3.0\"],\n", " sample_query=\"taxon == 'arabiensis'\",\n", @@ -322,8 +322,32 @@ " symbol=new_cohorts,\n", " color=\"year\",\n", " cohort_size=None,\n", - " width=1000,\n", - " height=400,\n", + " snp_transcript=\"AGAP002863-RA\",\n", + " distance_metric='hamming',\n", + " snp_filter_min_maf=0.05\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51846e18", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_haplotype_clustering_advanced(\n", + " region=\"2L:2,360,000-2,431,000\",\n", + " sample_sets=[\"AG1000G-GH\"],\n", + " analysis=\"gamb_colu\",\n", + " color=\"taxon\",\n", + " cohort_size=None,\n", + " snp_transcript=\"AGAP004707-RD\",\n", + " distance_metric='dxy',\n", + " cluster_criterion='distance',\n", + " cluster_threshold=0.0007,\n", + " min_cluster_size=5,\n", + " snp_filter_min_maf=0.05,\n", + " show=True\n", ")" ] }, @@ -435,12 +459,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "vscode": { - "interpreter": { - "hash": "3b9ddb1005cd06989fd869b9e3d566470f1be01faa610bb17d64e58e32302e8b" - } + "version": "3.12.8" }, "widgets": { "application/vnd.jupyter.widget-state+json": {