@@ -42,8 +42,9 @@ async def create_base_entity_graph(
4242 summarization_strategy : dict [str , Any ] | None = None ,
4343 summarization_num_threads : int = 4 ,
4444 embedding_strategy : dict [str , Any ] | None = None ,
45- graphml_snapshot_enabled : bool = False ,
46- raw_entity_snapshot_enabled : bool = False ,
45+ snapshot_graphml_enabled : bool = False ,
46+ snapshot_raw_entities_enabled : bool = False ,
47+ snapshot_transient_enabled : bool = False ,
4748) -> pd .DataFrame :
4849 """All the steps to create the base entity graph."""
4950 # this returns a graph for each text unit, to be merged later
@@ -92,15 +93,15 @@ async def create_base_entity_graph(
9293 strategy = embedding_strategy ,
9394 )
9495
95- if raw_entity_snapshot_enabled :
96+ if snapshot_raw_entities_enabled :
9697 await snapshot (
9798 entities ,
9899 name = "raw_extracted_entities" ,
99100 storage = storage ,
100101 formats = ["json" ],
101102 )
102103
103- if graphml_snapshot_enabled :
104+ if snapshot_graphml_enabled :
104105 await snapshot_graphml (
105106 merged_graph ,
106107 name = "merged_graph" ,
@@ -131,4 +132,14 @@ async def create_base_entity_graph(
131132 if embedding_strategy :
132133 final_columns .append ("embeddings" )
133134
134- return cast (pd .DataFrame , clustered [final_columns ])
135+ output = cast (pd .DataFrame , clustered [final_columns ])
136+
137+ if snapshot_transient_enabled :
138+ await snapshot (
139+ output ,
140+ name = "create_base_entity_graph" ,
141+ storage = storage ,
142+ formats = ["parquet" ],
143+ )
144+
145+ return output
0 commit comments