
Commit ae39aca

Follow-up on part 1 of the ColPali blog.
Goes into bit vectors, average vectors and token pooling to make late interaction vectors more scalable.
1 parent 72835b0 commit ae39aca

File tree

3 files changed: +1135 -0 lines changed
Lines changed: 394 additions & 0 deletions
@@ -0,0 +1,394 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e0056726-3359-4a9b-9913-016617525a6d",
"metadata": {},
"source": [
"# Scalable late interaction vectors in Elasticsearch: Bit Vectors #\n",
"\n",
"In this notebook, we will look at how to convert late interaction vectors to bit vectors to \n",
"1. Save significant disk space \n",
"2. Lower query latency\n",
" \n",
"We will also look at how we can use hamming distance to speed our queries up even further. \n",
"This notebook builds on part 1, where we downloaded the images, created ColPali vectors and saved them to disk. Please run that notebook first before trying the techniques in this one. \n",
" \n",
"Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. "
]
},
{
"cell_type": "markdown",
"id": "49dbcc61-5dab-4cf6-bbc5-7fa898707ce6",
"metadata": {},
"source": [
"This is the key part of this notebook. We use the `to_bit_vectors()` function to convert our vectors into bit vectors. \n",
"The function is simple in essence: values `> 0` are converted to `1`, values `<= 0` are converted to `0`. We then pack this array of `0`s and `1`s into a hex string that represents our bit vector. \n",
"So don't be surprised that the values we will be indexing look like strings rather than arrays. This is intended! \n",
"\n",
"Learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def to_bit_vectors(embedding: list) -> list:\n",
"    embeddings = []\n",
"    for patch_embedding in embedding:\n",
"        patch_embedding = np.array(patch_embedding)\n",
"        # Binarize each patch vector: values > 0 become 1, everything else 0.\n",
"        # The bits are then packed into bytes and encoded as a hex string.\n",
"        binary_vector = (\n",
"            np.packbits(np.where(patch_embedding > 0, 1, 0))\n",
"            .astype(np.int8)\n",
"            .tobytes()\n",
"            .hex()\n",
"        )\n",
"        embeddings.append(binary_vector)\n",
"    return embeddings"
]
},
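{
"cell_type": "markdown",
"id": "0a1b2c3d-toy-bit-vector-example-md",
"metadata": {},
"source": [
"For intuition, here is a minimal sketch of what the conversion does, using two made-up 8-dimensional toy vectors (not real ColPali patch embeddings): every patch vector becomes one hex string, and every hex character encodes 4 of the original dimensions as bits. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a1b2c3d-toy-bit-vector-example-code",
"metadata": {},
"outputs": [],
"source": [
"# Toy example with made-up values, just to illustrate the conversion.\n",
"toy_embedding = [\n",
"    [0.12, -0.4, 0.003, -0.9, 0.5, -0.1, 0.0, 0.7],  # bits 10101001 -> hex 'a9'\n",
"    [-0.2, -0.3, 0.8, 0.1, -0.6, 0.4, 0.2, -0.05],   # bits 00110110 -> hex '36'\n",
"]\n",
"\n",
"print(to_bit_vectors(toy_embedding))  # expected output: ['a9', '36']"
]
},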
{
"cell_type": "markdown",
"id": "52b7449b-8fbf-46b7-90c9-330070f6996a",
"metadata": {},
"source": [
"Here we are defining our mapping for our Elasticsearch index. Note how we set the `element_type` parameter to `bit` to inform Elasticsearch that we will be indexing bit vectors in this field. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Index 'searchlabs-colpali-hamming' already exists.\n"
]
}
],
"source": [
"import os\n",
"import time\n",
"from dotenv import load_dotenv\n",
"from elasticsearch import Elasticsearch\n",
"\n",
"load_dotenv(\"elastic.env\")\n",
"\n",
"ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n",
"ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n",
"INDEX_NAME = \"searchlabs-colpali-hamming\"\n",
"\n",
"es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
"\n",
"mappings = {\n",
"    \"mappings\": {\n",
"        \"properties\": {\n",
"            \"col_pali_vectors\": {\n",
"                \"type\": \"rank_vectors\",\n",
"                \"element_type\": \"bit\"\n",
"            }\n",
"        }\n",
"    }\n",
"}\n",
"\n",
"if not es.indices.exists(index=INDEX_NAME):\n",
"    print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
"    es.indices.create(index=INDEX_NAME, body=mappings)\n",
"else:\n",
"    print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
"\n",
"def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
"    # Index a single document, retrying with exponential backoff on failure.\n",
"    for attempt in range(1, retries + 1):\n",
"        try:\n",
"            return es_client.index(index=index, id=doc_id, document=document)\n",
"        except Exception as e:\n",
"            if attempt < retries:\n",
"                wait_time = initial_backoff * (2 ** (attempt - 1))\n",
"                print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n",
"                time.sleep(wait_time)\n",
"            else:\n",
"                print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n",
"                raise"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "022b4af8891b4a06962e023c7f92d8f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Indexing documents: 0%| | 0/500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Completed indexing 500 documents\n"
]
}
],
"source": [
"from concurrent.futures import ThreadPoolExecutor\n",
"from tqdm.notebook import tqdm\n",
"import pickle\n",
"\n",
"def process_file(file_name, vectors):\n",
"    if es.exists(index=INDEX_NAME, id=file_name):\n",
"        return\n",
"\n",
"    bit_vectors = to_bit_vectors(vectors)\n",
"\n",
"    index_document(\n",
"        es_client=es,\n",
"        index=INDEX_NAME,\n",
"        doc_id=file_name,\n",
"        document={\"col_pali_vectors\": bit_vectors}\n",
"    )\n",
"\n",
"with open('col_pali_vectors.pkl', 'rb') as f:\n",
"    file_to_multi_vectors = pickle.load(f)\n",
"\n",
"with ThreadPoolExecutor(max_workers=10) as executor:\n",
"    list(tqdm(\n",
"        executor.map(lambda item: process_file(*item), file_to_multi_vectors.items()),\n",
"        total=len(file_to_multi_vectors),\n",
"        desc=\"Indexing documents\"\n",
"    ))\n",
"\n",
"print(f\"Completed indexing {len(file_to_multi_vectors)} documents\")"
]
},
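{
"cell_type": "markdown",
"id": "0a1b2c3d-index-size-check-md",
"metadata": {},
"source": [
"Since the main motivation for bit vectors is saving disk space, it can be useful to eyeball the on-disk size of the index after indexing. The cell below is an illustrative sketch using the index stats API via the Python client; the exact response structure may vary slightly between Elasticsearch versions. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a1b2c3d-index-size-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Rough size check (illustrative): report the store size of the bit-vector index.\n",
"stats = es.indices.stats(index=INDEX_NAME)\n",
"size_in_bytes = stats[\"_all\"][\"total\"][\"store\"][\"size_in_bytes\"]\n",
"print(f\"{INDEX_NAME}: {size_in_bytes / 1024 / 1024:.1f} MB on disk\")"
]
},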
{
"cell_type": "code",
"execution_count": 4,
"id": "1dfc3713-d649-46db-aa81-171d6d92668e",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "064e33061bac40e4802138e30599225b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from PIL import Image\n",
"from colpali_engine.models import ColPali, ColPaliProcessor\n",
"\n",
"model_name = \"vidore/colpali-v1.3\"\n",
"model = ColPali.from_pretrained(\n",
"    model_name,\n",
"    torch_dtype=torch.float32,\n",
"    device_map=\"mps\",  # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
").eval()\n",
"\n",
"col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
"\n",
"def create_col_pali_query_vectors(query: str) -> list:\n",
"    queries = col_pali_processor.process_queries([query]).to(model.device)\n",
"    with torch.no_grad():\n",
"        return model(**queries).tolist()[0]"
]
},
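{
"cell_type": "markdown",
"id": "0a1b2c3d-query-vector-shape-md",
"metadata": {},
"source": [
"As a quick sanity check, the sketch below (using an illustrative query string, not one from the dataset) prints how many query token vectors ColPali produces and their dimensionality; these are the vectors that get binarized into hex strings at query time. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a1b2c3d-query-vector-shape-code",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: late interaction produces one dense vector per query token.\n",
"sample_query_vectors = create_col_pali_query_vectors(\"hello world\")\n",
"print(f\"{len(sample_query_vectors)} token vectors of dimension {len(sample_query_vectors[0])}\")"
]
},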
{
"cell_type": "markdown",
"id": "5e86697d-d9dd-4224-85c8-023c71c88548",
"metadata": {},
"source": [
"Here we run the search against our index, comparing our query vectors (converted to bit vectors) with the bit vectors in our index. \n",
"Trading off a bit of accuracy, this allows us to use hamming distance (`maxSimInvHamming(...)`), which can leverage optimizations such as bit masks, SIMD, etc. Again, learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. \n",
"\n",
"See the cell below for a different technique to query our bit vectors. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8e322b23-b4bc-409d-9e00-2dab93f6a295",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"_source\": false, \"query\": {\"script_score\": {\"query\": {\"match_all\": {}}, \"script\": {\"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\", \"params\": {\"query_vector\": [\"7747bcd9732859c3645aa81036f5c960\", \"729b3c418ba8594a67daa042eca1c961\", \"609e3d8a2ac379c2204aa0cfa8345bdc\", \"30bf378a2ac279da245aa8dfa83c3bdc\", \"64af77ea2acdf9c28c0aa5df863677f4\", \"686f3fce2ac871c26e6aaddf023455ec\", \"383f31a8e8c0f8ca2c4ab54f047c7dec\", \"203b33caaac279da0acaa54f8a3c6bcc\", \"319a63eba8d279ca30dbbccf8f757b8e\", \"203b73ca28d2798a325bb44f8c3c5bce\", \"203bb7caa8d2718a1a4bb14f8a3c5bdc\", \"203bb7caa8d2798a1a6aa14f8a3c5fdc\", \"303b33caa8d2798a0a4aa14f8a3c5bdc\", \"303b33caaad379ca0e4aa14f8a3c5bdc\", \"709b33caaac379ca0c4aa14f8a3c5fdc\", \"708e37eaaac779ca2c4aa1df863c1fdc\", \"648e77ea6acd79caac4ae1df86363ffc\", \"648e77ea6acdf9caac4ae5df06363ffc\", \"608f37ea2ac579ca2c4ea1df063c3ffc\", \"709f37c8aac379ca2c4ea1df863c1fdc\", \"70af31c82ac671ce2c6ab14fc43c1bfc\"]}}}}, \"size\": 5}\n"
]
},
{
"data": {
"text/html": [
"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import display, HTML\n",
"import os\n",
"import json\n",
"\n",
"DOCUMENT_DIR = \"searchlabs-colpali\"\n",
"\n",
"query = \"What do companies use for recruiting?\"\n",
"query_vector = to_bit_vectors(create_col_pali_query_vectors(query))\n",
"es_query = {\n",
"    \"_source\": False,\n",
"    \"query\": {\n",
"        \"script_score\": {\n",
"            \"query\": {\n",
"                \"match_all\": {}\n",
"            },\n",
"            \"script\": {\n",
"                \"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\",\n",
"                \"params\": {\n",
"                    \"query_vector\": query_vector\n",
"                }\n",
"            }\n",
"        }\n",
"    },\n",
"    \"size\": 5\n",
"}\n",
"print(json.dumps(es_query))\n",
"\n",
"results = es.search(index=INDEX_NAME, body=es_query)\n",
"image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
"\n",
"html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
"for image_id in image_ids:\n",
"    image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
"    html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
"html += \"</div>\"\n",
"\n",
"display(HTML(html))"
]
},
{
"cell_type": "markdown",
"id": "e27b68ac-bec8-4415-919e-8b916bc35816",
"metadata": {},
"source": [
"Above we have seen how to query our data using the `maxSimInvHamming(...)` function. \n",
"We can also pass the full-fidelity ColPali query vectors and use the `maxSimDotProduct(...)` function for [asymmetric similarity](https://www.elastic.co/guide/en/elasticsearch/reference/8.18/rank-vectors.html#rank-vectors-scoring) between the vectors. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "32fd9ee4-d7c6-4954-a766-7b06735290ff",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"query = \"What do companies use for recruiting?\"\n",
"query_vector = create_col_pali_query_vectors(query)\n",
"es_query = {\n",
"    \"_source\": False,\n",
"    \"query\": {\n",
"        \"script_score\": {\n",
"            \"query\": {\n",
"                \"match_all\": {}\n",
"            },\n",
"            \"script\": {\n",
"                \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n",
"                \"params\": {\n",
"                    \"query_vector\": query_vector\n",
"                }\n",
"            }\n",
"        }\n",
"    },\n",
"    \"size\": 5\n",
"}\n",
"\n",
"results = es.search(index=INDEX_NAME, body=es_query)\n",
"image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
"\n",
"html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
"for image_id in image_ids:\n",
"    image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
"    html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
"html += \"</div>\"\n",
"\n",
"display(HTML(html))"
]
},
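{
"cell_type": "markdown",
"id": "0a1b2c3d-latency-comparison-md",
"metadata": {},
"source": [
"Since lower query latency is one of the stated goals, a rough way to compare the two approaches is to time both `script_score` queries against the same index. The sketch below is a simple single-run timing (no warm-up, no statistics), so treat the numbers as a rough indication only; they will vary with hardware, caching and cluster setup. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a1b2c3d-latency-comparison-code",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def timed_search(script_source, vector):\n",
"    # Illustrative helper: run one script_score query and return the wall-clock time in seconds.\n",
"    body = {\n",
"        \"_source\": False,\n",
"        \"query\": {\n",
"            \"script_score\": {\n",
"                \"query\": {\"match_all\": {}},\n",
"                \"script\": {\"source\": script_source, \"params\": {\"query_vector\": vector}}\n",
"            }\n",
"        },\n",
"        \"size\": 5\n",
"    }\n",
"    start = time.perf_counter()\n",
"    es.search(index=INDEX_NAME, body=body)\n",
"    return time.perf_counter() - start\n",
"\n",
"float_vectors = create_col_pali_query_vectors(query)\n",
"bit_vectors = to_bit_vectors(float_vectors)\n",
"\n",
"dot_product_time = timed_search(\"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\", float_vectors)\n",
"hamming_time = timed_search(\"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\", bit_vectors)\n",
"print(f\"maxSimDotProduct: {dot_product_time:.3f}s | maxSimInvHamming: {hamming_time:.3f}s\")"
]
},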
{
"cell_type": "code",
"execution_count": null,
"id": "ee8df1e3-af66-4e35-9c26-7257c281536f",
"metadata": {},
"outputs": [],
"source": [
"# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
"print(\"Shutting down the kernel to free memory...\")\n",
"import os\n",
"os._exit(0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dependecy-test-colpali-blog",
"language": "python",
"name": "dependecy-test-colpali-blog"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
