1
-
1
+ import numpy as np
2
2
import click .testing as ct
3
3
import pytest
4
4
import tskit
7
7
from sc2ts import __main__ as main
8
8
from sc2ts import cli
9
9
10
- class TestInitialise :
11
10
11
+ class TestInitialise :
12
12
def test_defaults (self , tmp_path ):
13
13
ts_path = tmp_path / "trees.ts"
14
14
match_db_path = tmp_path / "match.db"
@@ -24,3 +24,23 @@ def test_defaults(self, tmp_path):
24
24
other_ts .tables .assert_equals (ts .tables )
25
25
match_db = sc2ts .MatchDb (match_db_path )
26
26
assert len (match_db ) == 0
27
+
28
+ @pytest .mark .parametrize ("additional" , [[100 ], [100 , 200 ]])
29
+ def test_additional_problematic_sites (self , tmp_path , additional ):
30
+ ts_path = tmp_path / "trees.ts"
31
+ match_db_path = tmp_path / "match.db"
32
+ problematic_path = tmp_path / "additional_problematic.txt"
33
+ np .savetxt (problematic_path , additional )
34
+ runner = ct .CliRunner (mix_stderr = False )
35
+ result = runner .invoke (
36
+ cli .cli ,
37
+ f"initialise { ts_path } { match_db_path } "
38
+ f"--additional-problematic-sites { problematic_path } " ,
39
+ catch_exceptions = False ,
40
+ )
41
+ assert result .exit_code == 0
42
+ ts = tskit .load (ts_path )
43
+ other_ts = sc2ts .initial_ts (additional_problematic_sites = additional )
44
+ other_ts .tables .assert_equals (ts .tables )
45
+ match_db = sc2ts .MatchDb (match_db_path )
46
+ assert len (match_db ) == 0
0 commit comments