diff --git a/CITATION.cff b/CITATION.cff index cd06753..42199b3 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,6 +1,6 @@ cff-version: 1.2.0 message: "If you use SemHash in your research, please cite it as below." -title: "SemHash: Fast Semantic Text Deduplication & Filtering" +title: "SemHash: Fast Multimodal Semantic Deduplication & Filtering" authors: - family-names: "van Dongen" given-names: "Thomas" @@ -14,7 +14,7 @@ date-released: "2025-01-05" preferred-citation: type: software - title: "SemHash: Fast Semantic Text Deduplication & Filtering" + title: "SemHash: Fast Multimodal Semantic Deduplication & Filtering" authors: - family-names: "van Dongen" given-names: "Thomas" diff --git a/Makefile b/Makefile index 2e79ab7..b5af932 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ install: venv uv run pre-commit install install-no-pre-commit: - uv pip install ".[dev]" + uv pip install ".[dev,all]" fix: uv run pre-commit run --all-files diff --git a/README.md b/README.md index fd2ed6a..cffb47d 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

SemHash logo
-  Fast Semantic Text Deduplication & Filtering
+  Fast Multimodal Semantic Deduplication & Filtering

@@ -38,9 +38,9 @@ -SemHash is a lightweight and flexible tool for deduplicating datasets, filtering outliers, and finding representative samples using semantic similarity. It combines fast embedding generation from [Model2Vec](https://github.com/MinishLab/model2vec) with efficient ANN-based similarity search through [Vicinity](https://github.com/MinishLab/vicinity). +SemHash is a lightweight library for semantic deduplication, outlier filtering, and representative sample selection. It's fully multimodal: text works out-of-the-box with fast Model2Vec embeddings, and you can bring your own encoders for images, audio, or custom models. -SemHash supports both single-dataset deduplication & filtering (e.g., cleaning up a train set by removing duplicates and outliers) and multi-dataset deduplication & filtering (e.g., ensuring no overlap between a test set and a train set). It works with simple datasets, such as text lists, and more complex ones, like multi-column QA datasets. Additionally, it includes functions to inspect deduplication results, making it easier to understand and refine your data cleaning process. +SemHash supports both single-dataset operations (clean a training set) and cross-dataset operations (deduplicate test against train). It works with simple lists and complex multi-column datasets, and includes inspection tools to help you understand and refine results. All operations use Vicinity for efficient similarity search. ## Quickstart @@ -49,6 +49,8 @@ Install the package with: pip install semhash ``` +### Text Deduplication, Filtering & Representative Sampling + Deduplicate a single dataset, filter outliers, and find representative samples with the following code (note: the examples assume you have `datasets` installed, which you can install with `pip install datasets`): ```python @@ -71,7 +73,35 @@ filtered_texts = semhash.self_filter_outliers().selected representative_texts = semhash.self_find_representative().selected ``` -Or, deduplicate across two datasets, filter outliers, and find representative samples with the following code (e.g., eliminating train/test leakage): +### Image Deduplication, Filtering & Representative Sampling + +Deduplicate an image dataset, filter outliers, and find representative samples using a vision model (requires `pip install sentence-transformers`): + +```python +from datasets import load_dataset +from sentence_transformers import SentenceTransformer +from semhash import SemHash + +# Load an image dataset and vision model +model = SentenceTransformer('clip-ViT-B-32') +dataset = load_dataset("uoft-cs/cifar10", split="test") + +# Initialize a SemHash instance with the 'img' column +semhash = SemHash.from_records(list(dataset), columns=["img"], model=model) + +# Deduplicate the images +deduplicated_images = semhash.self_deduplicate().selected + +# Filter outliers +filtered_images = semhash.self_filter_outliers().selected + +# Find representative images +representative_images = semhash.self_find_representative().selected +``` + +### Cross-Dataset Deduplication, Filtering & Representative Sampling + +Deduplicate across two datasets, filter outliers, and find representative samples (e.g., eliminating train/test leakage): ```python from datasets import load_dataset @@ -93,13 +123,12 @@ filtered_test_texts = semhash.filter_outliers(records=test_texts, outlier_percen # Find representative texts in the test data against the training data, # optionally with a specific selection size -representative_test_texts = semhash.find_representative( - 
records=test_texts, selection_size=10).selected
-
-
+representative_test_texts = semhash.find_representative(records=test_texts, selection_size=10).selected
 ```
 
-Or, deduplicate multi-column dataset, filter outliers, and find representative samples with the following code (e.g., deduplicating a QA dataset):
+### Multi-Column Deduplication
+
+Deduplicate multi-column datasets (e.g., deduplicating a QA dataset):
 
 ```python
 from datasets import load_dataset
 from semhash import SemHash
 
 # Load the dataset
 dataset = load_dataset("squad_v2", split="train")
 
 # Convert the dataset to a list of dictionaries
 records = [dict(row) for row in dataset]
 
 # Initialize a SemHash instance with the columns to deduplicate
@@ -116,15 +145,9 @@ semhash = SemHash.from_records(records=records, columns=["question", "context"])
 
 # Deduplicate the records
 deduplicated_records = semhash.self_deduplicate().selected
-
-# Filter outliers from the records
-filtered_texts = semhash.self_filter_outliers().selected
-
-# Find representative texts in the records
-representative_texts = semhash.self_find_representative().selected
 ```
 
-The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L58). This object stores the deduplicated corpus, a set of duplicate object (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result.
+The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L58). This object stores the deduplicated corpus, a set of duplicate objects (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result.
 
 The `filter_outliers`, `self_filter_outliers`, `find_representative`, and `self_find_representative` functions return a [FilterResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L179). This object stores the found outliers/representative samples.
 
@@ -212,14 +235,11 @@ The following code snippet shows how to deduplicate across two datasets, filter
 from datasets import load_dataset
 from semhash import SemHash
 
-# Initialize a SemHash instance
-semhash = SemHash()
-
 # Load two datasets to deduplicate
 train_texts = load_dataset("ag_news", split="train")["text"]
 test_texts = load_dataset("ag_news", split="test")["text"]
 
-# Initialize a SemHash instance
+# Initialize a SemHash instance with the training data
 semhash = SemHash.from_records(records=train_texts)
 
 # Deduplicate the test data against the training data
@@ -265,6 +285,56 @@ representative_records = semhash.self_find_representative().selected
+
+
+ Deduplicate, filter outliers, and find representative samples on image datasets +
+ +You can bring your own encoder for any modality by implementing the Encoder protocol. Here's an example using a vision model from timm for image deduplication: + +```python +from datasets import load_dataset +import timm +import torch +from semhash import SemHash + +# Requires: pip install timm torch datasets + +# Create a custom image encoder +class VisionEncoder: + """Custom encoder using timm models. Implements the Encoder protocol.""" + + def __init__(self, model_name: str = "mobilenetv3_small_100"): + self.model = timm.create_model(model_name, pretrained=True, num_classes=0).eval() + self.transform = timm.data.create_transform(**timm.data.resolve_model_data_config(self.model)) + + def encode(self, inputs): + """Encode a batch of PIL images into embeddings.""" + with torch.no_grad(): + return self.model(torch.stack([self.transform(img) for img in inputs])).numpy() + +# Load image dataset +dataset = load_dataset("uoft-cs/cifar10", split="test") +train_data = [{"img": img, "id": i} for i, img in enumerate(dataset["img"][:100])] +test_data = [{"img": img, "id": i} for i, img in enumerate(dataset["img"][100:150])] + +# Initialize SemHash with the custom vision encoder +semhash = SemHash.from_records(train_data, columns=["img"], model=VisionEncoder()) + +# Single-dataset operations +deduplicated = semhash.self_deduplicate().selected +outliers = semhash.self_filter_outliers().selected +representatives = semhash.self_find_representative().selected + +# Cross-dataset operations +test_deduplicated = semhash.deduplicate(test_data).selected +test_outliers = semhash.filter_outliers(test_data).selected +test_representatives = semhash.find_representative(test_data, selection_size=10).selected +``` + +The Encoder protocol requires only an `encode(inputs, **kwargs)` method that returns a numpy array. This makes it easy to integrate any embedding model for any modality. + +
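+To make the contract concrete, here is a minimal toy encoder that satisfies the protocol. This is an illustrative sketch only: the `HashingEncoder` name and its hashed bag-of-words scheme are not part of SemHash, but any object shaped like this can be passed as `model`:
+
+```python
+import zlib
+
+import numpy as np
+
+from semhash import SemHash
+
+
+class HashingEncoder:
+    """A deterministic hashed bag-of-words encoder, for illustration only."""
+
+    def __init__(self, dim: int = 256) -> None:
+        self.dim = dim
+
+    def encode(self, inputs, **kwargs) -> np.ndarray:
+        # Return one embedding row per input, as the Encoder protocol requires.
+        vectors = np.zeros((len(inputs), self.dim), dtype=np.float32)
+        for row, text in enumerate(inputs):
+            for token in str(text).split():
+                vectors[row, zlib.crc32(token.encode("utf-8")) % self.dim] += 1.0
+        return vectors
+
+
+semhash = SemHash.from_records(["an example", "an example", "something else"], model=HashingEncoder())
+print(semhash.self_deduplicate().selected)
+```
+
+Since SemHash only ever calls `encode` and treats the returned array as opaque vectors, swapping in a real model (Model2Vec, sentence-transformers, timm, etc.) requires no other changes.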
+
Using custom encoders
@@ -400,6 +470,44 @@ representative_texts = semhash.self_find_representative().selected ```
+
+ Initializing from a HuggingFace Dataset +
+You can easily use SemHash with HuggingFace Datasets by converting them to a list: + +```python +from datasets import load_dataset +from semhash import SemHash + +# Load a HuggingFace dataset +dataset = load_dataset("ag_news", split="train") + +# Convert to list and initialize SemHash +semhash = SemHash.from_records(records=list(dataset), columns=["text"]) + +# Deduplicate, filter outliers, and find representative samples +deduplicated_texts = semhash.self_deduplicate().selected +filtered_texts = semhash.self_filter_outliers().selected +representative_texts = semhash.self_find_representative().selected +``` + +This also works with multi-column datasets: + +```python +from datasets import load_dataset +from semhash import SemHash + +# Load a multi-column dataset +dataset = load_dataset("squad_v2", split="train") + +# Convert to list and initialize with multiple columns +semhash = SemHash.from_records(records=list(dataset), columns=["question", "context"]) + +# Deduplicate the records +deduplicated_records = semhash.self_deduplicate().selected +``` +
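+Continuing the example above, the returned `DeduplicationResult` can be inspected rather than just unpacked. A short sketch using the `duplicate_ratio`, `exact_duplicate_ratio`, and `get_least_similar_from_duplicates` helpers from `datamodels.py`:
+
+```python
+result = semhash.self_deduplicate()
+
+# Fraction of records flagged as duplicates (and as exact duplicates)
+print(result.duplicate_ratio)
+print(result.exact_duplicate_ratio)
+
+# The least similar duplicate pairs: a quick sanity check when tuning the threshold
+print(result.get_least_similar_from_duplicates(1))
+```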
+ @@ -419,7 +527,7 @@ If you use SemHash in your research, please cite the following: ```bibtex @software{minishlab2025semhash, author = {{van Dongen}, Thomas and Stephan Tulkens}, - title = {SemHash: Fast Semantic Text Deduplication \& Filtering}, + title = {SemHash: Fast Multimodal Semantic Deduplication \& Filtering}, year = {2025}, publisher = {Zenodo}, doi = {10.5281/zenodo.17265942}, diff --git a/pyproject.toml b/pyproject.toml index 966df8a..e8ba058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "semhash" -description = "Fast Semantic Text Deduplication & Filtering" +description = "Fast Multimodal Semantic Deduplication & Filtering" authors = [{name = "Thomas van Dongen", email = "thomas123@live.nl"}, { name = "Stéphan Tulkens", email = "stephantul@gmail.com"}] readme = { file = "README.md", content-type = "text/markdown" } dynamic = ["version"] @@ -43,6 +43,7 @@ dev = [ "ruff", ] + [project.urls] "Homepage" = "https://github.com/MinishLab" "Bug Reports" = "https://github.com/MinishLab/semhash/issues" diff --git a/semhash/records.py b/semhash/records.py index c2bbb27..2229e11 100644 --- a/semhash/records.py +++ b/semhash/records.py @@ -1,6 +1,129 @@ +from collections import defaultdict from collections.abc import Sequence +from typing import Any + +from frozendict import frozendict from semhash.datamodels import DeduplicationResult, DuplicateRecord +from semhash.utils import Record, coerce_value, to_frozendict + + +def group_records_by_key( + records: Sequence[dict[str, Any]], + columns: Sequence[str], +) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]: + """ + Group records by exact match on columns, preserving first-occurrence order. + + :param records: Records to group. + :param columns: Columns to use as grouping key. + :return: Tuple of (deduplicated_records, items) where: + - deduplicated_records: first record from each unique group + - items: list of groups, each group is a list of exact duplicates + """ + buckets: dict[frozendict[str, Any], list[dict[str, Any]]] = {} + order: list[frozendict[str, Any]] = [] + + for r in records: + key = to_frozendict(r, columns) + bucket = buckets.get(key) + if bucket is None: + buckets[key] = [r] + order.append(key) + else: + bucket.append(r) + + items = [buckets[k] for k in order] + deduplicated_records = [bucket[0] for bucket in items] + return deduplicated_records, items + + +def remove_exact_duplicates( + records: Sequence[dict[str, Any]], + columns: Sequence[str], + reference_records: list[list[dict[str, Any]]] | None = None, +) -> tuple[list[dict[str, Any]], list[tuple[dict[str, Any], list[dict[str, Any]]]]]: + """ + Remove exact duplicates based on the hashable representation of each record. + + If reference_records is None, the function will only check for duplicates within the records list. + + :param records: A list of records to check for exact duplicates. + :param columns: Columns to unpack. + :param reference_records: A list of records to compare against. These are already unpacked + :return: A list of deduplicated records and a list of duplicates. 
+ """ + deduplicated: list[dict[str, Any]] = [] + duplicates: list[tuple[dict[str, Any], list[dict[str, Any]]]] = [] + + column_set = set(columns) + # Build a seen set from reference_records if provided + seen: defaultdict[frozendict[str, Any], list[dict[str, Any]]] = defaultdict(list) + if reference_records is not None: + for record_set in reference_records: + key = to_frozendict(record_set[0], column_set) + seen[key] = list(record_set) + in_one_set = reference_records is None + + for record in records: + frozen_record = to_frozendict(record, column_set) + if duplicated_records := seen.get(frozen_record): + duplicates.append((record, duplicated_records)) + else: + deduplicated.append(record) + # Only add current documents to seen if no reference set is used + if in_one_set: + seen[frozen_record].append(record) + + return deduplicated, duplicates + + +def prepare_records( + records: Sequence[Record], columns: Sequence[str] | None +) -> tuple[list[dict[str, Any]], Sequence[str], bool]: + """ + Validate and prepare records for processing. + + :param records: A list of records (strings or dictionaries). + :param columns: Columns to use if records are dictionaries. + :return: Tuple of (dict_records, columns, was_string). + :raises ValueError: If records are empty. + :raises ValueError: If columns are not provided for dictionary records. + :raises ValueError: If dict record contains None values. + :raises ValueError: If records are not homogeneous (mixed strings and dicts). + """ + if len(records) == 0: + raise ValueError("records must not be empty") + + if columns is None and isinstance(records[0], dict): + raise ValueError("Columns must be specified when passing dictionaries.") + + if isinstance(records[0], str): + # Validate all records are strings + if not all(isinstance(r, str) for r in records): + raise ValueError("All records must be strings when the first record is a string.") + columns = ["text"] + dict_records: list[dict[str, Any]] = [{"text": record} for record in records] + was_string = True + else: + # Validate all records are dicts + if not all(isinstance(r, dict) for r in records): + raise ValueError("All records must be dicts when the first record is a dict.") + assert columns is not None + # Coerce values: stringify primitives, keep complex types raw (for images, etc.) 
+ dict_records_typed: list[dict[str, Any]] = list(records) # type: ignore[arg-type] + dict_records = [] + for r in dict_records_typed: + coerced: dict[str, Any] = {} + for c in columns: + val = r.get(c) + if val is None: + raise ValueError(f"Column '{c}' has None value in record {r}") + coerced[c] = coerce_value(val) + dict_records.append(coerced) + was_string = False + + return dict_records, columns, was_string def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str: diff --git a/semhash/semhash.py b/semhash/semhash.py index 2baaabe..e866bc4 100644 --- a/semhash/semhash.py +++ b/semhash/semhash.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections import defaultdict from collections.abc import Sequence from math import ceil from typing import Any, Generic, Literal @@ -11,15 +10,21 @@ from pyversity import Strategy, diversify from vicinity import Backend -from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record +from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult from semhash.index import Index -from semhash.records import add_scores_to_records, map_deduplication_result_to_strings +from semhash.records import ( + add_scores_to_records, + group_records_by_key, + map_deduplication_result_to_strings, + prepare_records, + remove_exact_duplicates, +) from semhash.utils import ( Encoder, + Record, + coerce_value, compute_candidate_limit, featurize, - prepare_records, - remove_exact_duplicates, to_frozendict, ) @@ -52,7 +57,7 @@ def from_records( """ Initialize a SemHash instance from records. - This removes exact duplicates, featurizes the records, and fits a vicinity index. + Removes exact duplicates, featurizes the records, and fits a vicinity index. :param records: A list of records (strings or dictionaries). :param columns: Columns to featurize if records are dictionaries. @@ -65,37 +70,17 @@ def from_records( dict_records, columns, was_string = prepare_records(records, columns) # If no model is provided, load the default model - if model is None: + if model is None: # pragma: no cover model = StaticModel.from_pretrained("minishlab/potion-base-8M") - # Remove exact duplicates - deduplicated_records, duplicates = remove_exact_duplicates(dict_records, columns) - - col_set = set(columns) - duplicate_map = defaultdict(list) - for x, _ in duplicates: - frozen_record = to_frozendict(x, col_set) - duplicate_map[frozen_record].append(x) - - items: list[list[dict[str, str]]] = [] - for record in deduplicated_records: - i = [record] - frozen_record = to_frozendict(record, col_set) - i.extend(duplicate_map[frozen_record]) - items.append(i) + # Group by exact match, preserving first-occurrence order + deduplicated_records, items = group_records_by_key(dict_records, columns) # Create embeddings for deduplicated records only embeddings = featurize(deduplicated_records, columns, model) - # Build the Vicinity index - index = Index.from_vectors_and_items( - vectors=embeddings, - items=items, - backend_type=ann_backend, - **kwargs, - ) - - return cls(index=index, columns=columns, model=model, was_string=was_string) + index = Index.from_vectors_and_items(vectors=embeddings, items=items, backend_type=ann_backend, **kwargs) + return cls(index=index, model=model, columns=columns, was_string=was_string) @classmethod def from_embeddings( @@ -110,7 +95,7 @@ def from_embeddings( """ Initialize a SemHash instance from pre-computed embeddings. 
-        This removes exact duplicates and fits a vicinity index using the provided embeddings.
+        Removes exact duplicates and fits a vicinity index using the provided embeddings.
 
         :param embeddings: Pre-computed embeddings as a numpy array of shape (n_records, embedding_dim).
         :param records: A list of records (strings or dictionaries) corresponding to the embeddings.
@@ -160,10 +145,7 @@ deduplicated_embeddings = embeddings[keep_embedding_indices]
 
         index = Index.from_vectors_and_items(
-            vectors=deduplicated_embeddings,
-            items=items,
-            backend_type=ann_backend,
-            **kwargs,
+            vectors=deduplicated_embeddings, items=items, backend_type=ann_backend, **kwargs
         )
 
         return cls(index=index, model=model, columns=columns, was_string=was_string)
@@ -267,8 +249,8 @@ duplicate_records.append(DuplicateRecord(record=curr_record, duplicates=items_with_score, exact=True))
 
             # If we don't see any similar_items, we know the record is not a duplicate.
-            # in rare cases, the item itself might not be a duplicate of itself.
-            if not similar_items:
+            # In rare cases, the item itself might not be returned by the index.
+            if not similar_items:  # pragma: no cover
                 deduplicated_records.append(record)
                 continue
             items, _ = zip(*similar_items)
@@ -299,30 +281,45 @@ return result
 
-    def _validate_if_strings(self, records: Sequence[dict[str, str] | str]) -> Sequence[dict[str, str]]:
+    def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[dict[str, Any]]:
         """
         Validate if the records are strings.
 
         If the records are strings, they are converted to dictionaries with a single column.
+        If the records are dicts, primitives are stringified and complex types (images, etc.) are kept raw.
 
         :param records: The records to validate.
         :return: The records as a list of dictionaries.
         :raises ValueError: If records are empty.
         :raises ValueError: If the records are strings but were not originally strings.
-        :raises ValueError: If the records are not all strings or dictionaries.
+        :raises ValueError: If the records are not all strings or all dictionaries.
+        :raises ValueError: If a dict record contains None values.
""" if len(records) == 0: raise ValueError("records must not be empty") + # String path if isinstance(records[0], str): if not self._was_string: raise ValueError("Records were not originally strings, but you passed strings.") - dict_records = [{"text": record} for record in records if isinstance(record, str)] - else: - dict_records = [record for record in records if isinstance(record, dict)] - if len(dict_records) != len(records): - raise ValueError("Records must be either strings or dictionaries.") - return dict_records + if not all(isinstance(r, str) for r in records): + raise ValueError("Records must be all strings.") + return [{"text": r} for r in records] + + # Dict path + if not all(isinstance(r, dict) for r in records): + raise ValueError("Records must be all dictionaries.") + + dict_records: Sequence[dict[str, Any]] = records # type: ignore[assignment] + result: list[dict[str, Any]] = [] + for r in dict_records: + out = {} + for c in self.columns: + if (val := r.get(c)) is None: + raise ValueError(f"Column '{c}' has None value in record {r}") + out[c] = coerce_value(val) + result.append(out) + return result def find_representative( self, diff --git a/semhash/utils.py b/semhash/utils.py index f3abb00..1b0d7b3 100644 --- a/semhash/utils.py +++ b/semhash/utils.py @@ -1,4 +1,4 @@ -from collections import defaultdict +import hashlib from collections.abc import Sequence from typing import Any, Protocol, TypeAlias, TypeVar @@ -11,26 +11,78 @@ class Encoder(Protocol): - """An encoder protocol for SemHash.""" + """An encoder protocol for SemHash. Supports text, images, or any encodable data.""" def encode( self, - sentences: list[str] | str | Sequence[str], + inputs: Sequence[Any] | Any, **kwargs: Any, ) -> np.ndarray: """ - Encode a list of sentences into embeddings. + Encode a list of inputs into embeddings. - :param sentences: A list of sentences to encode. + :param inputs: A list of inputs to encode (strings, images, etc.). :param **kwargs: Additional keyword arguments. - :return: The embeddings of the sentences. + :return: The embeddings of the inputs. """ ... # pragma: no cover -def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]: - """Convert a record to a frozendict.""" - return frozendict({k: record.get(k, "") for k in columns}) +def make_hashable(value: Any) -> Any: + """ + Convert a value to a hashable representation for use as dict keys. + + Strings and other hashable types are returned as-is. + Non-hashable types (like PIL images, numpy arrays) are hashed to a string. + + :param value: The value to make hashable. + :return: A hashable representation of the value. + """ + # Fast path: most values are strings or already hashable + if isinstance(value, (str, int, float, bool, type(None))): + return value + # Handle objects with tobytes() (PIL Image, numpy array, etc.) + if hasattr(value, "tobytes"): + return hashlib.md5(value.tobytes()).hexdigest() + # Fallback: try to hash, otherwise stringify + try: + hash(value) + return value + except TypeError: + return str(value) + + +def coerce_value(value: Any) -> Any: + """ + Coerce a value for encoding: stringify primitives, keep complex types raw. + + This ensures primitives (int, float, bool) work with text encoders, + while complex types (PIL images, tensors, etc.) are passed through for multimodal encoders. + + :param value: The value to coerce. + :return: The coerced value. 
+ """ + if isinstance(value, (str, bytes)): + return value + if isinstance(value, (int, float, bool)): + return str(value) + return value # Complex types (images, tensors, etc.) + + +def to_frozendict(record: dict[str, Any], columns: Sequence[str] | set[str]) -> frozendict[str, Any]: + """ + Convert a record to a frozendict with hashable values. + + :param record: The record to convert. + :param columns: The columns to include. + :return: A frozendict with only the specified columns (values made hashable). + :raises ValueError: If a column is missing from the record. + """ + try: + return frozendict({k: make_hashable(record[k]) for k in columns}) + except KeyError as e: + missing = e.args[0] + raise ValueError(f"Missing column '{missing}' in record {record}") from e def compute_candidate_limit( @@ -62,7 +114,7 @@ def compute_candidate_limit( def featurize( - records: Sequence[dict[str, str]], + records: Sequence[dict[str, Any]], columns: Sequence[str], model: Encoder, ) -> np.ndarray: @@ -73,81 +125,16 @@ def featurize( :param columns: Columns to featurize. :param model: An Encoder model. :return: The embeddings of the records. + :raises ValueError: If a column is missing from one or more records. """ # Extract the embeddings for each column across all records embeddings_per_col = [] for col in columns: - col_texts = [r[col] for r in records] + try: + col_texts = [r[col] for r in records] + except KeyError as e: + raise ValueError(f"Missing column '{col}' in one or more records") from e col_emb = model.encode(col_texts) embeddings_per_col.append(np.asarray(col_emb)) return np.concatenate(embeddings_per_col, axis=1) - - -def remove_exact_duplicates( - records: Sequence[dict[str, str]], - columns: Sequence[str], - reference_records: list[list[dict[str, str]]] | None = None, -) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]: - """ - Remove exact duplicates based on the unpacked string representation of each record. - - If reference_records is None, the function will only check for duplicates within the records list. - - :param records: A list of records to check for exact duplicates. - :param columns: Columns to unpack. - :param reference_records: A list of records to compare against. These are already unpacked - :return: A list of deduplicated records and a list of duplicates. - """ - deduplicated = [] - duplicates = [] - - column_set = set(columns) - # Build a seen set from reference_records if provided - seen: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list) - if reference_records is not None: - for record_set in reference_records: - key = to_frozendict(record_set[0], column_set) - seen[key] = list(record_set) - in_one_set = reference_records is None - - for record in records: - frozen_record = to_frozendict(record, column_set) - if duplicated_records := seen.get(frozen_record): - duplicates.append((record, duplicated_records)) - else: - deduplicated.append(record) - # Only add current documents to seen if no reference set is used - if in_one_set: - seen[frozen_record].append(record) - - return deduplicated, duplicates - - -def prepare_records( - records: Sequence[Record], columns: Sequence[str] | None -) -> tuple[list[dict[str, str]], Sequence[str], bool]: - """ - Validate and prepare records for processing. - - :param records: A list of records (strings or dictionaries). - :param columns: Columns to use if records are dictionaries. - :return: Tuple of (dict_records, columns, was_string). 
- :raises ValueError: If records are empty. - :raises ValueError: If columns are not provided for dictionary records. - """ - if len(records) == 0: - raise ValueError("records must not be empty") - - if columns is None and isinstance(records[0], dict): - raise ValueError("Columns must be specified when passing dictionaries.") - - if isinstance(records[0], str): - columns = ["text"] - dict_records: list[dict[str, str]] = [{"text": str(record)} for record in records] - was_string = True - else: - dict_records = list(records) - was_string = False - - return dict_records, columns, was_string diff --git a/semhash/version.py b/semhash/version.py index 9bfefb0..cc35d8d 100644 --- a/semhash/version.py +++ b/semhash/version.py @@ -1,2 +1,2 @@ -__version_triple__ = (0, 3, 3) -__version__ = ".".join(map(str, __version_triple__)) +__version_triple__ = (0, 3, 3) # pragma: no cover +__version__ = ".".join(map(str, __version_triple__)) # pragma: no cover diff --git a/tests/test_datamodels.py b/tests/test_datamodels.py index 59c2563..307eeeb 100644 --- a/tests/test_datamodels.py +++ b/tests/test_datamodels.py @@ -1,8 +1,6 @@ import pytest -import semhash -import semhash.version -from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates +from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, SelectedWithDuplicates def test_deduplication_scoring() -> None: @@ -25,34 +23,27 @@ def test_deduplication_scoring_exact() -> None: assert d.exact_duplicate_ratio == 0.2 -def test_deduplication_scoring_exact_empty() -> None: - """Test the deduplication scoring.""" - d = DeduplicationResult([], [], 0.8, columns=["text"]) - assert d.exact_duplicate_ratio == 0.0 - - def test_deduplication_scoring_empty() -> None: - """Test the deduplication scoring.""" + """Test the deduplication scoring with empty results.""" d = DeduplicationResult([], [], 0.8, columns=["text"]) assert d.duplicate_ratio == 0.0 + assert d.exact_duplicate_ratio == 0.0 def test_rethreshold() -> None: - """Test rethresholding the duplicates.""" + """Test rethresholding the duplicates, including empty case.""" d = DuplicateRecord("a", False, [("b", 0.9), ("c", 0.8)]) d._rethreshold(0.85) assert d.duplicates == [("b", 0.9)] - -def test_rethreshold_empty() -> None: - """Test rethresholding the duplicates.""" - d = DuplicateRecord("a", False, []) - d._rethreshold(0.85) - assert d.duplicates == [] + # Empty case + d_empty = DuplicateRecord("a", False, []) + d_empty._rethreshold(0.85) + assert d_empty.duplicates == [] def test_get_least_similar_from_duplicates() -> None: - """Test getting the least similar duplicates.""" + """Test getting the least similar duplicates, including empty case.""" d = DeduplicationResult( ["a", "b", "c"], [DuplicateRecord("a", False, [("b", 0.9), ("c", 0.7)]), DuplicateRecord("b", False, [("c", 0.8)])], @@ -61,11 +52,9 @@ def test_get_least_similar_from_duplicates() -> None: result = d.get_least_similar_from_duplicates(1) assert result == [("a", "c", 0.7)] - -def test_get_least_similar_from_duplicates_empty() -> None: - """Test getting the least similar duplicates.""" - d = DeduplicationResult([], [], 0.8, columns=["text"]) - assert d.get_least_similar_from_duplicates(1) == [] + # Empty case + d_empty = DeduplicationResult([], [], 0.8, columns=["text"]) + assert d_empty.get_least_similar_from_duplicates(1) == [] def test_rethreshold_deduplication_result() -> None: @@ -243,3 +232,10 @@ def test_selected_with_duplicates_cache_invalidation_on_rethreshold() -> None: assert 
result2[0].duplicates[0][0] == "duplicate_1" # Results should be different objects assert result1 is not result2 + + +def test_filter_result_empty() -> None: + """Test FilterResult ratios with empty lists.""" + result = FilterResult(selected=[], filtered=[]) + assert result.filter_ratio == 0.0 + assert result.selected_ratio == 1.0 diff --git a/tests/test_semhash.py b/tests/test_semhash.py index 1b4105c..55119fd 100644 --- a/tests/test_semhash.py +++ b/tests/test_semhash.py @@ -141,27 +141,31 @@ def test_deduplicate_with_only_exact_duplicates(model: Encoder) -> None: def test_self_find_representative(model: Encoder, train_texts: list[str]) -> None: """Test the self_find_representative method.""" semhash = SemHash.from_records(records=train_texts, model=model) - result = semhash.self_find_representative( - candidate_limit=5, - selection_size=3, - diversity=0.5, - ) + + # Test with explicit candidate_limit + result = semhash.self_find_representative(candidate_limit=5, selection_size=3, diversity=0.5) assert len(result.selected) == 3, "Expected 3 representatives" selected = {r["text"] for r in result.selected} - assert selected == { - "blueberry", - "pineapple", - "grape", - }, "Expected representatives to be blueberry, pineapple, and grape" + assert selected == {"blueberry", "pineapple", "grape"} + + # Test with auto candidate_limit (default) + result_auto = semhash.self_find_representative(selection_size=3, diversity=0.5) + assert len(result_auto.selected) == 3 def test_find_representative(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: """Test the find_representative method.""" semhash = SemHash.from_records(records=train_texts, model=model) + + # Test with explicit candidate_limit result = semhash.find_representative(records=test_texts, candidate_limit=5, selection_size=3, diversity=0.5) assert len(result.selected) == 3, "Expected 3 representatives" selected = {r["text"] for r in result.selected} - assert selected == {"grapefruit", "banana", "apple"}, "Expected representatives to be grapefruit, banana, and apple" + assert selected == {"grapefruit", "banana", "apple"} + + # Test with auto candidate_limit (default) + result_auto = semhash.find_representative(records=test_texts, selection_size=3, diversity=0.5) + assert len(result_auto.selected) == 3 def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: @@ -173,10 +177,23 @@ def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: lis filtered = {r["text"] for r in result.filtered} assert filtered == {"motorcycle", "plane"}, "Expected outliers to be motorcycle and plane" + # Test FilterResult ratio properties + assert result.filter_ratio == len(result.filtered) / len(test_texts) + assert result.selected_ratio == len(result.selected) / len(test_texts) + assert result.filter_ratio + result.selected_ratio == 1.0 + # Test with outlier_percentage=0.0 (should return no outliers) result_zero = semhash.filter_outliers(records=test_texts, outlier_percentage=0.0) assert result_zero.filtered == [] assert len(result_zero.selected) == len(test_texts) + assert result_zero.filter_ratio == 0.0 + assert result_zero.selected_ratio == 1.0 + + # Invalid outlier_percentage raises ValueError + with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"): + semhash.filter_outliers(records=test_texts, outlier_percentage=-0.1) + with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"): + semhash.filter_outliers(records=test_texts, 
outlier_percentage=1.5) def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None: @@ -193,6 +210,12 @@ def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None: assert result_zero.filtered == [] assert len(result_zero.selected) == len(train_texts) + # Invalid outlier_percentage raises ValueError + with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"): + semhash.self_filter_outliers(outlier_percentage=-0.1) + with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"): + semhash.self_filter_outliers(outlier_percentage=1.5) + def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None: """Test the _diversify method.""" @@ -226,30 +249,80 @@ def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None: def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None: """Test from_embeddings constructor with validation and comparison to from_records.""" - # Test validation: mismatched shapes + # Validation: empty records + with pytest.raises(ValueError, match="records must not be empty"): + SemHash.from_embeddings(embeddings=np.array([[]]), records=[], model=model) + + # Validation: non-2D embeddings + with pytest.raises(ValueError, match="must be a 2D array"): + SemHash.from_embeddings(embeddings=np.array([1, 2, 3]), records=["a", "b", "c"], model=model) + + # Validation: mismatched shapes with pytest.raises(ValueError, match="Number of embeddings"): wrong_embeddings = model.encode(["apple", "banana"]) SemHash.from_embeddings(embeddings=wrong_embeddings, records=train_texts, model=model) # Test that from_embeddings behaves same as from_records semhash_from_records = SemHash.from_records(records=train_texts, model=model) - embeddings = model.encode(train_texts) semhash_from_embeddings = SemHash.from_embeddings(embeddings=embeddings, records=train_texts, model=model) - # Both should give same deduplication results result1 = semhash_from_records.self_deduplicate(threshold=0.95) result2 = semhash_from_embeddings.self_deduplicate(threshold=0.95) - assert len(result1.selected) == len(result2.selected) - assert len(result1.filtered) == len(result2.filtered) # Test that from_embeddings keeps first-occurrence embeddings and drops duplicates records = ["apple", "banana", "apple", "cherry"] embeddings = np.array([[0.0], [1.0], [2.0], [3.0]], dtype=np.float32) - semhash = SemHash.from_embeddings(embeddings=embeddings, records=records, model=model) - assert semhash.index.vectors.shape == (3, 1) - # Should keep embeddings at indices 0, 1, 3 (first occurrences of img1, img2, img3) assert semhash.index.vectors.tolist() == [[0.0], [1.0], [3.0]] + + +def test_from_records_edge_cases(model: Encoder) -> None: + """Test from_records edge cases: coercion, order preservation, None rejection.""" + # Coerces non-string dict values to strings + records = [{"id": 1}, {"id": 2}, {"id": 1}] # Integers, with duplicate + semhash = SemHash.from_records(records, columns=["id"], model=model) + assert semhash.index.vectors.shape[0] == 2 # Deduplicated + assert 2 in [len(bucket) for bucket in semhash.index.items] # id=1 bucket has 2 + + # Preserves first-occurrence order (deterministic) + texts = ["zebra", "apple", "zebra", "banana", "apple", "cherry"] + semhash = SemHash.from_records(texts, model=model) + firsts = [bucket[0]["text"] for bucket in semhash.index.items] + assert firsts == ["zebra", "apple", "banana", "cherry"] + + # Rejects None values in dict records + with pytest.raises(ValueError, match="has None value"): + 
SemHash.from_records([{"text": "apple"}, {"text": None}], columns=["text"], model=model) + + +def test_deduplicate_edge_cases(model: Encoder) -> None: + """Test deduplicate() edge cases: coercion, None rejection, empty records, type mismatches.""" + semhash = SemHash.from_records(["1", "2", "3"], model=model) + + # Coerces non-string dict values + result = semhash.deduplicate([{"text": 1}, {"text": 4}], threshold=0.95) + assert len(result.filtered) + len(result.selected) == 2 + + # Rejects None values + with pytest.raises(ValueError, match="has None value"): + semhash.deduplicate([{"text": "cherry"}, {"text": None}], threshold=0.95) + + # Rejects empty records + with pytest.raises(ValueError, match="records must not be empty"): + semhash.deduplicate([], threshold=0.95) + + # Type mismatch: strings passed to dict-based index + semhash_dict = SemHash.from_records([{"col": "a"}, {"col": "b"}], columns=["col"], model=model) + with pytest.raises(ValueError, match="Records were not originally strings"): + semhash_dict.deduplicate(["x", "y"], threshold=0.95) + + # Type mismatch: mixed strings + with pytest.raises(ValueError, match="Records must be all strings"): + semhash.deduplicate(["a", {"text": "b"}], threshold=0.95) + + # Type mismatch: mixed dicts + with pytest.raises(ValueError, match="Records must be all dictionaries"): + semhash_dict.deduplicate([{"col": "a"}, "b"], threshold=0.95) diff --git a/tests/test_utils.py b/tests/test_utils.py index 371fb16..3224d73 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,23 +2,74 @@ import pytest from frozendict import frozendict -from semhash.utils import ( - Encoder, - compute_candidate_limit, - featurize, - prepare_records, - remove_exact_duplicates, - to_frozendict, -) +from semhash.records import prepare_records, remove_exact_duplicates +from semhash.utils import Encoder, coerce_value, compute_candidate_limit, featurize, make_hashable, to_frozendict + + +def test_make_hashable() -> None: + """Test make_hashable with various types.""" + # Fast path: primitives + assert make_hashable("hello") == "hello" + assert make_hashable(42) == 42 + assert make_hashable(3.14) == 3.14 + assert make_hashable(True) is True + assert make_hashable(None) is None + + # Objects with tobytes() (simulate PIL Image or numpy array) + class MockImage: + def tobytes(self) -> bytes: + return b"fake_image_data" + + img = MockImage() + result = make_hashable(img) + assert isinstance(result, str) + assert len(result) == 32 # MD5 hex digest + + # Hashable objects (like tuples) + assert make_hashable((1, 2, 3)) == (1, 2, 3) + + # Non-hashable fallback to string + unhashable = {"key": "value"} + result = make_hashable(unhashable) + assert result == "{'key': 'value'}" + + +def test_coerce_value() -> None: + """Test coerce_value for encoding preparation.""" + # Strings and bytes pass through + assert coerce_value("hello") == "hello" + assert coerce_value(b"bytes") == b"bytes" + + # Primitives converted to strings + assert coerce_value(42) == "42" + assert coerce_value(3.14) == "3.14" + assert coerce_value(True) == "True" + + # Complex types pass through unchanged + class MockImage: + pass + + img = MockImage() + assert coerce_value(img) is img def test_to_frozendict() -> None: - """Test converting dict to frozendict.""" + """Test converting dict to frozendict, including error cases.""" record = {"a": "1", "b": "2", "c": "3"} + + # Basic case: select subset of columns result = to_frozendict(record, {"a", "c"}) assert result == frozendict({"a": "1", "c": "3"}) assert 
"b" not in result + # Works with Sequence (not just set) + result = to_frozendict(record, ["a", "b"]) + assert result == frozendict({"a": "1", "b": "2"}) + + # Missing column raises ValueError + with pytest.raises(ValueError, match="Missing column 'missing'"): + to_frozendict(record, {"a", "missing"}) + def test_compute_candidate_limit() -> None: """Test candidate limit computation.""" @@ -33,30 +84,51 @@ def test_compute_candidate_limit() -> None: def test_featurize(model: Encoder) -> None: - """Test featurizing records.""" + """Test featurizing records, including error cases.""" records = [{"text": "hello"}, {"text": "world"}] embeddings = featurize(records, ["text"], model) assert embeddings.shape == (2, 128) # Model has 128 dims assert isinstance(embeddings, np.ndarray) + # Missing column raises ValueError + with pytest.raises(ValueError, match="Missing column 'missing'"): + featurize(records, ["missing"], model) + def test_remove_exact_duplicates() -> None: - """Test exact duplicate removal.""" + """Test exact duplicate removal, with and without reference records.""" + # Basic case: remove duplicates within same list records = [ {"text": "hello", "id": "1"}, {"text": "world", "id": "2"}, {"text": "hello", "id": "3"}, ] deduplicated, duplicates = remove_exact_duplicates(records, ["text"]) - assert len(deduplicated) == 2 assert len(duplicates) == 1 assert duplicates[0][0] == {"text": "hello", "id": "3"} + # With reference_records: cross-dataset filtering + reference_records = [ + [{"text": "apple"}], + [{"text": "banana"}, {"text": "banana"}], + ] + new_records = [ + {"text": "cherry"}, # New + {"text": "apple"}, # Exists in reference + {"text": "date"}, # New + {"text": "banana"}, # Exists in reference + ] + deduplicated, duplicates = remove_exact_duplicates(new_records, ["text"], reference_records=reference_records) + assert len(deduplicated) == 2 + assert {"text": "cherry"} in deduplicated + assert {"text": "date"} in deduplicated + assert len(duplicates) == 2 + def test_prepare_records() -> None: - """Test preparing records.""" - # String records + """Test preparing records, including validation and edge cases.""" + # String records -> converts to dicts with "text" column records = ["hello", "world"] dict_records, columns, was_string = prepare_records(records, None) assert was_string is True @@ -71,6 +143,15 @@ def test_prepare_records() -> None: assert dict_records == records # Dict records without columns raises ValueError - records = [{"text": "hello"}] with pytest.raises(ValueError, match="Columns must be specified"): - prepare_records(records, None) + prepare_records([{"text": "hello"}], None) + + # Empty records raises ValueError + with pytest.raises(ValueError, match="records must not be empty"): + prepare_records([], None) + + # Mixed types rejected + with pytest.raises(ValueError, match="All records must be"): + prepare_records(["a", {"text": "b"}], None) + with pytest.raises(ValueError, match="All records must be"): + prepare_records([{"text": "a"}, "b"], ["text"]) diff --git a/uv.lock b/uv.lock index 0bd321f..9bd57ea 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,11 @@ version = 1 revision = 3 requires-python = ">=3.10" +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", +] [[package]] name = "asttokens"