Skip to content

Commit 2fdab71

Browse files
committed
Put added-back samples at the correct times
1 parent ab7c55c commit 2fdab71

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

sc2ts/inference.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,17 @@ def trim_branches(ts):
14241424
return tables.tree_sequence()
14251425

14261426

1427-
def attach_tree(parent_ts, parent_tables, attach_path, reversions, child_ts, date):
1427+
def attach_tree(
1428+
parent_ts,
1429+
parent_tables,
1430+
attach_path,
1431+
reversions,
1432+
child_ts,
1433+
date,
1434+
epsilon=None,
1435+
):
1436+
if epsilon is None:
1437+
epsilon = 1e-6 # In time units of days ago
14281438

14291439
root_time = min(parent_ts.nodes_time[seg.parent] for seg in attach_path)
14301440
if root_time == 0:
@@ -1444,19 +1454,37 @@ def attach_tree(parent_ts, parent_tables, attach_path, reversions, child_ts, dat
14441454
child_ts = add_root_edge(child_ts)
14451455
tree = child_ts.first()
14461456

1457+
# Add sample node times
1458+
current_date = parse_date(date)
1459+
node_time = {} # In time units of days ago
1460+
for u in tree.postorder():
1461+
if tree.is_sample(u):
1462+
node = child_ts.node(u)
1463+
sample_date = parse_date(node.metadata['date'])
1464+
node_time[u] = (current_date - sample_date).days
1465+
assert node_time[u] >= 0.0
1466+
max_sample_time = max(node_time.values())
1467+
14471468
node_id_map = {}
14481469
if child_ts.nodes_time[tree.root] != 1.0:
14491470
raise ValueError("Time must be scaled from 0 to 1.")
1450-
node_time = {}
1471+
1472+
num_internal_nodes_visited = 0
14511473
for u in tree.postorder()[:-1]:
14521474
node = child_ts.node(u)
1453-
# Tree branch length is scaled from 0 to 1.
1454-
time = node.time * root_time
1455-
node_time[u] = time
1475+
if tree.is_sample(u):
1476+
# All sample nodes are terminal
1477+
time = node_time[u]
1478+
else:
1479+
num_internal_nodes_visited += 1
1480+
time = max_sample_time + num_internal_nodes_visited * epsilon
1481+
node_time[u] = time
14561482
metadata = node.metadata
14571483
if tree.is_internal(u):
14581484
metadata = {"date_added": date}
1459-
new_id = parent_tables.nodes.append(node.replace(time=time, metadata=metadata))
1485+
new_id = parent_tables.nodes.append(
1486+
node.replace(time=time, metadata=metadata)
1487+
)
14601488
node_id_map[node.id] = new_id
14611489
for v in tree.children(u):
14621490
parent_tables.edges.add_row(

0 commit comments

Comments
 (0)