Skip to content

Commit 3201f28

Browse files
authored
Add GraphRAG Cache package. (#2153)
* Add GraphRAG Cache package.
1 parent bffa400 commit 3201f28

File tree

36 files changed

+1763
-2055
lines changed

36 files changed

+1763
-2055
lines changed

docs/config/yaml.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,14 @@ This section controls the cache mechanism used by the pipeline. This is used to
141141

142142
#### Fields
143143

144-
- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file`
145-
- `base_dir` **str** - The base directory to write output artifacts to, relative to the root.
146-
- `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string.
147-
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name.
148-
- `storage_account_blob_url` **str** - (blob only) The storage account blob URL to use.
149-
- `cosmosdb_account_blob_url` **str** - (cosmosdb only) The CosmosDB account blob URL to use.
144+
- `type` **json|memory|none** - The cache type to use. Default=`json`
145+
- `storage` **StorageConfig**
146+
- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file`
147+
- `base_dir` **str** - The base directory to write output artifacts to, relative to the root.
148+
- `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string.
149+
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name.
150+
- `storage_account_blob_url` **str** - (blob only) The storage account blob URL to use.
151+
- `cosmosdb_account_blob_url` **str** - (cosmosdb only) The CosmosDB account blob URL to use.
150152

151153
### reporting
152154

packages/graphrag-cache/README.md

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# GraphRAG Cache
2+
3+
## Basic
4+
5+
```python
6+
import asyncio
7+
from graphrag_storage import StorageConfig, create_storage, StorageType
8+
from graphrag_cache import CacheConfig, create_cache, CacheType
9+
10+
async def run():
11+
cache = create_cache(
12+
CacheConfig(
13+
type=CacheType.Json,
14+
storage=StorageConfig(
15+
type=StorageType.File,
16+
base_dir="cache"
17+
)
18+
),
19+
)
20+
21+
await cache.set("my_key", {"some": "object to cache"})
22+
print(await cache.get("my_key"))
23+
24+
if __name__ == "__main__":
25+
asyncio.run(run())
26+
```
27+
28+
## Custom Cache
29+
30+
```python
31+
import asyncio
32+
from typing import Any
33+
from graphrag_storage import Storage
34+
from graphrag_cache import Cache, CacheConfig, create_cache, register_cache
35+
36+
class MyCache(Cache):
37+
def __init__(self, some_setting: str, optional_setting: str = "default setting", **kwargs: Any):
38+
# Validate settings and initialize
39+
# View the JsonCache implementation to see how to create a cache that relies on a Storage provider.
40+
...
41+
42+
# Implement rest of interface
43+
...
44+
45+
register_cache("MyCache", MyCache)
46+
47+
async def run():
48+
cache = create_cache(
49+
CacheConfig(
50+
type="MyCache",
51+
some_setting="My Setting"
52+
)
53+
)
54+
55+
# Or use the factory directly to instantiate with a dict instead of using
56+
# CacheConfig + create_cache
57+
# from graphrag_cache.cache_factory import cache_factory
58+
# cache = cache_factory.create(strategy="MyCache", init_args={"some_setting": "My Setting"})
59+
60+
await cache.set("my_key", {"some": "object to cache"})
61+
print(await cache.get("my_key"))
62+
63+
if __name__ == "__main__":
64+
asyncio.run(run())
65+
```
66+
67+
### Details
68+
69+
By default, `create_cache` comes with the following cache providers registered, which correspond to the entries in the `CacheType` enum.
70+
71+
- `JsonCache`
72+
- `MemoryCache`
73+
- `NoopCache`
74+
75+
The preregistration happens dynamically, e.g., `JsonCache` is only imported and registered if you request a `JsonCache` with `create_cache(CacheConfig(type=CacheType.Json, ...))`. There is no need to manually import and register builtin cache providers when using `create_cache`.
76+
77+
If you want a clean factory with no preregistered cache providers then directly import `cache_factory` and bypass using `create_cache`. The downside is that `cache_factory.create` uses a dict for init args instead of the strongly typed `CacheConfig` used with `create_cache`.
78+
79+
```python
80+
from graphrag_cache.cache_factory import cache_factory
81+
from graphrag_cache.json_cache import JsonCache
82+
83+
# cache_factory has no preregistered providers so you must register any
84+
# providers you plan on using.
85+
# May also register a custom implementation, see above for example.
86+
cache_factory.register("my_cache_impl", JsonCache)
87+
88+
cache = cache_factory.create(strategy="my_cache_impl", init_args={"storage": {"type": "file", "base_dir": "cache"}})
89+
90+
...
91+
92+
```
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""The GraphRAG Cache package."""
5+
6+
from graphrag_cache.cache import Cache
7+
from graphrag_cache.cache_config import CacheConfig
8+
from graphrag_cache.cache_factory import create_cache, register_cache
9+
from graphrag_cache.cache_type import CacheType
10+
11+
__all__ = [
12+
"Cache",
13+
"CacheConfig",
14+
"CacheType",
15+
"create_cache",
16+
"register_cache",
17+
]

packages/graphrag/graphrag/cache/pipeline_cache.py renamed to packages/graphrag-cache/graphrag_cache/cache.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
# Copyright (c) 2024 Microsoft Corporation.
22
# Licensed under the MIT License
33

4-
"""A module containing 'PipelineCache' model."""
4+
"""Abstract base class for cache."""
55

66
from __future__ import annotations
77

8-
from abc import ABCMeta, abstractmethod
8+
from abc import ABC, abstractmethod
99
from typing import Any
1010

1111

12-
class PipelineCache(metaclass=ABCMeta):
12+
class Cache(ABC):
1313
"""Provide a cache interface for the pipeline."""
1414

15+
@abstractmethod
16+
def __init__(self, **kwargs: Any) -> None:
17+
"""Create a cache instance."""
18+
1519
@abstractmethod
1620
async def get(self, key: str) -> Any:
1721
"""Get the value for the given key.
@@ -59,7 +63,7 @@ async def clear(self) -> None:
5963
"""Clear the cache."""
6064

6165
@abstractmethod
62-
def child(self, name: str) -> PipelineCache:
66+
def child(self, name: str) -> Cache:
6367
"""Create a child cache with the given name.
6468
6569
Args:
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""Cache configuration model."""
5+
6+
from graphrag_storage import StorageConfig
7+
from pydantic import BaseModel, ConfigDict, Field
8+
9+
from graphrag_cache.cache_type import CacheType
10+
11+
12+
class CacheConfig(BaseModel):
13+
"""The configuration section for cache."""
14+
15+
model_config = ConfigDict(extra="allow")
16+
"""Allow extra fields to support custom cache implementations."""
17+
18+
type: str = Field(
19+
description="The cache type to use. Builtin types include 'Json', 'Memory', and 'Noop'.",
20+
default=CacheType.Json,
21+
)
22+
23+
storage: StorageConfig | None = Field(
24+
description="The storage configuration to use for file-based caches such as 'Json'.",
25+
default=None,
26+
)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
5+
"""Cache factory implementation."""
6+
7+
from collections.abc import Callable
8+
9+
from graphrag_common.factory import Factory, ServiceScope
10+
from graphrag_storage import Storage
11+
12+
from graphrag_cache.cache import Cache
13+
from graphrag_cache.cache_config import CacheConfig
14+
from graphrag_cache.cache_type import CacheType
15+
16+
17+
class CacheFactory(Factory[Cache]):
18+
"""A factory class for cache implementations."""
19+
20+
21+
cache_factory = CacheFactory()
22+
23+
24+
def register_cache(
25+
cache_type: str,
26+
cache_initializer: Callable[..., Cache],
27+
scope: ServiceScope = "transient",
28+
) -> None:
29+
"""Register a custom cache implementation.
30+
31+
Args
32+
----
33+
- cache_type: str
34+
The cache id to register.
35+
- cache_initializer: Callable[..., Cache]
36+
The cache initializer to register.
37+
"""
38+
cache_factory.register(cache_type, cache_initializer, scope)
39+
40+
41+
def create_cache(config: CacheConfig, storage: Storage | None = None) -> Cache:
42+
"""Create a cache implementation based on the given configuration.
43+
44+
Args
45+
----
46+
- config: CacheConfig
47+
The cache configuration to use.
48+
- storage: Storage | None
49+
The storage implementation to use for file-based caches such as 'Json'.
50+
51+
Returns
52+
-------
53+
Cache
54+
The created cache implementation.
55+
"""
56+
config_model = config.model_dump()
57+
cache_strategy = config.type
58+
59+
if cache_strategy not in cache_factory:
60+
match cache_strategy:
61+
case CacheType.Json:
62+
from graphrag_cache.json_cache import JsonCache
63+
64+
register_cache(CacheType.Json, JsonCache)
65+
66+
case CacheType.Memory:
67+
from graphrag_cache.memory_cache import MemoryCache
68+
69+
register_cache(CacheType.Memory, MemoryCache)
70+
71+
case CacheType.Noop:
72+
from graphrag_cache.noop_cache import NoopCache
73+
74+
register_cache(CacheType.Noop, NoopCache)
75+
76+
case _:
77+
msg = f"CacheConfig.type '{cache_strategy}' is not registered in the CacheFactory. Registered types: {', '.join(cache_factory.keys())}."
78+
raise ValueError(msg)
79+
80+
if storage:
81+
config_model["storage"] = storage
82+
83+
return cache_factory.create(strategy=cache_strategy, init_args=config_model)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
5+
"""Builtin cache implementation types."""
6+
7+
from enum import StrEnum
8+
9+
10+
class CacheType(StrEnum):
11+
"""Enum for cache types."""
12+
13+
Json = "json"
14+
Memory = "memory"
15+
Noop = "none"

packages/graphrag/graphrag/cache/json_pipeline_cache.py renamed to packages/graphrag-cache/graphrag_cache/json_cache.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,35 @@
66
import json
77
from typing import Any
88

9-
from graphrag_storage import Storage
9+
from graphrag_storage import Storage, StorageConfig, create_storage
1010

11-
from graphrag.cache.pipeline_cache import PipelineCache
11+
from graphrag_cache.cache import Cache
1212

1313

14-
class JsonPipelineCache(PipelineCache):
14+
class JsonCache(Cache):
1515
"""File pipeline cache class definition."""
1616

1717
_storage: Storage
18-
_encoding: str
1918

20-
def __init__(self, storage: Storage, encoding="utf-8"):
19+
def __init__(
20+
self,
21+
storage: Storage | dict[str, Any] | None = None,
22+
**kwargs: Any,
23+
) -> None:
2124
"""Init method definition."""
22-
self._storage = storage
23-
self._encoding = encoding
25+
if storage is None:
26+
msg = "JsonCache requires either a Storage instance to be provided or a StorageConfig to create one."
27+
raise ValueError(msg)
28+
if isinstance(storage, Storage):
29+
self._storage = storage
30+
else:
31+
self._storage = create_storage(StorageConfig(**storage))
2432

25-
async def get(self, key: str) -> str | None:
33+
async def get(self, key: str) -> Any | None:
2634
"""Get method definition."""
2735
if await self.has(key):
2836
try:
29-
data = await self._storage.get(key, encoding=self._encoding)
37+
data = await self._storage.get(key)
3038
data = json.loads(data)
3139
except UnicodeDecodeError:
3240
await self._storage.delete(key)
@@ -44,9 +52,7 @@ async def set(self, key: str, value: Any, debug_data: dict | None = None) -> Non
4452
if value is None:
4553
return
4654
data = {"result": value, **(debug_data or {})}
47-
await self._storage.set(
48-
key, json.dumps(data, ensure_ascii=False), encoding=self._encoding
49-
)
55+
await self._storage.set(key, json.dumps(data, ensure_ascii=False))
5056

5157
async def has(self, key: str) -> bool:
5258
"""Has method definition."""
@@ -61,6 +67,6 @@ async def clear(self) -> None:
6167
"""Clear method definition."""
6268
await self._storage.clear()
6369

64-
def child(self, name: str) -> "JsonPipelineCache":
70+
def child(self, name: str) -> "Cache":
6571
"""Child method definition."""
66-
return JsonPipelineCache(self._storage.child(name), encoding=self._encoding)
72+
return JsonCache(storage=self._storage.child(name))

0 commit comments

Comments
 (0)