import json
import logging
import uuid
from typing import List, Literal, Optional, Tuple, cast

from qdrant_client import AsyncQdrantClient, models

from core.models.chunk import DocumentChunk

from .base_vector_store import BaseVectorStore

logger = logging.getLogger(__name__)
QDRANT_COLLECTION_NAME = "vector_embeddings"

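
# Qdrant point IDs must be UUIDs or unsigned integers, so chunk identity is
# encoded as a deterministic UUIDv5 of (chunk_number, document_id); re-ingesting
# the same chunk upserts its existing point instead of creating a duplicate.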
def _to_point_id(doc_id: str, chunk_number: int) -> str:
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{chunk_number}.{doc_id}.internal"))


def _get_qdrant_distance(metric: Literal["cosine", "dotProduct"]) -> models.Distance:
    match metric:
        case "cosine":
            return models.Distance.COSINE
        case "dotProduct":
            return models.Distance.DOT


class QdrantVectorStore(BaseVectorStore):
    """Vector store backed by a Qdrant collection of embedded document chunks."""

    def __init__(self, host: str, port: int, https: bool) -> None:
        # Imported here rather than at module level, presumably to avoid an import cycle.
        from core.config import get_settings

        settings = get_settings()

        self.dimensions = settings.VECTOR_DIMENSIONS
        self.collection_name = QDRANT_COLLECTION_NAME
        self.distance = _get_qdrant_distance(settings.EMBEDDING_SIMILARITY_METRIC)
        self.client = AsyncQdrantClient(
            host=host,
            port=port,
            https=https,
        )
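
    # Full-precision vectors live on disk (on_disk=True); int8 scalar-quantized
    # copies are kept in RAM (always_ram=True), trading a little recall for a
    # much smaller memory footprint.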
    async def _create_collection(self):
        return await self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.dimensions,
                distance=self.distance,
                on_disk=True,
            ),
            quantization_config=models.ScalarQuantization(
                scalar=models.ScalarQuantizationConfig(
                    type=models.ScalarType.INT8,
                    always_ram=True,
                ),
            ),
        )

    async def _check_collection_vector_size(self):
        collection = await self.client.get_collection(self.collection_name)
        params = collection.config.params
        assert params.vectors is not None
        vectors = cast(models.VectorParams, params.vectors)
        if vectors.size != self.dimensions:
            msg = (
                f"Vector dimensions changed from {vectors.size} to {self.dimensions}. "
                "This requires recreating the collection and will delete all existing vector data."
            )
            logger.error(msg)
            raise ValueError(msg)
        return True

    async def initialize(self):
        logger.info("Initializing qdrant vector collection")
        try:
            if not await self.client.collection_exists(self.collection_name):
                logger.info("Collection does not exist; creating qdrant collection")
                await self._create_collection()
            else:
                await self._check_collection_vector_size()

            # Index document_id so filtered queries and per-document deletes stay fast.
            await self.client.create_payload_index(
                self.collection_name,
                "document_id",
                models.PayloadSchemaType.UUID,
            )
            return True
        except Exception as e:
            logger.error(f"Error initializing Qdrant store: {e}")
            return False
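
    # Chunk metadata is stored as a JSON string in the point payload so that
    # arbitrary metadata keys round-trip unchanged through Qdrant.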
    async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
        try:
            batch = [
                models.PointStruct(
                    id=_to_point_id(chunk.document_id, chunk.chunk_number),
                    vector=cast(List[float], chunk.embedding),
                    payload={
                        "document_id": chunk.document_id,
                        "chunk_number": chunk.chunk_number,
                        "content": chunk.content,
                        "metadata": json.dumps(chunk.metadata) if chunk.metadata is not None else "{}",
                    },
                )
                for chunk in chunks
            ]
            await self.client.upsert(collection_name=self.collection_name, points=batch)
            return True, [cast(str, p.id) for p in batch]
        except Exception as e:
            logger.error(f"Error storing embeddings: {e}")
            return False, []

    async def query_similar(
        self,
        query_embedding: List[float],
        k: int,
        doc_ids: Optional[List[str]] = None,
    ) -> List[DocumentChunk]:
        try:
            # Optionally restrict the search to the given document IDs.
            query_filter = None
            if doc_ids is not None:
                query_filter = models.Filter(
                    must=models.FieldCondition(
                        key="document_id",
                        match=models.MatchAny(any=doc_ids),
                    ),
                )

            resp = await self.client.query_points(
                self.collection_name,
                query=query_embedding,
                limit=k,
                query_filter=query_filter,
                with_payload=True,
            )
            return [
                DocumentChunk(
                    document_id=p.payload["document_id"],
                    chunk_number=p.payload["chunk_number"],
                    content=p.payload["content"],
                    embedding=[],  # embeddings are not needed on the read path
                    metadata=json.loads(p.payload["metadata"]),
                    score=p.score,
                )
                for p in resp.points
                if p.payload is not None
            ]
        except Exception as e:
            logger.error(f"Error querying similar chunks: {e}")
            return []

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        try:
            if not chunk_identifiers:
                return []

            # Point IDs are deterministic, so chunks can be fetched directly.
            ids = [_to_point_id(doc_id, chunk_number) for (doc_id, chunk_number) in chunk_identifiers]
            resp = await self.client.retrieve(
                self.collection_name,
                ids=ids,
            )
            return [
                DocumentChunk(
                    document_id=p.payload["document_id"],
                    chunk_number=p.payload["chunk_number"],
                    content=p.payload["content"],
                    embedding=[],
                    metadata=json.loads(p.payload["metadata"]),
                    score=0,  # direct lookup, no similarity score
                )
                for p in resp
                if p.payload is not None
            ]
        except Exception as e:
            logger.error(f"Error retrieving chunks by ID: {e}")
            return []

    async def delete_chunks_by_document_id(self, document_id: str) -> bool:
        try:
            await self.client.delete(
                self.collection_name,
                points_selector=models.Filter(
                    must=models.FieldCondition(
                        key="document_id",
                        match=models.MatchValue(value=document_id),
                    ),
                ),
            )
            return True
        except Exception as e:
            logger.error(f"Error deleting chunks for document {document_id}: {e}")
            return False
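

# A minimal usage sketch, kept as a comment (the host/port values and the
# DocumentChunk construction below are illustrative assumptions, not part of
# this module):
#
#     store = QdrantVectorStore(host="localhost", port=6333, https=False)
#     await store.initialize()
#     ok, ids = await store.store_embeddings(
#         [DocumentChunk(document_id=doc_id, chunk_number=0, content="...",
#                        embedding=embedding, metadata={})]
#     )
#     hits = await store.query_similar(query_embedding, k=5, doc_ids=[doc_id])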