
Commit 52c95c5

Added image benchmarks

1 parent 356d577

8 files changed: +360 −15 lines changed

Makefile

Lines changed: 8 additions & 0 deletions

@@ -16,3 +16,11 @@ fix:
 
 test:
 	uv run pytest --cov=semhash --cov-report=term-missing
+
+benchmark-text:
+	uv run python -m benchmarks.run_text_benchmarks
+
+benchmark-image:
+	uv run python -m benchmarks.run_image_benchmarks
+
+benchmark: benchmark-text benchmark-image

README.md

Lines changed: 33 additions & 6 deletions

@@ -303,14 +303,28 @@ from semhash import SemHash
 class VisionEncoder:
     """Custom encoder using timm models. Implements the Encoder protocol."""
 
-    def __init__(self, model_name: str = "mobilenetv3_small_100"):
+    def __init__(self, model_name: str = "mobilenetv3_small_100.lamb_in1k"):
         self.model = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
-        self.transform = timm.data.create_transform(**timm.data.resolve_model_data_config(self.model))
+        data_config = timm.data.resolve_model_data_config(self.model)
+        self.transform = timm.data.create_transform(**data_config, is_training=False)
 
-    def encode(self, inputs):
+    def encode(self, inputs, batch_size: int = 128):
         """Encode a batch of PIL images into embeddings."""
+        import numpy as np
+
+        # Convert grayscale to RGB if needed
+        rgb_inputs = [img.convert("RGB") if img.mode != "RGB" else img for img in inputs]
+
+        # Process in batches to avoid memory issues
+        all_embeddings = []
         with torch.no_grad():
-            return self.model(torch.stack([self.transform(img) for img in inputs])).numpy()
+            for i in range(0, len(rgb_inputs), batch_size):
+                batch_inputs = rgb_inputs[i : i + batch_size]
+                batch = torch.stack([self.transform(img) for img in batch_inputs])
+                embeddings = self.model(batch).numpy()
+                all_embeddings.append(embeddings)
+
+        return np.vstack(all_embeddings)
 
 # Load image dataset
 dataset = load_dataset("uoft-cs/cifar10", split="test")
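This hunk ends at loading CIFAR-10; the README code that actually runs the deduplication sits between this hunk and the next and is not shown in the diff. Below is a rough sketch of how the encoder above could be wired into SemHash. Passing PIL images directly and the `model` keyword are assumptions; only `self_deduplicate().selected` is visible in the next hunk's context line.

```python
# Illustrative sketch, not part of this commit. Assumes the VisionEncoder class
# from the hunk above is in scope and that SemHash.from_records accepts a custom
# encoder via a `model` argument, mirroring the text examples in the README.
from datasets import load_dataset
from semhash import SemHash

encoder = VisionEncoder()
dataset = load_dataset("uoft-cs/cifar10", split="test")
images = [row["img"] for row in dataset]  # "img" is the CIFAR-10 image column (see benchmarks/data.py)

# Deduplicate the images using the custom vision encoder.
semhash = SemHash.from_records(records=images, model=encoder)
deduplicated_records = semhash.self_deduplicate().selected
```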
@@ -513,9 +527,22 @@ deduplicated_records = semhash.self_deduplicate().selected
 
 ## Benchmarks
 
-SemHash is extremely fast and scales to large datasets with millions of records. We've benchmarked both single-dataset deduplication and train/test deduplication across a variety of datasets. For example, deduplicating 1.8M records takes only ~83 seconds on CPU.
+SemHash is extremely fast and scales to large datasets with millions of records. We've benchmarked both text and image deduplication across a variety of datasets. For example, deduplicating 1.8M text records takes only ~83 seconds on CPU.
+
+For detailed benchmark results and analysis, see the [benchmarks directory](benchmarks/README.md).
+
+### Running Benchmarks
 
-For detailed benchmark results including performance metrics across 17 datasets, as well as code to reproduce the benchmarks, see the [benchmarks directory](benchmarks/README.md).
+```bash
+# Run text benchmarks
+make benchmark-text
+
+# Run image benchmarks
+make benchmark-image
+
+# Run all benchmarks
+make benchmark
+```
 
 ## License
 
benchmarks/README.md

Lines changed: 77 additions & 9 deletions

@@ -1,16 +1,19 @@
 # SemHash Benchmarks
 
-This directory contains the benchmarking code and results for SemHash. The benchmarks measure deduplication performance and speed across a variety of datasets.
+This directory contains the benchmarking code and results for SemHash. The benchmarks measure deduplication performance and speed across a variety of text and image datasets.
 
-## Setup
+## Text Benchmarks
 
-All benchmarks were run with the following configuration:
+### Setup
+
+All text benchmarks were run with the following configuration:
 - **CPU-only**: All benchmarks run on CPU (no GPU acceleration)
 - **ANN backend**: Default backend (USearch)
 - **Encoder**: Default encoder ([potion-base-8M](https://huggingface.co/minishlab/potion-base-8M))
 - **Timing**: Includes encoding time, index building time, and deduplication time
+- **Dependencies**: Requires `datasets` package (`pip install datasets`)
 
-## Results
+### Results
 
 ### Train Deduplication Benchmark
 
@@ -60,7 +63,7 @@ This benchmark measures the performance of deduplicating a test dataset against
 | squad_v2 | 130319 | 11873 | 11863 | 0.08 | 7.13 |
 | wikitext | 1801350 | 4358 | 2139 | 50.92 | 40.32 |
 
-## Key Findings
+### Key Findings
 
 SemHash is extremely fast and scales to large datasets with millions of records. Some notable findings include:
 
@@ -70,12 +73,77 @@ SemHash is extremely fast and scales to large datasets with millions of records.
 - `student`: 52% of test data overlaps with training data
 - `wikitext`: 51% of test data overlaps with training data
 
-## Running the Benchmarks
+### Running Text Benchmarks
+
+To run the text benchmarks yourself:
+
+```bash
+# Install dependencies
+pip install datasets
+
+# Run benchmarks
+python -m benchmarks.run_text_benchmarks
+# Or using make
+make benchmark-text
+```
+
+## Image Benchmarks
+
+### Setup
+
+All image benchmarks were run with the following configuration:
+- **Device**: Apple Silicon GPU (MPS)
+- **ANN backend**: Default backend (USearch)
+- **Encoder**: MobileNetV3-Small ([mobilenetv3_small_100.lamb_in1k](https://huggingface.co/timm/mobilenetv3_small_100.lamb_in1k))
+- **Batch size**: 128 images per batch
+- **Timing**: Includes encoding time, index building time, and deduplication time
+
+### Results
+
+#### Train Deduplication Benchmark
+
+This benchmark measures the performance of deduplicating within a single training dataset.
+
+| Dataset       | Original Train Size | Deduplicated Train Size | % Removed | Deduplication Time (s) |
+|---------------|---------------------|-------------------------|-----------|------------------------|
+| cifar10       | 50000               | 48274                   | 3.45      | 61.20                  |
+| fashion_mnist | 60000               | 16714                   | 72.14     | 86.61                  |
+
+#### Train/Test Deduplication Benchmark
+
+This benchmark measures the performance of deduplicating a test dataset against a training dataset.
+
+| Dataset       | Train Size | Test Size | Deduplicated Test Size | % Removed | Deduplication Time (s) |
+|---------------|------------|-----------|------------------------|-----------|------------------------|
+| cifar10       | 50000      | 10000     | 9397                   | 6.03      | 67.43                  |
+| fashion_mnist | 60000      | 10000     | 2052                   | 79.48     | 72.14                  |
+
+### Key Findings
 
-To run the benchmarks yourself:
+- **Fashion-MNIST high deduplication**: Fashion-MNIST shows very high duplication rates (72% train, 79% test) due to the simple nature of the dataset (10 clothing categories with similar items)
+- **CIFAR-10 moderate deduplication**: CIFAR-10 shows lower duplication (3.45% train, 6.03% test) as it contains more diverse natural images
+- **Speed**: Image deduplication is fast even for large datasets (60k images in ~87 seconds on MPS)
+
+### Running Image Benchmarks
+
+To run the image benchmarks yourself:
 
 ```bash
-python -m benchmarks.run_benchmarks
+# Install dependencies
+pip install timm torch datasets
+
+# Run benchmarks
+python -m benchmarks.run_image_benchmarks
+# Or using make
+make benchmark-image
 ```
 
-The datasets can be customized by editing `benchmarks/data.py`.
+The image datasets can be customized by editing `benchmarks/data.py` (see `IMAGE_DATASET_DICT`).
+
+## Running All Benchmarks
+
+To run both text and image benchmarks:
+
+```bash
+make benchmark
+```
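The image-benchmark setup above reports running on an Apple Silicon GPU (MPS) with a batch size of 128, but no device handling appears anywhere in this diff. The following is a hedged sketch of one way the timm encoder could be placed on MPS; the actual run_image_benchmarks implementation is not shown here, so treat the device logic as an assumption.

```python
# Hedged sketch, not from this commit: select MPS when available and keep the
# batched encoding pattern from the README encoder, moving tensors to the device.
import numpy as np
import timm
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

model = timm.create_model("mobilenetv3_small_100.lamb_in1k", pretrained=True, num_classes=0).eval().to(device)
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)


def encode(images, batch_size: int = 128) -> np.ndarray:
    """Encode PIL images in batches; embeddings come back to CPU as a numpy array."""
    rgb = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
    chunks = []
    with torch.no_grad():
        for i in range(0, len(rgb), batch_size):
            batch = torch.stack([transform(img) for img in rgb[i : i + batch_size]]).to(device)
            chunks.append(model(batch).cpu().numpy())
    return np.vstack(chunks)
```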

benchmarks/data.py

Lines changed: 6 additions & 0 deletions

@@ -12,6 +12,7 @@ class DatasetRecord:
     columns: list[str] | None = None
     split_one: str = "train"
     split_two: str = "test"
+    modality: str = "text"
 
 
 DATASET_DICT: dict[str, DatasetRecord] = {
@@ -41,3 +42,8 @@ class DatasetRecord:
         name="Salesforce/wikitext", text_name="text", label_name="text", sub_directory="wikitext-103-raw-v1"
     ),
 }
+
+IMAGE_DATASET_DICT: dict[str, DatasetRecord] = {
+    "cifar10": DatasetRecord(name="uoft-cs/cifar10", columns=["img"], modality="image"),
+    "fashion_mnist": DatasetRecord(name="fashion_mnist", columns=["image"], modality="image"),
+}
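For context on how the new `IMAGE_DATASET_DICT` entries might be consumed, here is a hypothetical sketch of a benchmark loop. The real `benchmarks/run_image_benchmarks.py` is not among the lines shown in this commit, so the function name and the SemHash call are illustrative assumptions; the result keys simply mirror the JSON output below.

```python
# Hypothetical runner sketch, not taken from this commit. `run_train_benchmark`
# is an illustrative name; only IMAGE_DATASET_DICT and DatasetRecord come from
# benchmarks/data.py above.
import time

from datasets import load_dataset
from semhash import SemHash

from benchmarks.data import IMAGE_DATASET_DICT


def run_train_benchmark(encoder) -> list[dict]:
    """Time self-deduplication for each image dataset record."""
    results = []
    for key, record in IMAGE_DATASET_DICT.items():
        train = load_dataset(record.name, split=record.split_one)
        images = [row[record.columns[0]] for row in train]

        start = time.perf_counter()
        semhash = SemHash.from_records(records=images, model=encoder)  # assumed API, see README example
        deduplicated = semhash.self_deduplicate().selected
        elapsed = time.perf_counter() - start

        results.append(
            {
                "dataset": key,
                "original_train_size": len(images),
                "deduplicated_train_size": len(deduplicated),
                "percent_removed": 100 * (1 - len(deduplicated) / len(images)),
                "time_seconds": elapsed,
            }
        )
    return results
```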
Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+[
+  {
+    "dataset": "cifar10",
+    "original_train_size": 50000,
+    "deduplicated_train_size": 48274,
+    "percent_removed": 3.4519999999999995,
+    "build_time_seconds": 56.00128899999254,
+    "deduplication_time_seconds": 5.201297917010379,
+    "time_seconds": 61.20258691700292
+  },
+  {
+    "dataset": "fashion_mnist",
+    "original_train_size": 60000,
+    "deduplicated_train_size": 16714,
+    "percent_removed": 72.14333333333333,
+    "build_time_seconds": 61.14413262500602,
+    "deduplication_time_seconds": 25.46288070900482,
+    "time_seconds": 86.60701333401084
+  }
+]
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+[
+  {
+    "dataset": "cifar10",
+    "train_size": 50000,
+    "test_size": 10000,
+    "deduplicated_test_size": 9397,
+    "percent_removed": 6.030000000000002,
+    "build_time_seconds": 56.00128899999254,
+    "deduplication_time_seconds": 11.428115875009098,
+    "time_seconds": 67.42940487500164
+  },
+  {
+    "dataset": "fashion_mnist",
+    "train_size": 60000,
+    "test_size": 10000,
+    "deduplicated_test_size": 2052,
+    "percent_removed": 79.47999999999999,
+    "build_time_seconds": 61.14413262500602,
+    "deduplication_time_seconds": 10.998616750002839,
+    "time_seconds": 72.14274937500886
+  }
+]
