from colpali_engine.models import ColQwen2, ColQwen2Processor
from datasets import load_dataset
- from pgvector.psycopg import register_vector
+ from pgvector.psycopg import register_vector, Bit
import psycopg
import torch

register_vector(conn)

conn.execute('DROP TABLE IF EXISTS documents')
- conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, embeddings vector(128)[])')
+ conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, embeddings bit(128)[])')
conn.execute("""
- CREATE OR REPLACE FUNCTION max_sim(document vector[], query vector[]) RETURNS double precision AS $$
+ CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
    WITH queries AS (
        SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query)
    ),
    documents AS (
        SELECT unnest(document) AS document
    ),
    similarities AS (
-         SELECT query_number, 1 - (document <=> query) AS similarity FROM queries CROSS JOIN documents
+         SELECT query_number, 1 - ((document <~> query) / bit_length(query)) AS similarity FROM queries CROSS JOIN documents
    ),
    max_similarities AS (
        SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
@@ -40,13 +40,17 @@ def generate_embeddings(processed):
    return model(**processed.to(model.device)).to(device='cpu', dtype=torch.float32)


+ def binary_quantize(embedding):
+     return Bit(embedding > 0)
+
+
input = load_dataset('vidore/docvqa_test_subsampled', split='test[:3]')['image']
for content in input:
-     embeddings = [e.numpy() for e in generate_embeddings(processor.process_images([content]))[0]]
+     embeddings = [binary_quantize(e.numpy()) for e in generate_embeddings(processor.process_images([content]))[0]]
    conn.execute('INSERT INTO documents (embeddings) VALUES (%s)', (embeddings,))

query = 'dividend'
- query_embeddings = [e.numpy() for e in generate_embeddings(processor.process_queries([query]))[0]]
+ query_embeddings = [binary_quantize(e.numpy()) for e in generate_embeddings(processor.process_queries([query]))[0]]
result = conn.execute('SELECT id, max_sim(embeddings, %s) AS max_sim FROM documents ORDER BY max_sim DESC LIMIT 5', (query_embeddings,)).fetchall()
for row in result:
    print(row)
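

# A rough NumPy sketch of what the <~>-based similarity inside max_sim computes
# (an illustration, not part of the change above): <~> is pgvector's Hamming distance
# for bit vectors, so 1 - (document <~> query) / bit_length(query) is the fraction of
# matching bits between a document token and a query token. Names below are illustrative.
import numpy as np

def hamming_similarity(doc_bits, query_bits):
    # fraction of equal bits between two boolean arrays of the same length
    return 1 - np.count_nonzero(doc_bits != query_bits) / query_bits.size

# stand-ins for two binary-quantized 128-dimensional token embeddings
doc_bits = np.random.randn(128) > 0
query_bits = np.random.randn(128) > 0
print(hamming_similarity(doc_bits, query_bits))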