Skip to content

Commit da37fea

Browse files
Initial first tests on real data
1 parent 841c377 commit da37fea

File tree

2 files changed

+58
-104
lines changed

2 files changed

+58
-104
lines changed

sc2ts/inference.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,10 +563,17 @@ def extend(
563563
show_progress=False,
564564
max_submission_delay=None,
565565
max_daily_samples=None,
566-
num_threads=None,
566+
num_threads=0,
567567
precision=None,
568568
rng=None,
569569
):
570+
if num_mismatches is None:
571+
num_mismatches = 3
572+
if max_hmm_cost is None:
573+
max_hmm_cost = 5
574+
if min_group_size is None:
575+
min_group_size = 10
576+
570577
check_base_ts(base_ts)
571578
logger.info(
572579
f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"

tests/test_inference.py

Lines changed: 50 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -485,106 +485,53 @@ def test_high_recomb_mutation(self):
485485
self.check_double_mirror(ts)
486486

487487

488-
# # @pytest.fixture
489-
# def small_sd_fixture():
490-
# reference = core.get_reference_sequence()
491-
# print(reference)
492-
# fasta = {"REF": reference}
493-
# rows = [{"strain": "REF"}]
494-
# sd = convert.convert_alignments(rows, fasta)
495-
496-
# return sd
497-
498-
# ref = core.get_reference_sequence()
499-
# with tsinfer.SampleData(sequence_length=len(ref)) as sd:
500-
# sd.add_individual(
501-
# metadata={
502-
# "strain": "A",
503-
# "date": "2019-12-30",
504-
# "date_submitted": "2020-01-02",
505-
# }
506-
# )
507-
# sd.add_individual(
508-
# metadata={
509-
# "strain": "B",
510-
# "date": "2020-01-01",
511-
# "date_submitted": "2020-02-02",
512-
# }
513-
# )
514-
# sd.add_individual(
515-
# metadata={
516-
# "strain": "C",
517-
# "date": "2020-01-01",
518-
# "date_submitted": "2020-02-02",
519-
# }
520-
# )
521-
# sd.add_individual(
522-
# metadata={
523-
# "strain": "D",
524-
# "date": "2020-01-02",
525-
# "date_submitted": "2022-02-02",
526-
# }
527-
# )
528-
# sd.add_individual(
529-
# metadata={
530-
# "strain": "E",
531-
# "date": "2020-01-06",
532-
# "date_submitted": "2020-02-02",
533-
# }
534-
# )
535-
# for
536-
# return sd
537-
538-
# class TestInitialTables:
539-
# def test_site_schema(self):
540-
# sd = small_sd_fixture()
541-
# pass
542-
543-
544-
@pytest.mark.skip()
545-
class TestInference:
546-
def test_small_sd_times(self, small_sd_fixture):
547-
ts = sc2ts.infer(small_sd_fixture)
548-
inference.validate(small_sd_fixture, ts)
549-
# Day 0 is Jan 6, and ultimate ancestor is one day older than the
550-
# real root (reference)
551-
np.testing.assert_array_equal(ts.nodes_time, [9, 8, 7, 5, 5, 4, 0])
552-
553-
def test_small_sd_submission_delay(self, small_sd_fixture):
554-
ts = sc2ts.infer(small_sd_fixture, max_submission_delay=100)
555-
strains = [ts.node(u).metadata["strain"] for u in ts.samples()]
556-
# Strain D should be filtered out.
557-
assert strains == ["A", "B", "C", "E"]
558-
with pytest.raises(ValueError):
559-
inference.validate(small_sd_fixture, ts)
560-
inference.validate(small_sd_fixture, ts, max_submission_delay=100)
561-
562-
def test_daily_prefix(self, tmp_path, sd_fixture):
563-
prefix = str(tmp_path) + "/x"
564-
ts = sc2ts.infer(sd_fixture, daily_prefix=prefix)
565-
paths = sorted(list(tmp_path.glob("x*")))
566-
dailies = [tskit.load(x) for x in paths]
567-
assert len(dailies) > 0
568-
ts.tables.assert_equals(dailies[-1].tables)
569-
570-
@pytest.mark.parametrize("num_mismatches", [1, 2, 4, 1000])
571-
def test_integrity(self, sd_fixture, num_mismatches):
572-
ts = sc2ts.infer(sd_fixture, num_mismatches=num_mismatches)
573-
assert ts.sequence_length == 29904
574-
inference.validate(sd_fixture, ts)
575-
576-
577-
@pytest.mark.skip()
578-
class TestSubsetInferenceDefaults:
579-
def test_metadata(self, ts_fixture):
580-
for node in ts_fixture.nodes():
581-
if node.flags == 0:
582-
assert node.metadata == {}
583-
elif node.flags == tskit.NODE_IS_SAMPLE:
584-
assert "strain" in node.metadata
585-
else:
586-
assert node.flags == tsinfer.NODE_IS_IDENTICAL_SAMPLE_ANCESTOR
587-
assert node.metadata == {}
588-
589-
def test_integrity(self, sd_fixture, ts_fixture):
590-
inference.validate(sd_fixture, ts_fixture)
488+
class TestRealData:
489+
def test_first_day(self, tmp_path, alignments_store, metadata_db):
490+
ts = sc2ts.extend(
491+
alignment_store=alignments_store,
492+
metadata_db=metadata_db,
493+
base_ts=sc2ts.initial_ts(),
494+
date="2020-01-19",
495+
match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"),
496+
)
497+
# 25.00┊ 0 ┊
498+
# ┊ ┃ ┊
499+
# 24.00┊ 1 ┊
500+
# ┊ ┃ ┊
501+
# 0.00 ┊ 2 ┊
502+
# 0 29904
503+
assert ts.num_trees == 1
504+
assert ts.num_nodes == 3
505+
assert ts.num_samples == 2
506+
assert ts.num_mutations == 3
507+
assert list(ts.nodes_time) == [25, 24, 0]
508+
assert ts.metadata["sc2ts"]["date"] == "2020-01-19"
509+
assert ts.metadata["sc2ts"]["samples_strain"] == [
510+
"Wuhan/Hu-1/2019",
511+
"SRR11772659",
512+
]
513+
assert list(ts.samples()) == [1, 2]
514+
assert ts.node(1).metadata["strain"] == "Wuhan/Hu-1/2019"
515+
assert ts.node(2).metadata["strain"] == "SRR11772659"
516+
assert list(ts.mutations_node) == [2, 2, 2]
517+
assert list(ts.mutations_time) == [0, 0, 0]
518+
assert list(ts.mutations_site) == [8632, 17816, 27786]
519+
sc2ts_md = ts.node(2).metadata["sc2ts"]
520+
assert len(sc2ts_md["mutations"]) == 3
521+
for mut_md, mut in zip(sc2ts_md["mutations"], ts.mutations()):
522+
assert mut_md["derived_state"] == mut.derived_state
523+
assert mut_md["site_id"] == mut.site
524+
assert mut_md["site_position"] == ts.sites_position[mut.site]
525+
assert mut_md["inherited_state"] == ts.site(mut.site).ancestral_state
526+
assert sc2ts_md["path"] == [{"left": 0, "parent": 1, "right": 29904}]
527+
assert sc2ts_md["qc"] == {
528+
"num_masked_sites": 133,
529+
"original_base_composition": {
530+
"A": 8894,
531+
"C": 5472,
532+
"G": 5850,
533+
"N": 121,
534+
"T": 9566,
535+
},
536+
"original_md5": "e96feaa72c4f4baba73c2e147ede7502",
537+
}

0 commit comments

Comments
 (0)