
Commit 1bb9fa8

Unified factory (#2105)
* Simplify Factory interface
* Migrate CacheFactory to standard base class
* Migrate LoggerFactory to standard base class
* Migrate StorageFactory to standard base class
* Migrate VectorStoreFactory to standard base class
* Update vector store example notebook
* Delete notebook outputs
* Move default providers into factories
* Move retry/limit tests into integ
* Split language model factories
* Set smoke test tpm/rpm
* Fix factory integ tests
* Add method to smoke test, switch text to 'fast'
* Fix text smoke config for fast workflow
* Add new workflows to text smoke test
* Convert input readers to a proper factory
* Remove covariates from fast smoke test
* Update docs for input factory
* Bump smoke runtime
* Even longer runtime
* min-csv timeout
* Remove unnecessary lambdas
1 parent 0436405 commit 1bb9fa8


54 files changed, +520 −823 lines changed

docs/config/models.md

Lines changed: 3 additions & 3 deletions
@@ -95,20 +95,20 @@ Many users have used platforms such as [ollama](https://ollama.com/) and [LiteLL
 ### Model Protocol

-As of GraphRAG 2.0.0, we support model injection through the use of a standard chat and embedding Protocol and an accompanying ModelFactory that you can use to register your model implementation. This is not supported with the CLI, so you'll need to use GraphRAG as a library.
+As of GraphRAG 2.0.0, we support model injection through the use of standard chat and embedding Protocols and accompanying factories that you can use to register your model implementation. This is not supported with the CLI, so you'll need to use GraphRAG as a library.

 - Our Protocol is [defined here](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/protocol/base.py)
 - We have a simple mock implementation in our tests that you can [reference here](https://github.com/microsoft/graphrag/blob/main/tests/mock_provider.py)

-Once you have a model implementation, you need to register it with our ModelFactory:
+Once you have a model implementation, you need to register it with our ChatModelFactory or EmbeddingModelFactory:

 ```python
 class MyCustomModel:
     ...
     # implementation

 # elsewhere...
-ModelFactory.register_chat("my-custom-chat-model", lambda **kwargs: MyCustomModel(**kwargs))
+ChatModelFactory.register("my-custom-chat-model", lambda **kwargs: MyCustomModel(**kwargs))
 ```

 Then in your config you can reference the type name you used:
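For readers migrating, a minimal sketch of the new registration flow follows. It mirrors the documented snippet above; the MyCustomModel stub and its constructor are illustrative assumptions, and a real implementation must satisfy the chat Protocol linked above.

```python
# Minimal sketch: registering a custom chat model with the split factory.
# MyCustomModel is a hypothetical stub; the registration key is what you then
# reference as the model `type` in your config.
from graphrag.language_model.factory import ChatModelFactory

class MyCustomModel:
    """Hypothetical chat model; implement the chat Protocol methods here."""

    def __init__(self, **kwargs) -> None:
        self.config = kwargs

ChatModelFactory.register("my-custom-chat-model", lambda **kwargs: MyCustomModel(**kwargs))
```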

docs/examples_notebooks/custom_vector_store.ipynb

Lines changed: 23 additions & 35 deletions
@@ -155,18 +155,19 @@
         "        self.connected = True\n",
         "        print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
         "\n",
-        "    def load_documents(\n",
-        "        self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
-        "    ) -> None:\n",
+        "    def create_index(self) -> None:\n",
+        "        \"\"\"Create index in the vector store (no-op for in-memory store).\"\"\"\n",
+        "        self.documents.clear()\n",
+        "        self.vectors.clear()\n",
+        "\n",
+        "        print(f\"✅ Index '{self.index_name}' is ready in in-memory vector store\")\n",
+        "\n",
+        "    def load_documents(self, documents: list[VectorStoreDocument]) -> None:\n",
         "        \"\"\"Load documents into the vector store.\"\"\"\n",
         "        if not self.connected:\n",
         "            msg = \"Vector store not connected. Call connect() first.\"\n",
         "            raise RuntimeError(msg)\n",
         "\n",
-        "        if overwrite:\n",
-        "            self.documents.clear()\n",
-        "            self.vectors.clear()\n",
-        "\n",
         "        loaded_count = 0\n",
         "        for doc in documents:\n",
         "            if doc.vector is not None:\n",
@@ -230,13 +231,6 @@
         "        # Use vector search with the embedding\n",
         "        return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n",
         "\n",
-        "    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n",
-        "        \"\"\"Build a query filter to filter documents by id.\n",
-        "\n",
-        "        For this simple implementation, we return the list of IDs as the filter.\n",
-        "        \"\"\"\n",
-        "        return [str(id_) for id_ in include_ids]\n",
-        "\n",
         "    def search_by_id(self, id: str) -> VectorStoreDocument:\n",
         "        \"\"\"Search for a document by id.\"\"\"\n",
         "        doc_id = str(id)\n",
@@ -281,15 +275,15 @@
         "CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n",
         "\n",
         "# Register the vector store class\n",
-        "VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
+        "VectorStoreFactory().register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
         "\n",
         "print(f\"✅ Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n",
         "\n",
         "# Verify registration\n",
-        "available_types = VectorStoreFactory.get_vector_store_types()\n",
+        "available_types = VectorStoreFactory().keys()\n",
         "print(f\"\\n📋 Available vector store types: {available_types}\")\n",
         "print(\n",
-        "    f\"🔍 Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n",
+        "    f\"🔍 Is our custom type supported? {CUSTOM_VECTOR_STORE_TYPE in VectorStoreFactory()}\"\n",
         ")"
        ]
       },
@@ -347,8 +341,8 @@
         "schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
         "\n",
         "# Create vector store instance using factory\n",
-        "vector_store = VectorStoreFactory.create_vector_store(\n",
-        "    CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
+        "vector_store = VectorStoreFactory().create(\n",
+        "    CUSTOM_VECTOR_STORE_TYPE, {\"vector_store_schema_config\": schema}\n",
         ")\n",
         "\n",
         "print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
@@ -363,6 +357,7 @@
        "source": [
         "# Connect and load documents\n",
         "vector_store.connect()\n",
+        "vector_store.create_index()\n",
         "vector_store.load_documents(sample_documents)\n",
         "\n",
         "print(f\"📊 Updated stats: {vector_store.get_stats()}\")"
@@ -472,13 +467,12 @@
         "    # 1. GraphRAG creates vector store using factory\n",
         "    schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
         "\n",
-        "    store = VectorStoreFactory.create_vector_store(\n",
+        "    store = VectorStoreFactory().create(\n",
         "        CUSTOM_VECTOR_STORE_TYPE,\n",
-        "        vector_store_schema_config=schema,\n",
-        "        similarity_threshold=0.3,\n",
+        "        {\"vector_store_schema_config\": schema, \"similarity_threshold\": 0.3},\n",
         "    )\n",
         "    store.connect()\n",
-        "\n",
+        "    store.create_index()\n",
         "    print(\"✅ Step 1: Vector store created and connected\")\n",
         "\n",
         "    # 2. During indexing, GraphRAG loads extracted entities\n",
@@ -534,12 +528,12 @@
         "\n",
         "    # Test 1: Basic functionality\n",
         "    print(\"Test 1: Basic functionality\")\n",
-        "    store = VectorStoreFactory.create_vector_store(\n",
+        "    store = VectorStoreFactory().create(\n",
         "        CUSTOM_VECTOR_STORE_TYPE,\n",
-        "        vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
+        "        {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test\")},\n",
         "    )\n",
         "    store.connect()\n",
-        "\n",
+        "    store.create_index()\n",
         "    # Load test documents\n",
         "    test_docs = sample_documents[:2]\n",
         "    store.load_documents(test_docs)\n",
@@ -575,17 +569,11 @@
         "\n",
         "    print(\"✅ Search by ID test passed\")\n",
         "\n",
-        "    # Test 4: Filter functionality\n",
-        "    print(\"\\nTest 4: Filter functionality\")\n",
-        "    filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n",
-        "    assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n",
-        "    print(\"✅ Filter functionality test passed\")\n",
-        "\n",
-        "    # Test 5: Error handling\n",
+        "    # Test 4: Error handling\n",
         "    print(\"\\nTest 5: Error handling\")\n",
-        "    disconnected_store = VectorStoreFactory.create_vector_store(\n",
+        "    disconnected_store = VectorStoreFactory().create(\n",
         "        CUSTOM_VECTOR_STORE_TYPE,\n",
-        "        vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
+        "        {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test2\")},\n",
         "    )\n",
         "\n",
         "    try:\n",

docs/index/architecture.md

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ Several subsystems within GraphRAG use a factory pattern to register and retriev
 The following subsystems use a factory pattern that allows you to register your own implementations:

 - [language model](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/factory.py) - implement your own `chat` and `embed` methods to use a model provider of choice beyond the built-in OpenAI/Azure support
+- [input reader](https://github.com/microsoft/graphrag/blob/main/graphrag/index/input/factory.py) - implement your own input document reader to support file types other than text, CSV, and JSON
 - [cache](https://github.com/microsoft/graphrag/blob/main/graphrag/cache/factory.py) - create your own cache storage location in addition to the file, blob, and CosmosDB ones we provide
 - [logger](https://github.com/microsoft/graphrag/blob/main/graphrag/logger/factory.py) - create your own log writing location in addition to the built-in file and blob storage
 - [storage](https://github.com/microsoft/graphrag/blob/main/graphrag/storage/factory.py) - create your own storage provider (database, etc.) beyond the file, blob, and CosmosDB ones built in

docs/index/inputs.md

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ Also see the [outputs](outputs.md) documentation for the final documents table s
 As of version 2.6.0, GraphRAG's [indexing API method](https://github.com/microsoft/graphrag/blob/main/graphrag/api/index.py) allows you to pass in your own pandas DataFrame and bypass all of the input loading/parsing described in the next section. This is convenient if you have content in a format or storage location we don't support out-of-the-box. __You must ensure that your input DataFrame conforms to the schema described above.__ All of the chunking behavior described later will proceed exactly the same.

+## Custom File Handling
+
+As of version 3.0.0, we have migrated to an injectable InputReader provider class. This means you can implement any input file handling you want in a class that extends InputReader and register it with the InputReaderFactory. See the [architecture page](https://microsoft.github.io/graphrag/index/architecture/) for more info on our standard provider pattern.
+
 ## Formats

 We support three file formats out-of-the-box. This covers the overwhelming majority of use cases we have encountered. If you have a different format, we recommend writing a script to convert to one of these, which are widely used and supported by many tools and libraries.
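A registration sketch under stated assumptions: the InputReader base-class interface is not shown in this diff, so the read method name, its signature, its return schema, and the registration key below are illustrative only; check graphrag/index/input/factory.py for the real contract.

```python
# Illustrative only: an XML reader registered with the InputReaderFactory.
# The read() method name, its parameters, and the returned columns are
# assumptions; a real reader extends the InputReader base class.
import pandas as pd

from graphrag.index.input.factory import InputReaderFactory

class XmlInputReader:
    """Hypothetical reader for .xml files."""

    def __init__(self, **kwargs) -> None:
        self.config = kwargs

    async def read(self, storage, config) -> pd.DataFrame:  # assumed signature
        # Parse files from storage into the documents schema described above.
        rows = []  # populate by listing and parsing .xml files
        return pd.DataFrame(rows, columns=["title", "text"])

InputReaderFactory().register("xml", lambda **kwargs: XmlInputReader(**kwargs))
```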

graphrag/cache/factory.py

Lines changed: 9 additions & 58 deletions
@@ -5,23 +5,18 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, ClassVar
-
 from graphrag.cache.json_pipeline_cache import JsonPipelineCache
 from graphrag.cache.memory_pipeline_cache import InMemoryCache
 from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
+from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.config.enums import CacheType
+from graphrag.factory.factory import Factory
 from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
 from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
 from graphrag.storage.file_pipeline_storage import FilePipelineStorage

-if TYPE_CHECKING:
-    from collections.abc import Callable
-
-    from graphrag.cache.pipeline_cache import PipelineCache
-
-
-class CacheFactory:
+class CacheFactory(Factory[PipelineCache]):
     """A factory class for cache implementations.

     Includes a method for users to register a custom cache implementation.
@@ -30,51 +25,6 @@ class CacheFactory:
     for individual enforcement of required/optional arguments.
     """

-    _registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}
-
-    @classmethod
-    def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
-        """Register a custom cache implementation.
-
-        Args:
-            cache_type: The type identifier for the cache.
-            creator: A class or callable that creates an instance of PipelineCache.
-        """
-        cls._registry[cache_type] = creator
-
-    @classmethod
-    def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache:
-        """Create a cache object from the provided type.
-
-        Args:
-            cache_type: The type of cache to create.
-            root_dir: The root directory for file-based caches.
-            kwargs: Additional keyword arguments for the cache constructor.
-
-        Returns
-        -------
-        A PipelineCache instance.
-
-        Raises
-        ------
-        ValueError: If the cache type is not registered.
-        """
-        if cache_type not in cls._registry:
-            msg = f"Unknown cache type: {cache_type}"
-            raise ValueError(msg)
-
-        return cls._registry[cache_type](**kwargs)
-
-    @classmethod
-    def get_cache_types(cls) -> list[str]:
-        """Get the registered cache implementations."""
-        return list(cls._registry.keys())
-
-    @classmethod
-    def is_supported_type(cls, cache_type: str) -> bool:
-        """Check if the given cache type is supported."""
-        return cache_type in cls._registry
-

 # --- register built-in cache implementations ---
 def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
@@ -108,8 +58,9 @@ def create_memory_cache(**kwargs) -> PipelineCache:


 # --- register built-in cache implementations ---
-CacheFactory.register(CacheType.none.value, create_noop_cache)
-CacheFactory.register(CacheType.memory.value, create_memory_cache)
-CacheFactory.register(CacheType.file.value, create_file_cache)
-CacheFactory.register(CacheType.blob.value, create_blob_cache)
-CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
+cache_factory = CacheFactory()
+cache_factory.register(CacheType.none.value, create_noop_cache)
+cache_factory.register(CacheType.memory.value, create_memory_cache)
+cache_factory.register(CacheType.file.value, create_file_cache)
+cache_factory.register(CacheType.blob.value, create_blob_cache)
+cache_factory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
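Since graphrag/factory/factory.py itself is not part of the visible diff, here is an inferred sketch of the base class these migrations target, pieced together from the call sites in this commit (register, create with an init_args dict, keys(), and `in` membership). It is a reconstruction, not the actual implementation.

```python
# Inferred sketch of the unified Factory base class; the real one lives in
# graphrag/factory/factory.py and may differ.
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")

class Factory(Generic[T]):
    """Registry of named creator callables that produce T instances."""

    def __init__(self) -> None:
        # The call sites imply registrations survive across instances
        # (VectorStoreFactory().register(...) in one cell, then
        # VectorStoreFactory().create(...) in another), so this sketch keeps
        # one registry per subclass rather than per instance.
        cls = type(self)
        if "_registry" not in cls.__dict__:
            cls._registry = {}  # dict[str, Callable[..., T]]

    def register(self, strategy: str, creator: Callable[..., T]) -> None:
        """Register a creator callable under a strategy key."""
        type(self)._registry[strategy] = creator

    def create(self, strategy: str, init_args: dict | None = None) -> T:
        """Create an instance, forwarding init_args to the creator."""
        if strategy not in type(self)._registry:
            msg = f"Unknown strategy: {strategy}"
            raise ValueError(msg)
        return type(self)._registry[strategy](**(init_args or {}))

    def keys(self) -> list[str]:
        """List the registered strategy keys."""
        return list(type(self)._registry.keys())

    def __contains__(self, strategy: str) -> bool:
        return strategy in type(self)._registry
```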

graphrag/config/defaults.py

Lines changed: 0 additions & 31 deletions
@@ -3,7 +3,6 @@

 """Common default configuration values."""

-from collections.abc import Callable
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import ClassVar
@@ -24,25 +23,6 @@
 from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
     EN_STOP_WORDS,
 )
-from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
-    RateLimiter,
-)
-from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
-    StaticRateLimiter,
-)
-from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
-    ExponentialRetry,
-)
-from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
-    IncrementalWaitRetry,
-)
-from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
-    NativeRetry,
-)
-from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
-    RandomWaitRetry,
-)
-from graphrag.language_model.providers.litellm.services.retry.retry import Retry

 DEFAULT_OUTPUT_BASE_DIR = "output"
 DEFAULT_CHAT_MODEL_ID = "default_chat_model"
@@ -60,17 +40,6 @@

 DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]

-DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
-    "native": NativeRetry,
-    "exponential_backoff": ExponentialRetry,
-    "random_wait": RandomWaitRetry,
-    "incremental_wait": IncrementalWaitRetry,
-}
-
-DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
-    "static": StaticRateLimiter,
-}
-

 @dataclass
 class BasicSearchDefaults:

graphrag/config/models/graph_rag_config.py

Lines changed: 6 additions & 3 deletions
@@ -104,8 +104,10 @@ def _validate_retry_services(self) -> None:
         _ = retry_factory.create(
             strategy=model.retry_strategy,
-            max_retries=model.max_retries,
-            max_retry_wait=model.max_retry_wait,
+            init_args={
+                "max_retries": model.max_retries,
+                "max_retry_wait": model.max_retry_wait,
+            },
         )

     def _validate_rate_limiter_services(self) -> None:
@@ -130,7 +132,8 @@ def _validate_rate_limiter_services(self) -> None:
         )
         if rpm is not None or tpm is not None:
             _ = rate_limiter_factory.create(
-                strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
+                strategy=model.rate_limit_strategy,
+                init_args={"rpm": rpm, "tpm": tpm},
             )

     input: InputConfig = Field(
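The change here is purely to the create() call shape: strategy constructor arguments now travel in a single init_args dict instead of loose keyword arguments. A runnable sketch follows, reusing the inferred Factory base from above; RetryFactory and the stub strategy are illustrative stand-ins, while the "exponential_backoff" key and the max_retries/max_retry_wait names come from the defaults.py and config diffs.

```python
# Call-shape change only: constructor kwargs become one init_args dict.
# RetryFactory and ExponentialRetryStub are stand-ins; the real strategies
# live under graphrag/language_model/providers/litellm/services/retry/.

class ExponentialRetryStub:
    """Stand-in for the real ExponentialRetry strategy."""

    def __init__(self, max_retries: int, max_retry_wait: float) -> None:
        self.max_retries = max_retries
        self.max_retry_wait = max_retry_wait

class RetryFactory(Factory[ExponentialRetryStub]):  # reuses the sketch above
    pass

retry_factory = RetryFactory()
retry_factory.register("exponential_backoff", ExponentialRetryStub)

# Old: retry_factory.create(strategy=..., max_retries=5, max_retry_wait=30)
retry = retry_factory.create(
    strategy="exponential_backoff",
    init_args={"max_retries": 5, "max_retry_wait": 30.0},
)
```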

graphrag/config/models/language_model_config.py

Lines changed: 6 additions & 3 deletions
@@ -15,7 +15,7 @@
     AzureApiVersionMissingError,
     ConflictingSettingsError,
 )
-from graphrag.language_model.factory import ModelFactory
+from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory

 logger = logging.getLogger(__name__)

@@ -91,8 +91,11 @@ def _validate_type(self) -> None:
         If the model name is not recognized.
         """
         # Type should be contained by the registered models
-        if not ModelFactory.is_supported_model(self.type):
-            msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
+        if (
+            self.type not in ChatModelFactory()
+            and self.type not in EmbeddingModelFactory()
+        ):
+            msg = f"Model type {self.type} is not recognized, must be one of {ChatModelFactory().keys() + EmbeddingModelFactory().keys()}."
             raise KeyError(msg)

     model_provider: str | None = Field(
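The same membership check is available to user code, for instance to validate a configured type before building a pipeline. A quick sketch, assuming factory registries are shared across instances as the notebook usage implies; the model type string is hypothetical.

```python
# Validate a configured model type against the split factories, mirroring
# _validate_type above. "my-custom-chat-model" is a hypothetical key.
from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory

model_type = "my-custom-chat-model"
known = ChatModelFactory().keys() + EmbeddingModelFactory().keys()
if model_type not in ChatModelFactory() and model_type not in EmbeddingModelFactory():
    msg = f"Model type {model_type} is not recognized, must be one of {known}."
    raise KeyError(msg)
```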
