Skip to content

Commit 5e84934

Browse files
Copilot and jgbradley1 committed
Refactor CacheFactory to use registration functionality like StorageFactory
Co-authored-by: jgbradley1 <[email protected]>
1 parent 8e06e85 commit 5e84934

File tree

10 files changed

+318
-44
lines changed

10 files changed

+318
-44
lines changed
Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
changes:
2+
- type: patch
3+
description: Refactor CacheFactory to use registration functionality like StorageFactory and VectorStoreFactory. Added factory functions, utility methods, and comprehensive test suite. Fixed issues with FilePipelineStorage child method and MemoryPipelineStorage constructor compatibility.

docs/examples_notebooks/index_migration_to_v1.ipynb

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -202,10 +202,11 @@
202202
"metadata": {},
203203
"outputs": [],
204204
"source": [
205+
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
206+
"\n",
205207
"from graphrag.cache.factory import CacheFactory\n",
206208
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
207209
"from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings\n",
208-
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
209210
"\n",
210211
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
211212
"# We'll construct the context and run this function flow directly to avoid everything else\n",

examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,9 @@
2929
"\n",
3030
"import pandas as pd\n",
3131
"import tiktoken\n",
32+
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
33+
"from graphrag.query.llm.oai.embedding import OpenAIEmbedding\n",
34+
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
3235
"\n",
3336
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
3437
"from graphrag.query.indexer_adapters import (\n",
@@ -38,9 +41,6 @@
3841
" read_indexer_reports,\n",
3942
" read_indexer_text_units,\n",
4043
")\n",
41-
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
42-
"from graphrag.query.llm.oai.embedding import OpenAIEmbedding\n",
43-
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
4444
"from graphrag.query.structured_search.local_search.mixed_context import (\n",
4545
" LocalSearchMixedContext,\n",
4646
")\n",

graphrag/cache/factory.py

Lines changed: 99 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,63 +1,133 @@
11
# Copyright (c) 2024 Microsoft Corporation.
22
# Licensed under the MIT License
33

4-
"""A module containing create_cache method definition."""
4+
"""A module containing cache factory for creating cache implementations."""
55

66
from __future__ import annotations
77

88
from typing import TYPE_CHECKING, ClassVar
99

1010
from graphrag.config.enums import CacheType
11-
from graphrag.storage.blob_pipeline_storage import create_blob_storage
12-
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
11+
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
12+
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
1313
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
1414

1515
if TYPE_CHECKING:
16+
from collections.abc import Callable
17+
1618
from graphrag.cache.pipeline_cache import PipelineCache
1719

1820
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
1921
from graphrag.cache.memory_pipeline_cache import InMemoryCache
2022
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
2123

2224

25+
def create_noop_cache(**_kwargs) -> PipelineCache:
26+
"""Create a no-op cache implementation."""
27+
return NoopPipelineCache()
28+
29+
30+
def create_memory_cache(**_kwargs) -> PipelineCache:
31+
"""Create an in-memory cache implementation."""
32+
return InMemoryCache()
33+
34+
35+
def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
36+
"""Create a file-based cache implementation."""
37+
# Create storage with base_dir in kwargs since FilePipelineStorage expects it there
38+
storage_kwargs = {"base_dir": root_dir, **kwargs}
39+
storage = FilePipelineStorage(**storage_kwargs).child(base_dir)
40+
return JsonPipelineCache(storage)
41+
42+
43+
def create_blob_cache(**kwargs) -> PipelineCache:
44+
"""Create a blob storage-based cache implementation."""
45+
storage = BlobPipelineStorage(**kwargs)
46+
return JsonPipelineCache(storage)
47+
48+
49+
def create_cosmosdb_cache(**kwargs) -> PipelineCache:
50+
"""Create a CosmosDB-based cache implementation."""
51+
storage = CosmosDBPipelineStorage(**kwargs)
52+
return JsonPipelineCache(storage)
53+
54+
2355
class CacheFactory:
2456
"""A factory class for cache implementations.
2557
2658
Includes a method for users to register a custom cache implementation.
2759
28-
Configuration arguments are passed to each cache implementation as kwargs (where possible)
60+
Configuration arguments are passed to each cache implementation as kwargs
2961
for individual enforcement of required/optional arguments.
3062
"""
3163

32-
cache_types: ClassVar[dict[str, type]] = {}
64+
_registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}
3365

3466
@classmethod
35-
def register(cls, cache_type: str, cache: type):
36-
"""Register a custom cache implementation."""
37-
cls.cache_types[cache_type] = cache
67+
def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
68+
"""Register a custom cache implementation.
69+
70+
Args:
71+
cache_type: The type identifier for the cache.
72+
creator: A callable that creates an instance of the cache.
73+
74+
Raises
75+
------
76+
TypeError: If creator is a class type instead of a factory function.
77+
"""
78+
if isinstance(creator, type):
79+
msg = "Registering classes directly is no longer supported. Please provide a factory function instead."
80+
raise TypeError(msg)
81+
cls._registry[cache_type] = creator
3882

3983
@classmethod
4084
def create_cache(
4185
cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
4286
) -> PipelineCache:
43-
"""Create or get a cache from the provided type."""
44-
if not cache_type:
45-
return NoopPipelineCache()
46-
match cache_type:
47-
case CacheType.none:
48-
return NoopPipelineCache()
49-
case CacheType.memory:
50-
return InMemoryCache()
51-
case CacheType.file:
52-
return JsonPipelineCache(
53-
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
54-
)
55-
case CacheType.blob:
56-
return JsonPipelineCache(create_blob_storage(**kwargs))
57-
case CacheType.cosmosdb:
58-
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
59-
case _:
60-
if cache_type in cls.cache_types:
61-
return cls.cache_types[cache_type](**kwargs)
62-
msg = f"Unknown cache type: {cache_type}"
63-
raise ValueError(msg)
87+
"""Create a cache object from the provided type.
88+
89+
Args:
90+
cache_type: The type of cache to create.
91+
root_dir: The root directory for file-based caches.
92+
kwargs: Additional keyword arguments for the cache constructor.
93+
94+
Returns
95+
-------
96+
A PipelineCache instance.
97+
98+
Raises
99+
------
100+
ValueError: If the cache type is not registered.
101+
"""
102+
if not cache_type or cache_type == CacheType.none:
103+
return create_noop_cache()
104+
105+
type_str = cache_type.value if isinstance(cache_type, CacheType) else cache_type
106+
107+
if type_str not in cls._registry:
108+
msg = f"Unknown cache type: {cache_type}"
109+
raise ValueError(msg)
110+
111+
# Add root_dir to kwargs for file cache
112+
if type_str == CacheType.file.value:
113+
kwargs = {**kwargs, "root_dir": root_dir}
114+
115+
return cls._registry[type_str](**kwargs)
116+
117+
@classmethod
118+
def get_cache_types(cls) -> list[str]:
119+
"""Get the registered cache implementations."""
120+
return list(cls._registry.keys())
121+
122+
@classmethod
123+
def is_supported_type(cls, cache_type: str) -> bool:
124+
"""Check if the given cache type is supported."""
125+
return cache_type in cls._registry
126+
127+
128+
# --- register built-in cache implementations ---
129+
CacheFactory.register(CacheType.none.value, create_noop_cache)
130+
CacheFactory.register(CacheType.memory.value, create_memory_cache)
131+
CacheFactory.register(CacheType.file.value, create_file_cache)
132+
CacheFactory.register(CacheType.blob.value, create_blob_cache)
133+
CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)

graphrag/storage/factory.py

Lines changed: 26 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,9 @@ def register(
4444
------
4545
TypeError: If creator is a class type instead of a factory function.
4646
"""
47+
if isinstance(creator, type):
48+
msg = "Registering classes directly is no longer supported. Please provide a factory function instead."
49+
raise TypeError(msg)
4750
cls._registry[storage_type] = creator
4851

4952
@classmethod
@@ -70,11 +73,11 @@ def create_storage(
7073
else storage_type
7174
)
7275

73-
if type_str not in cls._storage_registry:
76+
if type_str not in cls._registry:
7477
msg = f"Unknown storage type: {storage_type}"
7578
raise ValueError(msg)
7679

77-
return cls._storage_registry[type_str](**kwargs)
80+
return cls._registry[type_str](**kwargs)
7881

7982
@classmethod
8083
def get_storage_types(cls) -> list[str]:
@@ -88,7 +91,24 @@ def is_supported_type(cls, storage_type: str) -> bool:
8891

8992

9093
# --- register built-in storage implementations ---
91-
StorageFactory.register(StorageType.blob.value, BlobPipelineStorage)
92-
StorageFactory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage)
93-
StorageFactory.register(StorageType.file.value, FilePipelineStorage)
94-
StorageFactory.register(StorageType.memory.value, MemoryPipelineStorage)
94+
StorageFactory.register(
95+
StorageType.blob.value,
96+
lambda **kwargs: BlobPipelineStorage(**{
97+
k: v for k, v in kwargs.items() if k != "type"
98+
}),
99+
)
100+
StorageFactory.register(
101+
StorageType.cosmosdb.value,
102+
lambda **kwargs: CosmosDBPipelineStorage(**{
103+
k: v for k, v in kwargs.items() if k != "type"
104+
}),
105+
)
106+
StorageFactory.register(
107+
StorageType.file.value,
108+
lambda **kwargs: FilePipelineStorage(**{
109+
k: v for k, v in kwargs.items() if k != "type"
110+
}),
111+
)
112+
StorageFactory.register(
113+
StorageType.memory.value, lambda **_kwargs: MemoryPipelineStorage()
114+
)

graphrag/storage/file_pipeline_storage.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,8 @@ def child(self, name: str | None) -> "PipelineStorage":
149149
"""Create a child storage instance."""
150150
if name is None:
151151
return self
152-
return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
152+
child_path = str(Path(self._root_dir) / Path(name))
153+
return FilePipelineStorage(base_dir=child_path, encoding=self._encoding)
153154

154155
def keys(self) -> list[str]:
155156
"""Return the keys in the storage."""

graphrag/storage/memory_pipeline_storage.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,11 @@ class MemoryPipelineStorage(FilePipelineStorage):
1818

1919
def __init__(self):
2020
"""Init method definition."""
21-
super().__init__()
21+
# MemoryPipelineStorage doesn't need actual file storage, use temp dir
22+
import tempfile
23+
24+
temp_dir = tempfile.mkdtemp()
25+
super().__init__(base_dir=temp_dir)
2226
self._storage = {}
2327

2428
async def get(

graphrag/vector_stores/factory.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,9 @@ def register(
5151
------
5252
TypeError: If creator is a class type instead of a factory function.
5353
"""
54+
if isinstance(creator, type):
55+
msg = "Registering classes directly is no longer supported. Please provide a factory function instead."
56+
raise TypeError(msg)
5457
cls._registry[vector_store_type] = creator
5558

5659
@classmethod
@@ -95,8 +98,13 @@ def is_supported_type(cls, vector_store_type: str) -> bool:
9598

9699

97100
# --- register built-in vector store implementations ---
98-
VectorStoreFactory.register(VectorStoreType.LanceDB.value, LanceDBVectorStore)
99101
VectorStoreFactory.register(
100-
VectorStoreType.AzureAISearch.value, AzureAISearchVectorStore
102+
VectorStoreType.LanceDB.value, lambda **kwargs: LanceDBVectorStore(**kwargs)
103+
)
104+
VectorStoreFactory.register(
105+
VectorStoreType.AzureAISearch.value,
106+
lambda **kwargs: AzureAISearchVectorStore(**kwargs),
107+
)
108+
VectorStoreFactory.register(
109+
VectorStoreType.CosmosDB.value, lambda **kwargs: CosmosDBVectorStore(**kwargs)
101110
)
102-
VectorStoreFactory.register(VectorStoreType.CosmosDB.value, CosmosDBVectorStore)
Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License

0 commit comments

Comments
 (0)