1+ from itertools import combinations
2+ from pathlib import Path
3+
4+ import pandas as pd
5+ import pytest
6+
7+ import bib_dedupe .cluster
8+ from bib_dedupe .bib_dedupe import block
9+ from bib_dedupe .bib_dedupe import match
10+ from bib_dedupe .bib_dedupe import prep
11+
12+
13+ ID_COL = "ID"
14+
15+ BENCHMARK_DIR = Path ("tests/ldd-full-benchmark" )
16+
17+ MAX_FP_CASES_TO_PRINT = 50
18+ MAX_FP_DIAGNOSTICS = 20
19+
def _normalize_groups(groups) -> set[frozenset[str]]:
    """Convert an iterable of iterables to a set of frozensets of trimmed strings.

    Items are stringified and stripped; empty strings and empty groups are
    dropped entirely.
    """
    normalized: set[frozenset[str]] = set()
    for group in groups:
        cleaned = frozenset(
            stripped for item in group if (stripped := str(item).strip())
        )
        if cleaned:
            normalized.add(cleaned)
    return normalized
28+
29+
def _load_expected_groups_from_csv(csv_path: Path) -> set[frozenset[str]]:
    """
    Read merged_record_ids.csv with a single column 'merged_ids',
    where each row is a semicolon-delimited group (e.g., 'id_a;id_b;id_c').
    Only groups with length > 1 are considered.
    """
    df = pd.read_csv(csv_path)
    if "merged_ids" not in df.columns:
        raise ValueError("CSV must contain a 'merged_ids' column.")

    # Split each row on ';', drop blank tokens, keep only true groups (>1 id).
    groups = []
    for row in df["merged_ids"].astype(str):
        tokens = [token.strip() for token in row.split(";") if token.strip()]
        if len(tokens) > 1:
            groups.append(tokens)

    return _normalize_groups(groups)
47+
48+
def _pairs_from_groups(groups: set[frozenset[str]]) -> set[frozenset[str]]:
    """Expand groups into unordered 2-item pairs."""
    return {
        frozenset(pair)
        for group in groups
        for pair in combinations(group, 2)
    }
56+
57+
def _benchmark_paths() -> list[Path]:
    """Return the benchmark dataset directories, sorted, excluding '.git'."""
    directories = (entry for entry in BENCHMARK_DIR.iterdir() if entry.is_dir())
    return sorted(entry for entry in directories if entry.name != ".git")
60+
61+
def _load_records_df(benchmark_path: Path) -> pd.DataFrame:
    """Load the pre-merged records CSV for one benchmark dataset.

    The 'depression' dataset ships as multiple part files which are
    concatenated; every other dataset has a single records_pre_merged.csv.
    """
    if benchmark_path.name != "depression":
        return pd.read_csv(benchmark_path / "records_pre_merged.csv")

    part_paths = sorted(benchmark_path.glob("records_pre_merged_part*.csv"))
    if not part_paths:
        raise FileNotFoundError(f"No part files found for 'depression' in {benchmark_path}")
    frames = (pd.read_csv(part) for part in part_paths)
    return pd.concat(frames, ignore_index=True)
70+
71+
def _assert_id_column(records_df: pd.DataFrame) -> None:
    """Raise KeyError if records_df is missing the expected ID column."""
    if ID_COL in records_df.columns:
        return
    raise KeyError(
        f"Expected an ID column '{ID_COL}' in records_df. "
        f"Available columns: {sorted(records_df.columns)}"
    )
78+
79+
def _rerun_pair_with_diagnostics(records_df: pd.DataFrame, a: str, b: str) -> None:
    """Run the pipeline on just the two records with verbosity_level=2."""
    _assert_id_column(records_df)

    pair_df = records_df[records_df[ID_COL].astype(str).isin([a, b])].copy()
    banner = "=" * 80
    print("\n" + banner)
    print(f"DIAGNOSTICS for false-positive pair: ({a}, {b})")
    print(f"Selected rows: {len(pair_df)}")

    not_found = {a, b} - set(pair_df[ID_COL].astype(str))
    if not_found:
        # Bail out early: nothing to rerun if either record is absent.
        print(f"WARNING: could not find IDs in records_df: {sorted(not_found)}")
        print(banner + "\n")
        return

    # Rerun the prep -> block -> match pipeline at maximum verbosity.
    pair_df = prep(pair_df, verbosity_level=2)
    blocked_df = block(records_df=pair_df, verbosity_level=2)
    matched_df = match(blocked_df, verbosity_level=2)

    print("Matched rows (top):")
    try:
        print(matched_df.head(50).to_string(index=False))
    except Exception:
        # to_string can fail on exotic cell contents; fall back to repr.
        print(matched_df.head(50))
    print(banner + "\n")
106+
107+
108+ @pytest .mark .parametrize ("benchmark_path" , _benchmark_paths (), ids = lambda p : p .name )
109+ def test_full_benchmark (benchmark_path : Path ) -> None :
110+ print (f"Dataset: { benchmark_path } " )
111+
112+
113+ try :
114+ records_df = _load_records_df (benchmark_path )
115+ _assert_id_column (records_df )
116+
117+ records_df_prepped = prep (records_df , verbosity_level = 0 )
118+ actual_blocked_df = block (records_df = records_df_prepped , verbosity_level = 0 )
119+ matched_df = match (actual_blocked_df , verbosity_level = 0 )
120+
121+ duplicate_id_sets = bib_dedupe .cluster .get_connected_components (matched_df )
122+ detected_groups = _normalize_groups (g for g in duplicate_id_sets if len (g ) > 1 )
123+
124+ expected_groups = _load_expected_groups_from_csv (benchmark_path / "merged_record_ids.csv" )
125+
126+ detected_pairs = _pairs_from_groups (detected_groups )
127+ expected_pairs = _pairs_from_groups (expected_groups )
128+
129+ false_positives = detected_pairs - expected_pairs
130+ if false_positives :
131+ fp_examples = sorted ([tuple (sorted (p )) for p in false_positives ])[:20 ]
132+
133+ # rerun diagnostics for (up to) the first 20 false-positive pairs
134+ for a , b in fp_examples :
135+ _rerun_pair_with_diagnostics (records_df , a , b )
136+
137+ raise AssertionError (
138+ "False positives: merged IDs not in the same expected group.\n "
139+ f"Count: { len (false_positives )} \n "
140+ f"Examples (up to 20): { fp_examples } \n "
141+ "Diagnostics printed above for the shown examples."
142+ )
143+
144+ finally :
145+ # ---- always print summary stats ----
146+ fp_cases = sorted ([tuple (sorted (p )) for p in false_positives ])[:MAX_FP_CASES_TO_PRINT ]
147+
148+ print ("\n " + "-" * 80 )
149+ print (f"SUMMARY for dataset: { benchmark_path .name } " )
150+ print (f"Detected pairs: { len (detected_pairs )} " )
151+ print (f"Expected pairs: { len (expected_pairs )} " )
152+ print (f"False positives: { len (false_positives )} " )
153+
154+ if fp_cases :
155+ print (f"False-positive cases (up to { MAX_FP_CASES_TO_PRINT } ):" )
156+ for a , b in fp_cases :
157+ print (f" - { a } <> { b } " )
158+ else :
159+ print ("False-positive cases: none" )
160+
161+ print ("-" * 80 + "\n " )