Skip to content

Commit cffb03a

Browse files
Gaudy BlancoGaudy Blanco
authored andcommitted
cosmosdb implementation
1 parent 9b14b9d commit cffb03a

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

graphrag/config/models/vector_store_schema_config.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,18 @@
33

44
"""Parameterization settings for the default configuration."""
55

6+
import re
7+
68
from pydantic import BaseModel, Field, model_validator
79

810
DEFAULT_VECTOR_SIZE: int = 1536
911

12+
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
13+
14+
def is_valid_field_name(field: str) -> bool:
15+
"""Check if a field name is valid for CosmosDB."""
16+
return bool(VALID_IDENTIFIER_REGEX.match(field))
17+
1018
class VectorStoreSchemaConfig(BaseModel):
1119
"""The default configuration section for Vector Store Schema."""
1220

@@ -40,9 +48,17 @@ class VectorStoreSchemaConfig(BaseModel):
4048
default=DEFAULT_VECTOR_SIZE,
4149
)
4250

43-
#TODO GAUDY
4451
def _validate_schema(self) -> None:
4552
"""Validate the schema."""
53+
for field in [
54+
self.id_field,
55+
self.vector_field,
56+
self.text_field,
57+
self.attributes_field,
58+
]:
59+
if not is_valid_field_name(field):
60+
msg = f"Unsafe or invalid field name: {field}"
61+
raise ValueError(msg)
4662

4763
@model_validator(mode="after")
4864
def _validate_model(self):

graphrag/vector_stores/cosmosdb.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from azure.cosmos.partition_key import PartitionKey
1212
from azure.identity import DefaultAzureCredential
1313

14+
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
1415
from graphrag.data_model.types import TextEmbedder
1516
from graphrag.vector_stores.base import (
1617
BaseVectorStore,
@@ -26,8 +27,8 @@ class CosmosDBVectorStore(BaseVectorStore):
2627
_database_client: DatabaseProxy
2728
_container_client: ContainerProxy
2829

29-
def __init__(self, **kwargs: Any) -> None:
30-
super().__init__(**kwargs)
30+
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
31+
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
3132

3233
def connect(self, **kwargs: Any) -> Any:
3334
"""Connect to CosmosDB vector storage."""
@@ -48,13 +49,12 @@ def connect(self, **kwargs: Any) -> Any:
4849
msg = "Database name must be provided."
4950
raise ValueError(msg)
5051
self._database_name = database_name
51-
collection_name = self.index_name
52-
if collection_name is None:
53-
msg = "Collection name is empty or not provided."
52+
if self.index_name is None:
53+
msg = "Index name is empty or not provided."
5454
raise ValueError(msg)
55-
self._container_name = collection_name
55+
self._container_name = self.index_name
5656

57-
self.vector_size = kwargs.get("vector_size", 1024) #TODO GAUDY fix it
57+
self.vector_size = self.vector_size
5858
self._create_database()
5959
self._create_container()
6060

@@ -85,7 +85,7 @@ def _create_container(self) -> None:
8585
vector_embedding_policy = {
8686
"vectorEmbeddings": [
8787
{
88-
"path": "/vector",
88+
"path": f"/{self.vector_field}",
8989
"dataType": "float32",
9090
"distanceFunction": "cosine",
9191
"dimensions": self.vector_size,
@@ -98,13 +98,13 @@ def _create_container(self) -> None:
9898
"indexingMode": "consistent",
9999
"automatic": True,
100100
"includedPaths": [{"path": "/*"}],
101-
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
101+
"excludedPaths": [{"path": "/_etag/?"}, {"path": f"/{self.vector_field}/*"}],
102102
}
103103

104104
# Currently, the CosmosDB emulator does not support the diskANN policy.
105105
try:
106106
# First try with the standard diskANN policy
107-
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
107+
indexing_policy["vectorIndexes"] = [{"path": f"/{self.vector_field}", "type": "diskANN"}]
108108

109109
# Create the container and container client
110110
self._database_client.create_container_if_not_exists(
@@ -158,10 +158,10 @@ def load_documents(
158158
for doc in documents:
159159
if doc.vector is not None:
160160
doc_json = {
161-
"id": doc.id,
162-
"vector": doc.vector,
163-
"text": doc.text,
164-
"attributes": json.dumps(doc.attributes),
161+
self.id_field: doc.id,
162+
self.vector_field: doc.vector,
163+
self.text_field: doc.text,
164+
self.attributes_field: json.dumps(doc.attributes),
165165
}
166166
self._container_client.upsert_item(doc_json)
167167

@@ -174,7 +174,7 @@ def similarity_search_by_vector(
174174
raise ValueError(msg)
175175

176176
try:
177-
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
177+
query = f"SELECT TOP {k} c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field}, VectorDistance(c.{self.vector_field}, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{self.vector_field}, @embedding)" # noqa: S608
178178
query_params = [{"name": "@embedding", "value": query_embedding}]
179179
items = list(
180180
self._container_client.query_items(
@@ -186,7 +186,7 @@ def similarity_search_by_vector(
186186
except (CosmosHttpResponseError, ValueError):
187187
# Currently, the CosmosDB emulator does not support the VectorDistance function.
188188
# For emulator or test environments - fetch all items and calculate distance locally
189-
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
189+
query = f"SELECT c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field} FROM c" # noqa: S608
190190
items = list(
191191
self._container_client.query_items(
192192
query=query,
@@ -205,7 +205,7 @@ def cosine_similarity(a, b):
205205

206206
# Calculate scores for all items
207207
for item in items:
208-
item_vector = item.get("vector", [])
208+
item_vector = item.get(self.vector_field, [])
209209
similarity = cosine_similarity(query_embedding, item_vector)
210210
item["SimilarityScore"] = similarity
211211

@@ -217,10 +217,10 @@ def cosine_similarity(a, b):
217217
return [
218218
VectorStoreSearchResult(
219219
document=VectorStoreDocument(
220-
id=item.get("id", ""),
221-
text=item.get("text", ""),
222-
vector=item.get("vector", []),
223-
attributes=(json.loads(item.get("attributes", "{}"))),
220+
id=item.get(self.id_field, ""),
221+
text=item.get(self.text_field, ""),
222+
vector=item.get(self.vector_field, []),
223+
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
224224
),
225225
score=item.get("SimilarityScore", 0.0),
226226
)
@@ -247,7 +247,7 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
247247
id_filter = ", ".join([f"'{id}'" for id in include_ids])
248248
else:
249249
id_filter = ", ".join([str(id) for id in include_ids])
250-
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
250+
self.query_filter = f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
251251
return self.query_filter
252252

253253
def search_by_id(self, id: str) -> VectorStoreDocument:
@@ -258,10 +258,10 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
258258

259259
item = self._container_client.read_item(item=id, partition_key=id)
260260
return VectorStoreDocument(
261-
id=item.get("id", ""),
262-
vector=item.get("vector", []),
263-
text=item.get("text", ""),
264-
attributes=(json.loads(item.get("attributes", "{}"))),
261+
id=item.get(self.id_field, ""),
262+
vector=item.get(self.vector_field, []),
263+
text=item.get(self.text_field, ""),
264+
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
265265
)
266266

267267
def clear(self) -> None:

0 commit comments

Comments
 (0)