|
7 | 7 | from typing import Any |
8 | 8 |
|
9 | 9 | from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy |
| 10 | +from azure.cosmos.exceptions import CosmosHttpResponseError |
10 | 11 | from azure.cosmos.partition_key import PartitionKey |
11 | 12 | from azure.identity import DefaultAzureCredential |
12 | 13 |
|
|
19 | 20 | ) |
20 | 21 |
|
21 | 22 |
|
22 | | -class CosmosDBVectoreStore(BaseVectorStore): |
| 23 | +class CosmosDBVectorStore(BaseVectorStore): |
23 | 24 | """Azure CosmosDB vector storage implementation.""" |
24 | 25 |
|
25 | 26 | _cosmos_client: CosmosClient |
@@ -99,16 +100,32 @@ def _create_container(self) -> None: |
99 | 100 | "automatic": True, |
100 | 101 | "includedPaths": [{"path": "/*"}], |
101 | 102 | "excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}], |
102 | | - "vectorIndexes": [{"path": "/vector", "type": "diskANN"}], |
103 | 103 | } |
104 | 104 |
|
105 | | - # Create the container and container client |
106 | | - self._database_client.create_container_if_not_exists( |
107 | | - id=self._container_name, |
108 | | - partition_key=partition_key, |
109 | | - indexing_policy=indexing_policy, |
110 | | - vector_embedding_policy=vector_embedding_policy, |
111 | | - ) |
| 105 | + # Currently, the CosmosDB emulator does not support the diskANN policy. |
| 106 | + try: |
| 107 | + # First try with the standard diskANN policy |
| 108 | + indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}] |
| 109 | + |
| 110 | + # Create the container and container client |
| 111 | + self._database_client.create_container_if_not_exists( |
| 112 | + id=self._container_name, |
| 113 | + partition_key=partition_key, |
| 114 | + indexing_policy=indexing_policy, |
| 115 | + vector_embedding_policy=vector_embedding_policy, |
| 116 | + ) |
| 117 | + except CosmosHttpResponseError: |
| 118 | + # If diskANN fails (likely in emulator), retry without vector indexes |
| 119 | + indexing_policy.pop("vectorIndexes", None) |
| 120 | + |
| 121 | + # Create the container with compatible indexing policy |
| 122 | + self._database_client.create_container_if_not_exists( |
| 123 | + id=self._container_name, |
| 124 | + partition_key=partition_key, |
| 125 | + indexing_policy=indexing_policy, |
| 126 | + vector_embedding_policy=vector_embedding_policy, |
| 127 | + ) |
| 128 | + |
112 | 129 | self._container_client = self._database_client.get_container_client( |
113 | 130 | self._container_name |
114 | 131 | ) |
@@ -157,13 +174,46 @@ def similarity_search_by_vector( |
157 | 174 | msg = "Container client is not initialized." |
158 | 175 | raise ValueError(msg) |
159 | 176 |
|
160 | | - query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608 |
161 | | - query_params = [{"name": "@embedding", "value": query_embedding}] |
162 | | - items = self._container_client.query_items( |
163 | | - query=query, |
164 | | - parameters=query_params, |
165 | | - enable_cross_partition_query=True, |
166 | | - ) |
| 177 | + try: |
| 178 | + query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608 |
| 179 | + query_params = [{"name": "@embedding", "value": query_embedding}] |
| 180 | + items = list( |
| 181 | + self._container_client.query_items( |
| 182 | + query=query, |
| 183 | + parameters=query_params, |
| 184 | + enable_cross_partition_query=True, |
| 185 | + ) |
| 186 | + ) |
| 187 | + except (CosmosHttpResponseError, ValueError): |
| 188 | + # Currently, the CosmosDB emulator does not support the VectorDistance function. |
| 189 | + # For emulator or test environments - fetch all items and calculate distance locally |
| 190 | + query = "SELECT c.id, c.text, c.vector, c.attributes FROM c" |
| 191 | + items = list( |
| 192 | + self._container_client.query_items( |
| 193 | + query=query, |
| 194 | + enable_cross_partition_query=True, |
| 195 | + ) |
| 196 | + ) |
| 197 | + |
| 198 | + # Calculate cosine similarity locally (1 - cosine distance) |
| 199 | + from numpy import dot |
| 200 | + from numpy.linalg import norm |
| 201 | + |
| 202 | + def cosine_similarity(a, b): |
| 203 | + if norm(a) * norm(b) == 0: |
| 204 | + return 0.0 |
| 205 | + return dot(a, b) / (norm(a) * norm(b)) |
| 206 | + |
| 207 | + # Calculate scores for all items |
| 208 | + for item in items: |
| 209 | + item_vector = item.get("vector", []) |
| 210 | + similarity = cosine_similarity(query_embedding, item_vector) |
| 211 | + item["SimilarityScore"] = similarity |
| 212 | + |
| 213 | + # Sort by similarity score (higher is better) and take top k |
| 214 | + items = sorted( |
| 215 | + items, key=lambda x: x.get("SimilarityScore", 0.0), reverse=True |
| 216 | + )[:k] |
167 | 217 |
|
168 | 218 | return [ |
169 | 219 | VectorStoreSearchResult( |
@@ -214,3 +264,8 @@ def search_by_id(self, id: str) -> VectorStoreDocument: |
214 | 264 | text=item.get("text", ""), |
215 | 265 | attributes=(json.loads(item.get("attributes", "{}"))), |
216 | 266 | ) |
| 267 | + |
| 268 | + def clear(self) -> None: |
| 269 | + """Clear the vector store by calling `_delete_container()` and then `_delete_database()`.""" |
| 270 | + self._delete_container() |
| 271 | + self._delete_database() |
0 commit comments