Skip to content

Commit 0691436

Browse files
Tests for additional problematic sites
1 parent 0e9f03d commit 0691436

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

sc2ts/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def initialise(ts, match_db, additional_problematic_sites, verbose, log_file):
201201
additional_problematic = []
202202
if additional_problematic_sites is not None:
203203
additional_problematic = (
204-
np.loadtxt(additional_problematic_sites).astype(int).tolist()
204+
np.loadtxt(additional_problematic_sites, ndmin=1).astype(int).tolist()
205205
)
206206
logger.info(
207207
f"Excluding additional {len(additional_problematic)} problematic sites"

tests/test_cli.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
import numpy as np
22
import click.testing as ct
33
import pytest
44
import tskit
@@ -7,8 +7,8 @@
77
from sc2ts import __main__ as main
88
from sc2ts import cli
99

10-
class TestInitialise:
1110

11+
class TestInitialise:
1212
def test_defaults(self, tmp_path):
1313
ts_path = tmp_path / "trees.ts"
1414
match_db_path = tmp_path / "match.db"
@@ -24,3 +24,23 @@ def test_defaults(self, tmp_path):
2424
other_ts.tables.assert_equals(ts.tables)
2525
match_db = sc2ts.MatchDb(match_db_path)
2626
assert len(match_db) == 0
27+
28+
@pytest.mark.parametrize("additional", [[100], [100, 200]])
29+
def test_additional_problematic_sites(self, tmp_path, additional):
30+
ts_path = tmp_path / "trees.ts"
31+
match_db_path = tmp_path / "match.db"
32+
problematic_path = tmp_path / "additional_problematic.txt"
33+
np.savetxt(problematic_path, additional)
34+
runner = ct.CliRunner(mix_stderr=False)
35+
result = runner.invoke(
36+
cli.cli,
37+
f"initialise {ts_path} {match_db_path} "
38+
f"--additional-problematic-sites {problematic_path}",
39+
catch_exceptions=False,
40+
)
41+
assert result.exit_code == 0
42+
ts = tskit.load(ts_path)
43+
other_ts = sc2ts.initial_ts(additional_problematic_sites=additional)
44+
other_ts.tables.assert_equals(ts.tables)
45+
match_db = sc2ts.MatchDb(match_db_path)
46+
assert len(match_db) == 0

0 commit comments

Comments
 (0)