Skip to content

Commit c69483b

Browse files
Fix: fix detached actors (#132)
* fix: fix detached llm actors * feat: pass env_vars in engine * fix: make storage's lifetime job type * fix: fix lint problems * fix: register storage actors in Driver * fix: fix lint problem
1 parent 84e9f50 commit c69483b

File tree

4 files changed

+101
-47
lines changed

4 files changed

+101
-47
lines changed

graphgen/common/init_llm.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import ray
55

66
from graphgen.bases import BaseLLMWrapper
7-
from graphgen.common.init_storage import get_actor_handle
87
from graphgen.models import Tokenizer
98

109

@@ -74,9 +73,9 @@ class LLMServiceProxy(BaseLLMWrapper):
7473
A proxy class to interact with the LLMServiceActor for distributed LLM operations.
7574
"""
7675

77-
def __init__(self, actor_name: str):
76+
def __init__(self, actor_handle: ray.actor.ActorHandle):
7877
super().__init__()
79-
self.actor_handle = get_actor_handle(actor_name)
78+
self.actor_handle = actor_handle
8079
self._create_local_tokenizer()
8180

8281
async def generate_answer(
@@ -128,25 +127,25 @@ def create_llm(
128127

129128
actor_name = f"Actor_LLM_{model_type}"
130129
try:
131-
ray.get_actor(actor_name)
130+
actor_handle = ray.get_actor(actor_name)
131+
print(f"Using existing Ray actor: {actor_name}")
132132
except ValueError:
133133
print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
134134
num_gpus = float(config.pop("num_gpus", 0))
135-
actor = (
135+
actor_handle = (
136136
ray.remote(LLMServiceActor)
137137
.options(
138138
name=actor_name,
139139
num_gpus=num_gpus,
140-
lifetime="detached",
141140
get_if_exists=True,
142141
)
143142
.remote(backend, config)
144143
)
145144

146145
# wait for actor to be ready
147-
ray.get(actor.ready.remote())
146+
ray.get(actor_handle.ready.remote())
148147

149-
return LLMServiceProxy(actor_name)
148+
return LLMServiceProxy(actor_handle)
150149

151150

152151
def _load_env_group(prefix: str) -> Dict[str, Any]:

graphgen/common/init_storage.py

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def drop(self):
4848
def reload(self):
4949
return self.kv.reload()
5050

51+
def ready(self) -> bool:
52+
return True
53+
5154

5255
class GraphStorageActor:
5356
def __init__(self, backend: str, working_dir: str, namespace: str):
@@ -114,22 +117,14 @@ def delete_node(self, node_id: str):
114117
def reload(self):
115118
return self.graph.reload()
116119

117-
118-
def get_actor_handle(name: str):
119-
try:
120-
return ray.get_actor(name)
121-
except ValueError as exc:
122-
raise RuntimeError(
123-
f"Actor {name} not found. Make sure it is created before accessing."
124-
) from exc
120+
def ready(self) -> bool:
121+
return True
125122

126123

127124
class RemoteKVStorageProxy(BaseKVStorage):
128-
def __init__(self, namespace: str):
125+
def __init__(self, actor_handle: ray.actor.ActorHandle):
129126
super().__init__()
130-
self.namespace = namespace
131-
self.actor_name = f"Actor_KV_{namespace}"
132-
self.actor = get_actor_handle(self.actor_name)
127+
self.actor = actor_handle
133128

134129
def data(self) -> Dict[str, Any]:
135130
return ray.get(self.actor.data.remote())
@@ -163,11 +158,9 @@ def reload(self):
163158

164159

165160
class RemoteGraphStorageProxy(BaseGraphStorage):
166-
def __init__(self, namespace: str):
161+
def __init__(self, actor_handle: ray.actor.ActorHandle):
167162
super().__init__()
168-
self.namespace = namespace
169-
self.actor_name = f"Actor_Graph_{namespace}"
170-
self.actor = get_actor_handle(self.actor_name)
163+
self.actor = actor_handle
171164

172165
def index_done_callback(self):
173166
return ray.get(self.actor.index_done_callback.remote())
@@ -235,27 +228,23 @@ class StorageFactory:
235228
def create_storage(backend: str, working_dir: str, namespace: str):
236229
if backend in ["json_kv", "rocksdb"]:
237230
actor_name = f"Actor_KV_{namespace}"
238-
try:
239-
ray.get_actor(actor_name)
240-
except ValueError:
241-
ray.remote(KVStorageActor).options(
242-
name=actor_name,
243-
lifetime="detached",
244-
get_if_exists=True,
245-
).remote(backend, working_dir, namespace)
246-
return RemoteKVStorageProxy(namespace)
247-
if backend in ["networkx", "kuzu"]:
231+
actor_class = KVStorageActor
232+
proxy_class = RemoteKVStorageProxy
233+
elif backend in ["networkx", "kuzu"]:
248234
actor_name = f"Actor_Graph_{namespace}"
249-
try:
250-
ray.get_actor(actor_name)
251-
except ValueError:
252-
ray.remote(GraphStorageActor).options(
253-
name=actor_name,
254-
lifetime="detached",
255-
get_if_exists=True,
256-
).remote(backend, working_dir, namespace)
257-
return RemoteGraphStorageProxy(namespace)
258-
raise ValueError(f"Unknown storage backend: {backend}")
235+
actor_class = GraphStorageActor
236+
proxy_class = RemoteGraphStorageProxy
237+
else:
238+
raise ValueError(f"Unknown storage backend: {backend}")
239+
try:
240+
actor_handle = ray.get_actor(actor_name)
241+
except ValueError:
242+
actor_handle = ray.remote(actor_class).options(
243+
name=actor_name,
244+
get_if_exists=True,
245+
).remote(backend, working_dir, namespace)
246+
ray.get(actor_handle.ready.remote())
247+
return proxy_class(actor_handle)
259248

260249

261250
def init_storage(backend: str, working_dir: str, namespace: str):

graphgen/engine.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
import os
12
import inspect
23
import logging
34
from collections import defaultdict, deque
45
from functools import wraps
56
from typing import Any, Callable, Dict, List, Set
7+
from dotenv import load_dotenv
68

79
import ray
810
import ray.data
911
from ray.data import DataContext
1012

1113
from graphgen.bases import Config, Node
1214
from graphgen.utils import logger
15+
from graphgen.common import init_llm, init_storage
1316

17+
load_dotenv()
1418

1519
class Engine:
1620
def __init__(
@@ -20,6 +24,8 @@ def __init__(
2024
self.global_params = self.config.global_params
2125
self.functions = functions
2226
self.datasets: Dict[str, ray.data.Dataset] = {}
27+
self.llm_actors = {}
28+
self.storage_actors = {}
2329

2430
ctx = DataContext.get_current()
2531
ctx.enable_rich_progress_bars = False
@@ -29,6 +35,16 @@ def __init__(
2935
ctx.enable_tensor_extension_casting = False
3036
ctx._metrics_export_port = 0 # Disable metrics exporter to avoid RpcError
3137

38+
all_env_vars = os.environ.copy()
39+
if "runtime_env" not in ray_init_kwargs:
40+
ray_init_kwargs["runtime_env"] = {}
41+
42+
existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
43+
ray_init_kwargs["runtime_env"]["env_vars"] = {
44+
**all_env_vars,
45+
**existing_env_vars
46+
}
47+
3248
if not ray.is_initialized():
3349
context = ray.init(
3450
ignore_reinit_error=True,
@@ -38,6 +54,59 @@ def __init__(
3854
)
3955
logger.info("Ray Dashboard URL: %s", context.dashboard_url)
4056

57+
self._init_llms()
58+
self._init_storage()
59+
60+
def _init_llms(self):
61+
self.llm_actors["synthesizer"] = init_llm("synthesizer")
62+
self.llm_actors["trainee"] = init_llm("trainee")
63+
64+
def _init_storage(self):
65+
kv_namespaces, graph_namespaces = self._scan_storage_requirements()
66+
working_dir = self.global_params["working_dir"]
67+
68+
for node_id in kv_namespaces:
69+
proxy = init_storage(self.global_params["kv_backend"], working_dir, node_id)
70+
self.storage_actors[f"kv_{node_id}"] = proxy
71+
logger.info("Create KV Storage Actor: namespace=%s", node_id)
72+
73+
for ns in graph_namespaces:
74+
proxy = init_storage(self.global_params["graph_backend"], working_dir, ns)
75+
self.storage_actors[f"graph_{ns}"] = proxy
76+
logger.info("Create Graph Storage Actor: namespace=%s", ns)
77+
78+
def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
79+
kv_namespaces = set()
80+
graph_namespaces = set()
81+
82+
# TODO: Temporarily hard-coded; node storage will be centrally managed later.
83+
for node in self.config.nodes:
84+
op_name = node.op_name
85+
if self._function_needs_param(op_name, "kv_backend"):
86+
kv_namespaces.add(op_name)
87+
if self._function_needs_param(op_name, "graph_backend"):
88+
graph_namespaces.add("graph")
89+
return kv_namespaces, graph_namespaces
90+
91+
def _function_needs_param(self, op_name: str, param_name: str) -> bool:
92+
if op_name not in self.functions:
93+
return False
94+
95+
func = self.functions[op_name]
96+
97+
if inspect.isclass(func):
98+
try:
99+
sig = inspect.signature(func.__init__)
100+
return param_name in sig.parameters
101+
except (ValueError, TypeError):
102+
return False
103+
104+
try:
105+
sig = inspect.signature(func)
106+
return param_name in sig.parameters
107+
except (ValueError, TypeError):
108+
return False
109+
41110
@staticmethod
42111
def _topo_sort(nodes: List[Node]) -> List[Node]:
43112
id_to_node: Dict[str, Node] = {}

graphgen/run.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import ray
88
import yaml
9-
from dotenv import load_dotenv
109
from ray.data.block import Block
1110
from ray.data.datasource.filename_provider import FilenameProvider
1211

@@ -16,8 +15,6 @@
1615

1716
sys_path = os.path.abspath(os.path.dirname(__file__))
1817

19-
load_dotenv()
20-
2118

2219
def set_working_dir(folder):
2320
os.makedirs(folder, exist_ok=True)

0 commit comments

Comments
 (0)