@@ -1469,6 +1469,7 @@ def plot_subgraph(
1469
1469
ts_id_labels = None ,
1470
1470
node_metadata_labels = None ,
1471
1471
sample_metadata_labels = None ,
1472
+ show_descendant_samples = None ,
1472
1473
edge_labels = None ,
1473
1474
edge_font_size = None ,
1474
1475
node_font_size = None ,
@@ -1508,8 +1509,12 @@ def plot_subgraph(
1508
1509
do not plot any all-node metadata.
1509
1510
:param str sample_metadata_labels: Should we additionally label sample nodes with a
1510
1511
value from their metadata: Default: ``None``, treated as ``"gisaid_epi_isl"``.
1511
- Notes representing multiple samples will have a label saying "XXX samples".
1512
- If ``""``, do not plot any sample node metadata.
1512
+ :param str show_descendant_samples: Should we label nodes with the maximum number
1513
+ of samples descending from them in any tree (in the format "+XXX samples").
1514
+ If ``"samples"``, only label sample nodes. If "tips", label all tip nodes.
1515
+ If ``"sample_tips"` label all tips that are also samples. If ``"all"``, label
1516
+ all nodes. If ``""`` or False, do not show labels. Default: ``None``, treated
1517
+ as ``"sample_tips"``. If a node has no descendant samples, a label is not placed.
1513
1518
:param dict edge_labels: a mapping of {(parent_id, child_id): "label")} with which
1514
1519
to label the edges. If ``None``, label with mutations or (if above a
1515
1520
recombination node) with the edge interval. If ``{}``, do not plot
@@ -1547,6 +1552,9 @@ def sort_mutation_label(s):
1547
1552
try :
1548
1553
return float (s )
1549
1554
except ValueError :
1555
+ if s [0 ] == "$" :
1556
+ # matplotlib mathtext - remove the $ and the formatting
1557
+ s = s .replace ("$" , "" ).replace (r"\bf" , "" ).replace ("\it" , "" ).replace ("{" , "" ).replace ("}" , "" )
1550
1558
try :
1551
1559
return float (s [1 :- 1 ])
1552
1560
except ValueError :
@@ -1564,11 +1572,18 @@ def sort_mutation_label(s):
1564
1572
node_metadata_labels = "Imputed_GISAID_lineage"
1565
1573
if sample_metadata_labels is None :
1566
1574
sample_metadata_labels = "gisaid_epi_isl"
1575
+ if show_descendant_samples is None :
1576
+ show_descendant_samples = "sample_tips"
1567
1577
if colour_metadata_key is None :
1568
1578
colour_metadata_key = "strain"
1569
1579
if exterior_edge_len is None :
1570
1580
exterior_edge_len = 0.4
1571
1581
1582
+ if show_descendant_samples not in {"samples" , "tips" , "sample_tips" , "all" , "" , False }:
1583
+ raise ValueError (
1584
+ "show_descendant_samples must be one of 'samples', 'tips', 'sample_tips', 'all', or '' / False"
1585
+ )
1586
+
1572
1587
# Read in characteristic mutations info
1573
1588
linmuts_dict = None
1574
1589
if mutations_json_filepath is not None :
@@ -1585,7 +1600,7 @@ def sort_mutation_label(s):
1585
1600
G = to_nx_subgraph (ts , nodes )
1586
1601
1587
1602
nodelabels = collections .defaultdict (list )
1588
- tip_samples = {}
1603
+ shown_tips = []
1589
1604
for u , out_deg in G .out_degree ():
1590
1605
node = ts .node (u )
1591
1606
if node_metadata_labels :
@@ -1595,16 +1610,22 @@ def sort_mutation_label(s):
1595
1610
if node .is_sample ():
1596
1611
if sample_metadata_labels :
1597
1612
nodelabels [u ].append (node .metadata [sample_metadata_labels ])
1598
- if out_deg == 0 : # Only show num descendants for tip samples
1599
- tip_samples [u ] = 0
1600
- for tree in ts .trees ():
1601
- for u in tip_samples .keys ():
1602
- # This is't quite right - it shows the max num samples per tree,
1603
- # not the total number of samples. But it's close enough.
1604
- tip_samples [u ] = max (tip_samples [u ], tree .num_samples (u ))
1605
- for u , s in tip_samples .items ():
1606
- if s > 1 :
1607
- nodelabels [u ].append (f"+{ s - 1 } { 'samples' if s > 2 else 'sample' } " )
1613
+ if show_descendant_samples :
1614
+ show = True if show_descendant_samples == "all" else False
1615
+ is_tip = out_deg == 0
1616
+ if show_descendant_samples == "tips" and is_tip :
1617
+ show = True
1618
+ elif node .is_sample ():
1619
+ if show_descendant_samples == "samples" :
1620
+ show = True
1621
+ elif show_descendant_samples == "sample_tips" and is_tip :
1622
+ show = True
1623
+ if show :
1624
+ s = ti .nodes_max_descendant_samples [u ]
1625
+ if node .is_sample ():
1626
+ s -= 1 # don't count self
1627
+ if s > 0 :
1628
+ nodelabels [u ].append (f"+{ s } { 'samples' if s > 1 else 'sample' } " )
1608
1629
1609
1630
nodelabels = {k : "\n " .join (v ) for k , v in nodelabels .items ()}
1610
1631
@@ -1628,7 +1649,9 @@ def sort_mutation_label(s):
1628
1649
inherited_state = ts .mutation (m .parent ).derived_state
1629
1650
1630
1651
if ti .mutations_is_reversion [m .id ]:
1631
- mutstr = f"{ inherited_state .lower ()} { pos } { m .derived_state .lower ()} "
1652
+ mutstr = f"$\\ bf{{{ inherited_state .lower ()} { pos } { m .derived_state .lower ()} }}$"
1653
+ elif ts .mutations_parent [m .id ] != tskit .NULL :
1654
+ mutstr = f"$\\ bf{{{ inherited_state .upper ()} { pos } { m .derived_state .upper ()} }}$"
1632
1655
else :
1633
1656
mutstr = f"{ inherited_state .upper ()} { pos } { m .derived_state .upper ()} "
1634
1657
if linmuts_dict is None or pos in linmuts_dict .all_positions :
@@ -1659,7 +1682,13 @@ def sort_mutation_label(s):
1659
1682
lpos = "lft"
1660
1683
elif edge .left > 0 and edge .right == ts .sequence_length :
1661
1684
lpos = "rgt"
1685
+ if interval_labels [lpos ][pc ]:
1686
+ interval_labels [lpos ][pc ] += " " # multiple same-side intervals for an edge
1687
+ if lpos == "rgt" and interval_labels ["lft" ][pc ]:
1688
+ interval_labels [lpos ][pc ] = " " + interval_labels [lpos ][pc ]
1662
1689
interval_labels [lpos ][pc ] = f"{ int (edge .left )} …{ int (edge .right )} "
1690
+ if lpos == "lft" and interval_labels ["rgt" ][pc ]:
1691
+ interval_labels [lpos ][pc ] += " "
1663
1692
1664
1693
if label_replace is not None :
1665
1694
for search , replace in label_replace .items ():
0 commit comments