Skip to content

Commit 5dd9fc5

Browse files
authored
Move embeddings snapshots (#1737)
* Move embedding snapshots to the workflow runner * Semver * Rename input tables
1 parent e0d233f commit 5dd9fc5

File tree

3 files changed

+70
-69
lines changed

3 files changed

+70
-69
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "minor",
3+
"description": "Move embeddings snspshots to the workflow runner."
4+
}

graphrag/index/update/incremental_index.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,19 +145,24 @@ async def update_dataframe_outputs(
145145
progress_logger.info("Updating Text Embeddings")
146146
embedded_fields = get_embedded_fields(config)
147147
text_embed = get_embedding_settings(config)
148-
await generate_text_embeddings(
149-
final_documents=final_documents_df,
150-
final_relationships=merged_relationships_df,
151-
final_text_units=merged_text_units,
152-
final_entities=merged_entities_df,
153-
final_community_reports=merged_community_reports,
148+
result = await generate_text_embeddings(
149+
documents=final_documents_df,
150+
relationships=merged_relationships_df,
151+
text_units=merged_text_units,
152+
entities=merged_entities_df,
153+
community_reports=merged_community_reports,
154154
callbacks=callbacks,
155155
cache=cache,
156-
storage=output_storage,
157156
text_embed_config=text_embed,
158157
embedded_fields=embedded_fields,
159-
snapshot_embeddings_enabled=config.snapshots.embeddings,
160158
)
159+
if config.snapshots.embeddings:
160+
for name, table in result.items():
161+
await write_table_to_storage(
162+
table,
163+
f"embeddings.{name}",
164+
output_storage,
165+
)
161166

162167

163168
async def _update_community_reports(

graphrag/index/workflows/generate_text_embeddings.py

Lines changed: 53 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from graphrag.index.context import PipelineRunContext
2626
from graphrag.index.operations.embed_text import embed_text
2727
from graphrag.index.typing import WorkflowFunctionOutput
28-
from graphrag.storage.pipeline_storage import PipelineStorage
2928
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
3029

3130
log = logging.getLogger(__name__)
@@ -37,114 +36,112 @@ async def run_workflow(
3736
callbacks: WorkflowCallbacks,
3837
) -> WorkflowFunctionOutput:
3938
"""All the steps to transform community reports."""
40-
final_documents = await load_table_from_storage("documents", context.storage)
41-
final_relationships = await load_table_from_storage(
42-
"relationships", context.storage
43-
)
44-
final_text_units = await load_table_from_storage("text_units", context.storage)
45-
final_entities = await load_table_from_storage("entities", context.storage)
46-
final_community_reports = await load_table_from_storage(
39+
documents = await load_table_from_storage("documents", context.storage)
40+
relationships = await load_table_from_storage("relationships", context.storage)
41+
text_units = await load_table_from_storage("text_units", context.storage)
42+
entities = await load_table_from_storage("entities", context.storage)
43+
community_reports = await load_table_from_storage(
4744
"community_reports", context.storage
4845
)
4946

5047
embedded_fields = get_embedded_fields(config)
5148
text_embed = get_embedding_settings(config)
5249

53-
await generate_text_embeddings(
54-
final_documents=final_documents,
55-
final_relationships=final_relationships,
56-
final_text_units=final_text_units,
57-
final_entities=final_entities,
58-
final_community_reports=final_community_reports,
50+
result = await generate_text_embeddings(
51+
documents=documents,
52+
relationships=relationships,
53+
text_units=text_units,
54+
entities=entities,
55+
community_reports=community_reports,
5956
callbacks=callbacks,
6057
cache=context.cache,
61-
storage=context.storage,
6258
text_embed_config=text_embed,
6359
embedded_fields=embedded_fields,
64-
snapshot_embeddings_enabled=config.snapshots.embeddings,
6560
)
6661

67-
return WorkflowFunctionOutput(result=None, config=None)
62+
if config.snapshots.embeddings:
63+
for name, table in result.items():
64+
await write_table_to_storage(
65+
table,
66+
f"embeddings.{name}",
67+
context.storage,
68+
)
69+
70+
return WorkflowFunctionOutput(result=result, config=None)
6871

6972

7073
async def generate_text_embeddings(
71-
final_documents: pd.DataFrame | None,
72-
final_relationships: pd.DataFrame | None,
73-
final_text_units: pd.DataFrame | None,
74-
final_entities: pd.DataFrame | None,
75-
final_community_reports: pd.DataFrame | None,
74+
documents: pd.DataFrame | None,
75+
relationships: pd.DataFrame | None,
76+
text_units: pd.DataFrame | None,
77+
entities: pd.DataFrame | None,
78+
community_reports: pd.DataFrame | None,
7679
callbacks: WorkflowCallbacks,
7780
cache: PipelineCache,
78-
storage: PipelineStorage,
7981
text_embed_config: dict,
8082
embedded_fields: set[str],
81-
snapshot_embeddings_enabled: bool = False,
82-
) -> None:
83+
) -> dict[str, pd.DataFrame]:
8384
"""All the steps to generate all embeddings."""
8485
embedding_param_map = {
8586
document_text_embedding: {
86-
"data": final_documents.loc[:, ["id", "text"]]
87-
if final_documents is not None
88-
else None,
87+
"data": documents.loc[:, ["id", "text"]] if documents is not None else None,
8988
"embed_column": "text",
9089
},
9190
relationship_description_embedding: {
92-
"data": final_relationships.loc[:, ["id", "description"]]
93-
if final_relationships is not None
91+
"data": relationships.loc[:, ["id", "description"]]
92+
if relationships is not None
9493
else None,
9594
"embed_column": "description",
9695
},
9796
text_unit_text_embedding: {
98-
"data": final_text_units.loc[:, ["id", "text"]]
99-
if final_text_units is not None
97+
"data": text_units.loc[:, ["id", "text"]]
98+
if text_units is not None
10099
else None,
101100
"embed_column": "text",
102101
},
103102
entity_title_embedding: {
104-
"data": final_entities.loc[:, ["id", "title"]]
105-
if final_entities is not None
106-
else None,
103+
"data": entities.loc[:, ["id", "title"]] if entities is not None else None,
107104
"embed_column": "title",
108105
},
109106
entity_description_embedding: {
110-
"data": final_entities.loc[:, ["id", "title", "description"]].assign(
107+
"data": entities.loc[:, ["id", "title", "description"]].assign(
111108
title_description=lambda df: df["title"] + ":" + df["description"]
112109
)
113-
if final_entities is not None
110+
if entities is not None
114111
else None,
115112
"embed_column": "title_description",
116113
},
117114
community_title_embedding: {
118-
"data": final_community_reports.loc[:, ["id", "title"]]
119-
if final_community_reports is not None
115+
"data": community_reports.loc[:, ["id", "title"]]
116+
if community_reports is not None
120117
else None,
121118
"embed_column": "title",
122119
},
123120
community_summary_embedding: {
124-
"data": final_community_reports.loc[:, ["id", "summary"]]
125-
if final_community_reports is not None
121+
"data": community_reports.loc[:, ["id", "summary"]]
122+
if community_reports is not None
126123
else None,
127124
"embed_column": "summary",
128125
},
129126
community_full_content_embedding: {
130-
"data": final_community_reports.loc[:, ["id", "full_content"]]
131-
if final_community_reports is not None
127+
"data": community_reports.loc[:, ["id", "full_content"]]
128+
if community_reports is not None
132129
else None,
133130
"embed_column": "full_content",
134131
},
135132
}
136133

137134
log.info("Creating embeddings")
135+
outputs = {}
138136
for field in embedded_fields:
139-
await _run_and_snapshot_embeddings(
137+
outputs[field] = await _run_and_snapshot_embeddings(
140138
name=field,
141139
callbacks=callbacks,
142140
cache=cache,
143-
storage=storage,
144141
text_embed_config=text_embed_config,
145-
snapshot_embeddings_enabled=snapshot_embeddings_enabled,
146142
**embedding_param_map[field],
147143
)
144+
return outputs
148145

149146

150147
async def _run_and_snapshot_embeddings(
@@ -153,21 +150,16 @@ async def _run_and_snapshot_embeddings(
153150
embed_column: str,
154151
callbacks: WorkflowCallbacks,
155152
cache: PipelineCache,
156-
storage: PipelineStorage,
157153
text_embed_config: dict,
158-
snapshot_embeddings_enabled: bool,
159-
) -> None:
154+
) -> pd.DataFrame:
160155
"""All the steps to generate single embedding."""
161-
if text_embed_config:
162-
data["embedding"] = await embed_text(
163-
input=data,
164-
callbacks=callbacks,
165-
cache=cache,
166-
embed_column=embed_column,
167-
embedding_name=name,
168-
strategy=text_embed_config["strategy"],
169-
)
156+
data["embedding"] = await embed_text(
157+
input=data,
158+
callbacks=callbacks,
159+
cache=cache,
160+
embed_column=embed_column,
161+
embedding_name=name,
162+
strategy=text_embed_config["strategy"],
163+
)
170164

171-
if snapshot_embeddings_enabled is True:
172-
data = data.loc[:, ["id", "embedding"]]
173-
await write_table_to_storage(data, f"embeddings.{name}", storage)
165+
return data.loc[:, ["id", "embedding"]]

0 commit comments

Comments
 (0)