Skip to content

Commit 15f2144

Browse files
Gaudy Blanco
authored and committed
uv run poe format
1 parent cffb03a commit 15f2144

File tree

8 files changed

+78
-42
lines changed

8 files changed

+78
-42
lines changed

graphrag/config/models/vector_store_schema_config.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,17 +11,16 @@
1111

1212
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
1313

14+
1415
def is_valid_field_name(field: str) -> bool:
1516
"""Check if a field name is valid for CosmosDB."""
1617
return bool(VALID_IDENTIFIER_REGEX.match(field))
1718

19+
1820
class VectorStoreSchemaConfig(BaseModel):
1921
"""The default configuration section for Vector Store Schema."""
2022

21-
index_name: str = Field(
22-
description="The index name to use.",
23-
default=""
24-
)
23+
index_name: str = Field(description="The index name to use.", default="")
2524

2625
id_field: str = Field(
2726
description="The ID field to use.",

graphrag/index/operations/embed_text/embed_text.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -188,10 +188,16 @@ def _create_vector_store(
188188
) -> BaseVectorStore:
189189
vector_store_type: str = str(vector_store_config.get("type"))
190190

191-
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get("embeddings_schema", {})
191+
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
192+
"embeddings_schema", {}
193+
)
192194
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
193195

194-
if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
196+
if (
197+
embeddings_schema is not None
198+
and embedding_name is not None
199+
and embedding_name in embeddings_schema
200+
):
195201
raw_config = embeddings_schema[embedding_name]
196202
if isinstance(raw_config, dict):
197203
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
@@ -202,7 +208,9 @@ def _create_vector_store(
202208
single_embedding_config.index_name = index_name
203209

204210
vector_store = VectorStoreFactory().create_vector_store(
205-
vector_store_schema_config=single_embedding_config, vector_store_type=vector_store_type, kwargs=vector_store_config
211+
vector_store_schema_config=single_embedding_config,
212+
vector_store_type=vector_store_type,
213+
kwargs=vector_store_config,
206214
)
207215

208216
vector_store.connect(**vector_store_config)

graphrag/utils/api.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -108,10 +108,16 @@ def get_embedding_store(
108108
store.get("container_name", "default"), embedding_name
109109
)
110110

111-
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get("embeddings_schema", {})
111+
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
112+
"embeddings_schema", {}
113+
)
112114
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
113115

114-
if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
116+
if (
117+
embeddings_schema is not None
118+
and embedding_name is not None
119+
and embedding_name in embeddings_schema
120+
):
115121
raw_config = embeddings_schema[embedding_name]
116122
if isinstance(raw_config, dict):
117123
single_embedding_config = VectorStoreSchemaConfig(**raw_config)

graphrag/vector_stores/azure_ai_search.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -38,8 +38,12 @@ class AzureAISearchVectorStore(BaseVectorStore):
3838

3939
index_client: SearchIndexClient
4040

41-
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
42-
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
41+
def __init__(
42+
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
43+
) -> None:
44+
super().__init__(
45+
vector_store_schema_config=vector_store_schema_config, **kwargs
46+
)
4347

4448
def connect(self, **kwargs: Any) -> Any:
4549
"""Connect to AI search vector storage."""
@@ -77,8 +81,11 @@ def load_documents(
7781
) -> None:
7882
"""Load documents into an Azure AI Search index."""
7983
if overwrite:
80-
if self.index_name != "" and self.index_name in self.index_client.list_index_names():
81-
self.index_client.delete_index(self.index_name)
84+
if (
85+
self.index_name != ""
86+
and self.index_name in self.index_client.list_index_names()
87+
):
88+
self.index_client.delete_index(self.index_name)
8289

8390
# Configure vector search profile
8491
vector_search = VectorSearch(
@@ -114,8 +121,12 @@ def load_documents(
114121
vector_search_dimensions=self.vector_size,
115122
vector_search_profile_name=self.vector_search_profile_name,
116123
),
117-
SearchableField(name=self.text_field, type=SearchFieldDataType.String),
118-
SimpleField(name=self.attributes_field, type=SearchFieldDataType.String,
124+
SearchableField(
125+
name=self.text_field, type=SearchFieldDataType.String
126+
),
127+
SimpleField(
128+
name=self.attributes_field,
129+
type=SearchFieldDataType.String,
119130
),
120131
],
121132
vector_search=vector_search,

graphrag/vector_stores/base.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,15 +51,14 @@ def __init__(
5151
self.document_collection = document_collection
5252
self.query_filter = query_filter
5353
self.kwargs = kwargs
54-
54+
5555
self.index_name = vector_store_schema_config.index_name
5656
self.id_field = vector_store_schema_config.id_field
5757
self.text_field = vector_store_schema_config.text_field
5858
self.vector_field = vector_store_schema_config.vector_field
5959
self.attributes_field = vector_store_schema_config.attributes_field
6060
self.vector_size = vector_store_schema_config.vector_size
6161

62-
6362
@abstractmethod
6463
def connect(self, **kwargs: Any) -> None:
6564
"""Connect to vector storage."""

graphrag/vector_stores/cosmosdb.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,12 @@ class CosmosDBVectorStore(BaseVectorStore):
2727
_database_client: DatabaseProxy
2828
_container_client: ContainerProxy
2929

30-
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
31-
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
30+
def __init__(
31+
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
32+
) -> None:
33+
super().__init__(
34+
vector_store_schema_config=vector_store_schema_config, **kwargs
35+
)
3236

3337
def connect(self, **kwargs: Any) -> Any:
3438
"""Connect to CosmosDB vector storage."""
@@ -98,13 +102,18 @@ def _create_container(self) -> None:
98102
"indexingMode": "consistent",
99103
"automatic": True,
100104
"includedPaths": [{"path": "/*"}],
101-
"excludedPaths": [{"path": "/_etag/?"}, {"path": f"/{self.vector_field}/*"}],
105+
"excludedPaths": [
106+
{"path": "/_etag/?"},
107+
{"path": f"/{self.vector_field}/*"},
108+
],
102109
}
103110

104111
# Currently, the CosmosDB emulator does not support the diskANN policy.
105112
try:
106113
# First try with the standard diskANN policy
107-
indexing_policy["vectorIndexes"] = [{"path": f"/{self.vector_field}", "type": "diskANN"}]
114+
indexing_policy["vectorIndexes"] = [
115+
{"path": f"/{self.vector_field}", "type": "diskANN"}
116+
]
108117

109118
# Create the container and container client
110119
self._database_client.create_container_if_not_exists(
@@ -247,7 +256,9 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
247256
id_filter = ", ".join([f"'{id}'" for id in include_ids])
248257
else:
249258
id_filter = ", ".join([str(id) for id in include_ids])
250-
self.query_filter = f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
259+
self.query_filter = (
260+
f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
261+
)
251262
return self.query_filter
252263

253264
def search_by_id(self, id: str) -> VectorStoreDocument:

graphrag/vector_stores/factory.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,10 @@ def register(
5050

5151
@classmethod
5252
def create_vector_store(
53-
cls, vector_store_type: str, vector_store_schema_config: VectorStoreSchemaConfig, kwargs: dict
53+
cls,
54+
vector_store_type: str,
55+
vector_store_schema_config: VectorStoreSchemaConfig,
56+
kwargs: dict,
5457
) -> BaseVectorStore:
5558
"""Create a vector store object from the provided type.
5659
@@ -71,10 +74,9 @@ def create_vector_store(
7174
raise ValueError(msg)
7275

7376
return cls._registry[vector_store_type](
74-
vector_store_schema_config=vector_store_schema_config,
75-
**kwargs
77+
vector_store_schema_config=vector_store_schema_config, **kwargs
7678
)
77-
79+
7880
@classmethod
7981
def get_vector_store_types(cls) -> list[str]:
8082
"""Get the registered vector store implementations."""

graphrag/vector_stores/lancedb.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -21,21 +21,19 @@
2121
class LanceDBVectorStore(BaseVectorStore):
2222
"""LanceDB vector storage implementation."""
2323

24-
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
25-
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
24+
def __init__(
25+
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
26+
) -> None:
27+
super().__init__(
28+
vector_store_schema_config=vector_store_schema_config, **kwargs
29+
)
2630

2731
def connect(self, **kwargs: Any) -> Any:
2832
"""Connect to the vector storage."""
2933
self.db_connection = lancedb.connect(kwargs["db_uri"])
3034

31-
if (
32-
self.index_name
33-
and self.index_name in self.db_connection.table_names()
34-
):
35-
self.document_collection = self.db_connection.open_table(
36-
self.index_name
37-
)
38-
35+
if self.index_name and self.index_name in self.db_connection.table_names():
36+
self.document_collection = self.db_connection.open_table(self.index_name)
3937

4038
def load_documents(
4139
self, documents: list[VectorStoreDocument], overwrite: bool = True
@@ -61,14 +59,16 @@ def load_documents(
6159
# Step 3: Flatten the vectors and build FixedSizeListArray manually
6260
flat_vector = np.concatenate(vectors).astype(np.float32)
6361
flat_array = pa.array(flat_vector, type=pa.float32())
64-
vector_column = pa.FixedSizeListArray.from_arrays(flat_array, self.vector_size)
62+
vector_column = pa.FixedSizeListArray.from_arrays(
63+
flat_array, self.vector_size
64+
)
6565

6666
# Step 4: Create PyArrow table (let schema be inferred)
6767
data = pa.table({
6868
self.id_field: pa.array(ids, type=pa.string()),
6969
self.text_field: pa.array(texts, type=pa.string()),
7070
self.vector_field: vector_column,
71-
self.attributes_field: pa.array(attributes, type=pa.string())
71+
self.attributes_field: pa.array(attributes, type=pa.string()),
7272
})
7373

7474
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
@@ -83,12 +83,12 @@ def load_documents(
8383
self.document_collection = self.db_connection.create_table(
8484
self.index_name, mode="overwrite"
8585
)
86-
self.document_collection.create_index(vector_column_name=self.vector_field, index_type="IVF_FLAT")
86+
self.document_collection.create_index(
87+
vector_column_name=self.vector_field, index_type="IVF_FLAT"
88+
)
8789
else:
8890
# add data to existing table
89-
self.document_collection = self.db_connection.open_table(
90-
self.index_name
91-
)
91+
self.document_collection = self.db_connection.open_table(self.index_name)
9292
if data:
9393
self.document_collection.add(data)
9494

0 commit comments

Comments
 (0)