11import asyncio
22from tqdm .asyncio import tqdm as tqdm_async
33
4- from models import OpenAIModel , NetworkXStorage , TraverseStrategy , Tokenizer
4+ from models import OpenAIModel , NetworkXStorage , TraverseStrategy , Tokenizer , JsonKVStorage
55from templates import ANSWER_REPHRASING_PROMPT , QUESTION_GENERATION_PROMPT
66from utils import detect_main_language , compute_content_hash , logger
77from graphgen .operators .split_graph import get_batches_with_strategy
@@ -49,6 +49,46 @@ async def handle_node(node: dict) -> dict:
4949 await graph_storage .index_done_callback ()
5050 return new_edges , new_nodes
5151
async def _construct_rephrasing_prompt(_process_nodes: list,
                                       _process_edges: list,
                                       _difficulty: str,
                                       text_chunks_storage: JsonKVStorage,
                                       add_context: bool = False
                                       ) -> str:
    """Build the answer-rephrasing prompt for one batch of graph nodes and edges.

    Formats the nodes and edges as numbered entity/relationship listings,
    picks Chinese or English templates based on the detected main language
    of that listing, and optionally prepends the original source chunks.

    :param _process_nodes: node dicts with 'node_id', 'description' (and
        'source_id' when add_context is True).
    :param _process_edges: (src, dst, attrs) triples; attrs carries
        'description' (and 'source_id' when add_context is True).
    :param _difficulty: key into ANSWER_REPHRASING_PROMPT.
    :param text_chunks_storage: KV store used to fetch original chunks.
    :param add_context: include the original text chunks in the prompt.
    :return: the fully formatted prompt string.
    """
    entity_lines = [f"{n['node_id']}: {n['description']}" for n in _process_nodes]
    relation_lines = [
        f"{e[0]} -- {e[1]}: {e[2]['description']}" for e in _process_edges
    ]

    entities_str = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(entity_lines))
    relations_str = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(relation_lines))
    language = "Chinese" if detect_main_language(entities_str + relations_str) == "zh" else "English"

    if not add_context:
        return ANSWER_REPHRASING_PROMPT[_difficulty][language]['TEMPLATE'].format(
            language=language,
            entities=entities_str,
            relationships=relations_str
        )

    # Only the first chunk id of each '<SEP>'-joined source_id is used;
    # duplicates across nodes and edges are collapsed.
    chunk_ids = {n['source_id'].split('<SEP>')[0] for n in _process_nodes}
    chunk_ids |= {e[2]['source_id'].split('<SEP>')[0] for e in _process_edges}
    chunks = await text_chunks_storage.get_by_ids(list(chunk_ids))
    original_text = "\n".join(
        f"{i + 1}. {chunk['content']}" for i, chunk in enumerate(chunks)
    )

    return ANSWER_REPHRASING_PROMPT[_difficulty][language]['CONTEXT_TEMPLATE'].format(
        language=language,
        original_text=original_text,
        entities=entities_str,
        relationships=relations_str
    )
5292
5393def get_loss_tercile (losses : list ) -> (float , float ):
5494 losses = sorted (losses )
def get_average_loss(batch: tuple) -> float:
    """Return the mean 'loss' across every node and edge in a batch.

    :param batch: tuple whose first element is a list of node dicts (each
        with a 'loss' key) and whose second element is a list of
        (src, dst, attrs) edge triples with attrs['loss']. Extra trailing
        elements (e.g. difficulty) are ignored.
    :return: sum of all node and edge losses divided by their total count.
    """
    nodes, edges = batch[0], batch[1]
    total = sum(node['loss'] for node in nodes) + sum(edge[2]['loss'] for edge in edges)
    # The whole numerator must be parenthesized: the previous version wrote
    # edge_sum + node_sum / count, so operator precedence divided only the
    # node sum by the count and then added the raw edge sum.
    return total / (len(nodes) + len(edges))
63103
104+ def _post_process_synthetic_data (data ):
105+ block = data .split ("\n \n " )
106+ qas = []
107+ for line in block :
108+ if "Question:" in line and "Answer:" in line :
109+ question = line .split ("Question:" )[1 ].split ("Answer:" )[0 ].strip ()
110+ answer = line .split ("Answer:" )[1 ].strip ()
111+ qas .append ({
112+ "question" : question ,
113+ "answer" : answer
114+ })
115+ elif "问题:" in line and "答案:" in line :
116+ question = line .split ("问题:" )[1 ].split ("答案:" )[0 ].strip ()
117+ answer = line .split ("答案:" )[1 ].strip ()
118+ qas .append ({
119+ "question" : question ,
120+ "answer" : answer
121+ })
122+ elif "问题:" in line and "回答:" in line :
123+ question = line .split ("问题:" )[1 ].split ("回答:" )[0 ].strip ()
124+ answer = line .split ("回答:" )[1 ].strip ()
125+ qas .append ({
126+ "question" : question ,
127+ "answer" : answer
128+ })
129+ return qas
130+
64131async def traverse_graph_by_edge (
65132 llm_client : OpenAIModel ,
66133 tokenizer : Tokenizer ,
67134 graph_storage : NetworkXStorage ,
68135 traverse_strategy : TraverseStrategy ,
136+ text_chunks_storage : JsonKVStorage ,
69137 max_concurrent : int = 1000
70138) -> dict :
71139 """
@@ -75,6 +143,7 @@ async def traverse_graph_by_edge(
75143 :param tokenizer
76144 :param graph_storage
77145 :param traverse_strategy
146+ :param text_chunks_storage
78147 :param max_concurrent
79148 :return: question and answer
80149 """
@@ -84,26 +153,15 @@ async def traverse_graph_by_edge(
84153 async def _process_nodes_and_edges (
85154 _process_nodes : list ,
86155 _process_edges : list ,
87- _difficulty : str
156+ _difficulty : str ,
88157 ) -> str :
89- entities = [
90- f"{ _process_node ['node_id' ]} : { _process_node ['description' ]} " for _process_node in _process_nodes
91- ]
92- relations = [
93- f"{ _process_edge [0 ]} -- { _process_edge [1 ]} : { _process_edge [2 ]['description' ]} "
94- for _process_edge in _process_edges
95- ]
96-
97- entities_str = "\n " .join ([f"{ index + 1 } . { entity } " for index , entity in enumerate (entities )])
98- relations_str = "\n " .join ([f"{ index + 1 } . { relation } " for index , relation in enumerate (relations )])
99-
100- language = "Chinese" if detect_main_language (entities_str + relations_str ) == "zh" else "English"
101- prompt = ANSWER_REPHRASING_PROMPT [_difficulty ][language ]['TEMPLATE' ].format (
102- language = language ,
103- entities = entities_str ,
104- relationships = relations_str
158+ prompt = await _construct_rephrasing_prompt (
159+ _process_nodes ,
160+ _process_edges ,
161+ _difficulty ,
162+ text_chunks_storage ,
163+ add_context = False
105164 )
106-
107165 context = await llm_client .generate_answer (prompt )
108166
109167 # post-process the context
@@ -115,7 +173,8 @@ async def _process_nodes_and_edges(
115173 return context
116174
117175 async def _process_single_batch (
118- _process_batch : tuple
176+ _process_batch : tuple ,
177+ question_type : str = "single"
119178 ) -> dict :
120179 async with semaphore :
121180 context = await _process_nodes_and_edges (
@@ -125,32 +184,55 @@ async def _process_single_batch(
125184 )
126185
127186 language = "Chinese" if detect_main_language (context ) == "zh" else "English"
128- question = await llm_client .generate_answer (
129- QUESTION_GENERATION_PROMPT [language ]['TEMPLATE' ].format (
130- answer = context
131- )
132- )
133-
134- if question .startswith ("Question:" ):
135- question = question [len ("Question:" ):].strip ()
136- elif question .startswith ("问题:" ):
137- question = question [len ("问题:" ):].strip ()
138-
139187 pre_length = sum (node ['length' ] for node in _process_batch [0 ]) \
140188 + sum (edge [2 ]['length' ] for edge in _process_batch [1 ])
141189
142190 logger .info ("%d nodes and %d edges processed" , len (_process_batch [0 ]), len (_process_batch [1 ]))
143191 logger .info ("Pre-length: %s" , pre_length )
144- logger .info ("Question: %s Answer: %s" , question , context )
145192
146- return {
147- compute_content_hash (context ): {
148- "question" : question ,
149- "answer" : context ,
193+ if question_type == "single" :
194+ question = await llm_client .generate_answer (
195+ QUESTION_GENERATION_PROMPT [language ]['SINGLE_TEMPLATE' ].format (
196+ answer = context
197+ )
198+ )
199+ if question .startswith ("Question:" ):
200+ question = question [len ("Question:" ):].strip ()
201+ elif question .startswith ("问题:" ):
202+ question = question [len ("问题:" ):].strip ()
203+
204+ return {
205+ compute_content_hash (context ): {
206+ "question" : question ,
207+ "answer" : context ,
208+ "loss" : get_average_loss (_process_batch ),
209+ "difficulty" : _process_batch [2 ],
210+ }
211+ }
212+
213+ content = await llm_client .generate_answer (
214+ QUESTION_GENERATION_PROMPT [language ]['MULTI_TEMPLATE' ].format (
215+ doc = context
216+ )
217+ )
218+ qas = _post_process_synthetic_data (content )
219+
220+ if len (qas ) == 0 :
221+ print (content )
222+ logger .error ("Error occurred while processing batch, question or answer is None" )
223+ return {}
224+
225+ final_results = {}
226+ for qa in qas :
227+ logger .info ("Question: %s" , qa ['question' ])
228+ logger .info ("Answer: %s" , qa ['answer' ])
229+ final_results [compute_content_hash (qa ['question' ])] = {
230+ "question" : qa ['question' ],
231+ "answer" : qa ['answer' ],
150232 "loss" : get_average_loss (_process_batch ),
151233 "difficulty" : _process_batch [2 ],
152234 }
153- }
235+ return final_results
154236
155237 results = {}
156238 edges = list (await graph_storage .get_all_edges ())
0 commit comments