Skip to content

Commit f96bd69

Browse files
Mike Lee
authored and committed
more tests
1 parent a9541d6 commit f96bd69

File tree

4 files changed

+303
-2
lines changed

4 files changed

+303
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
### Added
1818
- more test coverage of `bit-ez-screen`
19-
- unit tests for `bit-gen-kraken2-tax-plots`
19+
- unit tests for `bit-gen-kraken2-tax-plots` and `bit-kraken2-to-taxon-summaries`
2020
- integration test for `bit-cov-analyzer`
2121

2222
### Changed

bit/tests/test_gen_kraken2_tax_plots.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import io
21
from pathlib import Path
32
import pandas as pd
43
import pytest

bit/tests/test_input_parsing.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import pytest
2+
from bit.modules import input_parsing as m
3+
4+
5+
class PrematureExit(Exception):
    """Signals that the monkeypatched ``notify_premature_exit`` hook fired.

    Tests use this to assert that a failure path was taken without the
    process actually exiting.
    """
@pytest.fixture(autouse=True)
def patch_failure_hooks(monkeypatch):
    """Stub the module's failure hooks for every test in this file.

    ``report_message`` calls are recorded into the returned list, and
    ``notify_premature_exit`` raises :class:`PrematureExit` so failure
    paths can be asserted with ``pytest.raises``.
    """
    recorded = []

    def _capture_report(msg, initial_indent="", subsequent_indent=""):
        recorded.append(("report", msg, initial_indent, subsequent_indent))

    def _raise_premature_exit():
        raise PrematureExit("premature-exit")

    monkeypatch.setattr(m, "report_message", _capture_report)
    monkeypatch.setattr(m, "notify_premature_exit", _raise_premature_exit)
    return recorded
@pytest.mark.parametrize(
    "fname,expected",
    [
        ("sample_R1_.fastq.gz", ("sample", "R1")),
        ("sample_R2_.fastq.gz", ("sample", "R2")),
        ("sample-R1-.fq.gz", ("sample", "R1")),
        ("sample-R2-.fq.gz", ("sample", "R2")),
        ("sample.R1..fastq.gz", ("sample", "R1")),
        ("sample.R2..fastq.gz", ("sample", "R2")),
        ("sample_1..fastq.gz", ("sample", "R1")),
        ("sample_2..fastq.gz", ("sample", "R2")),
        ("no_tag.fastq.gz", (None, None)),
    ],
)
def test_parse_read_filename_variants(fname, expected):
    # Every accepted separator/tag style should yield (sample, direction);
    # an untagged name yields (None, None).
    parsed = m.parse_read_filename(fname)
    assert parsed == expected
def test_validate_extension_accepts_known_extensions(tmp_path):
    # Each extension the module advertises must validate without error.
    for accepted_ext in m.accepted_read_extensions:
        read_file = tmp_path / f"x{accepted_ext}"
        read_file.write_text("data")
        m.validate_extension(read_file)
def test_validate_extension_rejects_unknown_extension(tmp_path):
    # A plain .fq (no gzip) is not an accepted extension and must abort.
    bad_file = tmp_path / "x.fq"
    bad_file.write_text("data")
    with pytest.raises(PrematureExit):
        m.validate_extension(bad_file)
def test_get_input_reads_dict_from_paths_happy_path(tmp_path):
    """A matched R1/R2 pair produces one sample keyed by its prefix."""
    forward = tmp_path / "samp_R1_.fastq.gz"
    reverse = tmp_path / "samp_R2_.fastq.gz"
    forward.write_text("r1")
    reverse.write_text("r2")

    reads = m.get_input_reads_dict_from_paths(forward, reverse)

    assert list(reads) == ["samp"]
    assert set(reads["samp"]) == {"R1", "R2"}
    # paths are stored absolute
    assert reads["samp"]["R1"] == str(forward.resolve())
    assert reads["samp"]["R2"] == str(reverse.resolve())
def test_get_input_reads_dict_from_paths_missing_designation_calls_exit(tmp_path):
    # A file without any R1/R2 tag cannot be slotted -> premature exit.
    untagged = tmp_path / "samp.fastq.gz"
    untagged.write_text("x")
    with pytest.raises(PrematureExit):
        m.get_input_reads_dict_from_paths(untagged, None)
def test_get_input_reads_dict_from_paths_wrong_slot_calls_exit(tmp_path):
    # The filename is tagged R2 but handed in through the R1 argument,
    # which should be flagged as a mismatch.
    mismatched = tmp_path / "samp_R2_.fastq.gz"
    mismatched.write_text("x")
    with pytest.raises(PrematureExit):
        m.get_input_reads_dict_from_paths(mismatched, None)
def test_get_input_reads_dict_from_dir_pairs_samples_and_ignores_noise(tmp_path):
    """Only complete, properly named pairs survive directory scanning."""
    fixtures = [
        ("A_R1_.fastq.gz", "a1"),            # complete pair, half 1
        ("A_R2_.fastq.gz", "a2"),            # complete pair, half 2
        ("notes.txt", "ignore me"),          # unrelated file
        ("weird.fq", "ignore me"),           # bad extension
        ("junk.fastq.gz", "no R tag"),       # parses to (None, None) -> skipped
    ]
    for name, contents in fixtures:
        (tmp_path / name).write_text(contents)

    reads = m.get_input_reads_dict_from_dir(tmp_path)
    assert list(reads) == ["A"]
    assert set(reads["A"]) == {"R1", "R2"}
def test_get_input_reads_dict_from_dir_detects_incomplete_pairs_and_exits(tmp_path, patch_failure_hooks):
    # An R1 with no matching R2 is an incomplete pair and must abort.
    (tmp_path / "B_R1_.fastq.gz").write_text("b1")
    with pytest.raises(PrematureExit):
        m.get_input_reads_dict_from_dir(tmp_path)

    # A diagnostic must have gone through report_message, and it should
    # point at the offending input directory.
    report_calls = [entry for entry in patch_failure_hooks if entry[0] == "report"]
    assert report_calls, "expected report_message() to be called"
    assert str(tmp_path) in report_calls[-1][1]
def test_get_input_reads_dict_from_dir_handles_multiple_samples(tmp_path):
    # Sample C uses the underscore/.fastq.gz style, D the dash/.fq.gz style.
    (tmp_path / "C_R1_.fastq.gz").write_text("c1")
    (tmp_path / "C_R2_.fastq.gz").write_text("c2")
    (tmp_path / "D-R1-.fq.gz").write_text("d1")
    (tmp_path / "D-R2-.fq.gz").write_text("d2")

    reads = m.get_input_reads_dict_from_dir(tmp_path)

    assert set(reads) == {"C", "D"}
    for sample in ("C", "D"):
        assert set(reads[sample]) == {"R1", "R2"}
def test_get_input_reads_dict_from_dir_skips_files_with_no_designation(tmp_path):
    # A lone file lacking an R1/R2 tag yields no samples at all.
    (tmp_path / "E.fastq.gz").write_text("no tag")
    assert m.get_input_reads_dict_from_dir(tmp_path) == {}
def test_parse_read_filename_uses_basename_not_path(tmp_path):
    """Parent directories must not interfere with filename parsing."""
    nested = tmp_path / "subdir"
    nested.mkdir()
    read_file = nested / "sample.R1..fastq.gz"
    read_file.write_text("x")

    # Hand over the full path; the function reduces it to its basename.
    sample, direction = m.parse_read_filename(str(read_file))
    assert (sample, direction) == ("sample", "R1")
bit/tests/test_kraken2_to_taxon_summaries.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import pandas as pd
2+
import numpy as np
3+
import pytest
4+
from bit.modules import kraken2_to_taxon_summaries as k
5+
6+
7+
@pytest.fixture
def report_min(tmp_path):
    """
    Tiny kraken2-like report covering:
    - unclassified (U)
    - root (R)
    - domain given as R1 with Bacteria (to test normalize_rank_code)
    - genus lines with spaces around names (to test stripping)
    - one non-standard rank '-' which should produce a row but not change lineage
    """
    # columns: percent clade_reads taxon_reads rank taxid name(with indent)
    report_lines = [
        "10.00 100 100 U 0 unclassified",
        "90.00 900 50 R 1 root",
        "80.00 800 0 R1 2 Bacteria",
        "25.00 250 250 G 123 GenX",  # leading spaces
        "15.00 150 150 G 124 GenY ",  # trailing spaces
        " 0.50 5 5 - 9999 12345_like_strain",  # non-standard rank
        "",  # blank line should be ignored
    ]
    report_path = tmp_path / "k.report"
    report_path.write_text("\n".join(report_lines) + "\n")
    return report_path
def test_parse_report_line_basic_space_split():
    # A space-delimited line parses field-for-field; the percent column
    # is not part of the returned record.
    record = k.parse_report_line("12.3 120 12 G 1234 Some Genus")
    expected = {
        "clade_reads": 120,
        "taxon_reads": 12,
        "rank": "G",
        "taxid": 1234,
        "name": "Some Genus",
    }
    assert record == expected
def test_parse_report_line_tab_fallback():
    record = k.parse_report_line("12.3\t120\t12\tG\t1234\t Some\tGenus ")
    assert record["taxid"] == 1234
    assert record["rank"] == "G"
    # surrounding whitespace is stripped but an internal tab survives
    assert record["name"] == "Some\tGenus"
def test_normalize_rank_code_domain_R1():
    # R1 is promoted to domain ("D") only for recognized domain names;
    # anything else passes through unchanged.
    for domain_name in ("Bacteria", "Viruses"):
        assert k.normalize_rank_code("R1", domain_name) == "D"
    assert k.normalize_rank_code("R1", "WeirdDomain") == "R1"
    assert k.normalize_rank_code("G", "Genus") == "G"
def test_parse_report_builds_lineages(report_min):
    df = k.parse_report(str(report_min))

    # the unclassified row carries "Unclassified" at every standard rank
    unclassified = df[df["taxid"] == 0].iloc[0]
    assert all(unclassified[rank] == "Unclassified" for rank in k.STD_RANKS)
    assert unclassified["read_counts"] == 100

    # genus names are stripped and both inherit the Bacteria domain
    for taxid, genus in ((123, "GenX"), (124, "GenY")):
        genus_row = df[df["taxid"] == taxid].iloc[0]
        assert genus_row["genus"] == genus
        assert genus_row["domain"] == "Bacteria"

    # the '-' rank row still inherits the current lineage
    dash_row = df[df["taxid"] == 9999].iloc[0]
    assert dash_row["domain"] == "Bacteria"
    assert dash_row["domain"] != "Unclassified"
def test_refine_df_aggregates_duplicates_and_computes_percents():
    genz_lineage = {
        "domain": "Bacteria", "phylum": "Firmicutes", "class": "Bacilli",
        "order": "NA", "family": "NA", "genus": "GenZ", "species": "NA",
    }
    df = pd.DataFrame([
        {"taxid": 5, **genz_lineage, "read_counts": 30},
        {"taxid": 5, **genz_lineage, "read_counts": 70},
        {"taxid": 0, **{rank: "Unclassified" for rank in k.STD_RANKS}, "read_counts": 100},
    ])
    refined = k.refine_df(df.copy())

    # the two taxid-5 rows collapse into one with summed reads
    assert refined[refined["taxid"] == 5].iloc[0]["read_counts"] == 100

    # 100 + 100 total reads -> 50% each
    percents = refined.set_index("taxid")["percent_of_reads"].to_dict()
    assert np.isclose(percents[5], 50.0, atol=1e-6)
    assert np.isclose(percents[0], 50.0, atol=1e-6)
def test_refine_df_zero_total_reads():
    # when every row has zero reads, refine_df drops them all
    df = pd.DataFrame([
        {"taxid": taxid, **{rank: "NA" for rank in k.STD_RANKS}, "read_counts": 0}
        for taxid in (1, 2)
    ])
    assert k.refine_df(df.copy()).empty
def test_sort_df_custom_groups_and_stability():
    """Exercise sort_df's grouping: Unclassified (taxid 0) first, the
    all-NA lineage second, then the rest ordered by lineage and taxid
    (mergesort keeps ties stable)."""
    archaea = {"domain": "Archaea", "phylum": "Eury", "class": "C1",
               "order": "O1", "family": "F1", "genus": "G1", "species": "S1"}
    rows = [
        {"taxid": 3, "domain": "Bacteria", "phylum": "Actino", "class": "C1",
         "order": "O1", "family": "F1", "genus": "G1", "species": "S1",
         "read_counts": 10},
        {"taxid": 0, **{rank: "Unclassified" for rank in k.STD_RANKS}, "read_counts": 10},
        {"taxid": 2, **{rank: "NA" for rank in k.STD_RANKS}, "read_counts": 10},
        {"taxid": 4, **archaea, "read_counts": 10},
        {"taxid": 5, **archaea, "read_counts": 10},
    ]
    ordered = k.sort_df(pd.DataFrame(rows).copy())

    # expected: 0 first, then all-NA 2, then Archaea 4/5 (taxid ascending),
    # then Bacteria 3
    assert list(ordered["taxid"]) == [0, 2, 4, 5, 3]
def test_kraken2_to_taxon_summaries_writes_tsv(tmp_path, report_min, monkeypatch):
    # Skip the file-existence check so the unit test stays filesystem-light.
    monkeypatch.setattr(k, "check_files_are_found", lambda paths: None)

    out_path = tmp_path / "summary.tsv"
    k.kraken2_to_taxon_summaries(str(report_min), str(out_path))
    assert out_path.exists()

    written = pd.read_csv(out_path, sep="\t")
    # column layout: taxid, the standard ranks, then the two count columns
    assert list(written.columns) == ["taxid"] + k.STD_RANKS + ["read_counts", "percent_of_reads"]
    assert (written["taxid"] == 0).any()
    assert written["percent_of_reads"].map(lambda v: isinstance(v, float)).all()
def test_parse_report_ignores_blank_and_handles_U_and_R1(tmp_path):
    report = tmp_path / "mini.report"
    report.write_text("\n".join([
        " ",  # blank
        "5.0 50 50 U 0 unclassified",
        "95.0 950 0 R1 2 Bacteria",
        "20.0 200 200 G 9 GenA",
    ]))

    df = k.parse_report(str(report))
    # the unclassified row made it through; the genus inherits the domain
    assert (df["taxid"] == 0).any()
    assert df[df["taxid"] == 9].iloc[0]["domain"] == "Bacteria"
def test_preflight_checks_calls_validator(monkeypatch):
    # preflight_checks should hand its single path to check_files_are_found
    # wrapped in a list.
    seen = {}
    monkeypatch.setattr(
        k, "check_files_are_found", lambda paths: seen.setdefault("ok", paths)
    )
    k.preflight_checks("abc.txt")
    assert seen.get("ok") == ["abc.txt"]
def test_parse_report_line_bad_line_uses_report_failure(monkeypatch):
    """A malformed report line must be routed through report_failure().

    report_failure is monkeypatched to raise so we can observe the call;
    the unused ``msgs`` accumulator from the original version is removed.
    """
    def fake_report_failure(msg):
        raise ValueError(msg)

    monkeypatch.setattr(k, "report_failure", fake_report_failure)

    with pytest.raises(ValueError):
        k.parse_report_line("not enough fields")

0 commit comments

Comments
 (0)