@@ -1,10 +1,10 @@
 import asyncio
+from tqdm.asyncio import tqdm as tqdm_async
 
 from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer
 from templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT
-from utils import detect_main_language, compute_content_hash, logger, create_event_loop
-from tqdm.asyncio import tqdm as tqdm_async
-from .split_graph import get_batches_with_strategy
+from utils import detect_main_language, compute_content_hash, logger
+from graphgen.operators.split_graph import get_batches_with_strategy
 
 
 async def _pre_tokenize(graph_storage: NetworkXStorage,
@@ -17,25 +17,31 @@ async def handle_edge(edge: tuple) -> tuple:
         async with sem:
             if 'length' not in edge[2]:
                 edge[2]['length'] = len(
-                    await asyncio.get_event_loop().run_in_executor(None, tokenizer.encode_string, edge[2]['description']))
+                    await asyncio.get_event_loop().run_in_executor(None,
+                                                                   tokenizer.encode_string,
+                                                                   edge[2]['description']))
         return edge
 
     async def handle_node(node: dict) -> dict:
         async with sem:
             if 'length' not in node[1]:
                 node[1]['length'] = len(
-                    await asyncio.get_event_loop().run_in_executor(None, tokenizer.encode_string, node[1]['description']))
+                    await asyncio.get_event_loop().run_in_executor(None,
+                                                                   tokenizer.encode_string,
+                                                                   node[1]['description']))
         return node
 
     new_edges = []
     new_nodes = []
 
-    for result in tqdm_async(asyncio.as_completed([handle_edge(edge) for edge in edges]), total=len(edges), desc="Pre-tokenizing edges"):
+    for result in tqdm_async(asyncio.as_completed([handle_edge(edge) for edge in edges]),
+                             total=len(edges), desc="Pre-tokenizing edges"):
         new_edge = await result
         await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
         new_edges.append(new_edge)
 
-    for result in tqdm_async(asyncio.as_completed([handle_node(node) for node in nodes]), total=len(nodes), desc="Pre-tokenizing nodes"):
+    for result in tqdm_async(asyncio.as_completed([handle_node(node) for node in nodes]),
+                             total=len(nodes), desc="Pre-tokenizing nodes"):
         new_node = await result
         await graph_storage.update_node(new_node[0], new_node[1])
         new_nodes.append(new_node)
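
Note: the pattern this hunk reformats (a semaphore-bounded `run_in_executor` call drained through `asyncio.as_completed` under a tqdm progress bar) is worth seeing in isolation. A minimal, runnable sketch, assuming a toy blocking `encode_string` in place of the repo's `Tokenizer`; names and item shapes here are illustrative, not the repo's code:

```python
import asyncio
from tqdm.asyncio import tqdm as tqdm_async

def encode_string(text: str) -> list:
    # Stand-in for a blocking, CPU-bound tokenizer call (assumption).
    return text.split()

async def main() -> None:
    sem = asyncio.Semaphore(8)  # bound the number of concurrent executor jobs
    nodes = [{"description": f"entity {i} does something"} for i in range(50)]

    async def handle(node: dict) -> dict:
        async with sem:
            # Offload the blocking call to the default thread pool so the
            # event loop stays responsive (asyncio.get_running_loop() is the
            # modern spelling of the same idea).
            tokens = await asyncio.get_event_loop().run_in_executor(
                None, encode_string, node["description"])
            node["length"] = len(tokens)
        return node

    # as_completed yields futures in completion order; tqdm needs an
    # explicit total because the generator has no __len__.
    for fut in tqdm_async(asyncio.as_completed([handle(n) for n in nodes]),
                          total=len(nodes), desc="Pre-tokenizing nodes"):
        await fut

asyncio.run(main())
```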
@@ -71,7 +77,8 @@ async def _process_nodes_and_edges(
         f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
     ]
     relations = [
-        f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}" for _process_edge in _process_edges
+        f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+        for _process_edge in _process_edges
     ]
 
     entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
@@ -115,16 +122,20 @@ async def _process_single_batch(
     elif question.startswith("问题:"):
        question = question[len("问题:"):].strip()
 
-    pre_length = sum([node['length'] for node in _process_batch[0]]) + sum([edge[2]['length'] for edge in _process_batch[1]])
+    pre_length = sum(node['length'] for node in _process_batch[0]) \
+                 + sum(edge[2]['length'] for edge in _process_batch[1])
+
+    losses = [(edge[0], edge[1], edge[2]['loss']) for edge in _process_batch[1]]
 
-    logger.info(f"{len(_process_batch[0])} nodes and {len(_process_batch[1])} edges processed")
-    logger.info(f"Pre-length: {pre_length}")
-    logger.info(f"Question: {question} Answer: {context}")
+    logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
+    logger.info("Pre-length: %s", pre_length)
+    logger.info("Question: %s Answer: %s", question, context)
 
     return {
         compute_content_hash(context): {
             "question": question,
-            "answer": context
+            "answer": context,
+            "losses": losses
         }
     }
 
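Note: the switch from f-strings to %-style logging here is the standard lazy-formatting fix (what pylint flags as `logging-fstring-interpolation`). A small sketch of the difference, under a toy logger setup with nothing repo-specific assumed:

```python
import logging

logging.basicConfig(level=logging.WARNING)  # INFO records are discarded
logger = logging.getLogger(__name__)

question, context = "What links A and B?", "A long rephrased answer..."

# f-string: the message is interpolated eagerly, even though the INFO
# record is dropped before any handler sees it.
logger.info(f"Question: {question} Answer: {context}")

# %-style: logging stores the args and defers interpolation until a
# handler actually emits the record, so this disabled call does no
# formatting work. (The arguments themselves are still evaluated.)
logger.info("Question: %s Answer: %s", question, context)
```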
@@ -146,7 +157,7 @@ async def _process_single_batch(
     ), total=len(processing_batches), desc="Processing batches"):
         try:
             results.update(await result)
-        except Exception as e:
+        except Exception as e:  # pylint: disable=broad-except
             logger.error("Error occurred while processing batches: %s", e)
 
     return results
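
Note on the newly annotated broad except: with `asyncio.as_completed`, awaiting a yielded future re-raises whatever the task threw, so this guard is what lets one bad batch be logged and skipped instead of aborting the whole loop. A toy demonstration (`maybe_fail` is hypothetical, not the repo's batch processing):

```python
import asyncio

async def maybe_fail(i: int) -> int:
    if i == 3:
        raise ValueError("bad batch")
    return i

async def main() -> None:
    results = []
    for fut in asyncio.as_completed([maybe_fail(i) for i in range(5)]):
        try:
            results.append(await fut)
        except Exception as e:  # pylint: disable=broad-except
            # Log and skip the failed batch; the others still complete.
            print(f"skipped: {e}")
    print(sorted(results))  # -> [0, 1, 2, 4]

asyncio.run(main())
```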