Skip to content

Commit 30bdb35

Browse files
authored
Selective embeddings loading (#2035)
* Invert embedding table loading logic * Semver
1 parent 77fb7d9 commit 30bdb35

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "generate_text_embeddings only loads tables if embedding field is specified."
4+
}

graphrag/index/workflows/generate_text_embeddings.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from graphrag.index.typing.workflow import WorkflowFunctionOutput
2727
from graphrag.utils.storage import (
2828
load_table_from_storage,
29-
storage_has_table,
3029
write_table_to_storage,
3130
)
3231

@@ -39,27 +38,35 @@ async def run_workflow(
3938
) -> WorkflowFunctionOutput:
4039
"""All the steps to transform community reports."""
4140
logger.info("Workflow started: generate_text_embeddings")
41+
embedded_fields = config.embed_text.names
42+
logger.info("Embedding the following fields: %s", embedded_fields)
4243
documents = None
4344
relationships = None
4445
text_units = None
4546
entities = None
4647
community_reports = None
47-
if await storage_has_table("documents", context.output_storage):
48+
if document_text_embedding in embedded_fields:
4849
documents = await load_table_from_storage("documents", context.output_storage)
49-
if await storage_has_table("relationships", context.output_storage):
50+
if relationship_description_embedding in embedded_fields:
5051
relationships = await load_table_from_storage(
5152
"relationships", context.output_storage
5253
)
53-
if await storage_has_table("text_units", context.output_storage):
54+
if text_unit_text_embedding in embedded_fields:
5455
text_units = await load_table_from_storage("text_units", context.output_storage)
55-
if await storage_has_table("entities", context.output_storage):
56+
if (
57+
entity_title_embedding in embedded_fields
58+
or entity_description_embedding in embedded_fields
59+
):
5660
entities = await load_table_from_storage("entities", context.output_storage)
57-
if await storage_has_table("community_reports", context.output_storage):
61+
if (
62+
community_title_embedding in embedded_fields
63+
or community_summary_embedding in embedded_fields
64+
or community_full_content_embedding in embedded_fields
65+
):
5866
community_reports = await load_table_from_storage(
5967
"community_reports", context.output_storage
6068
)
6169

62-
embedded_fields = config.embed_text.names
6370
text_embed = get_embedding_settings(config)
6471

6572
output = await generate_text_embeddings(

0 commit comments

Comments
 (0)