|
| 1 | +from itertools import combinations |
1 | 2 | from pathlib import Path |
| 3 | + |
2 | 4 | import pandas as pd |
3 | 5 |
|
4 | 6 | import bib_dedupe.cluster |
5 | | -from bib_dedupe.bib_dedupe import block, match, merge, prep |
| 7 | +from bib_dedupe.bib_dedupe import block |
| 8 | +from bib_dedupe.bib_dedupe import match |
| 9 | +from bib_dedupe.bib_dedupe import prep |
6 | 10 |
|
7 | 11 |
|
8 | 12 | def _normalize_groups(groups) -> set[frozenset[str]]: |
@@ -35,31 +39,64 @@ def _load_expected_groups_from_csv(csv_path: Path) -> set[frozenset[str]]: |
35 | 39 | return _normalize_groups(expected_groups) |
36 | 40 |
|
37 | 41 |
|
38 | | -def test_full() -> None: |
39 | | - benchmark_path = Path("tests/data") |
40 | | - print(f"Dataset: {benchmark_path}") |
| 42 | +def _pairs_from_groups(groups: set[frozenset[str]]) -> set[frozenset[str]]: |
| 43 | + """Expand groups into unordered 2-item pairs.""" |
| 44 | + pairs: set[frozenset[str]] = set() |
| 45 | + for g in groups: |
| 46 | + for a, b in combinations(g, 2): |
| 47 | + pairs.add(frozenset((a, b))) |
| 48 | + return pairs |
| 49 | + |
| 50 | + |
# TODO : replace by core once ready?
def test_full_benchmark() -> None:
    """Run the full dedupe pipeline against every benchmark dataset.

    For each dataset directory under ``tests/ldd-full-benchmark`` the records
    are prepped, blocked, matched, and clustered into connected components.
    Detected duplicate groups are expanded into unordered ID pairs and compared
    against the expected merged-ID groups, so that ONLY false positives fail
    the test: a pair the pipeline merged that the gold standard does not place
    in the same group. False negatives (missed merges) are deliberately not
    asserted on.

    Raises:
        AssertionError: on the first dataset producing any false-positive pair.
        FileNotFoundError: if a dataset directory has no pre-merge records CSV.
    """
    benchmark_dir = Path("tests/ldd-full-benchmark")
    dataset_paths = sorted(
        p for p in benchmark_dir.iterdir() if p.is_dir() and p.name != ".git"
    )
    for benchmark_path in dataset_paths:
        print(f"Dataset: {benchmark_path}")

        records_df = _load_records(benchmark_path)
        records_df = prep(records_df)
        actual_blocked_df = block(records_df=records_df)
        matched_df = match(actual_blocked_df)

        # Get connected components and keep only duplicate groups (>1)
        duplicate_id_sets = bib_dedupe.cluster.get_connected_components(matched_df)
        detected_groups = _normalize_groups(g for g in duplicate_id_sets if len(g) > 1)

        expected_groups = _load_expected_groups_from_csv(
            benchmark_path / "merged_record_ids.csv"
        )

        # Compare on pairs so group-boundary differences that only merge
        # *fewer* records than expected do not fail the test.
        detected_pairs = _pairs_from_groups(detected_groups)
        expected_pairs = _pairs_from_groups(expected_groups)

        false_positives = detected_pairs - expected_pairs
        if false_positives:
            # Sorted, truncated examples keep the failure message readable.
            fp_examples = sorted(tuple(sorted(p)) for p in false_positives)[:20]
            raise AssertionError(
                "False positives: merged IDs not in the same expected group.\n"
                f"Dataset: {benchmark_path.name}\n"
                f"Count: {len(false_positives)}\n"
                f"Examples (up to 20): {fp_examples}"
            )


def _load_records(benchmark_path: Path) -> pd.DataFrame:
    """Read a dataset's pre-merge records.

    Oversized datasets (currently only 'depression') ship as split CSVs named
    ``records_pre_merged_part*.csv``; when such parts exist they are
    concatenated, otherwise the single ``records_pre_merged.csv`` is read.
    """
    part_paths = sorted(benchmark_path.glob("records_pre_merged_part*.csv"))
    if part_paths:
        return pd.concat((pd.read_csv(p) for p in part_paths), ignore_index=True)
    single_path = benchmark_path / "records_pre_merged.csv"
    if not single_path.is_file():
        raise FileNotFoundError(
            f"No records_pre_merged*.csv found in {benchmark_path}"
        )
    return pd.read_csv(single_path)
0 commit comments