|
1 | 1 | # Copyright (c) 2024 Microsoft Corporation. |
2 | 2 | # Licensed under the MIT License |
3 | 3 |
|
4 | | -"""A module containing create_cache method definition.""" |
| 4 | +"""A module containing cache factory for creating cache implementations.""" |
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | 8 | from typing import TYPE_CHECKING, ClassVar |
9 | 9 |
|
10 | 10 | from graphrag.config.enums import CacheType |
11 | | -from graphrag.storage.blob_pipeline_storage import create_blob_storage |
12 | | -from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage |
| 11 | +from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage |
| 12 | +from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage |
13 | 13 | from graphrag.storage.file_pipeline_storage import FilePipelineStorage |
14 | 14 |
|
15 | 15 | if TYPE_CHECKING: |
| 16 | + from collections.abc import Callable |
| 17 | + |
16 | 18 | from graphrag.cache.pipeline_cache import PipelineCache |
17 | 19 |
|
18 | 20 | from graphrag.cache.json_pipeline_cache import JsonPipelineCache |
19 | 21 | from graphrag.cache.memory_pipeline_cache import InMemoryCache |
20 | 22 | from graphrag.cache.noop_pipeline_cache import NoopPipelineCache |
21 | 23 |
|
22 | 24 |
|
| 25 | +def create_noop_cache(**_kwargs) -> PipelineCache: |
| 26 | + """Create a no-op cache implementation.""" |
| 27 | + return NoopPipelineCache() |
| 28 | + |
| 29 | + |
| 30 | +def create_memory_cache(**_kwargs) -> PipelineCache: |
| 31 | + """Create an in-memory cache implementation.""" |
| 32 | + return InMemoryCache() |
| 33 | + |
| 34 | + |
| 35 | +def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache: |
| 36 | + """Create a file-based cache implementation.""" |
| 37 | + # Create storage with base_dir in kwargs since FilePipelineStorage expects it there |
| 38 | + storage_kwargs = {"base_dir": root_dir, **kwargs} |
| 39 | + storage = FilePipelineStorage(**storage_kwargs).child(base_dir) |
| 40 | + return JsonPipelineCache(storage) |
| 41 | + |
| 42 | + |
| 43 | +def create_blob_cache(**kwargs) -> PipelineCache: |
| 44 | + """Create a blob storage-based cache implementation.""" |
| 45 | + storage = BlobPipelineStorage(**kwargs) |
| 46 | + return JsonPipelineCache(storage) |
| 47 | + |
| 48 | + |
| 49 | +def create_cosmosdb_cache(**kwargs) -> PipelineCache: |
| 50 | + """Create a CosmosDB-based cache implementation.""" |
| 51 | + storage = CosmosDBPipelineStorage(**kwargs) |
| 52 | + return JsonPipelineCache(storage) |
| 53 | + |
| 54 | + |
23 | 55 | class CacheFactory: |
24 | 56 | """A factory class for cache implementations. |
25 | 57 |
|
26 | 58 | Includes a method for users to register a custom cache implementation. |
27 | 59 |
|
28 | | - Configuration arguments are passed to each cache implementation as kwargs (where possible) |
| 60 | + Configuration arguments are passed to each cache implementation as kwargs |
29 | 61 | for individual enforcement of required/optional arguments. |
30 | 62 | """ |
31 | 63 |
|
32 | | - cache_types: ClassVar[dict[str, type]] = {} |
| 64 | + _registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {} |
33 | 65 |
|
34 | 66 | @classmethod |
35 | | - def register(cls, cache_type: str, cache: type): |
36 | | - """Register a custom cache implementation.""" |
37 | | - cls.cache_types[cache_type] = cache |
| 67 | + def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None: |
| 68 | + """Register a custom cache implementation. |
| 69 | +
|
| 70 | + Args: |
| 71 | + cache_type: The type identifier for the cache. |
| 72 | + creator: A callable that creates an instance of the cache. |
| 73 | +
|
| 74 | + Raises |
| 75 | + ------ |
| 76 | + TypeError: If creator is a class type instead of a factory function. |
| 77 | + """ |
| 78 | + if isinstance(creator, type): |
| 79 | + msg = "Registering classes directly is no longer supported. Please provide a factory function instead." |
| 80 | + raise TypeError(msg) |
| 81 | + cls._registry[cache_type] = creator |
38 | 82 |
|
39 | 83 | @classmethod |
40 | 84 | def create_cache( |
41 | 85 | cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict |
42 | 86 | ) -> PipelineCache: |
43 | | - """Create or get a cache from the provided type.""" |
44 | | - if not cache_type: |
45 | | - return NoopPipelineCache() |
46 | | - match cache_type: |
47 | | - case CacheType.none: |
48 | | - return NoopPipelineCache() |
49 | | - case CacheType.memory: |
50 | | - return InMemoryCache() |
51 | | - case CacheType.file: |
52 | | - return JsonPipelineCache( |
53 | | - FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"]) |
54 | | - ) |
55 | | - case CacheType.blob: |
56 | | - return JsonPipelineCache(create_blob_storage(**kwargs)) |
57 | | - case CacheType.cosmosdb: |
58 | | - return JsonPipelineCache(create_cosmosdb_storage(**kwargs)) |
59 | | - case _: |
60 | | - if cache_type in cls.cache_types: |
61 | | - return cls.cache_types[cache_type](**kwargs) |
62 | | - msg = f"Unknown cache type: {cache_type}" |
63 | | - raise ValueError(msg) |
| 87 | + """Create a cache object from the provided type. |
| 88 | +
|
| 89 | + Args: |
| 90 | + cache_type: The type of cache to create. |
| 91 | + root_dir: The root directory for file-based caches. |
| 92 | + kwargs: Additional keyword arguments for the cache constructor. |
| 93 | +
|
| 94 | + Returns |
| 95 | + ------- |
| 96 | + A PipelineCache instance. |
| 97 | +
|
| 98 | + Raises |
| 99 | + ------ |
| 100 | + ValueError: If the cache type is not registered. |
| 101 | + """ |
| 102 | + if not cache_type or cache_type == CacheType.none: |
| 103 | + return create_noop_cache() |
| 104 | + |
| 105 | + type_str = cache_type.value if isinstance(cache_type, CacheType) else cache_type |
| 106 | + |
| 107 | + if type_str not in cls._registry: |
| 108 | + msg = f"Unknown cache type: {cache_type}" |
| 109 | + raise ValueError(msg) |
| 110 | + |
| 111 | + # Add root_dir to kwargs for file cache |
| 112 | + if type_str == CacheType.file.value: |
| 113 | + kwargs = {**kwargs, "root_dir": root_dir} |
| 114 | + |
| 115 | + return cls._registry[type_str](**kwargs) |
| 116 | + |
| 117 | + @classmethod |
| 118 | + def get_cache_types(cls) -> list[str]: |
| 119 | + """Get the registered cache implementations.""" |
| 120 | + return list(cls._registry.keys()) |
| 121 | + |
| 122 | + @classmethod |
| 123 | + def is_supported_type(cls, cache_type: str) -> bool: |
| 124 | + """Check if the given cache type is supported.""" |
| 125 | + return cache_type in cls._registry |
| 126 | + |
| 127 | + |
| 128 | +# --- register built-in cache implementations --- |
| 129 | +CacheFactory.register(CacheType.none.value, create_noop_cache) |
| 130 | +CacheFactory.register(CacheType.memory.value, create_memory_cache) |
| 131 | +CacheFactory.register(CacheType.file.value, create_file_cache) |
| 132 | +CacheFactory.register(CacheType.blob.value, create_blob_cache) |
| 133 | +CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache) |
0 commit comments