diff --git a/malariagen_data/anoph/dipclust.py b/malariagen_data/anoph/dipclust.py
index b96ee428..825c9286 100644
--- a/malariagen_data/anoph/dipclust.py
+++ b/malariagen_data/anoph/dipclust.py
@@ -13,7 +13,7 @@
multiallelic_diplotype_mean_sqeuclidean,
multiallelic_diplotype_mean_cityblock,
)
-from ..plotly_dendrogram import plot_dendrogram
+from ..plotly_dendrogram import plot_dendrogram, concat_clustering_subplots
from . import (
base_params,
plotly_params,
@@ -22,13 +22,9 @@
dipclust_params,
cnv_params,
)
-from .snp_frq import AnophelesSnpFrequencyAnalysis
+from .snp_frq import AnophelesSnpFrequencyAnalysis, AA_CHANGE_QUERY
from .cnv_frq import AnophelesCnvFrequencyAnalysis
-AA_CHANGE_QUERY = (
- "effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']"
-)
-
class AnophelesDipClustAnalysis(
AnophelesCnvFrequencyAnalysis, AnophelesSnpFrequencyAnalysis
@@ -503,56 +499,6 @@ def _dipclust_snp_trace(
return snp_trace, n_snps
- def _dipclust_concat_subplots(
- self,
- figures,
- width,
- height,
- row_heights,
- region: base_params.regions,
- n_snps: int,
- sample_sets: Optional[base_params.sample_sets],
- sample_query: Optional[base_params.sample_query],
- ):
- from plotly.subplots import make_subplots # type: ignore
-
- title_lines = []
- if sample_sets is not None:
- title_lines.append(f"sample sets: {sample_sets}")
- if sample_query is not None:
- title_lines.append(f"sample query: {sample_query}")
- title_lines.append(f"genomic region: {region} ({n_snps} SNPs)")
- title = "<br>".join(title_lines)
-
- # make subplots
- fig = make_subplots(
- rows=len(figures),
- cols=1,
- shared_xaxes=True,
- vertical_spacing=0.02,
- row_heights=row_heights,
- )
-
- for i, figure in enumerate(figures):
- if isinstance(figure, go.Figure):
- # This is a figure, access the traces within it.
- for trace in range(len(figure["data"])):
- fig.append_trace(figure["data"][trace], row=i + 1, col=1)
- else:
- # Assume this is a trace, add directly.
- fig.append_trace(figure, row=i + 1, col=1)
-
- fig.update_xaxes(visible=False)
- fig.update_layout(
- title=title,
- width=width,
- height=height,
- hovermode="closest",
- plot_bgcolor="white",
- )
-
- return fig
-
def _insert_dipclust_snp_trace(
self,
*,
@@ -592,7 +538,7 @@ def _insert_dipclust_snp_trace(
print(
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
)
- return figures, subplot_heights
+ return figures, subplot_heights, n_snps_transcript
@doc(
summary=""""
@@ -733,8 +679,13 @@ def plot_diplotype_clustering_advanced(
figures.append(cnv_trace)
subplot_heights.append(cnv_row_height * n_cnv_genes)
+ n_snps_transcripts = []
if isinstance(snp_transcript, str):
- figures, subplot_heights = self._insert_dipclust_snp_trace(
+ (
+ figures,
+ subplot_heights,
+ n_snps_transcript,
+ ) = self._insert_dipclust_snp_trace(
transcript=snp_transcript,
figures=figures,
subplot_heights=subplot_heights,
@@ -746,12 +697,18 @@ def plot_diplotype_clustering_advanced(
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
+ snp_row_height=snp_row_height,
chunks=chunks,
inline_array=inline_array,
)
+ n_snps_transcripts.append(n_snps_transcript)
elif isinstance(snp_transcript, list):
for st in snp_transcript:
- figures, subplot_heights = self._insert_dipclust_snp_trace(
+ (
+ figures,
+ subplot_heights,
+ n_snps_transcript,
+ ) = self._insert_dipclust_snp_trace(
transcript=st,
figures=figures,
subplot_heights=subplot_heights,
@@ -763,14 +720,16 @@ def plot_diplotype_clustering_advanced(
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
+ snp_row_height=snp_row_height,
chunks=chunks,
inline_array=inline_array,
)
+ n_snps_transcripts.append(n_snps_transcript)
# Calculate total height based on subplot heights, plus a fixed
# additional component to allow for title, axes etc.
height = sum(subplot_heights) + 50
- fig = self._dipclust_concat_subplots(
+ fig = concat_clustering_subplots(
figures=figures,
width=width,
height=height,
@@ -791,6 +750,44 @@ def plot_diplotype_clustering_advanced(
legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)
+ # draw grey separator lines between heatmap rows and frame each amino acid subplot, for readability
+ if snp_transcript:
+ n_transcripts = (
+ len(snp_transcript) if isinstance(snp_transcript, list) else 1
+ )
+ for i in range(n_transcripts):
+ tx_idx = len(figures) - n_transcripts + i + 1
+ if n_snps_transcripts[i] > 0:
+ fig.add_hline(
+ y=-0.5, line_width=1, line_color="grey", row=tx_idx, col=1
+ )
+ for j in range(n_snps_transcripts[i]):
+ fig.add_hline(
+ y=j + 0.5,
+ line_width=1,
+ line_color="grey",
+ row=tx_idx,
+ col=1,
+ )
+
+ fig.update_xaxes(
+ showline=True,
+ linecolor="grey",
+ linewidth=1,
+ row=tx_idx,
+ col=1,
+ mirror=True,
+ )
+
+ fig.update_yaxes(
+ showline=True,
+ linecolor="grey",
+ linewidth=1,
+ row=tx_idx,
+ col=1,
+ mirror=True,
+ )
+
if show:
fig.show(renderer=renderer)
return None
diff --git a/malariagen_data/anoph/hapclust.py b/malariagen_data/anoph/hapclust.py
index 05eb6acb..792b3350 100644
--- a/malariagen_data/anoph/hapclust.py
+++ b/malariagen_data/anoph/hapclust.py
@@ -5,8 +5,8 @@
import pandas as pd
from numpydoc_decorator import doc # type: ignore
-from ..util import CacheMiss, check_types, pdist_abs_hamming
-from ..plotly_dendrogram import plot_dendrogram
+from ..util import CacheMiss, check_types, pdist_abs_hamming, pandas_apply
+from ..plotly_dendrogram import plot_dendrogram, concat_clustering_subplots
from . import (
base_params,
plotly_params,
@@ -14,12 +14,16 @@
hap_params,
clustering_params,
hapclust_params,
+ dipclust_params,
)
from .snp_data import AnophelesSnpData
+from .snp_frq import AnophelesFrequencyAnalysis, AA_CHANGE_QUERY, _make_snp_label_effect
from .hap_data import AnophelesHapData
-class AnophelesHapClustAnalysis(AnophelesHapData, AnophelesSnpData):
+class AnophelesHapClustAnalysis(
+ AnophelesHapData, AnophelesSnpData, AnophelesFrequencyAnalysis
+):
def __init__(
self,
**kwargs,
@@ -47,6 +51,7 @@ def plot_haplotype_clustering(
color: plotly_params.color = None,
symbol: plotly_params.symbol = None,
linkage_method: hapclust_params.linkage_method = hapclust_params.linkage_method_default,
+ distance_metric: hapclust_params.distance_metric = hapclust_params.distance_metric_default,
count_sort: Optional[tree_params.count_sort] = None,
distance_sort: Optional[tree_params.distance_sort] = None,
title: plotly_params.title = True,
@@ -66,7 +71,7 @@ def plot_haplotype_clustering(
legend_sizing: plotly_params.legend_sizing = "constant",
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
- ) -> plotly_params.figure:
+ ) -> Optional[dict]:
import sys
# Normalise params.
@@ -89,6 +94,7 @@ def plot_haplotype_clustering(
dist, phased_samples, n_snps_used = self.haplotype_pairwise_distances(
region=region,
analysis=analysis,
+ distance_metric=distance_metric,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
@@ -127,6 +133,7 @@ def plot_haplotype_clustering(
# Repeat the dataframe so there is one row of metadata for each haplotype.
df_haps = pd.DataFrame(np.repeat(df_samples_phased.values, 2, axis=0))
df_haps.columns = df_samples_phased.columns
+ leaf_data = df_haps.assign(sample_id=_make_unique(df_haps.sample_id))
# Configure hover data.
hover_data = self._setup_sample_hover_data_plotly(
@@ -145,7 +152,7 @@ def plot_haplotype_clustering(
# Create the plot.
with self._spinner("Plot dendrogram"):
- fig, _ = plot_dendrogram(
+ fig, leaf_data = plot_dendrogram(
dist=dist,
linkage_method=linkage_method,
count_sort=count_sort,
@@ -157,7 +164,7 @@ def plot_haplotype_clustering(
line_width=line_width,
line_color=line_color,
marker_size=marker_size,
- leaf_data=df_haps,
+ leaf_data=leaf_data,
leaf_hover_name="sample_id",
leaf_hover_data=hover_data,
leaf_color=color_prepped,
@@ -166,7 +173,7 @@ def plot_haplotype_clustering(
leaf_color_discrete_map=color_discrete_map_prepped,
leaf_category_orders=category_orders_prepped,
template="simple_white",
- y_axis_title="Distance (no. SNPs)",
+ y_axis_title=f"Distance ({distance_metric})",
y_axis_buffer=1,
)
@@ -182,7 +189,13 @@ def plot_haplotype_clustering(
fig.show(renderer=renderer)
return None
else:
- return fig
+ return {
+ "figure": fig,
+ "n_snps": n_snps_used,
+ "dist": dist,
+ "dist_samples": phased_samples,
+ "leaf_data": leaf_data,
+ }
@doc(
summary="""
@@ -197,6 +210,7 @@ def plot_haplotype_clustering(
def haplotype_pairwise_distances(
self,
region: base_params.regions,
+ distance_metric: hapclust_params.distance_metric = hapclust_params.distance_metric_default,
analysis: hap_params.analysis = base_params.DEFAULT,
sample_sets: Optional[base_params.sample_sets] = None,
sample_query: Optional[base_params.sample_query] = None,
@@ -215,6 +229,7 @@ def haplotype_pairwise_distances(
region_prepped = self._prep_region_cache_param(region=region)
params = dict(
region=region_prepped,
+ distance_metric=distance_metric,
analysis=analysis,
sample_sets=sample_sets_prepped,
sample_query=sample_query,
@@ -244,6 +259,7 @@ def _haplotype_pairwise_distances(
self,
*,
region,
+ distance_metric,
analysis,
sample_sets,
sample_query,
@@ -287,8 +303,494 @@ def _haplotype_pairwise_distances(
# to allow these to be saved to the results cache.
phased_samples = ds_haps["sample_id"].values.astype("U")
+ # Number of sites
+ n_total_sites = region["end"] - region["start"] + 1
+
+ # Adjust distances if dxy requested
+ if distance_metric == "dxy":
+ # Normalize by total sites (common definition of dxy)
+ dist = dist / n_total_sites
+ elif distance_metric == "hamming":
+ # Leave as raw SNP differences
+ pass
+ else:
+ raise ValueError(
+ f"Unsupported distance_metric: {distance_metric}. "
+ "Choose from {'hamming', 'dxy'}."
+ )
+
return dict(
dist=dist,
phased_samples=phased_samples,
- n_snps=np.array(ht.shape[0]),
+ n_snps=np.array(n_total_sites),
+ n_seg_sites=np.array(ht.shape[0]),
+ )
+
+ @check_types
+ @doc(
+ summary="""
+ Hierarchically cluster haplotypes in region, and produce an interactive plot
+ with optional SNP haplotype heatmap and/or cluster assignments.
+ """,
+ returns="""
+ If `show` is False, returns a tuple (fig, leaf_data) where
+ `fig` is a plotly Figure object and `leaf_data` is a DataFrame with
+ metadata for each haplotype in the dendrogram, including the cluster
+ assignment of each haplotype when clusters are computed.
+ If `show` is True, displays the figure and returns None.
+ """,
+ parameters=dict(
+ snp_transcript="Plot amino acid variants for these transcripts.",
+ snp_filter_min_maf="Filter amino acid variants with alternate allele frequency below this threshold.",
+ snp_query="Query to filter SNPs for amino acid heatmap. Default is to include all non-synonymous SNPs.",
+ cluster_threshold="Height at which to cut the dendrogram to form clusters. If not provided, no cluster assignment is performed.",
+ min_cluster_size="Minimum number of haplotypes required in a cluster to be included when cutting the dendrogram. Default is 5.",
+ cluster_criterion="The cluster_criterion to use in forming flat clusters. One of 'inconsistent', 'distance', 'maxclust', 'maxclust_monocrit', 'monocrit'. See scipy.cluster.hierarchy.fcluster for details.",
+ ),
+ )
+ def plot_haplotype_clustering_advanced(
+ self,
+ region: base_params.regions,
+ analysis: hap_params.analysis = base_params.DEFAULT,
+ snp_transcript: Optional[dipclust_params.snp_transcript] = None,
+ snp_colorscale: Optional[plotly_params.color_continuous_scale] = "Greys",
+ snp_filter_min_maf: float = 0.05,
+ snp_query=AA_CHANGE_QUERY,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ sample_query: Optional[base_params.sample_query] = None,
+ sample_query_options: Optional[base_params.sample_query_options] = None,
+ random_seed: base_params.random_seed = 42,
+ cohort_size: Optional[base_params.cohort_size] = None,
+ distance_metric: hapclust_params.distance_metric = hapclust_params.distance_metric_default,
+ cluster_threshold: Optional[float] = None,
+ min_cluster_size: Optional[int] = 5,
+ cluster_criterion="distance",
+ color: plotly_params.color = None,
+ symbol: plotly_params.symbol = None,
+ linkage_method: dipclust_params.linkage_method = "complete",
+ count_sort: Optional[tree_params.count_sort] = None,
+ distance_sort: Optional[tree_params.distance_sort] = None,
+ title: plotly_params.title = True,
+ title_font_size: plotly_params.title_font_size = 14,
+ width: plotly_params.fig_width = None,
+ dendrogram_height: plotly_params.height = 300,
+ snp_row_height: plotly_params.height = 25,
+ show: plotly_params.show = True,
+ renderer: plotly_params.renderer = None,
+ render_mode: plotly_params.render_mode = "svg",
+ leaf_y: clustering_params.leaf_y = 0,
+ marker_size: plotly_params.marker_size = 5,
+ line_width: plotly_params.line_width = 0.5,
+ line_color: plotly_params.line_color = "black",
+ color_discrete_sequence: plotly_params.color_discrete_sequence = None,
+ color_discrete_map: plotly_params.color_discrete_map = None,
+ category_orders: plotly_params.category_order = None,
+ legend_sizing: plotly_params.legend_sizing = "constant",
+ chunks: base_params.chunks = base_params.native_chunks,
+ inline_array: base_params.inline_array = base_params.inline_array_default,
+ ):
+ import plotly.express as px
+ import plotly.graph_objects as go
+
+ if cohort_size and snp_transcript:
+ cohort_size = None
+ print(
+ "Cohort size is not supported with amino acid heatmap. Overriding cohort size to None."
+ )
+
+ res = self.plot_haplotype_clustering(
+ region=region,
+ analysis=analysis,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ sample_query_options=sample_query_options,
+ count_sort=count_sort,
+ cohort_size=cohort_size,
+ distance_sort=distance_sort,
+ distance_metric=distance_metric,
+ linkage_method=linkage_method,
+ color=color,
+ symbol=symbol,
+ title=title,
+ title_font_size=title_font_size,
+ width=width,
+ height=dendrogram_height,
+ show=False,
+ renderer=renderer,
+ render_mode=render_mode,
+ leaf_y=leaf_y,
+ marker_size=marker_size,
+ line_width=line_width,
+ line_color=line_color,
+ color_discrete_sequence=color_discrete_sequence,
+ color_discrete_map=color_discrete_map,
+ category_orders=category_orders,
+ legend_sizing=legend_sizing,
+ random_seed=random_seed,
+ chunks=chunks,
+ inline_array=inline_array,
+ )
+
+ fig_dendro = res["figure"]
+ n_snps_cluster = res["n_snps"]
+ leaf_data = res["leaf_data"]
+ dendro_sample_id_order = np.asarray(leaf_data["sample_id"].to_list())
+
+ figures = [fig_dendro]
+ subplot_heights = [dendrogram_height]
+
+ if cluster_threshold and min_cluster_size:
+ df_clusters = self.cut_dist_tree(
+ dist=res["dist"],
+ dist_samples=_make_unique(np.repeat(res["dist_samples"], 2)),
+ dendro_sample_id_order=dendro_sample_id_order,
+ linkage_method=linkage_method,
+ cluster_threshold=cluster_threshold,
+ min_cluster_size=min_cluster_size,
+ cluster_criterion=cluster_criterion,
+ )
+
+ leaf_data = leaf_data.merge(df_clusters.T.reset_index())
+
+ # if more than 8 clusters, use px.colors.qualitative.Alphabet
+ if df_clusters.max().max() > 8:
+ cluster_col_list = px.colors.qualitative.Alphabet.copy()
+ cluster_col_list.insert(0, "white")
+ else:
+ cluster_col_list = px.colors.qualitative.Dark2.copy()
+ cluster_col_list.insert(0, "white")
+
+ snp_trace = go.Heatmap(
+ z=df_clusters.values,
+ y=df_clusters.index.to_list(),
+ colorscale=cluster_col_list,
+ showlegend=False,
+ showscale=False,
+ )
+ figures.append(snp_trace)
+ subplot_heights.append(25)
+
+ n_snps_transcripts = []
+ if isinstance(snp_transcript, str):
+ (
+ figures,
+ subplot_heights,
+ n_snps_transcript,
+ ) = self._insert_hapclust_snp_trace(
+ transcript=snp_transcript,
+ snp_query=snp_query,
+ figures=figures,
+ subplot_heights=subplot_heights,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ analysis=analysis,
+ dendro_sample_id_order=dendro_sample_id_order,
+ snp_filter_min_maf=snp_filter_min_maf,
+ snp_colorscale=snp_colorscale,
+ snp_row_height=snp_row_height,
+ chunks=chunks,
+ inline_array=inline_array,
+ )
+ n_snps_transcripts.append(n_snps_transcript)
+ elif isinstance(snp_transcript, list):
+ for st in snp_transcript:
+ (
+ figures,
+ subplot_heights,
+ n_snps_transcript,
+ ) = self._insert_hapclust_snp_trace(
+ transcript=st,
+ snp_query=snp_query,
+ figures=figures,
+ subplot_heights=subplot_heights,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ analysis=analysis,
+ dendro_sample_id_order=dendro_sample_id_order,
+ snp_filter_min_maf=snp_filter_min_maf,
+ snp_colorscale=snp_colorscale,
+ snp_row_height=snp_row_height,
+ chunks=chunks,
+ inline_array=inline_array,
+ )
+ n_snps_transcripts.append(n_snps_transcript)
+
+ # Calculate total height based on subplot heights, plus a fixed
+ # additional component to allow for title, axes etc.
+ height = sum(subplot_heights) + 50
+ fig = concat_clustering_subplots(
+ figures=figures,
+ width=width,
+ height=height,
+ row_heights=subplot_heights,
+ sample_sets=sample_sets,
+ sample_query=sample_query, # Only uses query for title.
+ region=region,
+ n_snps=n_snps_cluster,
+ )
+
+ fig["layout"]["yaxis"]["title"] = f"Distance ({distance_metric})"
+ fig.update_layout(
+ title_font=dict(
+ size=title_font_size,
+ ),
+ legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)
+
+ # draw grey separator lines between heatmap rows and frame each amino acid subplot, for readability
+ if snp_transcript:
+ n_transcripts = (
+ len(snp_transcript) if isinstance(snp_transcript, list) else 1
+ )
+ for i in range(n_transcripts):
+ tx_idx = len(figures) - n_transcripts + i + 1
+ if n_snps_transcripts[i] > 0:
+ fig.add_hline(
+ y=-0.5, line_width=1, line_color="grey", row=tx_idx, col=1
+ )
+ for j in range(n_snps_transcripts[i]):
+ fig.add_hline(
+ y=j + 0.5,
+ line_width=1,
+ line_color="grey",
+ row=tx_idx,
+ col=1,
+ )
+
+ fig.update_xaxes(
+ showline=True,
+ linecolor="grey",
+ linewidth=1,
+ row=tx_idx,
+ col=1,
+ mirror=True,
+ )
+
+ fig.update_yaxes(
+ showline=True,
+ linecolor="grey",
+ linewidth=1,
+ row=tx_idx,
+ col=1,
+ mirror=True,
+ )
+
+ if show:
+ fig.show(renderer=renderer)
+ return None
+ else:
+ return fig, leaf_data
+
+ def transcript_haplotypes(
+ self,
+ transcript,
+ sample_sets,
+ sample_query,
+ analysis,
+ snp_query,
+ chunks,
+ inline_array,
+ ):
+ """
+ Extract haplotype calls for a given transcript.
+ """
+
+ # Get SNP effects for the transcript, keeping only variants that match snp_query
+ df_eff = (
+ self.snp_effects(
+ transcript=transcript,
+ )
+ .query(snp_query)
+ .reset_index(drop=True)
+ )
+
+ df_eff["label"] = pandas_apply(
+ _make_snp_label_effect,
+ df_eff,
+ columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"],
+ )
+
+ # Add a unique variant identifier: "position-alt_allele"
+ df_eff = df_eff.assign(
+ pos_alt=lambda x: x.position.astype(str) + "-" + x.alt_allele
+ )
+
+ # Get haplotypes for the transcript
+ ds_haps = self.haplotypes(
+ region=transcript,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ analysis=analysis,
+ chunks=chunks,
+ inline_array=inline_array,
+ )
+
+ # Convert genotype calls to haplotypes
+ haps = allel.GenotypeArray(ds_haps["call_genotype"].values).to_haplotypes()
+ h_pos = allel.SortedIndex(ds_haps["variant_position"].values)
+ h_alts = ds_haps["variant_allele"].values.astype(str)[:, 1]
+ h_pos_alts = np.array([f"{pos}-{h_alts[i]}" for i, pos in enumerate(h_pos)])
+
+ # Filter df_eff to haplotypes, and filter haplotypes to SNPs present in df_eff
+ df_eff = df_eff.query("pos_alt in @h_pos_alts")
+ label = df_eff.label.values
+ haps_bool = np.isin(h_pos_alts, df_eff.pos_alt)
+ haps = haps.compress(haps_bool)
+
+ # Build haplotype DataFrame
+ df_haps = pd.DataFrame(
+ haps,
+ index=label,
+ columns=_make_unique(
+ np.repeat(ds_haps["sample_id"].values, 2)
+ ), # two haplotypes per sample
+ )
+
+ return df_haps
+
+ def _insert_hapclust_snp_trace(
+ self,
+ figures,
+ subplot_heights,
+ transcript,
+ analysis,
+ dendro_sample_id_order: np.ndarray,
+ snp_filter_min_maf: float,
+ snp_colorscale,
+ snp_query,
+ snp_row_height,
+ sample_sets,
+ sample_query,
+ chunks,
+ inline_array,
+ ):
+ from plotly import graph_objects as go
+
+ # load haplotype calls at the transcript's SNP variants (one column per haplotype)
+ df_haps = self.transcript_haplotypes(
+ transcript=transcript,
+ snp_query=snp_query,
+ sample_query=sample_query,
+ sample_sets=sample_sets,
+ analysis=analysis,
+ chunks=chunks,
+ inline_array=inline_array,
+ )
+
+ # reorder columns to match the haplotype dendrogram leaf order
+ df_haps = df_haps.loc[:, dendro_sample_id_order]
+
+ if snp_filter_min_maf:
+ df_haps = df_haps.assign(af=lambda x: x.sum(axis=1) / x.shape[1])
+ df_haps = df_haps.query("af > @snp_filter_min_maf").drop(columns="af")
+
+ n_snps_transcript = df_haps.shape[0]
+
+ if not df_haps.empty:
+ snp_trace = go.Heatmap(
+ z=df_haps.values,
+ y=df_haps.index.to_list(),
+ colorscale=snp_colorscale,
+ showlegend=False,
+ showscale=False,
+ )
+ else:
+ snp_trace = None
+
+ if snp_trace:
+ figures.append(snp_trace)
+ subplot_heights.append(snp_row_height * df_haps.shape[0])
+ else:
+ print(
+ f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
+ )
+ return figures, subplot_heights, n_snps_transcript
+
+ def cut_dist_tree(
+ self,
+ dist,
+ dist_samples,
+ dendro_sample_id_order,
+ linkage_method,
+ cluster_threshold,
+ cluster_criterion,
+ min_cluster_size,
+ ):
+ """
+ Create a one-row DataFrame with haplotype_ids as columns and cluster assignments as values
+
+ Parameters:
+ -----------
+ dist : ndarray
+ distance array
+ dist_samples : array-like
+ List/array of individual identifiers (haplotype_ids)
+ linkage_method : str
+ Method used to calculate the linkage matrix
+ cluster_threshold : float
+ Height at which to cut the dendrogram
+ min_cluster_size : int
+ Minimum number of haplotypes required for a cluster to keep its own ID; members of smaller clusters are assigned 0
+ dendro_sample_id_order : array-like
+ List/array of individual identifiers (haplotype_ids) in the order they appear in
+ the dendrogram
+ cluster_criterion : str
+ The cluster_criterion to use in forming flat clusters. One of
+ 'inconsistent', 'distance', 'maxclust', 'maxclust_monocrit', 'monocrit'
+ See scipy.cluster.hierarchy.fcluster for details.
+
+ Returns:
+ --------
+ pd.DataFrame
+ One-row DataFrame with haplotype_ids as columns and assigned cluster numbers (1...n, or 0 if unassigned) as values
+ """
+ from scipy.cluster.hierarchy import linkage, fcluster
+
+ Z = linkage(dist, method=linkage_method)
+
+ # Get cluster assignments for each individual
+ cluster_assignments = fcluster(
+ Z, t=cluster_threshold, criterion=cluster_criterion
+ )
+
+ # Create initial DataFrame
+ df = (
+ pd.DataFrame(
+ {
+ "sample_id": dist_samples,
+ "Cluster ID": _filter_and_remap(
+ cluster_assignments, x=min_cluster_size
+ ),
+ }
+ )
+ .set_index("sample_id")
+ .T.loc[:, dendro_sample_id_order]
+ )
+
+ return df
+
+
+def _filter_and_remap(arr, x):
+ from collections import Counter
+
+ # Get unique values that appear >= x times
+ valid_values = [val for val, count in Counter(arr).items() if count >= x]
+ # Create mapping to 1, 2, 3, ..., n
+ mapping = {val: i + 1 for i, val in enumerate(sorted(valid_values))}
+ # Apply transformation
+ return np.array([mapping.get(val, 0) for val in arr])
+
+
+def _make_unique(values):
+ value_counts = {}
+ unique_values = []
+
+ for value in values:
+ if value in value_counts:
+ value_counts[value] += 1
+ unique_values.append(f"{value}_{value_counts[value]}")
+ else:
+ value_counts[value] = 0
+ unique_values.append(f"{value}_{value_counts[value]}")
+
+ return np.array(unique_values)
diff --git a/malariagen_data/anoph/hapclust_params.py b/malariagen_data/anoph/hapclust_params.py
index 4f53fc82..71c5c2b4 100644
--- a/malariagen_data/anoph/hapclust_params.py
+++ b/malariagen_data/anoph/hapclust_params.py
@@ -1,5 +1,17 @@
"""Parameters for haplotype clustering functions."""
from .clustering_params import linkage_method
+from typing_extensions import Annotated, TypeAlias, Literal
linkage_method_default: linkage_method = "single"
+
+distance_metric: TypeAlias = Annotated[
+ Literal["hamming", "dxy"],
+ """
+ The distance metric to use for calculating pairwise distances between haplotypes.
+ 'hamming' computes the Hamming distance (number of differing SNPs) between haplotypes.
+ 'dxy' computes the average number of nucleotide differences per site between haplotypes.
+ """,
+]
+
+distance_metric_default: Literal["hamming", "dxy"] = "hamming"
diff --git a/malariagen_data/plotly_dendrogram.py b/malariagen_data/plotly_dendrogram.py
index a172c469..8a94907d 100644
--- a/malariagen_data/plotly_dendrogram.py
+++ b/malariagen_data/plotly_dendrogram.py
@@ -130,3 +130,54 @@ def plot_dendrogram(
)
return fig, leaf_data
+
+
+def concat_clustering_subplots(
+ figures,
+ width,
+ height,
+ row_heights,
+ region,
+ n_snps,
+ sample_sets,
+ sample_query,
+):
+ from plotly.subplots import make_subplots # type: ignore
+ import plotly.graph_objects as go # type: ignore
+
+ title_lines = []
+ if sample_sets is not None:
+ title_lines.append(f"sample sets: {sample_sets}")
+ if sample_query is not None:
+ title_lines.append(f"sample query: {sample_query}")
+ title_lines.append(f"genomic region: {region} ({n_snps} SNPs)")
+ title = "<br>".join(title_lines)
+
+ # make subplots
+ fig = make_subplots(
+ rows=len(figures),
+ cols=1,
+ shared_xaxes=True,
+ vertical_spacing=0.02,
+ row_heights=row_heights,
+ )
+
+ for i, figure in enumerate(figures):
+ if isinstance(figure, go.Figure):
+ # This is a figure, access the traces within it.
+ for trace in range(len(figure["data"])):
+ fig.append_trace(figure["data"][trace], row=i + 1, col=1)
+ else:
+ # Assume this is a trace, add directly.
+ fig.append_trace(figure, row=i + 1, col=1)
+
+ fig.update_xaxes(visible=False)
+ fig.update_layout(
+ title=title,
+ width=width,
+ height=height,
+ hovermode="closest",
+ plot_bgcolor="white",
+ )
+
+ return fig
diff --git a/notebooks/plot_haplotype_clustering.ipynb b/notebooks/plot_haplotype_clustering.ipynb
index c6dbdf4f..f035681b 100644
--- a/notebooks/plot_haplotype_clustering.ipynb
+++ b/notebooks/plot_haplotype_clustering.ipynb
@@ -314,7 +314,7 @@
"metadata": {},
"outputs": [],
"source": [
- "ag3.plot_haplotype_clustering(\n",
+ "ag3.plot_haplotype_clustering_advanced(\n",
" region=\"2R:28,480,000-28,490,000\",\n",
" sample_sets=[\"3.0\"],\n",
" sample_query=\"taxon == 'arabiensis'\",\n",
@@ -322,8 +322,32 @@
" symbol=new_cohorts,\n",
" color=\"year\",\n",
" cohort_size=None,\n",
- " width=1000,\n",
- " height=400,\n",
+ " snp_transcript=\"AGAP002863-RA\",\n",
+ " distance_metric='hamming',\n",
+ " snp_filter_min_maf=0.05\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "51846e18",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_haplotype_clustering_advanced(\n",
+ " region=\"2L:2,360,000-2,431,000\",\n",
+ " sample_sets=[\"AG1000G-GH\"],\n",
+ " analysis=\"gamb_colu\",\n",
+ " color=\"taxon\",\n",
+ " cohort_size=None,\n",
+ " snp_transcript=\"AGAP004707-RD\",\n",
+ " distance_metric='dxy',\n",
+ " cluster_criterion='distance',\n",
+ " cluster_threshold=0.0007,\n",
+ " min_cluster_size=5,\n",
+ " snp_filter_min_maf=0.05,\n",
+ " show=True\n",
")"
]
},
@@ -435,12 +459,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.12"
- },
- "vscode": {
- "interpreter": {
- "hash": "3b9ddb1005cd06989fd869b9e3d566470f1be01faa610bb17d64e58e32302e8b"
- }
+ "version": "3.12.8"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {