Skip to content

Commit a49aeaf

Browse files
Merge pull request #177 from hyanwong/plot-labelling
Allow numbers of descendant samples to be shown even on nonsample nodes
2 parents 916e9f6 + a5a614b commit a49aeaf

File tree

1 file changed

+43
-14
lines changed

1 file changed

+43
-14
lines changed

sc2ts/utils.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,7 @@ def plot_subgraph(
14691469
ts_id_labels=None,
14701470
node_metadata_labels=None,
14711471
sample_metadata_labels=None,
1472+
show_descendant_samples=None,
14721473
edge_labels=None,
14731474
edge_font_size=None,
14741475
node_font_size=None,
@@ -1508,8 +1509,12 @@ def plot_subgraph(
15081509
do not plot any all-node metadata.
15091510
:param str sample_metadata_labels: Should we additionally label sample nodes with a
15101511
value from their metadata: Default: ``None``, treated as ``"gisaid_epi_isl"``.
1511-
Notes representing multiple samples will have a label saying "XXX samples".
1512-
If ``""``, do not plot any sample node metadata.
1512+
:param str show_descendant_samples: Should we label nodes with the maximum number
1513+
of samples descending from them in any tree (in the format "+XXX samples").
1514+
If ``"samples"``, only label sample nodes. If "tips", label all tip nodes.
1515+
If ``"sample_tips"` label all tips that are also samples. If ``"all"``, label
1516+
all nodes. If ``""`` or False, do not show labels. Default: ``None``, treated
1517+
as ``"sample_tips"``. If a node has no descendant samples, a label is not placed.
15131518
:param dict edge_labels: a mapping of {(parent_id, child_id): "label")} with which
15141519
to label the edges. If ``None``, label with mutations or (if above a
15151520
recombination node) with the edge interval. If ``{}``, do not plot
@@ -1547,6 +1552,9 @@ def sort_mutation_label(s):
15471552
try:
15481553
return float(s)
15491554
except ValueError:
1555+
if s[0] == "$":
1556+
# matplotlib mathtext - remove the $ and the formatting
1557+
s = s.replace("$", "").replace(r"\bf", "").replace("\it", "").replace("{", "").replace("}", "")
15501558
try:
15511559
return float(s[1:-1])
15521560
except ValueError:
@@ -1564,11 +1572,18 @@ def sort_mutation_label(s):
15641572
node_metadata_labels = "Imputed_GISAID_lineage"
15651573
if sample_metadata_labels is None:
15661574
sample_metadata_labels = "gisaid_epi_isl"
1575+
if show_descendant_samples is None:
1576+
show_descendant_samples = "sample_tips"
15671577
if colour_metadata_key is None:
15681578
colour_metadata_key = "strain"
15691579
if exterior_edge_len is None:
15701580
exterior_edge_len = 0.4
15711581

1582+
if show_descendant_samples not in {"samples", "tips", "sample_tips", "all", "", False}:
1583+
raise ValueError(
1584+
"show_descendant_samples must be one of 'samples', 'tips', 'sample_tips', 'all', or '' / False"
1585+
)
1586+
15721587
# Read in characteristic mutations info
15731588
linmuts_dict = None
15741589
if mutations_json_filepath is not None:
@@ -1585,7 +1600,7 @@ def sort_mutation_label(s):
15851600
G = to_nx_subgraph(ts, nodes)
15861601

15871602
nodelabels = collections.defaultdict(list)
1588-
tip_samples = {}
1603+
shown_tips = []
15891604
for u, out_deg in G.out_degree():
15901605
node = ts.node(u)
15911606
if node_metadata_labels:
@@ -1595,16 +1610,22 @@ def sort_mutation_label(s):
15951610
if node.is_sample():
15961611
if sample_metadata_labels:
15971612
nodelabels[u].append(node.metadata[sample_metadata_labels])
1598-
if out_deg == 0: # Only show num descendants for tip samples
1599-
tip_samples[u] = 0
1600-
for tree in ts.trees():
1601-
for u in tip_samples.keys():
1602-
# This is't quite right - it shows the max num samples per tree,
1603-
# not the total number of samples. But it's close enough.
1604-
tip_samples[u] = max(tip_samples[u], tree.num_samples(u))
1605-
for u, s in tip_samples.items():
1606-
if s > 1:
1607-
nodelabels[u].append(f"+{s-1} {'samples' if s > 2 else 'sample'}")
1613+
if show_descendant_samples:
1614+
show = True if show_descendant_samples == "all" else False
1615+
is_tip = out_deg == 0
1616+
if show_descendant_samples == "tips" and is_tip:
1617+
show = True
1618+
elif node.is_sample():
1619+
if show_descendant_samples == "samples":
1620+
show = True
1621+
elif show_descendant_samples == "sample_tips" and is_tip:
1622+
show = True
1623+
if show:
1624+
s = ti.nodes_max_descendant_samples[u]
1625+
if node.is_sample():
1626+
s -= 1 # don't count self
1627+
if s > 0:
1628+
nodelabels[u].append(f"+{s} {'samples' if s > 1 else 'sample'}")
16081629

16091630
nodelabels = {k: "\n".join(v) for k, v in nodelabels.items()}
16101631

@@ -1628,7 +1649,9 @@ def sort_mutation_label(s):
16281649
inherited_state = ts.mutation(m.parent).derived_state
16291650

16301651
if ti.mutations_is_reversion[m.id]:
1631-
mutstr = f"{inherited_state.lower()}{pos}{m.derived_state.lower()}"
1652+
mutstr = f"$\\bf{{{inherited_state.lower()}{pos}{m.derived_state.lower()}}}$"
1653+
elif ts.mutations_parent[m.id] != tskit.NULL:
1654+
mutstr = f"$\\bf{{{inherited_state.upper()}{pos}{m.derived_state.upper()}}}$"
16321655
else:
16331656
mutstr = f"{inherited_state.upper()}{pos}{m.derived_state.upper()}"
16341657
if linmuts_dict is None or pos in linmuts_dict.all_positions:
@@ -1659,7 +1682,13 @@ def sort_mutation_label(s):
16591682
lpos = "lft"
16601683
elif edge.left > 0 and edge.right == ts.sequence_length:
16611684
lpos = "rgt"
1685+
if interval_labels[lpos][pc]:
1686+
interval_labels[lpos][pc] += " " # multiple same-side intervals for an edge
1687+
if lpos == "rgt" and interval_labels["lft"][pc]:
1688+
interval_labels[lpos][pc] = " " + interval_labels[lpos][pc]
16621689
interval_labels[lpos][pc] = f"{int(edge.left)}{int(edge.right)}"
1690+
if lpos == "lft" and interval_labels["rgt"][pc]:
1691+
interval_labels[lpos][pc] += " "
16631692

16641693
if label_replace is not None:
16651694
for search, replace in label_replace.items():

0 commit comments

Comments
 (0)