2525from graphrag .index .context import PipelineRunContext
2626from graphrag .index .operations .embed_text import embed_text
2727from graphrag .index .typing import WorkflowFunctionOutput
28- from graphrag .storage .pipeline_storage import PipelineStorage
2928from graphrag .utils .storage import load_table_from_storage , write_table_to_storage
3029
3130log = logging .getLogger (__name__ )
@@ -37,114 +36,112 @@ async def run_workflow(
3736 callbacks : WorkflowCallbacks ,
3837) -> WorkflowFunctionOutput :
3938 """All the steps to transform community reports."""
40- final_documents = await load_table_from_storage ("documents" , context .storage )
41- final_relationships = await load_table_from_storage (
42- "relationships" , context .storage
43- )
44- final_text_units = await load_table_from_storage ("text_units" , context .storage )
45- final_entities = await load_table_from_storage ("entities" , context .storage )
46- final_community_reports = await load_table_from_storage (
39+ documents = await load_table_from_storage ("documents" , context .storage )
40+ relationships = await load_table_from_storage ("relationships" , context .storage )
41+ text_units = await load_table_from_storage ("text_units" , context .storage )
42+ entities = await load_table_from_storage ("entities" , context .storage )
43+ community_reports = await load_table_from_storage (
4744 "community_reports" , context .storage
4845 )
4946
5047 embedded_fields = get_embedded_fields (config )
5148 text_embed = get_embedding_settings (config )
5249
53- await generate_text_embeddings (
54- final_documents = final_documents ,
55- final_relationships = final_relationships ,
56- final_text_units = final_text_units ,
57- final_entities = final_entities ,
58- final_community_reports = final_community_reports ,
50+ result = await generate_text_embeddings (
51+ documents = documents ,
52+ relationships = relationships ,
53+ text_units = text_units ,
54+ entities = entities ,
55+ community_reports = community_reports ,
5956 callbacks = callbacks ,
6057 cache = context .cache ,
61- storage = context .storage ,
6258 text_embed_config = text_embed ,
6359 embedded_fields = embedded_fields ,
64- snapshot_embeddings_enabled = config .snapshots .embeddings ,
6560 )
6661
67- return WorkflowFunctionOutput (result = None , config = None )
62+ if config .snapshots .embeddings :
63+ for name , table in result .items ():
64+ await write_table_to_storage (
65+ table ,
66+ f"embeddings.{ name } " ,
67+ context .storage ,
68+ )
69+
70+ return WorkflowFunctionOutput (result = result , config = None )
6871
6972
7073async def generate_text_embeddings (
71- final_documents : pd .DataFrame | None ,
72- final_relationships : pd .DataFrame | None ,
73- final_text_units : pd .DataFrame | None ,
74- final_entities : pd .DataFrame | None ,
75- final_community_reports : pd .DataFrame | None ,
74+ documents : pd .DataFrame | None ,
75+ relationships : pd .DataFrame | None ,
76+ text_units : pd .DataFrame | None ,
77+ entities : pd .DataFrame | None ,
78+ community_reports : pd .DataFrame | None ,
7679 callbacks : WorkflowCallbacks ,
7780 cache : PipelineCache ,
78- storage : PipelineStorage ,
7981 text_embed_config : dict ,
8082 embedded_fields : set [str ],
81- snapshot_embeddings_enabled : bool = False ,
82- ) -> None :
83+ ) -> dict [str , pd .DataFrame ]:
8384 """All the steps to generate all embeddings."""
8485 embedding_param_map = {
8586 document_text_embedding : {
86- "data" : final_documents .loc [:, ["id" , "text" ]]
87- if final_documents is not None
88- else None ,
87+ "data" : documents .loc [:, ["id" , "text" ]] if documents is not None else None ,
8988 "embed_column" : "text" ,
9089 },
9190 relationship_description_embedding : {
92- "data" : final_relationships .loc [:, ["id" , "description" ]]
93- if final_relationships is not None
91+ "data" : relationships .loc [:, ["id" , "description" ]]
92+ if relationships is not None
9493 else None ,
9594 "embed_column" : "description" ,
9695 },
9796 text_unit_text_embedding : {
98- "data" : final_text_units .loc [:, ["id" , "text" ]]
99- if final_text_units is not None
97+ "data" : text_units .loc [:, ["id" , "text" ]]
98+ if text_units is not None
10099 else None ,
101100 "embed_column" : "text" ,
102101 },
103102 entity_title_embedding : {
104- "data" : final_entities .loc [:, ["id" , "title" ]]
105- if final_entities is not None
106- else None ,
103+ "data" : entities .loc [:, ["id" , "title" ]] if entities is not None else None ,
107104 "embed_column" : "title" ,
108105 },
109106 entity_description_embedding : {
110- "data" : final_entities .loc [:, ["id" , "title" , "description" ]].assign (
107+ "data" : entities .loc [:, ["id" , "title" , "description" ]].assign (
111108 title_description = lambda df : df ["title" ] + ":" + df ["description" ]
112109 )
113- if final_entities is not None
110+ if entities is not None
114111 else None ,
115112 "embed_column" : "title_description" ,
116113 },
117114 community_title_embedding : {
118- "data" : final_community_reports .loc [:, ["id" , "title" ]]
119- if final_community_reports is not None
115+ "data" : community_reports .loc [:, ["id" , "title" ]]
116+ if community_reports is not None
120117 else None ,
121118 "embed_column" : "title" ,
122119 },
123120 community_summary_embedding : {
124- "data" : final_community_reports .loc [:, ["id" , "summary" ]]
125- if final_community_reports is not None
121+ "data" : community_reports .loc [:, ["id" , "summary" ]]
122+ if community_reports is not None
126123 else None ,
127124 "embed_column" : "summary" ,
128125 },
129126 community_full_content_embedding : {
130- "data" : final_community_reports .loc [:, ["id" , "full_content" ]]
131- if final_community_reports is not None
127+ "data" : community_reports .loc [:, ["id" , "full_content" ]]
128+ if community_reports is not None
132129 else None ,
133130 "embed_column" : "full_content" ,
134131 },
135132 }
136133
137134 log .info ("Creating embeddings" )
135+ outputs = {}
138136 for field in embedded_fields :
139- await _run_and_snapshot_embeddings (
137+ outputs [ field ] = await _run_and_snapshot_embeddings (
140138 name = field ,
141139 callbacks = callbacks ,
142140 cache = cache ,
143- storage = storage ,
144141 text_embed_config = text_embed_config ,
145- snapshot_embeddings_enabled = snapshot_embeddings_enabled ,
146142 ** embedding_param_map [field ],
147143 )
144+ return outputs
148145
149146
150147async def _run_and_snapshot_embeddings (
@@ -153,21 +150,16 @@ async def _run_and_snapshot_embeddings(
153150 embed_column : str ,
154151 callbacks : WorkflowCallbacks ,
155152 cache : PipelineCache ,
156- storage : PipelineStorage ,
157153 text_embed_config : dict ,
158- snapshot_embeddings_enabled : bool ,
159- ) -> None :
154+ ) -> pd .DataFrame :
160155 """All the steps to generate single embedding."""
161- if text_embed_config :
162- data ["embedding" ] = await embed_text (
163- input = data ,
164- callbacks = callbacks ,
165- cache = cache ,
166- embed_column = embed_column ,
167- embedding_name = name ,
168- strategy = text_embed_config ["strategy" ],
169- )
156+ data ["embedding" ] = await embed_text (
157+ input = data ,
158+ callbacks = callbacks ,
159+ cache = cache ,
160+ embed_column = embed_column ,
161+ embedding_name = name ,
162+ strategy = text_embed_config ["strategy" ],
163+ )
170164
171- if snapshot_embeddings_enabled is True :
172- data = data .loc [:, ["id" , "embedding" ]]
173- await write_table_to_storage (data , f"embeddings.{ name } " , storage )
165+ return data .loc [:, ["id" , "embedding" ]]
0 commit comments