|
| 1 | +from collections import defaultdict |
1 | 2 | from collections.abc import Sequence |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +from frozendict import frozendict |
2 | 6 |
|
3 | 7 | from semhash.datamodels import DeduplicationResult, DuplicateRecord |
| 8 | +from semhash.utils import DatasetLike, Record, coerce_value, to_frozendict |
| 9 | + |
| 10 | + |
def group_records_by_key(
    records: Sequence[dict[str, Any]],
    columns: Sequence[str],
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]:
    """
    Group records by exact match on columns, preserving first-occurrence order.

    :param records: Records to group.
    :param columns: Columns to use as grouping key.
    :return: Tuple of (deduplicated_records, items) where:
        - deduplicated_records: first record from each unique group
        - items: list of groups, each group is a list of exact duplicates
    """
    # A plain dict preserves insertion order (Python 3.7+), so there is no need
    # for a separate `order` list to track first occurrence: the dict's own key
    # order already is first-occurrence order.
    buckets: dict[frozendict[str, Any], list[dict[str, Any]]] = {}
    for record in records:
        buckets.setdefault(to_frozendict(record, columns), []).append(record)

    items = list(buckets.values())
    deduplicated_records = [bucket[0] for bucket in items]
    return deduplicated_records, items
| 39 | + |
| 40 | + |
def remove_exact_duplicates(
    records: Sequence[dict[str, Any]],
    columns: Sequence[str],
    reference_records: list[list[dict[str, Any]]] | None = None,
) -> tuple[list[dict[str, Any]], list[tuple[dict[str, Any], list[dict[str, Any]]]]]:
    """
    Remove exact duplicates based on the hashable representation of each record.

    If reference_records is None, the function will only check for duplicates within the records list.

    :param records: A list of records to check for exact duplicates.
    :param columns: Columns to unpack.
    :param reference_records: A list of records to compare against. These are already unpacked
    :return: A list of deduplicated records and a list of duplicates.
    """
    column_set = set(columns)
    in_one_set = reference_records is None

    # Map each hashable key to the records previously seen under that key.
    # When a reference set is given, it is the only source of "seen" records.
    seen: defaultdict[frozendict[str, Any], list[dict[str, Any]]] = defaultdict(list)
    if reference_records is not None:
        for record_set in reference_records:
            seen[to_frozendict(record_set[0], column_set)] = list(record_set)

    deduplicated: list[dict[str, Any]] = []
    duplicates: list[tuple[dict[str, Any], list[dict[str, Any]]]] = []

    for record in records:
        key = to_frozendict(record, column_set)
        matches = seen.get(key)
        if matches:
            duplicates.append((record, matches))
            continue
        deduplicated.append(record)
        # Only add current documents to seen if no reference set is used
        if in_one_set:
            seen[key].append(record)

    return deduplicated, duplicates
| 79 | + |
| 80 | + |
def prepare_records(
    records: Sequence[Record], columns: Sequence[str] | None
) -> tuple[list[dict[str, Any]], Sequence[str], bool]:
    """
    Validate and prepare records for processing.

    :param records: A list of records (strings or dictionaries).
    :param columns: Columns to use if records are dictionaries.
    :return: Tuple of (dict_records, columns, was_string).
    :raises ValueError: If records are empty.
    :raises ValueError: If columns are not provided for dictionary records.
    :raises ValueError: If dict record contains None values.
    :raises ValueError: If records are not homogeneous (mixed strings and dicts).
    """
    if not records:
        raise ValueError("records must not be empty")

    first = records[0]
    if columns is None and isinstance(first, dict):
        raise ValueError("Columns must be specified when passing dictionaries.")

    if isinstance(first, str):
        # Homogeneity check: every record must be a string like the first one.
        if any(not isinstance(r, str) for r in records):
            raise ValueError("All records must be strings when the first record is a string.")
        # Strings are wrapped into single-column dicts under the "text" key.
        return [{"text": record} for record in records], ["text"], True

    # Homogeneity check: every record must be a dict like the first one.
    if any(not isinstance(r, dict) for r in records):
        raise ValueError("All records must be dicts when the first record is a dict.")
    assert columns is not None

    # Coerce values: stringify primitives, keep complex types raw (for images, etc.)
    dict_records: list[dict[str, Any]] = []
    for record in records:
        row: dict[str, Any] = {}
        for column in columns:
            value = record.get(column)  # type: ignore[union-attr]
            if value is None:
                raise ValueError(f"Column '{column}' has None value in record {record}")
            row[column] = coerce_value(value)
        dict_records.append(row)

    return dict_records, columns, False
| 127 | + |
| 128 | + |
| 129 | +def _validate_dataset(dataset: DatasetLike, columns: Sequence[str]) -> tuple[dict[str, Sequence[Any]], int]: |
| 130 | + """Validate dataset structure and extract columns.""" |
| 131 | + try: |
| 132 | + column_names = dataset.column_names |
| 133 | + except AttributeError as e: |
| 134 | + raise TypeError("dataset must satisfy DatasetLike (column_names, __len__, __getitem__)") from e |
| 135 | + |
| 136 | + missing = set(columns) - set(column_names) |
| 137 | + if missing: |
| 138 | + raise ValueError(f"Columns {missing} not found in dataset") |
| 139 | + |
| 140 | + n = len(dataset) |
| 141 | + if n == 0: |
| 142 | + raise ValueError("dataset must not be empty") |
| 143 | + |
| 144 | + cols = {c: dataset[c] for c in columns} |
| 145 | + for c in columns: |
| 146 | + if len(cols[c]) != n: |
| 147 | + raise ValueError(f"Column '{c}' length ({len(cols[c])}) does not match dataset length ({n})") |
| 148 | + |
| 149 | + return cols, n |
| 150 | + |
| 151 | + |
def prepare_dataset_records(
    dataset: DatasetLike,
    columns: Sequence[str],
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]], bool]:
    """
    Extract, validate, and exact-deduplicate dataset rows using columnar access.

    :param dataset: A dataset-like object with columnar access.
    :param columns: Columns to use for deduplication.
    :return: Tuple of (deduplicated_records, items, was_string) where:
        - deduplicated_records: representative record per exact-duplicate bucket
        - items: buckets of exact duplicates (each bucket is list[record])
        - was_string: True iff columns == ["text"] and ALL raw values were strings
    """
    cols, n = _validate_dataset(dataset, columns)

    # was_string controls whether deduplicate() returns strings or dicts.
    # We only return strings if: (1) single column named "text", AND (2) all raw
    # values in the dataset are actual strings (not integers/floats coerced to str).
    text_only = len(columns) == 1 and columns[0] == "text"
    was_string = text_only and all(isinstance(value, str) for value in cols["text"])

    # Build one dict per row, validating and coercing each cell on the way.
    records: list[dict[str, Any]] = []
    for idx in range(n):
        row: dict[str, Any] = {}
        for col in columns:
            raw = cols[col][idx]
            # None is never a valid cell; fail fast with the exact location.
            if raw is None:
                raise ValueError(f"Column '{col}' has None at index {idx}")
            row[col] = coerce_value(raw)
        records.append(row)

    # Group by exact match, preserving first-occurrence order
    deduplicated_records, items = group_records_by_key(records, columns)

    return deduplicated_records, items, was_string
4 | 190 |
|
5 | 191 |
|
6 | 192 | def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str: |
|
0 commit comments