Skip to content

Commit cbb8f87

Browse files
authored
Fix storage class instantiation (#1582)
1 parent a35cb12 commit cbb8f87

File tree

5 files changed

+92
-8
lines changed

5 files changed

+92
-8
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "fix instantiation of storage classes."
4+
}

graphrag/cache/factory.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import TYPE_CHECKING, ClassVar
99

1010
from graphrag.config.enums import CacheType
11-
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
11+
from graphrag.storage.blob_pipeline_storage import create_blob_storage
1212
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
1313
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
1414

@@ -24,6 +24,9 @@ class CacheFactory:
2424
"""A factory class for cache implementations.
2525
2626
Includes a method for users to register a custom cache implementation.
27+
28+
Configuration arguments are passed to each cache implementation as kwargs (where possible)
29+
for individual enforcement of required/optional arguments.
2730
"""
2831

2932
cache_types: ClassVar[dict[str, type]] = {}
@@ -50,7 +53,7 @@ def create_cache(
5053
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
5154
)
5255
case CacheType.blob:
53-
return JsonPipelineCache(BlobPipelineStorage(**kwargs))
56+
return JsonPipelineCache(create_blob_storage(**kwargs))
5457
case CacheType.cosmosdb:
5558
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
5659
case _:

graphrag/storage/blob_pipeline_storage.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,12 @@ def _abfs_url(self, key: str) -> str:
290290
return f"abfs://{path}"
291291

292292

293-
def create_blob_storage(
294-
connection_string: str | None,
295-
storage_account_blob_url: str | None,
296-
container_name: str,
297-
base_dir: str | None,
298-
) -> PipelineStorage:
293+
def create_blob_storage(**kwargs: Any) -> PipelineStorage:
299294
"""Create a blob based storage."""
295+
connection_string = kwargs.get("connection_string")
296+
storage_account_blob_url = kwargs.get("storage_account_blob_url")
297+
base_dir = kwargs.get("base_dir")
298+
container_name = kwargs["container_name"]
300299
log.info("Creating blob storage at %s", container_name)
301300
if container_name is None:
302301
msg = "No container name provided for blob storage."

graphrag/storage/factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class StorageFactory:
2121
"""A factory class for storage implementations.
2222
2323
Includes a method for users to register a custom storage implementation.
24+
25+
Configuration arguments are passed to each storage implementation as kwargs
26+
for individual enforcement of required/optional arguments.
2427
"""
2528

2629
storage_types: ClassVar[dict[str, type]] = {}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
"""StorageFactory Tests.
4+
5+
These tests will test the StorageFactory class and the creation of each storage type that is natively supported.
6+
"""
7+
8+
import sys
9+
10+
import pytest
11+
12+
from graphrag.config.enums import StorageType
13+
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
14+
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
15+
from graphrag.storage.factory import StorageFactory
16+
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
17+
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
18+
19+
# cspell:disable-next-line well-known-key
20+
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
21+
# cspell:disable-next-line well-known-key
22+
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
23+
24+
25+
def test_create_blob_storage():
26+
kwargs = {
27+
"type": "blob",
28+
"connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
29+
"base_dir": "testbasedir",
30+
"container_name": "testcontainer",
31+
}
32+
storage = StorageFactory.create_storage(StorageType.blob, kwargs)
33+
assert isinstance(storage, BlobPipelineStorage)
34+
35+
36+
@pytest.mark.skipif(
37+
not sys.platform.startswith("win"),
38+
reason="cosmosdb emulator is only available on windows runners at this time",
39+
)
40+
def test_create_cosmosdb_storage():
41+
kwargs = {
42+
"type": "cosmosdb",
43+
"connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
44+
"base_dir": "testdatabase",
45+
"container_name": "testcontainer",
46+
}
47+
storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
48+
assert isinstance(storage, CosmosDBPipelineStorage)
49+
50+
51+
def test_create_file_storage():
52+
kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
53+
storage = StorageFactory.create_storage(StorageType.file, kwargs)
54+
assert isinstance(storage, FilePipelineStorage)
55+
56+
57+
def test_create_memory_storage():
58+
kwargs = {"type": "memory"}
59+
storage = StorageFactory.create_storage(StorageType.memory, kwargs)
60+
assert isinstance(storage, MemoryPipelineStorage)
61+
62+
63+
def test_register_and_create_custom_storage():
64+
class CustomStorage:
65+
def __init__(self, **kwargs):
66+
pass
67+
68+
StorageFactory.register("custom", CustomStorage)
69+
storage = StorageFactory.create_storage("custom", {})
70+
assert isinstance(storage, CustomStorage)
71+
72+
73+
def test_create_unknown_storage():
74+
with pytest.raises(ValueError, match="Unknown storage type: unknown"):
75+
StorageFactory.create_storage("unknown", {})

0 commit comments

Comments
 (0)