Commit 8233421

Cleanup factory methods (#1482)
* cleanup factory methods to have similar design pattern across codebase
* add semversioner file
* cleanup logging factory
* update developer guide
* add comment
* typo fix
* cleanup reporter terminology
* rename reporter to logger
* fix comments
* update comment
* instantiate factory classes correctly and update index api callback parameter

Co-authored-by: Alonso Guevara <[email protected]>
1 parent 0440580 commit 8233421

51 files changed: +1249 −1152 lines
Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "cleanup and refactor factory classes."
+}
```

DEVELOPING.md

Lines changed: 40 additions & 8 deletions

````diff
@@ -10,24 +10,56 @@
 # Getting Started

 ## Install Dependencies
-
-```sh
-# Install Python dependencies.
+```shell
+# install python dependencies
 poetry install
 ```

-## Executing the Indexing Engine
-
-```sh
+## Execute the indexing engine
+```shell
 poetry run poe index <...args>
 ```

-## Executing Queries
+## Execute prompt tuning
+```shell
+poetry run poe prompt_tune <...args>
+```

-```sh
+## Execute Queries
+```shell
 poetry run poe query <...args>
 ```

+## Repository Structure
+An overview of the repository's top-level folder structure is provided below, detailing the overall design and purpose.
+We leverage a factory design pattern where possible, enabling a variety of implementations for each core component of graphrag.
+
+```shell
+graphrag
+├── api            # library API definitions
+├── cache          # cache module supporting several options
+│   └─ factory.py  # └─ main entrypoint to create a cache
+├── callbacks      # a collection of commonly used callback functions
+├── cli            # library CLI
+│   └─ main.py     # └─ primary CLI entrypoint
+├── config         # configuration management
+├── index          # indexing engine
+|   └─ run/run.py  # main entrypoint to build an index
+├── llm            # generic llm interfaces
+├── logger         # logger module supporting several options
+│   └─ factory.py  # └─ main entrypoint to create a logger
+├── model          # data model definitions associated with the knowledge graph
+├── prompt_tune    # prompt tuning module
+├── prompts        # a collection of all the system prompts used by graphrag
+├── query          # query engine
+├── storage        # storage module supporting several options
+│   └─ factory.py  # └─ main entrypoint to create/load a storage endpoint
+├── utils          # helper functions used throughout the library
+└── vector_stores  # vector store module containing a few options
+    └─ factory.py  # └─ main entrypoint to create a vector store
+```
+Where appropriate, the factories expose a registration method for users to provide their own custom implementations if desired.
+
 ## Versioning

 We use [semversioner](https://github.com/raulgomis/semversioner) to automate and enforce semantic versioning in the release process. Our CI/CD pipeline checks that all PR's include a json file generated by semversioner. When submitting a PR, please run:
````

docs/config/env_vars.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -156,7 +156,7 @@ This section controls the storage mechanism used by the pipeline used for export

 | Parameter | Description | Type | Required or Optional | Default |
 | --------- | ----------- | ---- | -------------------- | ------- |
-| `GRAPHRAG_STORAGE_TYPE` | The type of reporter to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
+| `GRAPHRAG_STORAGE_TYPE` | The type of storage to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
 | `GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
 | `GRAPHRAG_STORAGE_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
 | `GRAPHRAG_STORAGE_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |
```

graphrag/api/index.py

Lines changed: 12 additions & 12 deletions

```diff
@@ -17,7 +17,7 @@
 from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.run import run_pipeline_with_config
 from graphrag.index.typing import PipelineRunResult
-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger


 async def build_index(
@@ -26,7 +26,7 @@ async def build_index(
     is_resume_run: bool = False,
     memory_profile: bool = False,
     callbacks: list[WorkflowCallbacks] | None = None,
-    progress_reporter: ProgressReporter | None = None,
+    progress_logger: ProgressLogger | None = None,
 ) -> list[PipelineRunResult]:
     """Run the pipeline with the given configuration.

@@ -42,8 +42,8 @@ async def build_index(
         Whether to enable memory profiling.
     callbacks : list[WorkflowCallbacks] | None default=None
         A list of callbacks to register.
-    progress_reporter : ProgressReporter | None default=None
-        The progress reporter.
+    progress_logger : ProgressLogger | None default=None
+        The progress logger.

     Returns
     -------
@@ -60,26 +60,26 @@ async def build_index(
     pipeline_cache = (
         NoopPipelineCache() if config.cache.type == CacheType.none is None else None
     )
+    # create a pipeline reporter and add to any additional callbacks
     # TODO: remove the type ignore once the new config engine has been refactored
-    callbacks = (
-        [create_pipeline_reporter(config.reporting, None)] if config.reporting else None  # type: ignore
-    )  # type: ignore
+    callbacks = callbacks or []
+    callbacks.append(create_pipeline_reporter(config.reporting, None))  # type: ignore
     outputs: list[PipelineRunResult] = []
     async for output in run_pipeline_with_config(
         pipeline_config,
         run_id=run_id,
         memory_profile=memory_profile,
         cache=pipeline_cache,
         callbacks=callbacks,
-        progress_reporter=progress_reporter,
+        logger=progress_logger,
         is_resume_run=is_resume_run,
         is_update_run=is_update_run,
     ):
         outputs.append(output)
-        if progress_reporter:
+        if progress_logger:
             if output.errors and len(output.errors) > 0:
-                progress_reporter.error(output.workflow)
+                progress_logger.error(output.workflow)
             else:
-                progress_reporter.success(output.workflow)
-                progress_reporter.info(str(output.result))
+                progress_logger.success(output.workflow)
+                progress_logger.info(str(output.result))
     return outputs
```

graphrag/api/prompt_tune.py

Lines changed: 14 additions & 14 deletions

```diff
@@ -16,7 +16,7 @@

 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.llm.load_llm import load_llm
-from graphrag.logging.print_progress import PrintProgressReporter
+from graphrag.logger.print_progress import PrintProgressLogger
 from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.generator.community_report_rating import (
     generate_community_report_rating,
@@ -80,15 +80,15 @@ async def generate_indexing_prompts(
     -------
     tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
     """
-    reporter = PrintProgressReporter("")
+    logger = PrintProgressLogger("")

     # Retrieve documents
     doc_list = await load_docs_in_chunks(
         root=root,
         config=config,
         limit=limit,
         select_method=selection_method,
-        reporter=reporter,
+        logger=logger,
         chunk_size=chunk_size,
         n_subset_max=n_subset_max,
         k=k,
@@ -103,25 +103,25 @@ async def generate_indexing_prompts(
     )

     if not domain:
-        reporter.info("Generating domain...")
+        logger.info("Generating domain...")
         domain = await generate_domain(llm, doc_list)
-        reporter.info(f"Generated domain: {domain}")
+        logger.info(f"Generated domain: {domain}")  # noqa

     if not language:
-        reporter.info("Detecting language...")
+        logger.info("Detecting language...")
         language = await detect_language(llm, doc_list)

-    reporter.info("Generating persona...")
+    logger.info("Generating persona...")
     persona = await generate_persona(llm, domain)

-    reporter.info("Generating community report ranking description...")
+    logger.info("Generating community report ranking description...")
     community_report_ranking = await generate_community_report_rating(
         llm, domain=domain, persona=persona, docs=doc_list
     )

     entity_types = None
     if discover_entity_types:
-        reporter.info("Generating entity types...")
+        logger.info("Generating entity types...")
         entity_types = await generate_entity_types(
             llm,
             domain=domain,
@@ -130,7 +130,7 @@ async def generate_indexing_prompts(
         json_mode=config.llm.model_supports_json or False,
     )

-    reporter.info("Generating entity relationship examples...")
+    logger.info("Generating entity relationship examples...")
     examples = await generate_entity_relationship_examples(
         llm,
         persona=persona,
@@ -140,7 +140,7 @@ async def generate_indexing_prompts(
         json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
     )

-    reporter.info("Generating entity extraction prompt...")
+    logger.info("Generating entity extraction prompt...")
     entity_extraction_prompt = create_entity_extraction_prompt(
         entity_types=entity_types,
         docs=doc_list,
@@ -152,18 +152,18 @@ async def generate_indexing_prompts(
         min_examples_required=min_examples_required,
     )

-    reporter.info("Generating entity summarization prompt...")
+    logger.info("Generating entity summarization prompt...")
     entity_summarization_prompt = create_entity_summarization_prompt(
         persona=persona,
         language=language,
     )

-    reporter.info("Generating community reporter role...")
+    logger.info("Generating community reporter role...")
     community_reporter_role = await generate_community_reporter_role(
         llm, domain=domain, persona=persona, docs=doc_list
     )

-    reporter.info("Generating community summarization prompt...")
+    logger.info("Generating community summarization prompt...")
     community_summarization_prompt = create_community_summarization_prompt(
         persona=persona,
         role=community_reporter_role,
```

graphrag/api/query.py

Lines changed: 10 additions & 8 deletions

```diff
@@ -19,7 +19,7 @@

 from collections.abc import AsyncGenerator
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import pandas as pd
 from pydantic import validate_call
@@ -29,7 +29,7 @@
     community_full_content_embedding,
     entity_description_embedding,
 )
-from graphrag.logging.print_progress import PrintProgressReporter
+from graphrag.logger.print_progress import PrintProgressLogger
 from graphrag.query.factory import (
     get_drift_search_engine,
     get_global_search_engine,
@@ -44,13 +44,15 @@
     read_indexer_reports,
     read_indexer_text_units,
 )
-from graphrag.query.structured_search.base import SearchResult  # noqa: TC001
 from graphrag.utils.cli import redact
 from graphrag.utils.embeddings import create_collection_name
 from graphrag.vector_stores.base import BaseVectorStore
 from graphrag.vector_stores.factory import VectorStoreFactory

-reporter = PrintProgressReporter("")
+if TYPE_CHECKING:
+    from graphrag.query.structured_search.base import SearchResult
+
+logger = PrintProgressLogger("")


 @validate_call(config={"arbitrary_types_allowed": True})
@@ -241,7 +243,7 @@ async def local_search(
     TODO: Document any exceptions to expect.
     """
     vector_store_args = config.embeddings.vector_store
-    reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

     description_embedding_store = _get_embedding_store(
         config_args=vector_store_args,  # type: ignore
@@ -307,7 +309,7 @@ async def local_search_streaming(
     TODO: Document any exceptions to expect.
     """
     vector_store_args = config.embeddings.vector_store
-    reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

     description_embedding_store = _get_embedding_store(
         config_args=vector_store_args,  # type: ignore
@@ -380,7 +382,7 @@ async def drift_search(
     TODO: Document any exceptions to expect.
     """
     vector_store_args = config.embeddings.vector_store
-    reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

     description_embedding_store = _get_embedding_store(
         config_args=vector_store_args,  # type: ignore
@@ -430,7 +432,7 @@ def _get_embedding_store(
     collection_name = create_collection_name(
         config_args.get("container_name", "default"), embedding_name
     )
-    embedding_store = VectorStoreFactory.get_vector_store(
+    embedding_store = VectorStoreFactory().create_vector_store(
        vector_store_type=vector_store_type,
        kwargs={**config_args, "collection_name": collection_name},
     )
```

graphrag/cache/factory.py

Lines changed: 35 additions & 31 deletions

```diff
@@ -5,49 +5,53 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, ClassVar

 from graphrag.config.enums import CacheType
 from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
 from graphrag.storage.file_pipeline_storage import FilePipelineStorage

 if TYPE_CHECKING:
     from graphrag.cache.pipeline_cache import PipelineCache
-    from graphrag.index.config.cache import (
-        PipelineBlobCacheConfig,
-        PipelineCacheConfig,
-        PipelineFileCacheConfig,
-    )

 from graphrag.cache.json_pipeline_cache import JsonPipelineCache
 from graphrag.cache.memory_pipeline_cache import InMemoryCache
 from graphrag.cache.noop_pipeline_cache import NoopPipelineCache


-def create_cache(
-    config: PipelineCacheConfig | None, root_dir: str | None
-) -> PipelineCache:
-    """Create a cache from the given config."""
-    if config is None:
-        return NoopPipelineCache()
-
-    match config.type:
-        case CacheType.none:
-            return NoopPipelineCache()
-        case CacheType.memory:
-            return InMemoryCache()
-        case CacheType.file:
-            config = cast("PipelineFileCacheConfig", config)
-            storage = FilePipelineStorage(root_dir).child(config.base_dir)
-            return JsonPipelineCache(storage)
-        case CacheType.blob:
-            config = cast("PipelineBlobCacheConfig", config)
-            storage = BlobPipelineStorage(
-                config.connection_string,
-                config.container_name,
-                storage_account_blob_url=config.storage_account_blob_url,
-            ).child(config.base_dir)
-            return JsonPipelineCache(storage)
-        case _:
-            msg = f"Unknown cache type: {config.type}"
-            raise ValueError(msg)
+class CacheFactory:
+    """A factory class for cache implementations.
+
+    Includes a method for users to register a custom cache implementation.
+    """
+
+    cache_types: ClassVar[dict[str, type]] = {}
+
+    @classmethod
+    def register(cls, cache_type: str, cache: type):
+        """Register a custom cache implementation."""
+        cls.cache_types[cache_type] = cache
+
+    @classmethod
+    def create_cache(
+        cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
+    ) -> PipelineCache:
+        """Create or get a cache from the provided type."""
+        if not cache_type:
+            return NoopPipelineCache()
+        match cache_type:
+            case CacheType.none:
+                return NoopPipelineCache()
+            case CacheType.memory:
+                return InMemoryCache()
+            case CacheType.file:
+                return JsonPipelineCache(
+                    FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
+                )
+            case CacheType.blob:
+                return JsonPipelineCache(BlobPipelineStorage(**kwargs))
+            case _:
+                if cache_type in cls.cache_types:
+                    return cls.cache_types[cache_type](**kwargs)
+                msg = f"Unknown cache type: {cache_type}"
+                raise ValueError(msg)
```
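The `register` hook added above is what lets users plug in their own cache backend. A self-contained sketch of using it — graphrag's imports and built-in cache branches are omitted, and `RedisCache` is a hypothetical user class, not part of the library:

```python
from typing import ClassVar


class CacheFactory:
    """Trimmed-down version of the factory in this diff.

    Only the registration mechanics are shown; built-in cache types omitted.
    """

    cache_types: ClassVar[dict[str, type]] = {}

    @classmethod
    def register(cls, cache_type: str, cache: type) -> None:
        """Register a custom cache implementation under a string key."""
        cls.cache_types[cache_type] = cache

    @classmethod
    def create_cache(cls, cache_type: str, root_dir: str, kwargs: dict):
        """Instantiate a registered implementation, passing kwargs through."""
        if cache_type in cls.cache_types:
            return cls.cache_types[cache_type](**kwargs)
        msg = f"Unknown cache type: {cache_type}"
        raise ValueError(msg)


class RedisCache:
    """Hypothetical user-supplied cache backend."""

    def __init__(self, url: str):
        self.url = url


CacheFactory.register("redis", RedisCache)
cache = CacheFactory.create_cache("redis", root_dir=".", kwargs={"url": "redis://localhost"})
print(type(cache).__name__)  # RedisCache
```

Because `cache_types` is a `ClassVar`, registrations are visible process-wide, which is why the real factory's `create_cache` can fall through to user types in its `case _:` branch.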
