11import asyncio
22import os
33import time
4- from dataclasses import dataclass , field
4+ from dataclasses import dataclass
55from typing import Dict , cast
66
77import gradio as gr
1414 NetworkXStorage ,
1515 OpenAIClient ,
1616 Tokenizer ,
17- TraverseStrategy ,
1817)
1918from graphgen .operators import (
2019 chunk_documents ,
4039
4140@dataclass
4241class GraphGen :
43- unique_id : int = int (time .time ())
4442 working_dir : str = os .path .join (sys_path , "cache" )
45- config : Dict = field (default_factory = dict )
43+ output_path : str = os .path .join (
44+ working_dir , "data" , "graphgen" , str (int (time .time ()))
45+ )
4646
4747 # llm
4848 tokenizer_instance : Tokenizer = None
4949 synthesizer_llm_client : OpenAIClient = None
5050 trainee_llm_client : OpenAIClient = None
5151
52- # search
53- search_config : dict = field (
54- default_factory = lambda : {"enabled" : False , "search_types" : ["wikipedia" ]}
55- )
56-
57- # traversal
58- traverse_strategy : TraverseStrategy = None
59-
6052 # webui
6153 progress_bar : gr .Progress = None
6254
6355 def __post_init__ (self ):
6456 self .tokenizer_instance : Tokenizer = Tokenizer (
65- model_name = self . config [ "tokenizer" ]
57+ model_name = os . getenv ( "TOKENIZER_MODEL" )
6658 )
59+
6760 self .synthesizer_llm_client : OpenAIClient = OpenAIClient (
6861 model_name = os .getenv ("SYNTHESIZER_MODEL" ),
6962 api_key = os .getenv ("SYNTHESIZER_API_KEY" ),
@@ -76,12 +69,6 @@ def __post_init__(self):
7669 base_url = os .getenv ("TRAINEE_BASE_URL" ),
7770 tokenizer = self .tokenizer_instance ,
7871 )
79- self .search_config = self .config ["search" ]
80-
81- if "traverse_strategy" in self .config :
82- self .traverse_strategy = TraverseStrategy (
83- ** self .config ["traverse_strategy" ]
84- )
8572
8673 self .full_docs_storage : JsonKVStorage = JsonKVStorage (
8774 self .working_dir , namespace = "full_docs"
@@ -99,24 +86,17 @@ def __post_init__(self):
9986 self .working_dir , namespace = "rephrase"
10087 )
10188 self .qa_storage : JsonListStorage = JsonListStorage (
102- os .path .join (
103- self .working_dir ,
104- "data" ,
105- "graphgen" ,
106- f"{ self .unique_id } _{ self .config ['output_data_type' ]} " ,
107- ),
89+ self .working_dir ,
10890 namespace = "qa" ,
10991 )
11092
11193 @async_to_sync_method
112- async def insert (self ):
94+ async def insert (self , read_config : Dict , split_config : Dict ):
11395 """
11496 insert chunks into the graph
11597 """
116- input_file = self .config ["read" ]["input_file" ]
117-
11898 # Step 1: Read files
119- data = read_files (input_file )
99+ data = read_files (read_config [ " input_file" ] )
120100 if len (data ) == 0 :
121101 logger .warning ("No data to process" )
122102 return
@@ -141,8 +121,8 @@ async def insert(self):
141121
142122 inserting_chunks = await chunk_documents (
143123 new_docs ,
144- self . config [ "split" ] ["chunk_size" ],
145- self . config [ "split" ] ["chunk_overlap" ],
124+ split_config ["chunk_size" ],
125+ split_config ["chunk_overlap" ],
146126 self .tokenizer_instance ,
147127 self .progress_bar ,
148128 )
@@ -178,6 +158,7 @@ async def insert(self):
178158 return
179159
180160 await self ._insert_done ()
161+ return _add_entities_and_relations
181162
182163 async def _insert_done (self ):
183164 tasks = []
@@ -193,14 +174,12 @@ async def _insert_done(self):
193174 await asyncio .gather (* tasks )
194175
195176 @async_to_sync_method
196- async def search (self ):
177+ async def search (self , search_config : Dict ):
197178 logger .info (
198- "Search is %s" , "enabled" if self . search_config ["enabled" ] else "disabled"
179+ "Search is %s" , "enabled" if search_config ["enabled" ] else "disabled"
199180 )
200- if self .search_config ["enabled" ]:
201- logger .info (
202- "[Search] %s ..." , ", " .join (self .search_config ["search_types" ])
203- )
181+ if search_config ["enabled" ]:
182+ logger .info ("[Search] %s ..." , ", " .join (search_config ["search_types" ]))
204183 all_nodes = await self .graph_storage .get_all_nodes ()
205184 all_nodes_names = [node [0 ] for node in all_nodes ]
206185 new_search_entities = await self .full_docs_storage .filter_keys (
@@ -210,7 +189,7 @@ async def search(self):
210189 "[Search] Found %d entities to search" , len (new_search_entities )
211190 )
212191 _add_search_data = await search_all (
213- search_types = self . search_config ["search_types" ],
192+ search_types = search_config ["search_types" ],
214193 search_entities = new_search_entities ,
215194 )
216195 if _add_search_data :
@@ -230,27 +209,37 @@ async def search(self):
230209 await self .insert ()
231210
@async_to_sync_method
async def quiz_and_judge(self, quiz_and_judge_config: Dict):
    """
    Run the quiz step and then the judge step over the graph storage.

    The step is skipped (with a warning) when ``quiz_and_judge_config`` is
    ``None`` or its ``"enabled"`` entry is missing or falsy.

    :param quiz_and_judge_config: options for this step. Keys read here:
        ``"enabled"`` (optional, defaults to ``False``), ``"quiz_samples"``
        (forwarded to ``quiz`` as the sample limit) and ``"re_judge"``
        (forwarded to ``judge_statement``).
    :raises KeyError: if the step is enabled but ``"quiz_samples"`` or
        ``"re_judge"`` is missing from the config.
    """
    if quiz_and_judge_config is None or not quiz_and_judge_config.get(
        "enabled", False
    ):
        logger.warning("Quiz and Judge is not used in this pipeline.")
        return

    # Quiz: uses the synthesizer client and the rephrase storage.
    max_samples = quiz_and_judge_config["quiz_samples"]
    await quiz(
        self.synthesizer_llm_client,
        self.graph_storage,
        self.rephrase_storage,
        max_samples,
    )

    # Judge: uses the trainee client against the same storages.
    # TODO: assert trainee_llm_client is valid before judge
    re_judge = quiz_and_judge_config["re_judge"]
    _update_relations = await judge_statement(
        self.trainee_llm_client,
        self.graph_storage,
        self.rephrase_storage,
        re_judge,
    )

    # Flush both storages only after quiz and judge have both completed.
    await self.rephrase_storage.index_done_callback()
    await _update_relations.index_done_callback()
253236
@async_to_sync_method
async def generate(self, partition_config: Dict, generate_config: Dict):
    """
    Placeholder for the generation step; currently does nothing.

    :param partition_config: options for graph partitioning (not read yet).
    :param generate_config: options for generation (not read yet).
    """
    # Step 1: partition the graph
    # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
242+
254243 @async_to_sync_method
255244 async def traverse (self ):
256245 output_data_type = self .config ["output_data_type" ]
0 commit comments