Skip to content

Commit 7d8a417

Browse files
committed
Updated ColPali example to use binary quantization [skip ci]
1 parent 267d796 commit 7d8a417

File tree

2 files changed

+10
-59
lines changed

2 files changed

+10
-59
lines changed

examples/colbert/exact_binary.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

examples/colpali/exact.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from colpali_engine.models import ColQwen2, ColQwen2Processor
22
from datasets import load_dataset
3-
from pgvector.psycopg import register_vector
3+
from pgvector.psycopg import register_vector, Bit
44
import psycopg
55
import torch
66

@@ -10,17 +10,17 @@
1010
register_vector(conn)
1111

1212
conn.execute('DROP TABLE IF EXISTS documents')
13-
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, embeddings vector(128)[])')
13+
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, embeddings bit(128)[])')
1414
conn.execute("""
15-
CREATE OR REPLACE FUNCTION max_sim(document vector[], query vector[]) RETURNS double precision AS $$
15+
CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
1616
WITH queries AS (
1717
SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query)
1818
),
1919
documents AS (
2020
SELECT unnest(document) AS document
2121
),
2222
similarities AS (
23-
SELECT query_number, 1 - (document <=> query) AS similarity FROM queries CROSS JOIN documents
23+
SELECT query_number, 1 - ((document <~> query) / bit_length(query)) AS similarity FROM queries CROSS JOIN documents
2424
),
2525
max_similarities AS (
2626
SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
@@ -40,13 +40,17 @@ def generate_embeddings(processed):
4040
return model(**processed.to(model.device)).to(device='cpu', dtype=torch.float32)
4141

4242

43+
def binary_quantize(embedding):
44+
return Bit(embedding > 0)
45+
46+
4347
input = load_dataset('vidore/docvqa_test_subsampled', split='test[:3]')['image']
4448
for content in input:
45-
embeddings = [e.numpy() for e in generate_embeddings(processor.process_images([content]))[0]]
49+
embeddings = [binary_quantize(e.numpy()) for e in generate_embeddings(processor.process_images([content]))[0]]
4650
conn.execute('INSERT INTO documents (embeddings) VALUES (%s)', (embeddings,))
4751

4852
query = 'dividend'
49-
query_embeddings = [e.numpy() for e in generate_embeddings(processor.process_queries([query]))[0]]
53+
query_embeddings = [binary_quantize(e.numpy()) for e in generate_embeddings(processor.process_queries([query]))[0]]
5054
result = conn.execute('SELECT id, max_sim(embeddings, %s) AS max_sim FROM documents ORDER BY max_sim DESC LIMIT 5', (query_embeddings,)).fetchall()
5155
for row in result:
5256
print(row)

0 commit comments

Comments
 (0)