Skip to content

Commit c494093

Browse files
fix(graphgen): add unique_id for log
1 parent 2623b15 commit c494093

File tree

5 files changed

+16
-19
lines changed

5 files changed

+16
-19
lines changed

baselines/Genie/genie.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# https://arxiv.org/pdf/2401.14367
22
import os
33
import json
4-
from dotenv import load_dotenv
54
import argparse
65
import asyncio
76

87
from dataclasses import dataclass
8+
from dotenv import load_dotenv
99
from models import OpenAIModel
1010
from typing import List
1111
from utils import create_event_loop, compute_content_hash
@@ -58,7 +58,7 @@ def generate(self, docs: List[List[dict]]) -> List[dict]:
5858
loop = create_event_loop()
5959
return loop.run_until_complete(self.async_generate(docs))
6060

61-
async def async_generate(self, docs: List[List[dict]]) -> List[dict]:
61+
async def async_generate(self, docs: List[List[dict]]) -> dict:
6262
final_results = {}
6363
semaphore = asyncio.Semaphore(self.max_concurrent)
6464

generate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import json
3+
import time
34
import argparse
45
from dotenv import load_dotenv
56

@@ -8,7 +9,8 @@
89
from utils import set_logger
910

1011
sys_path = os.path.abspath(os.path.dirname(__file__))
11-
set_logger(os.path.join(sys_path, "cache", "logs", "graphgen.log"), if_stream=False)
12+
unique_id = int(time.time())
13+
set_logger(os.path.join(sys_path, "cache", "logs", f"graphgen_{unique_id}.log"), if_stream=False)
1214

1315
load_dotenv()
1416

@@ -58,6 +60,7 @@
5860
traverse_strategy = TraverseStrategy()
5961

6062
graph_gen = GraphGen(
63+
unique_id=unique_id,
6164
teacher_llm_client=teacher_llm_client,
6265
student_llm_client=student_llm_client,
6366
if_web_search=args.web_search,

graphgen/graphgen.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# https://github.com/HKUDS/LightRAG
22

33
import os
4-
import time
54
import asyncio
5+
import time
6+
from typing import List, cast, Union
7+
from dataclasses import dataclass
68
from tqdm.asyncio import tqdm as tqdm_async
79

810
from .operators import *
911
from models import Chunk, JsonKVStorage, OpenAIModel, NetworkXStorage, WikiSearch, Tokenizer, TraverseStrategy
10-
from typing import List, cast, Union
11-
12-
from dataclasses import dataclass
1312
from utils import create_event_loop, logger, compute_content_hash
1413
from models.storage.base_storage import StorageNameSpace
1514

@@ -19,6 +18,7 @@
1918

2019
@dataclass
2120
class GraphGen:
21+
unique_id: int = int(time.time())
2222
working_dir: str = os.path.join(sys_path, "cache")
2323
full_docs_storage: JsonKVStorage = JsonKVStorage(
2424
working_dir, namespace="full_docs"
@@ -33,7 +33,7 @@ class GraphGen:
3333
working_dir, namespace="graph"
3434
)
3535
qa_storage: JsonKVStorage = JsonKVStorage(
36-
os.path.join(working_dir, "data", "graphgen"), namespace=f"qa-{int(time.time())}"
36+
os.path.join(working_dir, "data", "graphgen"), namespace=f"qa-{unique_id}"
3737
)
3838

3939
# text chunking

models/storage/networkx_storage.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import os
22
import html
33
import networkx as nx
4-
import numpy as np
54

6-
from typing import Any, Union, cast
5+
from typing import Any, Union, cast, Optional
76
from dataclasses import dataclass
7+
88
from .base_storage import BaseGraphStorage
99
from utils import logger
1010

1111
@dataclass
1212
class NetworkXStorage(BaseGraphStorage):
1313
@staticmethod
14-
def load_nx_graph(file_name) -> nx.Graph:
14+
def load_nx_graph(file_name) -> Optional[nx.Graph]:
1515
if os.path.exists(file_name):
1616
return nx.read_graphml(file_name)
1717
return None
@@ -85,9 +85,6 @@ def __post_init__(self):
8585
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
8686
)
8787
self._graph = preloaded_graph or nx.Graph()
88-
# self._node_embed_algorithms = {
89-
# "node2vec": self._node2vec_embed,
90-
# }
9188

9289
async def index_done_callback(self):
9390
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
@@ -154,8 +151,3 @@ async def delete_node(self, node_id: str):
154151
logger.info(f"Node {node_id} deleted from the graph.")
155152
else:
156153
logger.warning(f"Node {node_id} not found in the graph for deletion.")
157-
158-
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
159-
if algorithm not in self._node_embed_algorithms:
160-
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
161-
return await self._node_embed_algorithms[algorithm]()

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ jieba
1414
torch
1515
plotly
1616
pandas
17+
gradio
18+
kaleido

0 commit comments

Comments (0)