Skip to content

Commit 4ffd540

Browse files
authored
External full tests based on ldd-full-benchmark (#48)
* full-tests with ext ldd-full-benchmark * test workflow: clone submodule * split unit tests in CI * fix * rename * improve diagnostics for full-tests
1 parent 1120346 commit 4ffd540

File tree

4 files changed

+209
-5
lines changed

4 files changed

+209
-5
lines changed

.github/workflows/tests.yml

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,57 @@
11
name: Run Tests
22

33
on:
4-
- push
5-
- pull_request
4+
push:
5+
pull_request:
66

77
jobs:
8-
test-full-deps:
8+
test:
9+
name: Quick tests (${{ matrix.platform }}, py${{ matrix.python-version }})
910
strategy:
1011
matrix:
1112
platform: [ubuntu-latest, windows-latest]
1213
python-version: ['3.10', '3.11', '3.12', '3.13']
1314
runs-on: ${{ matrix.platform }}
1415
steps:
1516
- uses: actions/checkout@v4
17+
with:
18+
submodules: recursive
19+
fetch-depth: 0
20+
21+
- name: Set up Python ${{ matrix.python-version }}
22+
uses: actions/setup-python@v4
23+
with:
24+
python-version: ${{ matrix.python-version }}
25+
26+
- name: Install uv and dependencies
27+
run: |
28+
pip install uv
29+
uv venv
30+
uv pip install -e .[dev] || echo "No dev extra"
31+
echo "Dependencies installed successfully"
32+
33+
- name: Setup git
34+
run: |
35+
git config --global user.name "CoLRev update"
36+
git config --global user.email "actions@users.noreply.github.com"
37+
git config --global url.https://github.com/.insteadOf git://github.com/
38+
39+
- name: Run tests (excluding full_test.py)
40+
run: uv run pytest --ignore=tests/full_test.py
41+
42+
full-test:
43+
name: Full test (${{ matrix.platform }}, py${{ matrix.python-version }})
44+
needs: test
45+
strategy:
46+
matrix:
47+
platform: [ubuntu-latest, windows-latest]
48+
python-version: ['3.10', '3.11', '3.12', '3.13']
49+
runs-on: ${{ matrix.platform }}
50+
steps:
51+
- uses: actions/checkout@v4
52+
with:
53+
submodules: recursive
54+
fetch-depth: 0
1655

1756
- name: Set up Python ${{ matrix.python-version }}
1857
uses: actions/setup-python@v4
@@ -32,5 +71,5 @@ jobs:
3271
git config --global user.email "actions@users.noreply.github.com"
3372
git config --global url.https://github.com/.insteadOf git://github.com/
3473
35-
- name: Run tests
36-
run: uv run pytest
74+
- name: Run full_test.py
75+
run: uv run pytest -q tests/full_test.py

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "tests/ldd-full-benchmark"]
	path = tests/ldd-full-benchmark
	# https URL so CI (actions/checkout with the default GITHUB_TOKEN) can
	# clone the submodule anonymously; the ssh form (git@github.com:...)
	# fails in CI unless a deploy key is configured.
	url = https://github.com/CoLRev-Environment/ldd-full-benchmark.git

tests/full_test.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from itertools import combinations
2+
from pathlib import Path
3+
4+
import pandas as pd
5+
import pytest
6+
7+
import bib_dedupe.cluster
8+
from bib_dedupe.bib_dedupe import block
9+
from bib_dedupe.bib_dedupe import match
10+
from bib_dedupe.bib_dedupe import prep
11+
12+
13+
# Name of the record-identifier column expected in every benchmark CSV.
ID_COL = "ID"

# Root of the ldd-full-benchmark submodule; one sub-directory per dataset.
BENCHMARK_DIR = Path("tests/ldd-full-benchmark")

# Cap on false-positive pairs listed in the always-printed summary.
MAX_FP_CASES_TO_PRINT = 50
# Intended cap on verbose diagnostic reruns.
# NOTE(review): the test body currently hardcodes 20 instead of using this.
MAX_FP_DIAGNOSTICS = 20
19+
20+
def _normalize_groups(groups) -> set[frozenset[str]]:
21+
"""Convert an iterable of iterables to a set of frozensets of trimmed strings."""
22+
norm = set()
23+
for g in groups:
24+
parts = {str(x).strip() for x in g if str(x).strip()}
25+
if parts:
26+
norm.add(frozenset(parts))
27+
return norm
28+
29+
30+
def _load_expected_groups_from_csv(csv_path: Path) -> set[frozenset[str]]:
    """Load the expected duplicate groups from merged_record_ids.csv.

    The file must contain a single 'merged_ids' column where each row is
    one semicolon-delimited group (e.g. 'id_a;id_b;id_c'). Groups with
    fewer than two members are ignored.
    """
    df = pd.read_csv(csv_path)
    if "merged_ids" not in df.columns:
        raise ValueError("CSV must contain a 'merged_ids' column.")

    groups = []
    for row in df["merged_ids"].astype(str):
        ids = [token.strip() for token in row.split(";")]
        ids = [token for token in ids if token]
        if len(ids) > 1:
            groups.append(ids)

    return _normalize_groups(groups)
47+
48+
49+
def _pairs_from_groups(groups: set[frozenset[str]]) -> set[frozenset[str]]:
50+
"""Expand groups into unordered 2-item pairs."""
51+
pairs: set[frozenset[str]] = set()
52+
for g in groups:
53+
for a, b in combinations(g, 2):
54+
pairs.add(frozenset((a, b)))
55+
return pairs
56+
57+
58+
def _benchmark_paths() -> list[Path]:
    """Return the benchmark dataset directories, sorted, skipping '.git'."""
    datasets = (
        entry
        for entry in BENCHMARK_DIR.iterdir()
        if entry.is_dir() and entry.name != ".git"
    )
    return sorted(datasets)
60+
61+
62+
def _load_records_df(benchmark_path: Path) -> pd.DataFrame:
63+
if benchmark_path.name == "depression":
64+
part_paths = sorted(benchmark_path.glob("records_pre_merged_part*.csv"))
65+
if not part_paths:
66+
raise FileNotFoundError(f"No part files found for 'depression' in {benchmark_path}")
67+
return pd.concat((pd.read_csv(p) for p in part_paths), ignore_index=True)
68+
69+
return pd.read_csv(benchmark_path / "records_pre_merged.csv")
70+
71+
72+
def _assert_id_column(records_df: pd.DataFrame) -> None:
    """Raise KeyError when the records frame lacks the expected ID column."""
    if ID_COL in records_df.columns:
        return
    raise KeyError(
        f"Expected an ID column '{ID_COL}' in records_df. "
        f"Available columns: {sorted(records_df.columns)}"
    )
78+
79+
80+
def _rerun_pair_with_diagnostics(records_df: pd.DataFrame, a: str, b: str) -> None:
    """Run the pipeline on just the two records with verbosity_level=2."""
    _assert_id_column(records_df)

    banner = "=" * 80
    pair_df = records_df[records_df[ID_COL].astype(str).isin([a, b])].copy()
    print("\n" + banner)
    print(f"DIAGNOSTICS for false-positive pair: ({a}, {b})")
    print(f"Selected rows: {len(pair_df)}")

    missing = {a, b} - set(pair_df[ID_COL].astype(str))
    if missing:
        # Bail out: diagnostics are meaningless without both records.
        print(f"WARNING: could not find IDs in records_df: {sorted(missing)}")
        print(banner + "\n")
        return

    # Rerun with high verbosity
    pair_df = prep(pair_df, verbosity_level=2)
    blocked_df = block(records_df=pair_df, verbosity_level=2)
    matched_df = match(blocked_df, verbosity_level=2)

    print("Matched rows (top):")
    try:
        print(matched_df.head(50).to_string(index=False))
    except Exception:
        # to_string can fail on odd dtypes; fall back to the default repr.
        print(matched_df.head(50))
    print(banner + "\n")
106+
107+
108+
@pytest.mark.parametrize("benchmark_path", _benchmark_paths(), ids=lambda p: p.name)
def test_full_benchmark(benchmark_path: Path) -> None:
    """Run the full dedupe pipeline on one benchmark dataset.

    Fails when the pipeline merges record IDs that are not in the same
    expected group (false positives). Summary statistics are printed in
    all cases, including when the pipeline itself raises.
    """
    print(f"Dataset: {benchmark_path}")

    # Pre-initialize so the summary in `finally` never hits an unbound
    # local. Previously, an exception raised before these assignments
    # (e.g. a missing CSV) caused an UnboundLocalError in `finally`
    # that masked the original error.
    detected_pairs: set[frozenset[str]] = set()
    expected_pairs: set[frozenset[str]] = set()
    false_positives: set[frozenset[str]] = set()

    try:
        records_df = _load_records_df(benchmark_path)
        _assert_id_column(records_df)

        records_df_prepped = prep(records_df, verbosity_level=0)
        actual_blocked_df = block(records_df=records_df_prepped, verbosity_level=0)
        matched_df = match(actual_blocked_df, verbosity_level=0)

        duplicate_id_sets = bib_dedupe.cluster.get_connected_components(matched_df)
        detected_groups = _normalize_groups(g for g in duplicate_id_sets if len(g) > 1)

        expected_groups = _load_expected_groups_from_csv(benchmark_path / "merged_record_ids.csv")

        detected_pairs = _pairs_from_groups(detected_groups)
        expected_pairs = _pairs_from_groups(expected_groups)

        false_positives = detected_pairs - expected_pairs
        if false_positives:
            # Use the module constant instead of a hardcoded 20.
            fp_examples = sorted(tuple(sorted(p)) for p in false_positives)[:MAX_FP_DIAGNOSTICS]

            # Rerun diagnostics for (up to) the first MAX_FP_DIAGNOSTICS pairs
            for a, b in fp_examples:
                _rerun_pair_with_diagnostics(records_df, a, b)

            raise AssertionError(
                "False positives: merged IDs not in the same expected group.\n"
                f"Count: {len(false_positives)}\n"
                f"Examples (up to {MAX_FP_DIAGNOSTICS}): {fp_examples}\n"
                "Diagnostics printed above for the shown examples."
            )

    finally:
        # ---- always print summary stats ----
        fp_cases = sorted(tuple(sorted(p)) for p in false_positives)[:MAX_FP_CASES_TO_PRINT]

        print("\n" + "-" * 80)
        print(f"SUMMARY for dataset: {benchmark_path.name}")
        print(f"Detected pairs: {len(detected_pairs)}")
        print(f"Expected pairs: {len(expected_pairs)}")
        print(f"False positives: {len(false_positives)}")

        if fp_cases:
            print(f"False-positive cases (up to {MAX_FP_CASES_TO_PRINT}):")
            for a, b in fp_cases:
                print(f" - {a} <> {b}")
        else:
            print("False-positive cases: none")

        print("-" * 80 + "\n")

tests/ldd-full-benchmark

Submodule ldd-full-benchmark added at 48e8cba

0 commit comments

Comments
 (0)