Skip to content

Commit 6b2c42a

Browse files
Add special case for inserting exact matches
1 parent 23ef308 commit 6b2c42a

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

sc2ts/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
NODE_IS_MUTATION_OVERLAP = 1 << 21
1818
NODE_IS_REVERSION_PUSH = 1 << 22
1919
NODE_IS_RECOMBINANT = 1 << 23
20+
NODE_IS_EXACT_MATCH = 1 << 24
2021

2122

2223
__version__ = "undefined"

sc2ts/inference.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,9 +593,11 @@ def extend(
593593
match_db.create_mask_table(base_ts)
594594
ts = increment_time(date, base_ts)
595595

596+
ts = add_exact_matches(ts=ts, match_db=match_db, date=date)
597+
596598
logger.info(f"Update ARG with low-cost samples for {date}")
597599
ts = add_matching_results(
598-
f"match_date=='{date}' and hmm_cost<={max_hmm_cost}",
600+
f"match_date=='{date}' and hmm_cost>0 and hmm_cost<={max_hmm_cost}",
599601
ts=ts,
600602
match_db=match_db,
601603
date=date,
@@ -674,6 +676,31 @@ def match_path_ts(samples, ts, path, reversions):
674676
# print(tables)
675677

676678

679+
def add_exact_matches(match_db, ts, date):
680+
where_clause = f"match_date=='{date}' AND hmm_cost==0"
681+
logger.info(f"Querying match DB WHERE: {where_clause}")
682+
samples = list(match_db.get(where_clause))
683+
if len(samples) == 0:
684+
logger.info(f"No exact matches on {date}")
685+
return ts
686+
logger.info(f"Update ARG with {len(samples)} exact matches for {date}")
687+
tables = ts.dump_tables()
688+
for sample in samples:
689+
assert len(sample.path) == 1
690+
assert len(sample.mutations) == 0
691+
node_id = tables.nodes.add_row(
692+
flags=tskit.NODE_IS_SAMPLE | core.NODE_IS_EXACT_MATCH,
693+
time=0,
694+
metadata=sample.metadata,
695+
)
696+
parent = sample.path[0].parent
697+
logger.debug(f"ARG add exact match {sample.strain}:{node_id}->{parent}")
698+
tables.edges.add_row(0, ts.sequence_length, parent=parent, child=node_id)
699+
tables.sort()
700+
tables.build_index()
701+
return tables.tree_sequence()
702+
703+
677704
def add_matching_results(
678705
where_clause,
679706
match_db,

0 commit comments

Comments
 (0)