Skip to content

Commit d2aa701

Browse files
feat: calculate average loss
1 parent 0e0bdd0 commit d2aa701

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

judge.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@
1010

1111
load_dotenv()
1212

13+
def calculate_average_loss(graph: NetworkXStorage):
14+
"""
15+
Calculate the average loss of the graph.
16+
17+
:param graph: NetworkXStorage
18+
:return: float
19+
"""
20+
edges = asyncio.run(graph.get_all_edges())
21+
total_loss = 0
22+
for edge in edges:
23+
total_loss += edge[2]['loss']
24+
return total_loss / len(edges)
25+
26+
1327

1428
if __name__ == '__main__':
1529
parser = argparse.ArgumentParser()
@@ -38,3 +52,6 @@
3852
graph_file = asyncio.run(graph_storage.get_graph())
3953

4054
new_graph.write_nx_graph(graph_file, args.output)
55+
56+
average_loss = calculate_average_loss(graph_storage)
57+
print(f"Average loss of the graph: {average_loss}")

scripts/judge.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
python3 evaluate.py --output cache/output/new_graph.graphml \
1+
python3 judge.py --output cache/output/new_graph.graphml \

0 commit comments

Comments
 (0)