Commit 8e39654

feat(graphgen): add losses to output
1 parent a65e7b3 commit 8e39654

File tree: 1 file changed (+25 −14)

graphgen/operators/traverse_graph.py

Lines changed: 25 additions & 14 deletions
@@ -1,10 +1,10 @@
 import asyncio
+from tqdm.asyncio import tqdm as tqdm_async
 
 from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer
 from templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT
-from utils import detect_main_language, compute_content_hash, logger, create_event_loop
-from tqdm.asyncio import tqdm as tqdm_async
-from .split_graph import get_batches_with_strategy
+from utils import detect_main_language, compute_content_hash, logger
+from graphgen.operators.split_graph import get_batches_with_strategy
 
 
 async def _pre_tokenize(graph_storage: NetworkXStorage,
@@ -17,25 +17,31 @@ async def handle_edge(edge: tuple) -> tuple:
         async with sem:
             if 'length' not in edge[2]:
                 edge[2]['length'] = len(
-                    await asyncio.get_event_loop().run_in_executor(None, tokenizer.encode_string, edge[2]['description']))
+                    await asyncio.get_event_loop().run_in_executor(None,
+                                                                   tokenizer.encode_string,
+                                                                   edge[2]['description']))
             return edge
 
     async def handle_node(node: dict) -> dict:
         async with sem:
             if 'length' not in node[1]:
                 node[1]['length'] = len(
-                    await asyncio.get_event_loop().run_in_executor(None, tokenizer.encode_string, node[1]['description']))
+                    await asyncio.get_event_loop().run_in_executor(None,
+                                                                   tokenizer.encode_string,
+                                                                   node[1]['description']))
             return node
 
     new_edges = []
     new_nodes = []
 
-    for result in tqdm_async(asyncio.as_completed([handle_edge(edge) for edge in edges]), total=len(edges), desc="Pre-tokenizing edges"):
+    for result in tqdm_async(asyncio.as_completed([handle_edge(edge) for edge in edges]),
+                             total=len(edges), desc="Pre-tokenizing edges"):
         new_edge = await result
         await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
         new_edges.append(new_edge)
 
-    for result in tqdm_async(asyncio.as_completed([handle_node(node) for node in nodes]), total=len(nodes), desc="Pre-tokenizing nodes"):
+    for result in tqdm_async(asyncio.as_completed([handle_node(node) for node in nodes]),
+                             total=len(nodes), desc="Pre-tokenizing nodes"):
         new_node = await result
         await graph_storage.update_node(new_node[0], new_node[1])
         new_nodes.append(new_node)
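
The reformatting above leaves the run_in_executor calls intact; the point of the pattern is that the CPU-bound tokenizer runs in the default thread pool so the event loop stays free for the other pre-tokenize coroutines. A minimal runnable sketch of the same pattern, with an invented stand-in for Tokenizer.encode_string:

import asyncio

def encode_string(text: str) -> list:
    # Stand-in for Tokenizer.encode_string (assumption for illustration);
    # any blocking, CPU-bound encoder is offloaded the same way.
    return text.split()

async def measure_length(description: str) -> int:
    # The diff uses asyncio.get_event_loop(); get_running_loop() is the
    # modern equivalent inside a coroutine. Passing None selects the
    # default ThreadPoolExecutor.
    loop = asyncio.get_running_loop()
    tokens = await loop.run_in_executor(None, encode_string, description)
    return len(tokens)

print(asyncio.run(measure_length("an edge description to tokenize")))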
@@ -71,7 +77,8 @@ async def _process_nodes_and_edges(
         f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
     ]
     relations = [
-        f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}" for _process_edge in _process_edges
+        f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+        for _process_edge in _process_edges
     ]
 
     entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
@@ -115,16 +122,20 @@ async def _process_single_batch(
     elif question.startswith("问题:"):
         question = question[len("问题:"):].strip()
 
-    pre_length = sum([node['length'] for node in _process_batch[0]]) + sum([edge[2]['length'] for edge in _process_batch[1]])
+    pre_length = sum(node['length'] for node in _process_batch[0]) \
+                 + sum(edge[2]['length'] for edge in _process_batch[1])
+
+    losses = [(edge[0], edge[1], edge[2]['loss']) for edge in _process_batch[1]]
 
-    logger.info(f"{len(_process_batch[0])} nodes and {len(_process_batch[1])} edges processed")
-    logger.info(f"Pre-length: {pre_length}")
-    logger.info(f"Question: {question} Answer: {context}")
+    logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
+    logger.info("Pre-length: %s", pre_length)
+    logger.info("Question: %s Answer: %s", question, context)
 
     return {
         compute_content_hash(context): {
             "question": question,
-            "answer": context
+            "answer": context,
+            "losses": losses
         }
     }
 
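The new losses field is what the commit title refers to: each generated QA pair, keyed by the content hash of its answer, now also carries one (source_node, target_node, loss) tuple per edge in the batch it was generated from. A hypothetical example of one returned entry (hash key, text, and loss values invented for illustration):

{
    "b94d27b9934d3e08...": {  # compute_content_hash(context)
        "question": "How is entity A related to entity B?",
        "answer": "A supplies raw material to B ...",
        "losses": [("A", "B", 0.42), ("B", "C", 1.07)]
    }
}
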
@@ -146,7 +157,7 @@ async def _process_single_batch(
     ), total=len(processing_batches), desc="Processing batches"):
         try:
             results.update(await result)
-        except Exception as e:
+        except Exception as e: # pylint: disable=broad-except
             logger.error("Error occurred while processing batches: %s", e)
 
     return results
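
The pylint annotation documents that the broad except is deliberate: a batch that raises is logged and skipped so the remaining batches still complete. A self-contained sketch of that fault-tolerant as_completed loop, with a toy work coroutine invented for illustration:

import asyncio

async def work(i: int) -> dict:
    if i == 2:
        raise ValueError("bad batch")  # simulate one failing batch
    return {f"batch-{i}": "ok"}

async def main() -> dict:
    results = {}
    for fut in asyncio.as_completed([work(i) for i in range(4)]):
        try:
            results.update(await fut)
        except Exception as e:  # pylint: disable=broad-except
            print(f"skipped a batch: {e}")  # the real code calls logger.error
    return results

print(asyncio.run(main()))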
