1111from azure .cosmos .partition_key import PartitionKey
1212from azure .identity import DefaultAzureCredential
1313
14+ from graphrag .config .models .vector_store_schema_config import VectorStoreSchemaConfig
1415from graphrag .data_model .types import TextEmbedder
1516from graphrag .vector_stores .base import (
1617 BaseVectorStore ,
@@ -26,8 +27,8 @@ class CosmosDBVectorStore(BaseVectorStore):
2627 _database_client : DatabaseProxy
2728 _container_client : ContainerProxy
2829
29- def __init__ (self , ** kwargs : Any ) -> None :
30- super ().__init__ (** kwargs )
30+ def __init__ (self , vector_store_schema_config : VectorStoreSchemaConfig , ** kwargs : Any ) -> None :
31+ super ().__init__ (vector_store_schema_config = vector_store_schema_config , ** kwargs )
3132
3233 def connect (self , ** kwargs : Any ) -> Any :
3334 """Connect to CosmosDB vector storage."""
@@ -48,13 +49,12 @@ def connect(self, **kwargs: Any) -> Any:
4849 msg = "Database name must be provided."
4950 raise ValueError (msg )
5051 self ._database_name = database_name
51- collection_name = self .index_name
52- if collection_name is None :
53- msg = "Collection name is empty or not provided."
52+ if self .index_name is None :
53+ msg = "Index name is empty or not provided."
5454 raise ValueError (msg )
55- self ._container_name = collection_name
55+ self ._container_name = self . index_name
5656
57- self .vector_size = kwargs . get ( " vector_size" , 1024 ) #TODO GAUDY fix it
57+ self .vector_size = self . vector_size
5858 self ._create_database ()
5959 self ._create_container ()
6060
@@ -85,7 +85,7 @@ def _create_container(self) -> None:
8585 vector_embedding_policy = {
8686 "vectorEmbeddings" : [
8787 {
88- "path" : "/vector " ,
88+ "path" : f"/ { self . vector_field } " ,
8989 "dataType" : "float32" ,
9090 "distanceFunction" : "cosine" ,
9191 "dimensions" : self .vector_size ,
@@ -98,13 +98,13 @@ def _create_container(self) -> None:
9898 "indexingMode" : "consistent" ,
9999 "automatic" : True ,
100100 "includedPaths" : [{"path" : "/*" }],
101- "excludedPaths" : [{"path" : "/_etag/?" }, {"path" : "/vector /*" }],
101+ "excludedPaths" : [{"path" : "/_etag/?" }, {"path" : f"/ { self . vector_field } /*" }],
102102 }
103103
104104 # Currently, the CosmosDB emulator does not support the diskANN policy.
105105 try :
106106 # First try with the standard diskANN policy
107- indexing_policy ["vectorIndexes" ] = [{"path" : "/vector " , "type" : "diskANN" }]
107+ indexing_policy ["vectorIndexes" ] = [{"path" : f"/ { self . vector_field } " , "type" : "diskANN" }]
108108
109109 # Create the container and container client
110110 self ._database_client .create_container_if_not_exists (
@@ -158,10 +158,10 @@ def load_documents(
158158 for doc in documents :
159159 if doc .vector is not None :
160160 doc_json = {
161- "id" : doc .id ,
162- "vector" : doc .vector ,
163- "text" : doc .text ,
164- "attributes" : json .dumps (doc .attributes ),
161+ self . id_field : doc .id ,
162+ self . vector_field : doc .vector ,
163+ self . text_field : doc .text ,
164+ self . attributes_field : json .dumps (doc .attributes ),
165165 }
166166 self ._container_client .upsert_item (doc_json )
167167
@@ -174,7 +174,7 @@ def similarity_search_by_vector(
174174 raise ValueError (msg )
175175
176176 try :
177- query = f"SELECT TOP { k } c.id , c.text , c.vector , c.attributes , VectorDistance(c.vector , @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector , @embedding)" # noqa: S608
177+ query = f"SELECT TOP { k } c.{ self . id_field } , c.{ self . text_field } , c.{ self . vector_field } , c.{ self . attributes_field } , VectorDistance(c.{ self . vector_field } , @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{ self . vector_field } , @embedding)" # noqa: S608
178178 query_params = [{"name" : "@embedding" , "value" : query_embedding }]
179179 items = list (
180180 self ._container_client .query_items (
@@ -186,7 +186,7 @@ def similarity_search_by_vector(
186186 except (CosmosHttpResponseError , ValueError ):
187187 # Currently, the CosmosDB emulator does not support the VectorDistance function.
188188 # For emulator or test environments - fetch all items and calculate distance locally
189- query = "SELECT c.id , c.text , c.vector , c.attributes FROM c"
189+ query = f "SELECT c.{ self . id_field } , c.{ self . text_field } , c.{ self . vector_field } , c.{ self . attributes_field } FROM c" # noqa: S608
190190 items = list (
191191 self ._container_client .query_items (
192192 query = query ,
@@ -205,7 +205,7 @@ def cosine_similarity(a, b):
205205
206206 # Calculate scores for all items
207207 for item in items :
208- item_vector = item .get ("vector" , [])
208+ item_vector = item .get (self . vector_field , [])
209209 similarity = cosine_similarity (query_embedding , item_vector )
210210 item ["SimilarityScore" ] = similarity
211211
@@ -217,10 +217,10 @@ def cosine_similarity(a, b):
217217 return [
218218 VectorStoreSearchResult (
219219 document = VectorStoreDocument (
220- id = item .get ("id" , "" ),
221- text = item .get ("text" , "" ),
222- vector = item .get ("vector" , []),
223- attributes = (json .loads (item .get ("attributes" , "{}" ))),
220+ id = item .get (self . id_field , "" ),
221+ text = item .get (self . text_field , "" ),
222+ vector = item .get (self . vector_field , []),
223+ attributes = (json .loads (item .get (self . attributes_field , "{}" ))),
224224 ),
225225 score = item .get ("SimilarityScore" , 0.0 ),
226226 )
@@ -247,7 +247,7 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
247247 id_filter = ", " .join ([f"'{ id } '" for id in include_ids ])
248248 else :
249249 id_filter = ", " .join ([str (id ) for id in include_ids ])
250- self .query_filter = f"SELECT * FROM c WHERE c.id IN ({ id_filter } )" # noqa: S608
250+ self .query_filter = f"SELECT * FROM c WHERE c.{ self . id_field } IN ({ id_filter } )" # noqa: S608
251251 return self .query_filter
252252
253253 def search_by_id (self , id : str ) -> VectorStoreDocument :
@@ -258,10 +258,10 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
258258
259259 item = self ._container_client .read_item (item = id , partition_key = id )
260260 return VectorStoreDocument (
261- id = item .get ("id" , "" ),
262- vector = item .get ("vector" , []),
263- text = item .get ("text" , "" ),
264- attributes = (json .loads (item .get ("attributes" , "{}" ))),
261+ id = item .get (self . id_field , "" ),
262+ vector = item .get (self . vector_field , []),
263+ text = item .get (self . text_field , "" ),
264+ attributes = (json .loads (item .get (self . attributes_field , "{}" ))),
265265 )
266266
267267 def clear (self ) -> None :
0 commit comments