diff --git a/CITATION.cff b/CITATION.cff
index cd06753..42199b3 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -1,6 +1,6 @@
cff-version: 1.2.0
message: "If you use SemHash in your research, please cite it as below."
-title: "SemHash: Fast Semantic Text Deduplication & Filtering"
+title: "SemHash: Fast Multimodal Semantic Deduplication & Filtering"
authors:
- family-names: "van Dongen"
given-names: "Thomas"
@@ -14,7 +14,7 @@ date-released: "2025-01-05"
preferred-citation:
type: software
- title: "SemHash: Fast Semantic Text Deduplication & Filtering"
+ title: "SemHash: Fast Multimodal Semantic Deduplication & Filtering"
authors:
- family-names: "van Dongen"
given-names: "Thomas"
diff --git a/Makefile b/Makefile
index 2e79ab7..b5af932 100644
--- a/Makefile
+++ b/Makefile
@@ -9,7 +9,7 @@ install: venv
uv run pre-commit install
install-no-pre-commit:
- uv pip install ".[dev]"
+ uv pip install ".[dev,all]"
fix:
uv run pre-commit run --all-files
diff --git a/README.md b/README.md
index fd2ed6a..cffb47d 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@

- Fast Semantic Text Deduplication & Filtering
+ Fast Multimodal Semantic Deduplication & Filtering
@@ -38,9 +38,9 @@
-SemHash is a lightweight and flexible tool for deduplicating datasets, filtering outliers, and finding representative samples using semantic similarity. It combines fast embedding generation from [Model2Vec](https://github.com/MinishLab/model2vec) with efficient ANN-based similarity search through [Vicinity](https://github.com/MinishLab/vicinity).
+SemHash is a lightweight library for semantic deduplication, outlier filtering, and representative sample selection. It's fully multimodal: text works out of the box with fast [Model2Vec](https://github.com/MinishLab/model2vec) embeddings, and you can bring your own encoder for images, audio, or any other modality.
-SemHash supports both single-dataset deduplication & filtering (e.g., cleaning up a train set by removing duplicates and outliers) and multi-dataset deduplication & filtering (e.g., ensuring no overlap between a test set and a train set). It works with simple datasets, such as text lists, and more complex ones, like multi-column QA datasets. Additionally, it includes functions to inspect deduplication results, making it easier to understand and refine your data cleaning process.
+SemHash supports both single-dataset operations (e.g., cleaning a training set) and cross-dataset operations (e.g., deduplicating a test set against a train set). It works with simple lists and complex multi-column datasets, and includes inspection tools to help you understand and refine results. All operations use [Vicinity](https://github.com/MinishLab/vicinity) for efficient similarity search.
## Quickstart
@@ -49,6 +49,8 @@ Install the package with:
pip install semhash
```
+### Text Deduplication, Filtering & Representative Sampling
+
Deduplicate a single dataset, filter outliers, and find representative samples with the following code (note: the examples assume you have `datasets` installed, which you can install with `pip install datasets`):
```python
@@ -71,7 +73,35 @@ filtered_texts = semhash.self_filter_outliers().selected
representative_texts = semhash.self_find_representative().selected
```
-Or, deduplicate across two datasets, filter outliers, and find representative samples with the following code (e.g., eliminating train/test leakage):
+### Image Deduplication, Filtering & Representative Sampling
+
+Deduplicate an image dataset, filter outliers, and find representative samples using a vision model (requires `pip install sentence-transformers`):
+
+```python
+from datasets import load_dataset
+from sentence_transformers import SentenceTransformer
+from semhash import SemHash
+
+# Load an image dataset and vision model
+model = SentenceTransformer("clip-ViT-B-32")
+dataset = load_dataset("uoft-cs/cifar10", split="test")
+
+# Initialize a SemHash instance with the 'img' column
+semhash = SemHash.from_records(list(dataset), columns=["img"], model=model)
+
+# Deduplicate the images
+deduplicated_images = semhash.self_deduplicate().selected
+
+# Filter outliers
+filtered_images = semhash.self_filter_outliers().selected
+
+# Find representative images
+representative_images = semhash.self_find_representative().selected
+```
+
+### Cross-Dataset Deduplication, Filtering & Representative Sampling
+
+Deduplicate across two datasets, filter outliers, and find representative samples (e.g., eliminating train/test leakage):
```python
from datasets import load_dataset
@@ -93,13 +123,12 @@ filtered_test_texts = semhash.filter_outliers(records=test_texts, outlier_percen
# Find representative texts in the test data against the training data,
# optionally with a specific selection size
-representative_test_texts = semhash.find_representative(
- records=test_texts, selection_size=10).selected
-
-
+representative_test_texts = semhash.find_representative(records=test_texts, selection_size=10).selected
```
-Or, deduplicate multi-column dataset, filter outliers, and find representative samples with the following code (e.g., deduplicating a QA dataset):
+### Multi-Column Deduplication
+
+Deduplicate multi-column datasets (e.g., deduplicating a QA dataset):
```python
from datasets import load_dataset
@@ -116,15 +145,9 @@ semhash = SemHash.from_records(records=records, columns=["question", "context"])
# Deduplicate the records
deduplicated_records = semhash.self_deduplicate().selected
-
-# Filter outliers from the records
-filtered_texts = semhash.self_filter_outliers().selected
-
-# Find representative texts in the records
-representative_texts = semhash.self_find_representative().selected
```
-The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L58). This object stores the deduplicated corpus, a set of duplicate object (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result.
+The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L58). This object stores the deduplicated corpus, a set of duplicate objects (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result.
The `filter_outliers`, `self_filter_outliers`, `find_representative`, and `self_find_representative` functions return a [FilterResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L179). This object stores the found outliers/representative samples.
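+
+As a quick sketch of inspecting these result objects (using only attributes exercised elsewhere in this repository):
+
+```python
+result = semhash.self_deduplicate()
+print(result.duplicate_ratio)        # Fraction of records marked as duplicates
+print(result.exact_duplicate_ratio)  # Fraction marked as exact duplicates
+print(result.get_least_similar_from_duplicates(1))  # Least-similar duplicate pair(s)
+
+outlier_result = semhash.self_filter_outliers()
+print(outlier_result.filter_ratio)    # Fraction of records filtered out
+print(outlier_result.selected_ratio)  # Fraction of records kept
+```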
@@ -212,14 +235,11 @@ The following code snippet shows how to deduplicate across two datasets, filter
from datasets import load_dataset
from semhash import SemHash
-# Initialize a SemHash instance
-semhash = SemHash()
-
# Load two datasets to deduplicate
train_texts = load_dataset("ag_news", split="train")["text"]
test_texts = load_dataset("ag_news", split="test")["text"]
-# Initialize a SemHash instance
+# Initialize a SemHash instance with the training data
semhash = SemHash.from_records(records=train_texts)
# Deduplicate the test data against the training data
@@ -265,6 +285,56 @@ representative_records = semhash.self_find_representative().selected
+
+ Deduplicate, filter outliers, and find representative samples on image datasets
+
+
+You can bring your own encoder for any modality by implementing the `Encoder` protocol. Here's an example using a vision model from `timm` for image deduplication:
+
+```python
+from datasets import load_dataset
+import timm
+import torch
+from semhash import SemHash
+
+# Requires: pip install timm torch datasets
+
+# Create a custom image encoder
+class VisionEncoder:
+ """Custom encoder using timm models. Implements the Encoder protocol."""
+
+ def __init__(self, model_name: str = "mobilenetv3_small_100"):
+ self.model = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
+ self.transform = timm.data.create_transform(**timm.data.resolve_model_data_config(self.model))
+
+ def encode(self, inputs):
+ """Encode a batch of PIL images into embeddings."""
+ with torch.no_grad():
+ return self.model(torch.stack([self.transform(img) for img in inputs])).numpy()
+
+# Load image dataset
+dataset = load_dataset("uoft-cs/cifar10", split="test")
+train_data = [{"img": img, "id": i} for i, img in enumerate(dataset["img"][:100])]
+test_data = [{"img": img, "id": i} for i, img in enumerate(dataset["img"][100:150])]
+
+# Initialize SemHash with the custom vision encoder
+semhash = SemHash.from_records(train_data, columns=["img"], model=VisionEncoder())
+
+# Single-dataset operations
+deduplicated = semhash.self_deduplicate().selected
+outliers = semhash.self_filter_outliers().selected
+representatives = semhash.self_find_representative().selected
+
+# Cross-dataset operations
+test_deduplicated = semhash.deduplicate(test_data).selected
+test_outliers = semhash.filter_outliers(test_data).selected
+test_representatives = semhash.find_representative(test_data, selection_size=10).selected
+```
+
+The `Encoder` protocol requires only an `encode(inputs, **kwargs)` method that returns a numpy array. This makes it easy to integrate any embedding model for any modality.
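+
+As a minimal sketch, any object with a conforming `encode` method satisfies the protocol. The hash-based "embedding" below is a hypothetical stand-in, not a real model:
+
+```python
+import hashlib
+
+import numpy as np
+
+from semhash import SemHash
+
+
+class ToyHashEncoder:
+    """Illustrative encoder: maps each input to a deterministic pseudo-embedding."""
+
+    def encode(self, inputs, **kwargs) -> np.ndarray:
+        digests = [hashlib.sha256(str(x).encode()).digest() for x in inputs]
+        return np.stack([np.frombuffer(d, dtype=np.uint8).astype(np.float32) for d in digests])
+
+
+semhash = SemHash.from_records(["a", "b", "a"], model=ToyHashEncoder())
+```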
+
+
+
Using custom encoders
@@ -400,6 +470,44 @@ representative_texts = semhash.self_find_representative().selected
```
+
+ Initializing from a HuggingFace Dataset
+
+You can easily use SemHash with HuggingFace Datasets by converting them to a list:
+
+```python
+from datasets import load_dataset
+from semhash import SemHash
+
+# Load a HuggingFace dataset
+dataset = load_dataset("ag_news", split="train")
+
+# Convert to list and initialize SemHash
+semhash = SemHash.from_records(records=list(dataset), columns=["text"])
+
+# Deduplicate, filter outliers, and find representative samples
+deduplicated_texts = semhash.self_deduplicate().selected
+filtered_texts = semhash.self_filter_outliers().selected
+representative_texts = semhash.self_find_representative().selected
+```
+
+This also works with multi-column datasets:
+
+```python
+from datasets import load_dataset
+from semhash import SemHash
+
+# Load a multi-column dataset
+dataset = load_dataset("squad_v2", split="train")
+
+# Convert to list and initialize with multiple columns
+semhash = SemHash.from_records(records=list(dataset), columns=["question", "context"])
+
+# Deduplicate the records
+deduplicated_records = semhash.self_deduplicate().selected
+```
+
+
@@ -419,7 +527,7 @@ If you use SemHash in your research, please cite the following:
```bibtex
@software{minishlab2025semhash,
author = {{van Dongen}, Thomas and Stephan Tulkens},
- title = {SemHash: Fast Semantic Text Deduplication \& Filtering},
+ title = {SemHash: Fast Multimodal Semantic Deduplication \& Filtering},
year = {2025},
publisher = {Zenodo},
doi = {10.5281/zenodo.17265942},
diff --git a/pyproject.toml b/pyproject.toml
index 966df8a..e8ba058 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "semhash"
-description = "Fast Semantic Text Deduplication & Filtering"
+description = "Fast Multimodal Semantic Deduplication & Filtering"
authors = [{name = "Thomas van Dongen", email = "thomas123@live.nl"}, { name = "Stéphan Tulkens", email = "stephantul@gmail.com"}]
readme = { file = "README.md", content-type = "text/markdown" }
dynamic = ["version"]
@@ -43,6 +43,7 @@ dev = [
"ruff",
]
+
[project.urls]
"Homepage" = "https://github.com/MinishLab"
"Bug Reports" = "https://github.com/MinishLab/semhash/issues"
diff --git a/semhash/records.py b/semhash/records.py
index c2bbb27..2229e11 100644
--- a/semhash/records.py
+++ b/semhash/records.py
@@ -1,6 +1,129 @@
+from collections import defaultdict
from collections.abc import Sequence
+from typing import Any
+
+from frozendict import frozendict
from semhash.datamodels import DeduplicationResult, DuplicateRecord
+from semhash.utils import Record, coerce_value, to_frozendict
+
+
+def group_records_by_key(
+ records: Sequence[dict[str, Any]],
+ columns: Sequence[str],
+) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]:
+ """
+ Group records by exact match on columns, preserving first-occurrence order.
+
+ :param records: Records to group.
+ :param columns: Columns to use as grouping key.
+ :return: Tuple of (deduplicated_records, items) where:
+ - deduplicated_records: first record from each unique group
+ - items: list of groups, each group is a list of exact duplicates
+ """
+ buckets: dict[frozendict[str, Any], list[dict[str, Any]]] = {}
+ order: list[frozendict[str, Any]] = []
+
+ for r in records:
+ key = to_frozendict(r, columns)
+ bucket = buckets.get(key)
+ if bucket is None:
+ buckets[key] = [r]
+ order.append(key)
+ else:
+ bucket.append(r)
+
+ items = [buckets[k] for k in order]
+ deduplicated_records = [bucket[0] for bucket in items]
+ return deduplicated_records, items
+
+
+def remove_exact_duplicates(
+ records: Sequence[dict[str, Any]],
+ columns: Sequence[str],
+ reference_records: list[list[dict[str, Any]]] | None = None,
+) -> tuple[list[dict[str, Any]], list[tuple[dict[str, Any], list[dict[str, Any]]]]]:
+ """
+ Remove exact duplicates based on the hashable representation of each record.
+
+ If reference_records is None, the function will only check for duplicates within the records list.
+
+ :param records: A list of records to check for exact duplicates.
+ :param columns: Columns to unpack.
+ :param reference_records: A list of records to compare against. These are already unpacked
+ :return: A list of deduplicated records and a list of duplicates.
+ """
+ deduplicated: list[dict[str, Any]] = []
+ duplicates: list[tuple[dict[str, Any], list[dict[str, Any]]]] = []
+
+ column_set = set(columns)
+ # Build a seen set from reference_records if provided
+ seen: defaultdict[frozendict[str, Any], list[dict[str, Any]]] = defaultdict(list)
+ if reference_records is not None:
+ for record_set in reference_records:
+ key = to_frozendict(record_set[0], column_set)
+ seen[key] = list(record_set)
+ in_one_set = reference_records is None
+
+ for record in records:
+ frozen_record = to_frozendict(record, column_set)
+ if duplicated_records := seen.get(frozen_record):
+ duplicates.append((record, duplicated_records))
+ else:
+ deduplicated.append(record)
+ # Only add current documents to seen if no reference set is used
+ if in_one_set:
+ seen[frozen_record].append(record)
+
+ return deduplicated, duplicates
+
+
+def prepare_records(
+ records: Sequence[Record], columns: Sequence[str] | None
+) -> tuple[list[dict[str, Any]], Sequence[str], bool]:
+ """
+ Validate and prepare records for processing.
+
+ :param records: A list of records (strings or dictionaries).
+ :param columns: Columns to use if records are dictionaries.
+ :return: Tuple of (dict_records, columns, was_string).
+ :raises ValueError: If records are empty.
+ :raises ValueError: If columns are not provided for dictionary records.
+ :raises ValueError: If dict record contains None values.
+ :raises ValueError: If records are not homogeneous (mixed strings and dicts).
+ """
+ if len(records) == 0:
+ raise ValueError("records must not be empty")
+
+ if columns is None and isinstance(records[0], dict):
+ raise ValueError("Columns must be specified when passing dictionaries.")
+
+ if isinstance(records[0], str):
+ # Validate all records are strings
+ if not all(isinstance(r, str) for r in records):
+ raise ValueError("All records must be strings when the first record is a string.")
+ columns = ["text"]
+ dict_records: list[dict[str, Any]] = [{"text": record} for record in records]
+ was_string = True
+ else:
+ # Validate all records are dicts
+ if not all(isinstance(r, dict) for r in records):
+ raise ValueError("All records must be dicts when the first record is a dict.")
+ assert columns is not None
+ # Coerce values: stringify primitives, keep complex types raw (for images, etc.)
+ dict_records_typed: list[dict[str, Any]] = list(records) # type: ignore[arg-type]
+ dict_records = []
+ for r in dict_records_typed:
+ coerced: dict[str, Any] = {}
+ for c in columns:
+ val = r.get(c)
+ if val is None:
+ raise ValueError(f"Column '{c}' has None value in record {r}")
+ coerced[c] = coerce_value(val)
+ dict_records.append(coerced)
+ was_string = False
+
+ return dict_records, columns, was_string
def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str:
diff --git a/semhash/semhash.py b/semhash/semhash.py
index 2baaabe..e866bc4 100644
--- a/semhash/semhash.py
+++ b/semhash/semhash.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-from collections import defaultdict
from collections.abc import Sequence
from math import ceil
from typing import Any, Generic, Literal
@@ -11,15 +10,21 @@
from pyversity import Strategy, diversify
from vicinity import Backend
-from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record
+from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult
from semhash.index import Index
-from semhash.records import add_scores_to_records, map_deduplication_result_to_strings
+from semhash.records import (
+ add_scores_to_records,
+ group_records_by_key,
+ map_deduplication_result_to_strings,
+ prepare_records,
+ remove_exact_duplicates,
+)
from semhash.utils import (
Encoder,
+ Record,
+ coerce_value,
compute_candidate_limit,
featurize,
- prepare_records,
- remove_exact_duplicates,
to_frozendict,
)
@@ -52,7 +57,7 @@ def from_records(
"""
Initialize a SemHash instance from records.
- This removes exact duplicates, featurizes the records, and fits a vicinity index.
+ Removes exact duplicates, featurizes the records, and fits a vicinity index.
:param records: A list of records (strings or dictionaries).
:param columns: Columns to featurize if records are dictionaries.
@@ -65,37 +70,17 @@ def from_records(
dict_records, columns, was_string = prepare_records(records, columns)
# If no model is provided, load the default model
- if model is None:
+ if model is None: # pragma: no cover
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
- # Remove exact duplicates
- deduplicated_records, duplicates = remove_exact_duplicates(dict_records, columns)
-
- col_set = set(columns)
- duplicate_map = defaultdict(list)
- for x, _ in duplicates:
- frozen_record = to_frozendict(x, col_set)
- duplicate_map[frozen_record].append(x)
-
- items: list[list[dict[str, str]]] = []
- for record in deduplicated_records:
- i = [record]
- frozen_record = to_frozendict(record, col_set)
- i.extend(duplicate_map[frozen_record])
- items.append(i)
+ # Group by exact match, preserving first-occurrence order
+ deduplicated_records, items = group_records_by_key(dict_records, columns)
# Create embeddings for deduplicated records only
embeddings = featurize(deduplicated_records, columns, model)
- # Build the Vicinity index
- index = Index.from_vectors_and_items(
- vectors=embeddings,
- items=items,
- backend_type=ann_backend,
- **kwargs,
- )
-
- return cls(index=index, columns=columns, model=model, was_string=was_string)
+ index = Index.from_vectors_and_items(vectors=embeddings, items=items, backend_type=ann_backend, **kwargs)
+ return cls(index=index, model=model, columns=columns, was_string=was_string)
@classmethod
def from_embeddings(
@@ -110,7 +95,7 @@ def from_embeddings(
"""
Initialize a SemHash instance from pre-computed embeddings.
- This removes exact duplicates and fits a vicinity index using the provided embeddings.
+ Removes exact duplicates and fits a vicinity index using the provided embeddings.
:param embeddings: Pre-computed embeddings as a numpy array of shape (n_records, embedding_dim).
:param records: A list of records (strings or dictionaries) corresponding to the embeddings.
@@ -160,10 +145,7 @@ def from_embeddings(
deduplicated_embeddings = embeddings[keep_embedding_indices]
index = Index.from_vectors_and_items(
- vectors=deduplicated_embeddings,
- items=items,
- backend_type=ann_backend,
- **kwargs,
+ vectors=deduplicated_embeddings, items=items, backend_type=ann_backend, **kwargs
)
return cls(index=index, model=model, columns=columns, was_string=was_string)
@@ -267,8 +249,8 @@ def self_deduplicate(
duplicate_records.append(DuplicateRecord(record=curr_record, duplicates=items_with_score, exact=True))
# If we don't see any similar_items, we know the record is not a duplicate.
- # in rare cases, the item itself might not be a duplicate of itself.
- if not similar_items:
+ # In rare cases, the item itself might not be returned by the index.
+ if not similar_items: # pragma: no cover
deduplicated_records.append(record)
continue
items, _ = zip(*similar_items)
@@ -299,30 +281,45 @@ def self_deduplicate(
return result
- def _validate_if_strings(self, records: Sequence[dict[str, str] | str]) -> Sequence[dict[str, str]]:
+ def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[dict[str, Any]]:
"""
Validate if the records are strings.
If the records are strings, they are converted to dictionaries with a single column.
+ If the records are dicts, primitives are stringified and complex types (images, etc.) are kept raw.
:param records: The records to validate.
:return: The records as a list of dictionaries.
:raises ValueError: If records are empty.
:raises ValueError: If the records are strings but were not originally strings.
- :raises ValueError: If the records are not all strings or dictionaries.
+ :raises ValueError: If the records are not all strings or all dictionaries.
+ :raises ValueError: If dict record contains None values.
"""
if len(records) == 0:
raise ValueError("records must not be empty")
+ # String path
if isinstance(records[0], str):
if not self._was_string:
raise ValueError("Records were not originally strings, but you passed strings.")
- dict_records = [{"text": record} for record in records if isinstance(record, str)]
- else:
- dict_records = [record for record in records if isinstance(record, dict)]
- if len(dict_records) != len(records):
- raise ValueError("Records must be either strings or dictionaries.")
- return dict_records
+ if not all(isinstance(r, str) for r in records):
+ raise ValueError("Records must be all strings.")
+ return [{"text": r} for r in records]
+
+ # Dict path
+ if not all(isinstance(r, dict) for r in records):
+ raise ValueError("Records must be all dictionaries.")
+
+ dict_records: Sequence[dict[str, Any]] = records # type: ignore[assignment]
+ result: list[dict[str, Any]] = []
+ for r in dict_records:
+ out = {}
+ for c in self.columns:
+ if (val := r.get(c)) is None:
+ raise ValueError(f"Column '{c}' has None value in record {r}")
+ out[c] = coerce_value(val)
+ result.append(out)
+ return result
def find_representative(
self,
diff --git a/semhash/utils.py b/semhash/utils.py
index f3abb00..1b0d7b3 100644
--- a/semhash/utils.py
+++ b/semhash/utils.py
@@ -1,4 +1,4 @@
-from collections import defaultdict
+import hashlib
from collections.abc import Sequence
from typing import Any, Protocol, TypeAlias, TypeVar
@@ -11,26 +11,78 @@
class Encoder(Protocol):
- """An encoder protocol for SemHash."""
+ """An encoder protocol for SemHash. Supports text, images, or any encodable data."""
def encode(
self,
- sentences: list[str] | str | Sequence[str],
+ inputs: Sequence[Any] | Any,
**kwargs: Any,
) -> np.ndarray:
"""
- Encode a list of sentences into embeddings.
+ Encode a list of inputs into embeddings.
- :param sentences: A list of sentences to encode.
+ :param inputs: A list of inputs to encode (strings, images, etc.).
:param **kwargs: Additional keyword arguments.
- :return: The embeddings of the sentences.
+ :return: The embeddings of the inputs.
"""
... # pragma: no cover
-def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]:
- """Convert a record to a frozendict."""
- return frozendict({k: record.get(k, "") for k in columns})
+def make_hashable(value: Any) -> Any:
+ """
+ Convert a value to a hashable representation for use as dict keys.
+
+ Primitives are returned as-is. Objects with a tobytes() method (e.g., PIL images,
+ numpy arrays) are content-hashed to a string. Other values are returned as-is if
+ hashable, and stringified otherwise.
+
+ :param value: The value to make hashable.
+ :return: A hashable representation of the value.
+ """
+ # Fast path: most values are strings or already hashable
+ if isinstance(value, (str, int, float, bool, type(None))):
+ return value
+ # Handle objects with tobytes() (PIL Image, numpy array, etc.)
+ if hasattr(value, "tobytes"):
+ return hashlib.md5(value.tobytes()).hexdigest()
+ # Fallback: try to hash, otherwise stringify
+ try:
+ hash(value)
+ return value
+ except TypeError:
+ return str(value)
+
+
+def coerce_value(value: Any) -> Any:
+ """
+ Coerce a value for encoding: stringify primitives, keep complex types raw.
+
+ This ensures primitives (int, float, bool) work with text encoders,
+ while complex types (PIL images, tensors, etc.) are passed through for multimodal encoders.
+
+ :param value: The value to coerce.
+ :return: The coerced value.
+ """
+ if isinstance(value, (str, bytes)):
+ return value
+ if isinstance(value, (int, float, bool)):
+ return str(value)
+ return value # Complex types (images, tensors, etc.)
+
+
+def to_frozendict(record: dict[str, Any], columns: Sequence[str] | set[str]) -> frozendict[str, Any]:
+ """
+ Convert a record to a frozendict with hashable values.
+
+ :param record: The record to convert.
+ :param columns: The columns to include.
+ :return: A frozendict with only the specified columns (values made hashable).
+ :raises ValueError: If a column is missing from the record.
+ """
+ try:
+ return frozendict({k: make_hashable(record[k]) for k in columns})
+ except KeyError as e:
+ missing = e.args[0]
+ raise ValueError(f"Missing column '{missing}' in record {record}") from e
def compute_candidate_limit(
@@ -62,7 +114,7 @@ def compute_candidate_limit(
def featurize(
- records: Sequence[dict[str, str]],
+ records: Sequence[dict[str, Any]],
columns: Sequence[str],
model: Encoder,
) -> np.ndarray:
@@ -73,81 +125,16 @@ def featurize(
:param columns: Columns to featurize.
:param model: An Encoder model.
:return: The embeddings of the records.
+ :raises ValueError: If a column is missing from one or more records.
"""
# Extract the embeddings for each column across all records
embeddings_per_col = []
for col in columns:
- col_texts = [r[col] for r in records]
+ try:
+ col_texts = [r[col] for r in records]
+ except KeyError as e:
+ raise ValueError(f"Missing column '{col}' in one or more records") from e
col_emb = model.encode(col_texts)
embeddings_per_col.append(np.asarray(col_emb))
return np.concatenate(embeddings_per_col, axis=1)
-
-
-def remove_exact_duplicates(
- records: Sequence[dict[str, str]],
- columns: Sequence[str],
- reference_records: list[list[dict[str, str]]] | None = None,
-) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
- """
- Remove exact duplicates based on the unpacked string representation of each record.
-
- If reference_records is None, the function will only check for duplicates within the records list.
-
- :param records: A list of records to check for exact duplicates.
- :param columns: Columns to unpack.
- :param reference_records: A list of records to compare against. These are already unpacked
- :return: A list of deduplicated records and a list of duplicates.
- """
- deduplicated = []
- duplicates = []
-
- column_set = set(columns)
- # Build a seen set from reference_records if provided
- seen: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list)
- if reference_records is not None:
- for record_set in reference_records:
- key = to_frozendict(record_set[0], column_set)
- seen[key] = list(record_set)
- in_one_set = reference_records is None
-
- for record in records:
- frozen_record = to_frozendict(record, column_set)
- if duplicated_records := seen.get(frozen_record):
- duplicates.append((record, duplicated_records))
- else:
- deduplicated.append(record)
- # Only add current documents to seen if no reference set is used
- if in_one_set:
- seen[frozen_record].append(record)
-
- return deduplicated, duplicates
-
-
-def prepare_records(
- records: Sequence[Record], columns: Sequence[str] | None
-) -> tuple[list[dict[str, str]], Sequence[str], bool]:
- """
- Validate and prepare records for processing.
-
- :param records: A list of records (strings or dictionaries).
- :param columns: Columns to use if records are dictionaries.
- :return: Tuple of (dict_records, columns, was_string).
- :raises ValueError: If records are empty.
- :raises ValueError: If columns are not provided for dictionary records.
- """
- if len(records) == 0:
- raise ValueError("records must not be empty")
-
- if columns is None and isinstance(records[0], dict):
- raise ValueError("Columns must be specified when passing dictionaries.")
-
- if isinstance(records[0], str):
- columns = ["text"]
- dict_records: list[dict[str, str]] = [{"text": str(record)} for record in records]
- was_string = True
- else:
- dict_records = list(records)
- was_string = False
-
- return dict_records, columns, was_string
diff --git a/semhash/version.py b/semhash/version.py
index 9bfefb0..cc35d8d 100644
--- a/semhash/version.py
+++ b/semhash/version.py
@@ -1,2 +1,2 @@
-__version_triple__ = (0, 3, 3)
-__version__ = ".".join(map(str, __version_triple__))
+__version_triple__ = (0, 3, 3) # pragma: no cover
+__version__ = ".".join(map(str, __version_triple__)) # pragma: no cover
diff --git a/tests/test_datamodels.py b/tests/test_datamodels.py
index 59c2563..307eeeb 100644
--- a/tests/test_datamodels.py
+++ b/tests/test_datamodels.py
@@ -1,8 +1,6 @@
import pytest
-import semhash
-import semhash.version
-from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates
+from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, SelectedWithDuplicates
def test_deduplication_scoring() -> None:
@@ -25,34 +23,27 @@ def test_deduplication_scoring_exact() -> None:
assert d.exact_duplicate_ratio == 0.2
-def test_deduplication_scoring_exact_empty() -> None:
- """Test the deduplication scoring."""
- d = DeduplicationResult([], [], 0.8, columns=["text"])
- assert d.exact_duplicate_ratio == 0.0
-
-
def test_deduplication_scoring_empty() -> None:
- """Test the deduplication scoring."""
+ """Test the deduplication scoring with empty results."""
d = DeduplicationResult([], [], 0.8, columns=["text"])
assert d.duplicate_ratio == 0.0
+ assert d.exact_duplicate_ratio == 0.0
def test_rethreshold() -> None:
- """Test rethresholding the duplicates."""
+ """Test rethresholding the duplicates, including empty case."""
d = DuplicateRecord("a", False, [("b", 0.9), ("c", 0.8)])
d._rethreshold(0.85)
assert d.duplicates == [("b", 0.9)]
-
-def test_rethreshold_empty() -> None:
- """Test rethresholding the duplicates."""
- d = DuplicateRecord("a", False, [])
- d._rethreshold(0.85)
- assert d.duplicates == []
+ # Empty case
+ d_empty = DuplicateRecord("a", False, [])
+ d_empty._rethreshold(0.85)
+ assert d_empty.duplicates == []
def test_get_least_similar_from_duplicates() -> None:
- """Test getting the least similar duplicates."""
+ """Test getting the least similar duplicates, including empty case."""
d = DeduplicationResult(
["a", "b", "c"],
[DuplicateRecord("a", False, [("b", 0.9), ("c", 0.7)]), DuplicateRecord("b", False, [("c", 0.8)])],
@@ -61,11 +52,9 @@ def test_get_least_similar_from_duplicates() -> None:
result = d.get_least_similar_from_duplicates(1)
assert result == [("a", "c", 0.7)]
-
-def test_get_least_similar_from_duplicates_empty() -> None:
- """Test getting the least similar duplicates."""
- d = DeduplicationResult([], [], 0.8, columns=["text"])
- assert d.get_least_similar_from_duplicates(1) == []
+ # Empty case
+ d_empty = DeduplicationResult([], [], 0.8, columns=["text"])
+ assert d_empty.get_least_similar_from_duplicates(1) == []
def test_rethreshold_deduplication_result() -> None:
@@ -243,3 +232,10 @@ def test_selected_with_duplicates_cache_invalidation_on_rethreshold() -> None:
assert result2[0].duplicates[0][0] == "duplicate_1"
# Results should be different objects
assert result1 is not result2
+
+
+def test_filter_result_empty() -> None:
+ """Test FilterResult ratios with empty lists."""
+ result = FilterResult(selected=[], filtered=[])
+ assert result.filter_ratio == 0.0
+ assert result.selected_ratio == 1.0
diff --git a/tests/test_semhash.py b/tests/test_semhash.py
index 1b4105c..55119fd 100644
--- a/tests/test_semhash.py
+++ b/tests/test_semhash.py
@@ -141,27 +141,31 @@ def test_deduplicate_with_only_exact_duplicates(model: Encoder) -> None:
def test_self_find_representative(model: Encoder, train_texts: list[str]) -> None:
"""Test the self_find_representative method."""
semhash = SemHash.from_records(records=train_texts, model=model)
- result = semhash.self_find_representative(
- candidate_limit=5,
- selection_size=3,
- diversity=0.5,
- )
+
+ # Test with explicit candidate_limit
+ result = semhash.self_find_representative(candidate_limit=5, selection_size=3, diversity=0.5)
assert len(result.selected) == 3, "Expected 3 representatives"
selected = {r["text"] for r in result.selected}
- assert selected == {
- "blueberry",
- "pineapple",
- "grape",
- }, "Expected representatives to be blueberry, pineapple, and grape"
+ assert selected == {"blueberry", "pineapple", "grape"}
+
+ # Test with auto candidate_limit (default)
+ result_auto = semhash.self_find_representative(selection_size=3, diversity=0.5)
+ assert len(result_auto.selected) == 3
def test_find_representative(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
"""Test the find_representative method."""
semhash = SemHash.from_records(records=train_texts, model=model)
+
+ # Test with explicit candidate_limit
result = semhash.find_representative(records=test_texts, candidate_limit=5, selection_size=3, diversity=0.5)
assert len(result.selected) == 3, "Expected 3 representatives"
selected = {r["text"] for r in result.selected}
- assert selected == {"grapefruit", "banana", "apple"}, "Expected representatives to be grapefruit, banana, and apple"
+ assert selected == {"grapefruit", "banana", "apple"}
+
+ # Test with auto candidate_limit (default)
+ result_auto = semhash.find_representative(records=test_texts, selection_size=3, diversity=0.5)
+ assert len(result_auto.selected) == 3
def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
@@ -173,10 +177,23 @@ def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: lis
filtered = {r["text"] for r in result.filtered}
assert filtered == {"motorcycle", "plane"}, "Expected outliers to be motorcycle and plane"
+ # Test FilterResult ratio properties
+ assert result.filter_ratio == len(result.filtered) / len(test_texts)
+ assert result.selected_ratio == len(result.selected) / len(test_texts)
+ assert result.filter_ratio + result.selected_ratio == 1.0
+
# Test with outlier_percentage=0.0 (should return no outliers)
result_zero = semhash.filter_outliers(records=test_texts, outlier_percentage=0.0)
assert result_zero.filtered == []
assert len(result_zero.selected) == len(test_texts)
+ assert result_zero.filter_ratio == 0.0
+ assert result_zero.selected_ratio == 1.0
+
+ # Invalid outlier_percentage raises ValueError
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.filter_outliers(records=test_texts, outlier_percentage=-0.1)
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.filter_outliers(records=test_texts, outlier_percentage=1.5)
def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None:
@@ -193,6 +210,12 @@ def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None:
assert result_zero.filtered == []
assert len(result_zero.selected) == len(train_texts)
+ # Invalid outlier_percentage raises ValueError
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.self_filter_outliers(outlier_percentage=-0.1)
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.self_filter_outliers(outlier_percentage=1.5)
+
def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test the _diversify method."""
@@ -226,30 +249,80 @@ def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None:
def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None:
"""Test from_embeddings constructor with validation and comparison to from_records."""
- # Test validation: mismatched shapes
+ # Validation: empty records
+ with pytest.raises(ValueError, match="records must not be empty"):
+ SemHash.from_embeddings(embeddings=np.array([[]]), records=[], model=model)
+
+ # Validation: non-2D embeddings
+ with pytest.raises(ValueError, match="must be a 2D array"):
+ SemHash.from_embeddings(embeddings=np.array([1, 2, 3]), records=["a", "b", "c"], model=model)
+
+ # Validation: mismatched shapes
with pytest.raises(ValueError, match="Number of embeddings"):
wrong_embeddings = model.encode(["apple", "banana"])
SemHash.from_embeddings(embeddings=wrong_embeddings, records=train_texts, model=model)
# Test that from_embeddings behaves same as from_records
semhash_from_records = SemHash.from_records(records=train_texts, model=model)
-
embeddings = model.encode(train_texts)
semhash_from_embeddings = SemHash.from_embeddings(embeddings=embeddings, records=train_texts, model=model)
- # Both should give same deduplication results
result1 = semhash_from_records.self_deduplicate(threshold=0.95)
result2 = semhash_from_embeddings.self_deduplicate(threshold=0.95)
-
assert len(result1.selected) == len(result2.selected)
- assert len(result1.filtered) == len(result2.filtered)
# Test that from_embeddings keeps first-occurrence embeddings and drops duplicates
records = ["apple", "banana", "apple", "cherry"]
embeddings = np.array([[0.0], [1.0], [2.0], [3.0]], dtype=np.float32)
-
semhash = SemHash.from_embeddings(embeddings=embeddings, records=records, model=model)
-
assert semhash.index.vectors.shape == (3, 1)
- # Should keep embeddings at indices 0, 1, 3 (first occurrences of img1, img2, img3)
assert semhash.index.vectors.tolist() == [[0.0], [1.0], [3.0]]
+
+
+def test_from_records_edge_cases(model: Encoder) -> None:
+ """Test from_records edge cases: coercion, order preservation, None rejection."""
+ # Coerces non-string dict values to strings
+ records = [{"id": 1}, {"id": 2}, {"id": 1}] # Integers, with duplicate
+ semhash = SemHash.from_records(records, columns=["id"], model=model)
+ assert semhash.index.vectors.shape[0] == 2 # Deduplicated
+ assert 2 in [len(bucket) for bucket in semhash.index.items] # id=1 bucket has 2
+
+ # Preserves first-occurrence order (deterministic)
+ texts = ["zebra", "apple", "zebra", "banana", "apple", "cherry"]
+ semhash = SemHash.from_records(texts, model=model)
+ firsts = [bucket[0]["text"] for bucket in semhash.index.items]
+ assert firsts == ["zebra", "apple", "banana", "cherry"]
+
+ # Rejects None values in dict records
+ with pytest.raises(ValueError, match="has None value"):
+ SemHash.from_records([{"text": "apple"}, {"text": None}], columns=["text"], model=model)
+
+
+def test_deduplicate_edge_cases(model: Encoder) -> None:
+ """Test deduplicate() edge cases: coercion, None rejection, empty records, type mismatches."""
+ semhash = SemHash.from_records(["1", "2", "3"], model=model)
+
+ # Coerces non-string dict values
+ result = semhash.deduplicate([{"text": 1}, {"text": 4}], threshold=0.95)
+ assert len(result.filtered) + len(result.selected) == 2
+
+ # Rejects None values
+ with pytest.raises(ValueError, match="has None value"):
+ semhash.deduplicate([{"text": "cherry"}, {"text": None}], threshold=0.95)
+
+ # Rejects empty records
+ with pytest.raises(ValueError, match="records must not be empty"):
+ semhash.deduplicate([], threshold=0.95)
+
+ # Type mismatch: strings passed to dict-based index
+ semhash_dict = SemHash.from_records([{"col": "a"}, {"col": "b"}], columns=["col"], model=model)
+ with pytest.raises(ValueError, match="Records were not originally strings"):
+ semhash_dict.deduplicate(["x", "y"], threshold=0.95)
+
+ # Type mismatch: mixed strings
+ with pytest.raises(ValueError, match="Records must be all strings"):
+ semhash.deduplicate(["a", {"text": "b"}], threshold=0.95)
+
+ # Type mismatch: mixed dicts
+ with pytest.raises(ValueError, match="Records must be all dictionaries"):
+ semhash_dict.deduplicate([{"col": "a"}, "b"], threshold=0.95)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 371fb16..3224d73 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,23 +2,74 @@
import pytest
from frozendict import frozendict
-from semhash.utils import (
- Encoder,
- compute_candidate_limit,
- featurize,
- prepare_records,
- remove_exact_duplicates,
- to_frozendict,
-)
+from semhash.records import prepare_records, remove_exact_duplicates
+from semhash.utils import Encoder, coerce_value, compute_candidate_limit, featurize, make_hashable, to_frozendict
+
+
+def test_make_hashable() -> None:
+ """Test make_hashable with various types."""
+ # Fast path: primitives
+ assert make_hashable("hello") == "hello"
+ assert make_hashable(42) == 42
+ assert make_hashable(3.14) == 3.14
+ assert make_hashable(True) is True
+ assert make_hashable(None) is None
+
+ # Objects with tobytes() (simulate PIL Image or numpy array)
+ class MockImage:
+ def tobytes(self) -> bytes:
+ return b"fake_image_data"
+
+ img = MockImage()
+ result = make_hashable(img)
+ assert isinstance(result, str)
+ assert len(result) == 32 # MD5 hex digest
+
+ # Hashable objects (like tuples)
+ assert make_hashable((1, 2, 3)) == (1, 2, 3)
+
+ # Non-hashable fallback to string
+ unhashable = {"key": "value"}
+ result = make_hashable(unhashable)
+ assert result == "{'key': 'value'}"
+
+
+def test_coerce_value() -> None:
+ """Test coerce_value for encoding preparation."""
+ # Strings and bytes pass through
+ assert coerce_value("hello") == "hello"
+ assert coerce_value(b"bytes") == b"bytes"
+
+ # Primitives converted to strings
+ assert coerce_value(42) == "42"
+ assert coerce_value(3.14) == "3.14"
+ assert coerce_value(True) == "True"
+
+ # Complex types pass through unchanged
+ class MockImage:
+ pass
+
+ img = MockImage()
+ assert coerce_value(img) is img
def test_to_frozendict() -> None:
- """Test converting dict to frozendict."""
+ """Test converting dict to frozendict, including error cases."""
record = {"a": "1", "b": "2", "c": "3"}
+
+ # Basic case: select subset of columns
result = to_frozendict(record, {"a", "c"})
assert result == frozendict({"a": "1", "c": "3"})
assert "b" not in result
+ # Works with Sequence (not just set)
+ result = to_frozendict(record, ["a", "b"])
+ assert result == frozendict({"a": "1", "b": "2"})
+
+ # Missing column raises ValueError
+ with pytest.raises(ValueError, match="Missing column 'missing'"):
+ to_frozendict(record, {"a", "missing"})
+
def test_compute_candidate_limit() -> None:
"""Test candidate limit computation."""
@@ -33,30 +84,51 @@ def test_compute_candidate_limit() -> None:
def test_featurize(model: Encoder) -> None:
- """Test featurizing records."""
+ """Test featurizing records, including error cases."""
records = [{"text": "hello"}, {"text": "world"}]
embeddings = featurize(records, ["text"], model)
assert embeddings.shape == (2, 128) # Model has 128 dims
assert isinstance(embeddings, np.ndarray)
+ # Missing column raises ValueError
+ with pytest.raises(ValueError, match="Missing column 'missing'"):
+ featurize(records, ["missing"], model)
+
def test_remove_exact_duplicates() -> None:
- """Test exact duplicate removal."""
+ """Test exact duplicate removal, with and without reference records."""
+ # Basic case: remove duplicates within same list
records = [
{"text": "hello", "id": "1"},
{"text": "world", "id": "2"},
{"text": "hello", "id": "3"},
]
deduplicated, duplicates = remove_exact_duplicates(records, ["text"])
-
assert len(deduplicated) == 2
assert len(duplicates) == 1
assert duplicates[0][0] == {"text": "hello", "id": "3"}
+ # With reference_records: cross-dataset filtering
+ reference_records = [
+ [{"text": "apple"}],
+ [{"text": "banana"}, {"text": "banana"}],
+ ]
+ new_records = [
+ {"text": "cherry"}, # New
+ {"text": "apple"}, # Exists in reference
+ {"text": "date"}, # New
+ {"text": "banana"}, # Exists in reference
+ ]
+ deduplicated, duplicates = remove_exact_duplicates(new_records, ["text"], reference_records=reference_records)
+ assert len(deduplicated) == 2
+ assert {"text": "cherry"} in deduplicated
+ assert {"text": "date"} in deduplicated
+ assert len(duplicates) == 2
+
def test_prepare_records() -> None:
- """Test preparing records."""
- # String records
+ """Test preparing records, including validation and edge cases."""
+ # String records -> converts to dicts with "text" column
records = ["hello", "world"]
dict_records, columns, was_string = prepare_records(records, None)
assert was_string is True
@@ -71,6 +143,15 @@ def test_prepare_records() -> None:
assert dict_records == records
# Dict records without columns raises ValueError
- records = [{"text": "hello"}]
with pytest.raises(ValueError, match="Columns must be specified"):
- prepare_records(records, None)
+ prepare_records([{"text": "hello"}], None)
+
+ # Empty records raises ValueError
+ with pytest.raises(ValueError, match="records must not be empty"):
+ prepare_records([], None)
+
+ # Mixed types rejected
+ with pytest.raises(ValueError, match="All records must be"):
+ prepare_records(["a", {"text": "b"}], None)
+ with pytest.raises(ValueError, match="All records must be"):
+ prepare_records([{"text": "a"}, "b"], ["text"])
diff --git a/uv.lock b/uv.lock
index 0bd321f..9bd57ea 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,6 +1,11 @@
version = 1
revision = 3
requires-python = ">=3.10"
+resolution-markers = [
+ "python_full_version >= '3.12'",
+ "python_full_version == '3.11.*'",
+ "python_full_version < '3.11'",
+]
[[package]]
name = "asttokens"