Skip to content

Commit 0790ba4

Browse files
feat: use storage actor
1 parent b400d2e commit 0790ba4

File tree

9 files changed

+266
-30
lines changed

9 files changed

+266
-30
lines changed

graphgen/common/init_storage.py

Lines changed: 245 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,257 @@
1-
from graphgen.models import JsonKVStorage, NetworkXStorage
1+
from typing import Any, Dict, Union
2+
3+
import ray
4+
5+
from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage
6+
7+
8+
class KVStorageActor:
    """Own one key-value storage backend and forward every call to it.

    ``StorageFactory.create_storage`` wraps this class with ``ray.remote`` and
    registers it under the name ``Actor_KV_<namespace>``. Every public method
    is a direct pass-through to the wrapped backend stored in ``self.kv``.
    """

    def __init__(self, backend: str, working_dir: str, namespace: str):
        """Instantiate the backend selected by ``backend``.

        Args:
            backend: ``"json_kv"`` or ``"rocksdb"``.
            working_dir: forwarded as the backend's first positional argument.
            namespace: forwarded as the backend's second positional argument.

        Raises:
            ValueError: if ``backend`` is not one of the supported names.
        """
        # Backend classes are imported lazily, only for the branch taken.
        if backend == "json_kv":
            from graphgen.models import JsonKVStorage as storage_cls
        elif backend == "rocksdb":
            from graphgen.models import RocksDBKVStorage as storage_cls
        else:
            raise ValueError(f"Unknown KV backend: {backend}")

        self.kv = storage_cls(working_dir, namespace)

    def data(self) -> Dict[str, Dict]:
        """Return the backend's ``data`` attribute."""
        store = self.kv
        return store.data

    def all_keys(self) -> list[str]:
        """Forward to the backend's ``all_keys``."""
        store = self.kv
        return store.all_keys()

    def index_done_callback(self):
        """Forward to the backend's ``index_done_callback``."""
        store = self.kv
        return store.index_done_callback()

    def get_by_id(self, id: str) -> Dict:
        """Forward to the backend's ``get_by_id`` for a single key."""
        store = self.kv
        return store.get_by_id(id)

    def get_by_ids(self, ids: list[str], fields=None) -> list:
        """Forward to the backend's ``get_by_ids``; ``fields`` is passed through."""
        store = self.kv
        return store.get_by_ids(ids, fields)

    def get_all(self) -> Dict[str, Dict]:
        """Forward to the backend's ``get_all``."""
        store = self.kv
        return store.get_all()

    def filter_keys(self, data: list[str]) -> set[str]:
        """Forward to the backend's ``filter_keys``."""
        store = self.kv
        return store.filter_keys(data)

    def upsert(self, data: dict) -> dict:
        """Forward to the backend's ``upsert`` and return its result."""
        store = self.kv
        return store.upsert(data)

    def drop(self):
        """Forward to the backend's ``drop``."""
        store = self.kv
        return store.drop()

    def reload(self):
        """Forward to the backend's ``reload``."""
        store = self.kv
        return store.reload()
50+
51+
52+
class GraphStorageActor:
    """Own one graph storage backend and forward every call to it.

    ``StorageFactory.create_storage`` wraps this class with ``ray.remote`` and
    registers it under the name ``Actor_Graph_<namespace>``. Every public
    method is a direct pass-through to the backend stored in ``self.graph``.
    """

    def __init__(self, backend: str, working_dir: str, namespace: str):
        """Instantiate the backend selected by ``backend``.

        Args:
            backend: only ``"networkx"`` is accepted.
            working_dir: forwarded as the backend's first positional argument.
            namespace: forwarded as the backend's second positional argument.

        Raises:
            ValueError: if ``backend`` is not ``"networkx"``.
        """
        if backend != "networkx":
            raise ValueError(f"Unknown Graph backend: {backend}")

        # Imported lazily so the backend module is loaded only when needed.
        from graphgen.models import NetworkXStorage

        self.graph = NetworkXStorage(working_dir, namespace)

    def index_done_callback(self):
        """Forward to the backend's ``index_done_callback``."""
        store = self.graph
        return store.index_done_callback()

    def has_node(self, node_id: str) -> bool:
        """Forward to the backend's ``has_node``."""
        store = self.graph
        return store.has_node(node_id)

    def has_edge(self, source_node_id: str, target_node_id: str):
        """Forward to the backend's ``has_edge``."""
        store = self.graph
        return store.has_edge(source_node_id, target_node_id)

    def node_degree(self, node_id: str) -> int:
        """Forward to the backend's ``node_degree``."""
        store = self.graph
        return store.node_degree(node_id)

    def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Forward to the backend's ``edge_degree``."""
        store = self.graph
        return store.edge_degree(src_id, tgt_id)

    def get_node(self, node_id: str) -> Any:
        """Forward to the backend's ``get_node``."""
        store = self.graph
        return store.get_node(node_id)

    def update_node(self, node_id: str, node_data: dict[str, str]):
        """Forward to the backend's ``update_node``."""
        store = self.graph
        return store.update_node(node_id, node_data)

    def get_all_nodes(self) -> Any:
        """Forward to the backend's ``get_all_nodes``."""
        store = self.graph
        return store.get_all_nodes()

    def get_edge(self, source_node_id: str, target_node_id: str):
        """Forward to the backend's ``get_edge``."""
        store = self.graph
        return store.get_edge(source_node_id, target_node_id)

    def update_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Forward to the backend's ``update_edge``."""
        store = self.graph
        return store.update_edge(source_node_id, target_node_id, edge_data)

    def get_all_edges(self) -> Any:
        """Forward to the backend's ``get_all_edges``."""
        store = self.graph
        return store.get_all_edges()

    def get_node_edges(self, source_node_id: str) -> Any:
        """Forward to the backend's ``get_node_edges``."""
        store = self.graph
        return store.get_node_edges(source_node_id)

    def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Forward to the backend's ``upsert_node``."""
        store = self.graph
        return store.upsert_node(node_id, node_data)

    def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Forward to the backend's ``upsert_edge``."""
        store = self.graph
        return store.upsert_edge(source_node_id, target_node_id, edge_data)

    def delete_node(self, node_id: str):
        """Forward to the backend's ``delete_node``."""
        store = self.graph
        return store.delete_node(node_id)

    def reload(self):
        """Forward to the backend's ``reload``."""
        store = self.graph
        return store.reload()
112+
113+
114+
def get_actor_handle(name: str):
    """Look up the named Ray actor ``name`` and return its handle.

    Raises:
        RuntimeError: if ``ray.get_actor`` raises ``ValueError`` for ``name``;
            the original exception is chained as the cause.
    """
    try:
        handle = ray.get_actor(name)
    except ValueError as exc:
        message = f"Actor {name} not found. Make sure it is created before accessing."
        raise RuntimeError(message) from exc
    return handle
121+
122+
123+
class RemoteKVStorageProxy(BaseKVStorage):
    """``BaseKVStorage`` front-end that forwards every call to a Ray actor.

    The actor is looked up by the name ``Actor_KV_<namespace>`` when the proxy
    is constructed. Each method submits the matching actor method with
    ``.remote(...)`` and blocks on the result with ``ray.get``.
    """

    def __init__(self, namespace: str):
        """Resolve the actor handle for ``namespace``.

        ``get_actor_handle`` raises ``RuntimeError`` when no actor with the
        derived name exists.
        """
        super().__init__()
        self.namespace = namespace
        self.actor_name = f"Actor_KV_{namespace}"
        self.actor = get_actor_handle(self.actor_name)

    # NOTE(review): this is a plain method, so callers must write
    # ``proxy.data()``; confirm call sites do not use it as an attribute.
    def data(self) -> Dict[str, Any]:
        """Fetch the result of the actor's ``data`` method."""
        ref = self.actor.data.remote()
        return ray.get(ref)

    def all_keys(self) -> list[str]:
        """Fetch the result of the actor's ``all_keys`` method."""
        ref = self.actor.all_keys.remote()
        return ray.get(ref)

    def index_done_callback(self):
        """Run the actor's ``index_done_callback`` and wait for it."""
        ref = self.actor.index_done_callback.remote()
        return ray.get(ref)

    def get_by_id(self, id: str) -> Union[Any, None]:
        """Fetch the result of the actor's ``get_by_id`` for one key."""
        ref = self.actor.get_by_id.remote(id)
        return ray.get(ref)

    def get_by_ids(self, ids: list[str], fields=None) -> list[Any]:
        """Fetch the result of the actor's ``get_by_ids``; ``fields`` is passed through."""
        ref = self.actor.get_by_ids.remote(ids, fields)
        return ray.get(ref)

    def get_all(self) -> Dict[str, Any]:
        """Fetch the result of the actor's ``get_all`` method."""
        ref = self.actor.get_all.remote()
        return ray.get(ref)

    def filter_keys(self, data: list[str]) -> set[str]:
        """Fetch the result of the actor's ``filter_keys`` method."""
        ref = self.actor.filter_keys.remote(data)
        return ray.get(ref)

    def upsert(self, data: Dict[str, Any]):
        """Run the actor's ``upsert`` with ``data`` and return its result."""
        ref = self.actor.upsert.remote(data)
        return ray.get(ref)

    def drop(self):
        """Run the actor's ``drop`` and wait for it."""
        ref = self.actor.drop.remote()
        return ray.get(ref)

    def reload(self):
        """Run the actor's ``reload`` and wait for it."""
        ref = self.actor.reload.remote()
        return ray.get(ref)
159+
160+
161+
class RemoteGraphStorageProxy(BaseGraphStorage):
    """``BaseGraphStorage`` front-end that forwards every call to a Ray actor.

    The actor is looked up by the name ``Actor_Graph_<namespace>`` when the
    proxy is constructed. Each method submits the matching actor method with
    ``.remote(...)`` and blocks on the result with ``ray.get``.
    """

    def __init__(self, namespace: str):
        """Resolve the actor handle for ``namespace``.

        ``get_actor_handle`` raises ``RuntimeError`` when no actor with the
        derived name exists.
        """
        super().__init__()
        self.namespace = namespace
        self.actor_name = f"Actor_Graph_{namespace}"
        self.actor = get_actor_handle(self.actor_name)

    def index_done_callback(self):
        """Run the actor's ``index_done_callback`` and wait for it."""
        ref = self.actor.index_done_callback.remote()
        return ray.get(ref)

    def has_node(self, node_id: str) -> bool:
        """Fetch the result of the actor's ``has_node``."""
        ref = self.actor.has_node.remote(node_id)
        return ray.get(ref)

    def has_edge(self, source_node_id: str, target_node_id: str):
        """Fetch the result of the actor's ``has_edge``."""
        ref = self.actor.has_edge.remote(source_node_id, target_node_id)
        return ray.get(ref)

    def node_degree(self, node_id: str) -> int:
        """Fetch the result of the actor's ``node_degree``."""
        ref = self.actor.node_degree.remote(node_id)
        return ray.get(ref)

    def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Fetch the result of the actor's ``edge_degree``."""
        ref = self.actor.edge_degree.remote(src_id, tgt_id)
        return ray.get(ref)

    def get_node(self, node_id: str) -> Any:
        """Fetch the result of the actor's ``get_node``."""
        ref = self.actor.get_node.remote(node_id)
        return ray.get(ref)

    def update_node(self, node_id: str, node_data: dict[str, str]):
        """Run the actor's ``update_node`` and return its result."""
        ref = self.actor.update_node.remote(node_id, node_data)
        return ray.get(ref)

    def get_all_nodes(self) -> Any:
        """Fetch the result of the actor's ``get_all_nodes``."""
        ref = self.actor.get_all_nodes.remote()
        return ray.get(ref)

    def get_edge(self, source_node_id: str, target_node_id: str):
        """Fetch the result of the actor's ``get_edge``."""
        ref = self.actor.get_edge.remote(source_node_id, target_node_id)
        return ray.get(ref)

    def update_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Run the actor's ``update_edge`` and return its result."""
        ref = self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
        return ray.get(ref)

    def get_all_edges(self) -> Any:
        """Fetch the result of the actor's ``get_all_edges``."""
        ref = self.actor.get_all_edges.remote()
        return ray.get(ref)

    def get_node_edges(self, source_node_id: str) -> Any:
        """Fetch the result of the actor's ``get_node_edges``."""
        ref = self.actor.get_node_edges.remote(source_node_id)
        return ray.get(ref)

    def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Run the actor's ``upsert_node`` and return its result."""
        ref = self.actor.upsert_node.remote(node_id, node_data)
        return ray.get(ref)

    def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Run the actor's ``upsert_edge`` and return its result."""
        ref = self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
        return ray.get(ref)

    def delete_node(self, node_id: str):
        """Run the actor's ``delete_node`` and return its result."""
        ref = self.actor.delete_node.remote(node_id)
        return ray.get(ref)

    def reload(self):
        """Run the actor's ``reload`` and wait for it."""
        ref = self.actor.reload.remote()
        return ray.get(ref)
2223

3224

4225
class StorageFactory:
    """
    Factory that returns a proxy to a Ray-actor-backed storage instance.

    Backends handled by ``create_storage``:
        - ``json_kv`` / ``rocksdb``: ``KVStorageActor`` + ``RemoteKVStorageProxy``
        - ``networkx``: ``GraphStorageActor`` + ``RemoteGraphStorageProxy``
    """

    @staticmethod
    def create_storage(backend: str, working_dir: str, namespace: str):
        """Ensure the storage actor for ``namespace`` exists and return its proxy.

        The actor is first looked up by name. Only when that lookup raises
        ``ValueError`` is a new one started, with ``lifetime="detached"`` and
        ``get_if_exists=True``.

        Raises:
            ValueError: if ``backend`` is not a supported backend name.
        """
        # Pick the actor class, proxy class and actor name for this backend.
        if backend in ("json_kv", "rocksdb"):
            actor_name = f"Actor_KV_{namespace}"
            actor_cls, proxy_cls = KVStorageActor, RemoteKVStorageProxy
        elif backend == "networkx":
            actor_name = f"Actor_Graph_{namespace}"
            actor_cls, proxy_cls = GraphStorageActor, RemoteGraphStorageProxy
        else:
            raise ValueError(f"Unknown storage backend: {backend}")

        try:
            ray.get_actor(actor_name)
        except ValueError:
            # No actor registered under this name yet: start one.
            remote_cls = ray.remote(actor_cls).options(
                name=actor_name,
                lifetime="detached",
                get_if_exists=True,
            )
            remote_cls.remote(backend, working_dir, namespace)

        return proxy_cls(namespace)
25255

26256

27257
def init_storage(backend: str, working_dir: str, namespace: str):

graphgen/engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import ray.data
99

1010
from graphgen.bases import Config, Node
11+
from graphgen.utils import logger
1112

1213

1314
class Engine:
@@ -26,7 +27,7 @@ def __init__(
2627
log_to_driver=True,
2728
**ray_init_kwargs,
2829
)
29-
print(f"Ray Dashboard URL: {context.dashboard_url}")
30+
logger.info("Ray Dashboard URL: %s", context.dashboard_url)
3031

3132
@staticmethod
3233
def _topo_sort(nodes: List[Node]) -> List[Node]:

graphgen/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@
3232
from .searcher.web.bing_search import BingSearch
3333
from .searcher.web.google_search import GoogleSearch
3434
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
35-
from .storage import JsonKVStorage, NetworkXStorage, RocksDBCache
35+
from .storage import JsonKVStorage, NetworkXStorage, RocksDBCache, RocksDBKVStorage
3636
from .tokenizer import Tokenizer
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
22
from graphgen.models.storage.kv.json_storage import JsonKVStorage
3+
from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage
34

45
from .rocksdb_cache import RocksDBCache

graphgen/models/storage/kv/rocksdb_storage.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from rocksdict import Rdict
88

99
from graphgen.bases.base_storage import BaseKVStorage
10-
from graphgen.utils import logger
1110

1211

1312
@dataclass
@@ -18,7 +17,9 @@ class RocksDBKVStorage(BaseKVStorage):
1817
def __post_init__(self):
    """Open the RocksDB database at ``<working_dir>/<namespace>.db``."""
    db_file = f"{self.namespace}.db"
    self._db_path = os.path.join(self.working_dir, db_file)
    self._db = Rdict(self._db_path)
    print(
        f"RocksDBKVStorage initialized for namespace '{self.namespace}' at '{self._db_path}'"
    )
2223

2324
@property
2425
def data(self):
@@ -29,7 +30,7 @@ def all_keys(self) -> List[str]:
2930

3031
def index_done_callback(self):
    """Flush the open database handle, then print a confirmation line."""
    db = self._db
    db.flush()
    print(f"RocksDB flushed for {self.namespace}")
3334

3435
def get_by_id(self, id: str) -> Any:
    """Return the value stored under ``id``, or ``None`` when the key is absent."""
    record = self._db.get(id, None)
    return record
@@ -63,7 +64,6 @@ def upsert(self, data: Dict[str, Any]):
6364
if left_data:
6465
for k, v in left_data.items():
6566
self._db[k] = v
66-
6767
# if left_data is very large, it is recommended to use self._db.write_batch() for optimization
6868

6969
return left_data
@@ -72,8 +72,14 @@ def drop(self):
7272
self._db.close()
7373
Rdict.destroy(self._db_path)
7474
self._db = Rdict(self._db_path)
75-
logger.info("Dropped RocksDB %s", self.namespace)
75+
print(f"Dropped RocksDB {self.namespace}")
7676

7777
def close(self):
    """Close the database handle if ``self._db`` is truthy; otherwise do nothing."""
    db = self._db
    if not db:
        return
    db.close()
80+
81+
def reload(self):
    """Close the current handle (when truthy) and reopen the database at ``self._db_path``."""
    current = self._db
    if current:
        current.close()
    self._db = Rdict(self._db_path)
    print(f"Reloaded RocksDB {self.namespace}")

graphgen/operators/chunk/chunk_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, working_dir: str = "cache", **chunk_kwargs):
4747
tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
4848
self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
4949
self.chunk_storage = init_storage(
50-
backend="json_kv",
50+
backend="rocksdb",
5151
working_dir=working_dir,
5252
namespace="chunk",
5353
)

graphgen/operators/partition/partition_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, working_dir: str = "cache", **partition_kwargs):
2626
namespace="graph",
2727
)
2828
self.chunk_storage: BaseKVStorage = init_storage(
29-
backend="json_kv",
29+
backend="rocksdb",
3030
working_dir=working_dir,
3131
namespace="chunk",
3232
)

graphgen/operators/quiz/quiz_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
)
2424
# { _quiz_id: { "description": str, "quizzes": List[Tuple[str, str]] } }
2525
self.quiz_storage: BaseKVStorage = init_storage(
26-
backend="json_kv", working_dir=working_dir, namespace="quiz"
26+
backend="rocksdb", working_dir=working_dir, namespace="quiz"
2727
)
2828
self.generator = QuizGenerator(self.llm_client)
2929
self.concurrency_limit = concurrency_limit

0 commit comments

Comments
 (0)