Skip to content

Commit e808922

Browse files
committed
Moved functions to records
1 parent 95cc5e6 commit e808922

File tree

4 files changed

+196
-195
lines changed

4 files changed

+196
-195
lines changed

semhash/records.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,192 @@
1+
from collections import defaultdict
12
from collections.abc import Sequence
3+
from typing import Any
4+
5+
from frozendict import frozendict
26

37
from semhash.datamodels import DeduplicationResult, DuplicateRecord
8+
from semhash.utils import DatasetLike, Record, coerce_value, to_frozendict
9+
10+
11+
def group_records_by_key(
    records: Sequence[dict[str, Any]],
    columns: Sequence[str],
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]:
    """
    Group records by exact match on columns, preserving first-occurrence order.

    :param records: Records to group.
    :param columns: Columns to use as grouping key.
    :return: Tuple of (deduplicated_records, items) where:
        - deduplicated_records: first record from each unique group
        - items: list of groups, each group is a list of exact duplicates
    """
    # dicts preserve insertion order (Python 3.7+), so the mapping itself
    # tracks first-occurrence order — no separate `order` list is needed.
    buckets: dict[frozendict[str, Any], list[dict[str, Any]]] = {}
    for record in records:
        buckets.setdefault(to_frozendict(record, columns), []).append(record)

    items = list(buckets.values())
    deduplicated_records = [group[0] for group in items]
    return deduplicated_records, items
39+
40+
41+
def remove_exact_duplicates(
    records: Sequence[dict[str, Any]],
    columns: Sequence[str],
    reference_records: list[list[dict[str, Any]]] | None = None,
) -> tuple[list[dict[str, Any]], list[tuple[dict[str, Any], list[dict[str, Any]]]]]:
    """
    Remove exact duplicates based on the hashable representation of each record.

    If reference_records is None, the function will only check for duplicates within the records list.

    :param records: A list of records to check for exact duplicates.
    :param columns: Columns to unpack.
    :param reference_records: A list of records to compare against. These are already unpacked
    :return: A list of deduplicated records and a list of duplicates.
    """
    column_set = set(columns)

    # Seed the lookup table from the reference set, if one was given.
    # Each reference group is keyed by its first member's frozen representation.
    seen: dict[frozendict[str, Any], list[dict[str, Any]]] = {}
    if reference_records is not None:
        for group in reference_records:
            seen[to_frozendict(group[0], column_set)] = list(group)

    # With no reference set, records must also be deduplicated against each other.
    self_dedup = reference_records is None

    deduplicated: list[dict[str, Any]] = []
    duplicates: list[tuple[dict[str, Any], list[dict[str, Any]]]] = []

    for record in records:
        key = to_frozendict(record, column_set)
        matches = seen.get(key)
        if matches:
            duplicates.append((record, matches))
            continue
        deduplicated.append(record)
        # Only register this record as "seen" when deduplicating within one set.
        if self_dedup:
            seen.setdefault(key, []).append(record)

    return deduplicated, duplicates
79+
80+
81+
def prepare_records(
    records: Sequence[Record], columns: Sequence[str] | None
) -> tuple[list[dict[str, Any]], Sequence[str], bool]:
    """
    Validate and prepare records for processing.

    :param records: A list of records (strings or dictionaries).
    :param columns: Columns to use if records are dictionaries.
    :return: Tuple of (dict_records, columns, was_string).
    :raises ValueError: If records are empty.
    :raises ValueError: If columns are not provided for dictionary records.
    :raises ValueError: If dict record contains None values.
    :raises ValueError: If records are not homogeneous (mixed strings and dicts).
    """
    if len(records) == 0:
        raise ValueError("records must not be empty")

    if columns is None and isinstance(records[0], dict):
        raise ValueError("Columns must be specified when passing dictionaries.")

    if isinstance(records[0], str):
        # Validate all records are strings
        if not all(isinstance(r, str) for r in records):
            raise ValueError("All records must be strings when the first record is a string.")
        columns = ["text"]
        dict_records: list[dict[str, Any]] = [{"text": record} for record in records]
        was_string = True
    else:
        # Validate all records are dicts
        if not all(isinstance(r, dict) for r in records):
            raise ValueError("All records must be dicts when the first record is a dict.")
        assert columns is not None
        # Coerce values: stringify primitives, keep complex types raw (for images, etc.)
        # Iterate the input directly — no need for an intermediate copy of `records`.
        dict_records = []
        for r in records:
            coerced: dict[str, Any] = {}
            for c in columns:
                # .get(): a missing column and an explicit None are rejected alike.
                val = r.get(c)  # type: ignore[union-attr]
                if val is None:
                    raise ValueError(f"Column '{c}' has None value in record {r}")
                coerced[c] = coerce_value(val)
            dict_records.append(coerced)
        was_string = False

    return dict_records, columns, was_string
127+
128+
129+
def _validate_dataset(dataset: DatasetLike, columns: Sequence[str]) -> tuple[dict[str, Sequence[Any]], int]:
    """
    Validate dataset structure and extract columns.

    :param dataset: A dataset-like object exposing column_names, __len__, __getitem__.
    :param columns: Columns that must be present in the dataset.
    :return: Tuple of (extracted columns keyed by name, dataset length).
    :raises TypeError: If the object does not satisfy the DatasetLike interface.
    :raises ValueError: If columns are missing, the dataset is empty, or lengths disagree.
    """
    # EAFP: probe for the DatasetLike interface rather than isinstance-checking.
    try:
        available = dataset.column_names
    except AttributeError as e:
        raise TypeError("dataset must satisfy DatasetLike (column_names, __len__, __getitem__)") from e

    missing = set(columns) - set(available)
    if missing:
        raise ValueError(f"Columns {missing} not found in dataset")

    n = len(dataset)
    if n == 0:
        raise ValueError("dataset must not be empty")

    cols = {c: dataset[c] for c in columns}
    # Every extracted column must agree with the dataset's reported length.
    for c, values in cols.items():
        if len(values) != n:
            raise ValueError(f"Column '{c}' length ({len(values)}) does not match dataset length ({n})")

    return cols, n
150+
151+
152+
def prepare_dataset_records(
    dataset: DatasetLike,
    columns: Sequence[str],
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]], bool]:
    """
    Extract, validate, and exact-deduplicate dataset rows using columnar access.

    :param dataset: A dataset-like object with columnar access.
    :param columns: Columns to use for deduplication.
    :return: Tuple of (deduplicated_records, items, was_string) where:
        - deduplicated_records: representative record per exact-duplicate bucket
        - items: buckets of exact duplicates (each bucket is list[record])
        - was_string: True iff columns == ["text"] and ALL raw values were strings
    """
    cols, n = _validate_dataset(dataset, columns)

    # was_string controls whether deduplicate() returns strings or dicts.
    # Strings are only returned when the sole column is "text" AND every raw
    # value really is a str (ints/floats coerced to str do not count).
    was_string = len(columns) == 1 and columns[0] == "text"

    # Build all records while tracking was_string; reject None values eagerly.
    records: list[dict[str, Any]] = []
    for idx in range(n):
        if was_string and not isinstance(cols["text"][idx], str):
            was_string = False
        row: dict[str, Any] = {}
        for col in columns:
            raw = cols[col][idx]
            if raw is None:
                raise ValueError(f"Column '{col}' has None at index {idx}")
            # Coerce for encoding: stringify primitives, keep complex types raw.
            row[col] = coerce_value(raw)
        records.append(row)

    # Group by exact match, preserving first-occurrence order
    deduplicated_records, items = group_records_by_key(records, columns)

    return deduplicated_records, items, was_string
4190

5191

6192
def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str:

semhash/semhash.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,21 @@
1212

1313
from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult
1414
from semhash.index import Index
15-
from semhash.records import add_scores_to_records, map_deduplication_result_to_strings
15+
from semhash.records import (
16+
add_scores_to_records,
17+
group_records_by_key,
18+
map_deduplication_result_to_strings,
19+
prepare_dataset_records,
20+
prepare_records,
21+
remove_exact_duplicates,
22+
)
1623
from semhash.utils import (
1724
DatasetLike,
1825
Encoder,
1926
Record,
2027
coerce_value,
2128
compute_candidate_limit,
2229
featurize,
23-
group_records_by_key,
24-
prepare_dataset_records,
25-
prepare_records,
26-
remove_exact_duplicates,
2730
to_frozendict,
2831
)
2932

0 commit comments

Comments
 (0)