Skip to content

Commit 3b4235b

Browse files
Copilotjgbradley1
andcommitted
Fix Python CI test failures and improve code quality
Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
1 parent de9987f commit 3b4235b

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

graphrag/storage/factory.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from __future__ import annotations
77

8-
from collections.abc import Callable
9-
from typing import TYPE_CHECKING, Any, ClassVar
8+
from contextlib import suppress
9+
from typing import TYPE_CHECKING, ClassVar
1010

1111
from graphrag.config.enums import OutputType
1212
from graphrag.storage.blob_pipeline_storage import create_blob_storage
@@ -15,6 +15,8 @@
1515
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
1616

1717
if TYPE_CHECKING:
18+
from collections.abc import Callable
19+
1820
from graphrag.storage.pipeline_storage import PipelineStorage
1921

2022

@@ -31,47 +33,52 @@ class StorageFactory:
3133
storage_types: ClassVar[dict[str, type]] = {} # For backward compatibility
3234

3335
@classmethod
34-
def register(cls, storage_type: str, creator: Callable[..., PipelineStorage]) -> None:
36+
def register(
37+
cls, storage_type: str, creator: Callable[..., PipelineStorage]
38+
) -> None:
3539
"""Register a custom storage implementation.
36-
40+
3741
Args:
3842
storage_type: The type identifier for the storage.
3943
creator: A callable that creates an instance of the storage.
4044
"""
4145
cls._storage_registry[storage_type] = creator
42-
46+
4347
# For backward compatibility with code that may access storage_types directly
44-
if callable(creator) and hasattr(creator, "__annotations__") and "return" in creator.__annotations__:
45-
try:
48+
if (
49+
callable(creator)
50+
and hasattr(creator, "__annotations__")
51+
and "return" in creator.__annotations__
52+
):
53+
with suppress(TypeError, KeyError):
4654
cls.storage_types[storage_type] = creator.__annotations__["return"]
47-
except (TypeError, KeyError):
48-
# Just ignore if we can't maintain backward compatibility in this case
49-
pass
5055

5156
@classmethod
5257
def create_storage(
5358
cls, storage_type: OutputType | str, kwargs: dict
5459
) -> PipelineStorage:
5560
"""Create a storage object from the provided type.
56-
61+
5762
Args:
5863
storage_type: The type of storage to create.
5964
kwargs: Additional keyword arguments for the storage constructor.
60-
65+
6166
Returns
6267
-------
6368
A PipelineStorage instance.
64-
69+
6570
Raises
6671
------
6772
ValueError: If the storage type is not registered.
6873
"""
69-
storage_type_str = storage_type.value if isinstance(storage_type, OutputType) else storage_type
70-
74+
storage_type_str = (
75+
storage_type.value if isinstance(storage_type, OutputType) else storage_type
76+
)
77+
7178
if storage_type_str not in cls._storage_registry:
7279
msg = f"Unknown storage type: {storage_type}"
7380
raise ValueError(msg)
74-
81+
7582
return cls._storage_registry[storage_type_str](**kwargs)
7683

7784
@classmethod
@@ -89,4 +96,4 @@ def is_supported_storage_type(cls, storage_type: str) -> bool:
8996
StorageFactory.register(OutputType.blob.value, create_blob_storage)
9097
StorageFactory.register(OutputType.cosmosdb.value, create_cosmosdb_storage)
9198
StorageFactory.register(OutputType.file.value, create_file_storage)
92-
StorageFactory.register(OutputType.memory.value, lambda **kwargs: MemoryPipelineStorage())
99+
StorageFactory.register(OutputType.memory.value, lambda **_: MemoryPipelineStorage())

tests/integration/storage/test_factory.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from graphrag.storage.factory import StorageFactory
1616
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
1717
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
18+
from graphrag.storage.pipeline_storage import PipelineStorage
1819

1920
# cspell:disable-next-line well-known-key
2021
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
@@ -62,15 +63,25 @@ def test_create_memory_storage():
6263

6364

6465
def test_register_and_create_custom_storage():
65-
class CustomStorage:
66-
def __init__(self, **kwargs):
67-
self.initialized = True
68-
69-
StorageFactory.register("custom", lambda **kwargs: CustomStorage(**kwargs))
66+
"""Test registering and creating a custom storage type."""
67+
from unittest.mock import MagicMock
68+
69+
# Create a mock that satisfies the PipelineStorage interface
70+
custom_storage_class = MagicMock(spec=PipelineStorage)
71+
# Make the mock return a mock instance when instantiated
72+
instance = MagicMock()
73+
# We can set attributes on the mock instance, even if they don't exist on PipelineStorage
74+
instance.initialized = True
75+
custom_storage_class.return_value = instance
76+
77+
StorageFactory.register("custom", lambda **kwargs: custom_storage_class(**kwargs))
7078
storage = StorageFactory.create_storage("custom", {})
71-
assert isinstance(storage, CustomStorage)
72-
assert storage.initialized
73-
79+
80+
assert custom_storage_class.called
81+
assert storage is instance
82+
# Access the attribute we set on our mock
83+
assert storage.initialized is True # type: ignore # Attribute only exists on our mock
84+
7485
# Check if it's in the list of registered storage types
7586
assert "custom" in StorageFactory.get_storage_types()
7687
assert StorageFactory.is_supported_storage_type("custom")

0 commit comments

Comments
 (0)