Skip to content

Commit a6d32d7

Browse files
fix: fix detached llm actors
1 parent 84e9f50 commit a6d32d7

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

graphgen/common/init_llm.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import ray
55

66
from graphgen.bases import BaseLLMWrapper
7-
from graphgen.common.init_storage import get_actor_handle
87
from graphgen.models import Tokenizer
98

109

@@ -74,9 +73,9 @@ class LLMServiceProxy(BaseLLMWrapper):
7473
A proxy class to interact with the LLMServiceActor for distributed LLM operations.
7574
"""
7675

77-
def __init__(self, actor_name: str):
76+
def __init__(self, actor_handle: ray.actor.ActorHandle):
7877
super().__init__()
79-
self.actor_handle = get_actor_handle(actor_name)
78+
self.actor_handle = actor_handle
8079
self._create_local_tokenizer()
8180

8281
async def generate_answer(
@@ -128,25 +127,25 @@ def create_llm(
128127

129128
actor_name = f"Actor_LLM_{model_type}"
130129
try:
131-
ray.get_actor(actor_name)
130+
actor_handle = ray.get_actor(actor_name)
131+
print(f"Using existing Ray actor: {actor_name}")
132132
except ValueError:
133133
print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
134134
num_gpus = float(config.pop("num_gpus", 0))
135-
actor = (
135+
actor_handle = (
136136
ray.remote(LLMServiceActor)
137137
.options(
138138
name=actor_name,
139139
num_gpus=num_gpus,
140-
lifetime="detached",
141140
get_if_exists=True,
142141
)
143142
.remote(backend, config)
144143
)
145144

146145
# wait for actor to be ready
147-
ray.get(actor.ready.remote())
146+
ray.get(actor_handle.ready.remote())
148147

149-
return LLMServiceProxy(actor_name)
148+
return LLMServiceProxy(actor_handle)
150149

151150

152151
def _load_env_group(prefix: str) -> Dict[str, Any]:

graphgen/engine.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from graphgen.bases import Config, Node
1212
from graphgen.utils import logger
13+
from graphgen.common import init_llm
1314

1415

1516
class Engine:
@@ -20,6 +21,7 @@ def __init__(
2021
self.global_params = self.config.global_params
2122
self.functions = functions
2223
self.datasets: Dict[str, ray.data.Dataset] = {}
24+
self.llm_actors = {}
2325

2426
ctx = DataContext.get_current()
2527
ctx.enable_rich_progress_bars = False
@@ -37,6 +39,13 @@ def __init__(
3739
**ray_init_kwargs,
3840
)
3941
logger.info("Ray Dashboard URL: %s", context.dashboard_url)
42+
43+
self._init_llms()
44+
45+
def _init_llms(self):
46+
self.llm_actors["synthesizer"] = init_llm("synthesizer")
47+
self.llm_actors["trainee"] = init_llm("trainee")
48+
4049

4150
@staticmethod
4251
def _topo_sort(nodes: List[Node]) -> List[Node]:

0 commit comments

Comments (0)