|
25 | 25 | "source": [ |
26 | 26 | "!pip install -r requirements.txt\n", |
27 | 27 | "from IPython.display import clear_output\n", |
28 | | - "clear_output() # for less space usage. " |
| 28 | + "\n", |
| 29 | + "clear_output() # for less space usage." |
29 | 30 | ] |
30 | 31 | }, |
31 | 32 | { |
|
64 | 65 | ], |
65 | 66 | "source": [ |
66 | 67 | "from datasets import load_dataset\n", |
67 | | - "from tqdm.notebook import tqdm \n", |
| 68 | + "from tqdm.notebook import tqdm\n", |
68 | 69 | "import os\n", |
69 | 70 | "\n", |
70 | 71 | "DATASET_NAME = \"vidore/infovqa_test_subsampled\"\n", |
71 | 72 | "DOCUMENT_DIR = \"searchlabs-colpali\"\n", |
72 | 73 | "\n", |
73 | 74 | "os.makedirs(DOCUMENT_DIR, exist_ok=True)\n", |
74 | 75 | "dataset = load_dataset(DATASET_NAME, split=\"test\")\n", |
75 | | - " \n", |
| 76 | + "\n", |
76 | 77 | "for i, row in enumerate(tqdm(dataset, desc=\"Saving images to disk\")):\n", |
77 | 78 | " image = row.get(\"image\")\n", |
78 | 79 | " image_name = f\"image_{i}.jpg\"\n", |
|
123 | 124 | "model = ColPali.from_pretrained(\n", |
124 | 125 | " \"vidore/colpali-v1.3\",\n", |
125 | 126 | " torch_dtype=torch.float32,\n", |
126 | | - " device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n", |
| 127 | + " device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n", |
127 | 128 | ").eval()\n", |
128 | 129 | "\n", |
129 | 130 | "col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n", |
130 | 131 | "\n", |
| 132 | + "\n", |
131 | 133 | "def create_col_pali_image_vectors(image_path: str) -> list:\n", |
132 | | - " batch_images = col_pali_processor.process_images([Image.open(image_path)]).to(model.device)\n", |
133 | | - " \n", |
| 134 | + " batch_images = col_pali_processor.process_images([Image.open(image_path)]).to(\n", |
| 135 | + " model.device\n", |
| 136 | + " )\n", |
| 137 | + "\n", |
134 | 138 | " with torch.no_grad():\n", |
135 | 139 | " return model(**batch_images).tolist()[0]\n", |
136 | 140 | "\n", |
| 141 | + "\n", |
137 | 142 | "def create_col_pali_query_vectors(query: str) -> list:\n", |
138 | 143 | " queries = col_pali_processor.process_queries([query]).to(model.device)\n", |
139 | 144 | " with torch.no_grad():\n", |
|
194 | 199 | " vectors_f32 = create_col_pali_image_vectors(image_path)\n", |
195 | 200 | " file_to_multi_vectors[file_name] = vectors_f32\n", |
196 | 201 | "\n", |
197 | | - "with open('col_pali_vectors.pkl', 'wb') as f:\n", |
| 202 | + "with open(\"col_pali_vectors.pkl\", \"wb\") as f:\n", |
198 | 203 | " pickle.dump(file_to_multi_vectors, f)\n", |
199 | | - " \n", |
| 204 | + "\n", |
200 | 205 | "print(f\"Saved {len(file_to_multi_vectors)} vector entries to disk\")" |
201 | 206 | ] |
202 | 207 | }, |
|
239 | 244 | "\n", |
240 | 245 | "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n", |
241 | 246 | "\n", |
242 | | - "mappings = {\n", |
243 | | - " \"mappings\": {\n", |
244 | | - " \"properties\": {\n", |
245 | | - " \"col_pali_vectors\": {\n", |
246 | | - " \"type\": \"rank_vectors\"\n", |
247 | | - " }\n", |
248 | | - " }\n", |
249 | | - " }\n", |
250 | | - "}\n", |
| 247 | + "mappings = {\"mappings\": {\"properties\": {\"col_pali_vectors\": {\"type\": \"rank_vectors\"}}}}\n", |
251 | 248 | "\n", |
252 | 249 | "if not es.indices.exists(index=INDEX_NAME):\n", |
253 | 250 | " print(f\"[INFO] Creating index: {INDEX_NAME}\")\n", |
254 | 251 | " es.indices.create(index=INDEX_NAME, body=mappings)\n", |
255 | 252 | "else:\n", |
256 | 253 | " print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n", |
257 | 254 | "\n", |
| 255 | + "\n", |
258 | 256 | "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n", |
259 | 257 | " for attempt in range(1, retries + 1):\n", |
260 | 258 | " try:\n", |
|
304 | 302 | } |
305 | 303 | ], |
306 | 304 | "source": [ |
307 | | - "with open('col_pali_vectors.pkl', 'rb') as f:\n", |
| 305 | + "with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n", |
308 | 306 | " file_to_multi_vectors = pickle.load(f)\n", |
309 | 307 | "\n", |
310 | 308 | "for file_name, vectors in tqdm(file_to_multi_vectors.items(), desc=\"Index documents\"):\n", |
311 | 309 | " if es.exists(index=INDEX_NAME, id=file_name):\n", |
312 | 310 | " continue\n", |
313 | | - " \n", |
| 311 | + "\n", |
314 | 312 | " index_document(\n", |
315 | | - " es_client=es, \n", |
316 | | - " index=INDEX_NAME, \n", |
317 | | - " doc_id=file_name, \n", |
318 | | - " document={\"col_pali_vectors\": vectors}\n", |
| 313 | + " es_client=es,\n", |
| 314 | + " index=INDEX_NAME,\n", |
| 315 | + " doc_id=file_name,\n", |
| 316 | + " document={\"col_pali_vectors\": vectors},\n", |
319 | 317 | " )" |
320 | 318 | ] |
321 | 319 | }, |
|
360 | 358 | " \"_source\": False,\n", |
361 | 359 | " \"query\": {\n", |
362 | 360 | " \"script_score\": {\n", |
363 | | - " \"query\": {\n", |
364 | | - " \"match_all\": {}\n", |
365 | | - " },\n", |
| 361 | + " \"query\": {\"match_all\": {}},\n", |
366 | 362 | " \"script\": {\n", |
367 | 363 | " \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n", |
368 | | - " \"params\": {\n", |
369 | | - " \"query_vector\": create_col_pali_query_vectors(query)\n", |
370 | | - " }\n", |
371 | | - " }\n", |
| 364 | + " \"params\": {\"query_vector\": create_col_pali_query_vectors(query)},\n", |
| 365 | + " },\n", |
372 | 366 | " }\n", |
373 | 367 | " },\n", |
374 | | - " \"size\": 5\n", |
| 368 | + " \"size\": 5,\n", |
375 | 369 | "}\n", |
376 | 370 | "\n", |
377 | 371 | "results = es.search(index=INDEX_NAME, body=es_query)\n", |
|
393 | 387 | "metadata": {}, |
394 | 388 | "outputs": [], |
395 | 389 | "source": [ |
396 | | - "# We kill the kernel forcefully to free up the memory from the ColPali model. \n", |
| 390 | + "# We kill the kernel forcefully to free up the memory from the ColPali model.\n", |
397 | 391 | "print(\"Shutting down the kernel to free memory...\")\n", |
398 | 392 | "import os\n", |
| 393 | + "\n", |
399 | 394 | "os._exit(0)" |
400 | 395 | ] |
401 | 396 | } |
|