Skip to content

Commit 368b363

Browse files
committed
Added ColBERT example for binary embeddings [skip ci]
1 parent dbc44f4 commit 368b363

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

examples/colbert/exact_binary.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from colbert.infra import ColBERTConfig
2+
from colbert.modeling.checkpoint import Checkpoint
3+
from pgvector.psycopg import register_vector, Bit
4+
import psycopg
5+
6+
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)
7+
8+
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
9+
register_vector(conn)
10+
11+
conn.execute('DROP TABLE IF EXISTS documents')
12+
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embeddings bit(128)[])')
13+
conn.execute("""
14+
CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
15+
WITH queries AS (
16+
SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query)
17+
),
18+
documents AS (
19+
SELECT unnest(document) AS document
20+
),
21+
similarities AS (
22+
SELECT query_number, 1 - ((document <~> query) / bit_length(query)) AS similarity FROM queries CROSS JOIN documents
23+
),
24+
max_similarities AS (
25+
SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
26+
)
27+
SELECT SUM(max_similarity) FROM max_similarities
28+
$$ LANGUAGE SQL
29+
""")
30+
31+
32+
def binary_quantize(embeddings):
33+
return [Bit(e.numpy()) for e in (embeddings > 0)]
34+
35+
36+
config = ColBERTConfig(doc_maxlen=220, query_maxlen=32)
37+
checkpoint = Checkpoint('colbert-ir/colbertv2.0', colbert_config=config, verbose=0)
38+
39+
input = [
40+
'The dog is barking',
41+
'The cat is purring',
42+
'The bear is growling'
43+
]
44+
doc_embeddings = checkpoint.docFromText(input, keep_dims=False)
45+
for content, embeddings in zip(input, doc_embeddings):
46+
embeddings = binary_quantize(embeddings)
47+
conn.execute('INSERT INTO documents (content, embeddings) VALUES (%s, %s)', (content, embeddings))
48+
49+
query = 'puppy'
50+
query_embeddings = binary_quantize(checkpoint.queryFromText([query])[0])
51+
result = conn.execute('SELECT content, max_sim(embeddings, %s) AS max_sim FROM documents ORDER BY max_sim DESC LIMIT 5', (query_embeddings,)).fetchall()
52+
for row in result:
53+
print(row)

0 commit comments

Comments
 (0)