Skip to content

Commit 8e06e85

Browse files
committed
update storage factory design
1 parent e79ab29 commit 8e06e85

File tree

7 files changed

+73
-145
lines changed

7 files changed

+73
-145
lines changed

graphrag/storage/blob_pipeline_storage.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,20 @@ class BlobPipelineStorage(PipelineStorage):
2929
_encoding: str
3030
_storage_account_blob_url: str | None
3131

32-
def __init__(
33-
self,
34-
connection_string: str | None,
35-
container_name: str,
36-
encoding: str = "utf-8",
37-
path_prefix: str | None = None,
38-
storage_account_blob_url: str | None = None,
39-
):
32+
def __init__(self, **kwargs: Any) -> None:
4033
"""Create a new BlobStorage instance."""
34+
connection_string = kwargs.get("connection_string")
35+
storage_account_blob_url = kwargs.get("storage_account_blob_url")
36+
path_prefix = kwargs.get("base_dir")
37+
container_name = kwargs["container_name"]
38+
if container_name is None:
39+
msg = "No container name provided for blob storage."
40+
raise ValueError(msg)
41+
if connection_string is None and storage_account_blob_url is None:
42+
msg = "No storage account blob url provided for blob storage."
43+
raise ValueError(msg)
44+
45+
logger.info("Creating blob storage at %s", container_name)
4146
if connection_string:
4247
self._blob_service_client = BlobServiceClient.from_connection_string(
4348
connection_string
@@ -51,7 +56,7 @@ def __init__(
5156
account_url=storage_account_blob_url,
5257
credential=DefaultAzureCredential(),
5358
)
54-
self._encoding = encoding
59+
self._encoding = kwargs.get("encoding", "utf-8")
5560
self._container_name = container_name
5661
self._connection_string = connection_string
5762
self._path_prefix = path_prefix or ""
@@ -308,27 +313,6 @@ async def get_creation_date(self, key: str) -> str:
308313
return ""
309314

310315

311-
def create_blob_storage(**kwargs: Any) -> PipelineStorage:
312-
"""Create a blob based storage."""
313-
connection_string = kwargs.get("connection_string")
314-
storage_account_blob_url = kwargs.get("storage_account_blob_url")
315-
base_dir = kwargs.get("base_dir")
316-
container_name = kwargs["container_name"]
317-
logger.info("Creating blob storage at %s", container_name)
318-
if container_name is None:
319-
msg = "No container name provided for blob storage."
320-
raise ValueError(msg)
321-
if connection_string is None and storage_account_blob_url is None:
322-
msg = "No storage account blob url provided for blob storage."
323-
raise ValueError(msg)
324-
return BlobPipelineStorage(
325-
connection_string=connection_string,
326-
container_name=container_name,
327-
path_prefix=base_dir,
328-
storage_account_blob_url=storage_account_blob_url,
329-
)
330-
331-
332316
def validate_blob_container_name(container_name: str):
333317
"""
334318
Check if the provided blob container name is valid based on Azure rules.

graphrag/storage/cosmosdb_pipeline_storage.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,20 @@ class CosmosDBPipelineStorage(PipelineStorage):
3939
_encoding: str
4040
_no_id_prefixes: list[str]
4141

42-
def __init__(
43-
self,
44-
database_name: str,
45-
container_name: str,
46-
cosmosdb_account_url: str | None = None,
47-
connection_string: str | None = None,
48-
encoding: str = "utf-8",
49-
):
50-
"""Initialize the CosmosDB Storage."""
42+
def __init__(self, **kwargs: Any) -> None:
43+
"""Create a CosmosDB storage instance."""
44+
logger.info("Creating cosmosdb storage")
45+
cosmosdb_account_url = kwargs.get("cosmosdb_account_url")
46+
connection_string = kwargs.get("connection_string")
47+
database_name = kwargs["base_dir"]
48+
container_name = kwargs["container_name"]
49+
if not database_name:
50+
msg = "No base_dir provided for database name"
51+
raise ValueError(msg)
52+
if connection_string is None and cosmosdb_account_url is None:
53+
msg = "connection_string or cosmosdb_account_url is required."
54+
raise ValueError(msg)
55+
5156
if connection_string:
5257
self._cosmos_client = CosmosClient.from_connection_string(connection_string)
5358
else:
@@ -60,7 +65,7 @@ def __init__(
6065
url=cosmosdb_account_url,
6166
credential=DefaultAzureCredential(),
6267
)
63-
self._encoding = encoding
68+
self._encoding = kwargs.get("encoding", "utf-8")
6469
self._database_name = database_name
6570
self._connection_string = connection_string
6671
self._cosmosdb_account_url = cosmosdb_account_url
@@ -348,29 +353,6 @@ async def get_creation_date(self, key: str) -> str:
348353
return ""
349354

350355

351-
# TODO remove this helper function and have the factory instantiate the class directly
352-
# once the new config system is in place and will enforce the correct types/existence of certain fields
353-
def create_cosmosdb_storage(**kwargs: Any) -> PipelineStorage:
354-
"""Create a CosmosDB storage instance."""
355-
logger.info("Creating cosmosdb storage")
356-
cosmosdb_account_url = kwargs.get("cosmosdb_account_url")
357-
connection_string = kwargs.get("connection_string")
358-
base_dir = kwargs["base_dir"]
359-
container_name = kwargs["container_name"]
360-
if not base_dir:
361-
msg = "No base_dir provided for database name"
362-
raise ValueError(msg)
363-
if connection_string is None and cosmosdb_account_url is None:
364-
msg = "connection_string or cosmosdb_account_url is required."
365-
raise ValueError(msg)
366-
return CosmosDBPipelineStorage(
367-
cosmosdb_account_url=cosmosdb_account_url,
368-
connection_string=connection_string,
369-
database_name=base_dir,
370-
container_name=container_name,
371-
)
372-
373-
374356
def _create_progress_status(
375357
num_loaded: int, num_filtered: int, num_total: int
376358
) -> Progress:

graphrag/storage/factory.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from typing import TYPE_CHECKING, ClassVar
99

1010
from graphrag.config.enums import StorageType
11-
from graphrag.storage.blob_pipeline_storage import create_blob_storage
12-
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
13-
from graphrag.storage.file_pipeline_storage import create_file_storage
11+
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
12+
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
13+
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
1414
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
1515

1616
if TYPE_CHECKING:
@@ -28,7 +28,7 @@ class StorageFactory:
2828
for individual enforcement of required/optional arguments.
2929
"""
3030

31-
_storage_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
31+
_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
3232

3333
@classmethod
3434
def register(
@@ -44,10 +44,7 @@ def register(
4444
------
4545
TypeError: If creator is a class type instead of a factory function.
4646
"""
47-
if isinstance(creator, type):
48-
msg = "Registering classes directly is no longer supported. Please provide a factory function instead."
49-
raise TypeError(msg)
50-
cls._storage_registry[storage_type] = creator
47+
cls._registry[storage_type] = creator
5148

5249
@classmethod
5350
def create_storage(
@@ -67,31 +64,31 @@ def create_storage(
6764
------
6865
ValueError: If the storage type is not registered.
6966
"""
70-
storage_type_str = (
67+
type_str = (
7168
storage_type.value
7269
if isinstance(storage_type, StorageType)
7370
else storage_type
7471
)
7572

76-
if storage_type_str not in cls._storage_registry:
73+
if type_str not in cls._storage_registry:
7774
msg = f"Unknown storage type: {storage_type}"
7875
raise ValueError(msg)
7976

80-
return cls._storage_registry[storage_type_str](**kwargs)
77+
return cls._storage_registry[type_str](**kwargs)
8178

8279
@classmethod
8380
def get_storage_types(cls) -> list[str]:
8481
"""Get the registered storage implementations."""
85-
return list(cls._storage_registry.keys())
82+
return list(cls._registry.keys())
8683

8784
@classmethod
88-
def is_supported_storage_type(cls, storage_type: str) -> bool:
85+
def is_supported_type(cls, storage_type: str) -> bool:
8986
"""Check if the given storage type is supported."""
90-
return storage_type in cls._storage_registry
87+
return storage_type in cls._registry
9188

9289

93-
# --- Register default implementations ---
94-
StorageFactory.register(StorageType.blob.value, create_blob_storage)
95-
StorageFactory.register(StorageType.cosmosdb.value, create_cosmosdb_storage)
96-
StorageFactory.register(StorageType.file.value, create_file_storage)
97-
StorageFactory.register(StorageType.memory.value, lambda **_: MemoryPipelineStorage())
90+
# --- register built-in storage implementations ---
91+
StorageFactory.register(StorageType.blob.value, BlobPipelineStorage)
92+
StorageFactory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage)
93+
StorageFactory.register(StorageType.file.value, FilePipelineStorage)
94+
StorageFactory.register(StorageType.memory.value, MemoryPipelineStorage)

graphrag/storage/file_pipeline_storage.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ class FilePipelineStorage(PipelineStorage):
3030
_root_dir: str
3131
_encoding: str
3232

33-
def __init__(self, root_dir: str = "", encoding: str = "utf-8"):
34-
"""Init method definition."""
35-
self._root_dir = root_dir
36-
self._encoding = encoding
33+
def __init__(self, **kwargs: Any) -> None:
34+
"""Create a file based storage."""
35+
self._root_dir = kwargs["base_dir"]
36+
self._encoding = kwargs.get("encoding", "utf-8")
37+
logger.info("Creating file storage at %s", self._root_dir)
3738
Path(self._root_dir).mkdir(parents=True, exist_ok=True)
3839

3940
def find(
@@ -167,10 +168,3 @@ async def get_creation_date(self, key: str) -> str:
167168
def join_path(file_path: str, file_name: str) -> Path:
168169
"""Join a path and a file. Independent of the OS."""
169170
return Path(file_path) / Path(file_name).parent / Path(file_name).name
170-
171-
172-
def create_file_storage(**kwargs: Any) -> PipelineStorage:
173-
"""Create a file based storage."""
174-
base_dir = kwargs["base_dir"]
175-
logger.info("Creating file storage at %s", base_dir)
176-
return FilePipelineStorage(root_dir=base_dir)

graphrag/vector_stores/factory.py

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from enum import Enum
99
from typing import TYPE_CHECKING, ClassVar
1010

11+
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
12+
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
13+
from graphrag.vector_stores.lancedb import LanceDBVectorStore
14+
1115
if TYPE_CHECKING:
1216
from collections.abc import Callable
1317

@@ -31,7 +35,7 @@ class VectorStoreFactory:
3135
for individual enforcement of required/optional arguments.
3236
"""
3337

34-
_vector_store_registry: ClassVar[dict[str, Callable[..., BaseVectorStore]]] = {}
38+
_registry: ClassVar[dict[str, Callable[..., BaseVectorStore]]] = {}
3539

3640
@classmethod
3741
def register(
@@ -47,10 +51,7 @@ def register(
4751
------
4852
TypeError: If creator is a class type instead of a factory function.
4953
"""
50-
if isinstance(creator, type):
51-
msg = "Registering classes directly is no longer supported. Please provide a factory function instead."
52-
raise TypeError(msg)
53-
cls._vector_store_registry[vector_store_type] = creator
54+
cls._registry[vector_store_type] = creator
5455

5556
@classmethod
5657
def create_vector_store(
@@ -70,56 +71,32 @@ def create_vector_store(
7071
------
7172
ValueError: If the vector store type is not registered.
7273
"""
73-
vector_store_type_str = (
74+
type_str = (
7475
vector_store_type.value
7576
if isinstance(vector_store_type, VectorStoreType)
7677
else vector_store_type
7778
)
7879

79-
if vector_store_type_str not in cls._vector_store_registry:
80+
if type_str not in cls._registry:
8081
msg = f"Unknown vector store type: {vector_store_type}"
8182
raise ValueError(msg)
8283

83-
return cls._vector_store_registry[vector_store_type_str](**kwargs)
84+
return cls._registry[type_str](**kwargs)
8485

8586
@classmethod
8687
def get_vector_store_types(cls) -> list[str]:
8788
"""Get the registered vector store implementations."""
88-
return list(cls._vector_store_registry.keys())
89+
return list(cls._registry.keys())
8990

9091
@classmethod
91-
def is_supported_vector_store_type(cls, vector_store_type: str) -> bool:
92+
def is_supported_type(cls, vector_store_type: str) -> bool:
9293
"""Check if the given vector store type is supported."""
93-
return vector_store_type in cls._vector_store_registry
94-
95-
96-
# --- Factory functions for built-in vector stores ---
97-
def create_lancedb_vector_store(**kwargs) -> BaseVectorStore:
98-
"""Create a LanceDB vector store."""
99-
from graphrag.vector_stores.lancedb import LanceDBVectorStore
100-
101-
return LanceDBVectorStore(**kwargs)
102-
103-
104-
def create_azure_ai_search_vector_store(**kwargs) -> BaseVectorStore:
105-
"""Create an Azure AI Search vector store."""
106-
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
94+
return vector_store_type in cls._registry
10795

108-
return AzureAISearchVectorStore(**kwargs)
10996

110-
111-
def create_cosmosdb_vector_store(**kwargs) -> BaseVectorStore:
112-
"""Create a CosmosDB vector store."""
113-
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
114-
115-
return CosmosDBVectorStore(**kwargs)
116-
117-
118-
# --- register default implementations ---
119-
VectorStoreFactory.register(VectorStoreType.LanceDB.value, create_lancedb_vector_store)
120-
VectorStoreFactory.register(
121-
VectorStoreType.AzureAISearch.value, create_azure_ai_search_vector_store
122-
)
97+
# --- register built-in vector store implementations ---
98+
VectorStoreFactory.register(VectorStoreType.LanceDB.value, LanceDBVectorStore)
12399
VectorStoreFactory.register(
124-
VectorStoreType.CosmosDB.value, create_cosmosdb_vector_store
100+
VectorStoreType.AzureAISearch.value, AzureAISearchVectorStore
125101
)
102+
VectorStoreFactory.register(VectorStoreType.CosmosDB.value, CosmosDBVectorStore)

tests/integration/storage/test_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_register_and_create_custom_storage():
8484

8585
# Check if it's in the list of registered storage types
8686
assert "custom" in StorageFactory.get_storage_types()
87-
assert StorageFactory.is_supported_storage_type("custom")
87+
assert StorageFactory.is_supported_type("custom")
8888

8989

9090
def test_get_storage_types():

tests/integration/vector_stores/test_factory.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_register_and_create_custom_vector_store():
7575

7676
# Check if it's in the list of registered vector store types
7777
assert "custom" in VectorStoreFactory.get_vector_store_types()
78-
assert VectorStoreFactory.is_supported_vector_store_type("custom")
78+
assert VectorStoreFactory.is_supported_type("custom")
7979

8080

8181
def test_get_vector_store_types():
@@ -91,20 +91,14 @@ def test_create_unknown_vector_store():
9191
VectorStoreFactory.create_vector_store("unknown", {})
9292

9393

94-
def test_is_supported_vector_store_type():
94+
def test_is_supported_type():
9595
# Test built-in types
96-
assert VectorStoreFactory.is_supported_vector_store_type(
97-
VectorStoreType.LanceDB.value
98-
)
99-
assert VectorStoreFactory.is_supported_vector_store_type(
100-
VectorStoreType.AzureAISearch.value
101-
)
102-
assert VectorStoreFactory.is_supported_vector_store_type(
103-
VectorStoreType.CosmosDB.value
104-
)
96+
assert VectorStoreFactory.is_supported_type(VectorStoreType.LanceDB.value)
97+
assert VectorStoreFactory.is_supported_type(VectorStoreType.AzureAISearch.value)
98+
assert VectorStoreFactory.is_supported_type(VectorStoreType.CosmosDB.value)
10599

106100
# Test unknown type
107-
assert not VectorStoreFactory.is_supported_vector_store_type("unknown")
101+
assert not VectorStoreFactory.is_supported_type("unknown")
108102

109103

110104
def test_enum_and_string_compatibility():

0 commit comments

Comments
 (0)