@@ -178,6 +178,15 @@ def _create_list_query(
178
178
]
179
179
return query , values
180
180
181
async def _check_table_exists(self) -> bool:
    """
    Check whether a table named like this store's table is already present.

    The lookup runs against ``information_schema.tables`` and compares only the
    ``table_name`` column with ``self._table_name`` (passed as a bound
    parameter). The query does not restrict the schema.

    Returns:
        The value of the ``SELECT EXISTS (...)`` query as returned by the driver.
    """
    existence_sql = """
    SELECT EXISTS (
        SELECT FROM information_schema.tables
        WHERE table_name = $1
    ); """
    # Borrow a connection only for the duration of the single lookup.
    async with self._client.acquire() as connection:
        table_found = await connection.fetchval(existence_sql, self._table_name)
    return table_found
189
+
181
190
async def create_table (self ) -> None :
182
191
"""
183
192
Create a pgVector table with an HNSW index for given similarity.
@@ -188,43 +197,36 @@ async def create_table(self) -> None:
188
197
vector_size = self ._vector_size ,
189
198
hnsw_index_parameters = self ._hnsw_params ,
190
199
):
191
- check_table_existence = """
192
- SELECT EXISTS (
193
- SELECT FROM information_schema.tables
194
- WHERE table_name = $1
195
- ); """
196
200
distance = DISTANCE_OPS [self ._distance_method ].function_name
197
201
create_vector_extension = "CREATE EXTENSION IF NOT EXISTS vector;"
198
202
# _table_name has been validated in the class constructor, and it is a valid table name.
199
203
# _vector_size has been validated in the class constructor, and it is a valid vector size.
200
204
201
205
create_table_query = f"""
202
206
CREATE TABLE { self ._table_name }
203
- (id UUID, key TEXT, vector VECTOR({ self ._vector_size } ), metadata JSONB);
207
+ (id UUID, text TEXT, image_bytes BYTEA , vector VECTOR({ self ._vector_size } ), metadata JSONB);
204
208
"""
205
209
# _hnsw_params has been validated in the class constructor, and it is valid dict[str,int].
206
210
create_index_query = f"""
207
211
CREATE INDEX { self ._table_name + "_hnsw_idx" } ON { self ._table_name }
208
212
USING hnsw (vector { distance } )
209
213
WITH (m = { self ._hnsw_params ["m" ]} , ef_construction = { self ._hnsw_params ["ef_construction" ]} );
210
214
"""
211
-
215
+ if await self ._check_table_exists ():
216
+ print (f"Table { self ._table_name } already exist!" )
217
+ return
212
218
async with self ._client .acquire () as conn :
213
219
await conn .execute (create_vector_extension )
214
- exists = await conn .fetchval (check_table_existence , self ._table_name )
215
220
216
- if not exists :
217
- try :
218
- async with conn .transaction ():
219
- await conn .execute (create_table_query )
220
- await conn .execute (create_index_query )
221
+ try :
222
+ async with conn .transaction ():
223
+ await conn .execute (create_table_query )
224
+ await conn .execute (create_index_query )
221
225
222
- print ("Table and index created!" )
223
- except Exception as e :
224
- print (f"Failed to create table and index: { e } " )
225
- raise
226
- else :
227
- print ("Table already exists!" )
226
+ print ("Table and index created!" )
227
+ except Exception as e :
228
+ print (f"Failed to create table and index: { e } " )
229
+ raise
228
230
229
231
async def store (self , entries : list [VectorStoreEntry ]) -> None :
230
232
"""
@@ -237,8 +239,8 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
237
239
return
238
240
# _table_name has been validated in the class constructor, and it is a valid table name.
239
241
insert_query = f"""
240
- INSERT INTO { self ._table_name } (id, key , vector, metadata)
241
- VALUES ($1, $2, $3, $4)
242
+ INSERT INTO { self ._table_name } (id, text, image_bytes , vector, metadata)
243
+ VALUES ($1, $2, $3, $4, $5 )
242
244
""" # noqa S608
243
245
with trace (
244
246
table_name = self ._table_name ,
@@ -248,30 +250,28 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
248
250
embedding_type = self ._embedding_type ,
249
251
):
250
252
embeddings = await self ._create_embeddings (entries )
251
-
252
- try :
253
- async with self ._client .acquire () as conn :
254
- for entry in entries :
255
- if entry .id not in embeddings :
256
- continue
257
-
258
- await conn .execute (
259
- insert_query ,
260
- str (entry .id ),
261
- entry .text ,
262
- str (embeddings [entry .id ]),
263
- json .dumps (entry .metadata , default = pydantic_encoder ),
264
- )
265
- except asyncpg .exceptions .UndefinedTableError :
253
+ exists = await self ._check_table_exists ()
254
+ if not exists :
266
255
print (f"Table { self ._table_name } does not exist. Creating the table." )
267
256
try :
268
257
await self .create_table ()
269
258
except Exception as e :
270
259
print (f"Failed to handle missing table: { e } " )
271
260
return
272
261
273
- print ("Table created successfully. Inserting entries..." )
274
- await self .store (entries )
262
+ async with self ._client .acquire () as conn :
263
+ for entry in entries :
264
+ if entry .id not in embeddings :
265
+ continue
266
+
267
+ await conn .execute (
268
+ insert_query ,
269
+ str (entry .id ),
270
+ entry .text ,
271
+ entry .image_bytes ,
272
+ str (embeddings [entry .id ]),
273
+ json .dumps (entry .metadata , default = pydantic_encoder ),
274
+ )
275
275
276
276
async def remove (self , ids : list [UUID ]) -> None :
277
277
"""
@@ -296,34 +296,6 @@ async def remove(self, ids: list[UUID]) -> None:
296
296
print (f"Table { self ._table_name } does not exist." )
297
297
return
298
298
299
async def _fetch_records(self, query: str, values: list[Any]) -> list[VectorStoreEntry]:
    """
    Run a query against the pgVector table and convert the rows to entries.

    Args:
        query: SQL statement to execute.
        values: positional values bound to the placeholders of the statement.

    Returns:
        One VectorStoreEntry per returned row, built from the row's ``id``,
        ``key`` and JSON-decoded ``metadata`` columns. An empty list is
        returned when the database reports that the table is undefined.
    """
    try:
        async with self._client.acquire() as connection:
            rows = await connection.fetch(query, *values)
    except asyncpg.exceptions.UndefinedTableError:
        # A missing table is reported and treated as "no records".
        print(f"Table {self._table_name} does not exist.")
        return []

    entries = []
    for row in rows:
        entry = VectorStoreEntry(
            id=row["id"],
            text=row["key"],
            metadata=json.loads(row["metadata"]),
        )
        entries.append(entry)
    return entries
326
-
327
299
async def retrieve (
328
300
self ,
329
301
text : str ,
@@ -362,7 +334,8 @@ async def retrieve(
362
334
VectorStoreResult (
363
335
entry = VectorStoreEntry (
364
336
id = record ["id" ],
365
- text = record ["key" ],
337
+ text = record ["text" ],
338
+ image_bytes = record ["image_bytes" ],
366
339
metadata = json .loads (record ["metadata" ]),
367
340
),
368
341
vector = json .loads (record ["vector" ]),
@@ -393,5 +366,20 @@ async def list(
393
366
"""
394
367
with trace (table = self ._table_name , query = where , limit = limit , offset = offset ) as outputs :
395
368
list_query , values = self ._create_list_query (where , limit , offset )
396
- outputs .listed_entries = await self ._fetch_records (list_query , values )
369
+ try :
370
+ async with self ._client .acquire () as conn :
371
+ results = await conn .fetch (list_query , * values )
372
+ outputs .listed_entries = [
373
+ VectorStoreEntry (
374
+ id = record ["id" ],
375
+ text = record ["text" ],
376
+ image_bytes = record ["image_bytes" ],
377
+ metadata = json .loads (record ["metadata" ]),
378
+ )
379
+ for record in results
380
+ ]
381
+
382
+ except asyncpg .exceptions .UndefinedTableError :
383
+ print (f"Table { self ._table_name } does not exist." )
384
+ outputs .listed_entries = []
397
385
return outputs .listed_entries
0 commit comments