Skip to content

Commit d155d20

Browse files
fix: register storage actors in Driver
1 parent 48192fb commit d155d20

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

graphgen/engine.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from graphgen.bases import Config, Node
1414
from graphgen.utils import logger
15-
from graphgen.common import init_llm
15+
from graphgen.common import init_llm, init_storage
1616

1717
load_dotenv()
1818

@@ -25,6 +25,7 @@ def __init__(
2525
self.functions = functions
2626
self.datasets: Dict[str, ray.data.Dataset] = {}
2727
self.llm_actors = {}
28+
self.storage_actors = {}
2829

2930
ctx = DataContext.get_current()
3031
ctx.enable_rich_progress_bars = False
@@ -54,11 +55,57 @@ def __init__(
5455
logger.info("Ray Dashboard URL: %s", context.dashboard_url)
5556

5657
self._init_llms()
58+
self._init_storage()
5759

5860
def _init_llms(self):
    """Populate ``self.llm_actors`` with one entry per LLM role.

    Each role name is passed to ``init_llm`` and the returned object is
    stored under that same name.
    """
    for role in ("synthesizer", "trainee"):
        self.llm_actors[role] = init_llm(role)
6163

64+
def _init_storage(self):
    """Create storage objects for every namespace the configured nodes need.

    The required namespaces come from ``_scan_storage_requirements``. For
    each one, ``init_storage`` is called with the matching backend from
    ``self.global_params``, the working directory and the namespace; the
    result is stored in ``self.storage_actors`` under ``"kv_<namespace>"``
    or ``"graph_<namespace>"``.
    """
    required_kv, required_graph = self._scan_storage_requirements()
    base_dir = self.global_params["working_dir"]

    # (key prefix, backend parameter, log message, namespaces) -- KV first,
    # then graph, which is the order the storages are created in.
    creation_plan = (
        ("kv", "kv_backend", "Create KV Storage Actor: namespace=%s", required_kv),
        (
            "graph",
            "graph_backend",
            "Create Graph Storage Actor: namespace=%s",
            required_graph,
        ),
    )

    for prefix, backend_key, log_message, namespaces in creation_plan:
        for namespace in namespaces:
            # The backend is looked up per namespace, so a backend key is
            # only required when at least one namespace actually uses it.
            backend = self.global_params[backend_key]
            self.storage_actors[f"{prefix}_{namespace}"] = init_storage(
                backend, base_dir, namespace
            )
            logger.info(log_message, namespace)
77+
78+
def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
    """Collect the storage namespaces required by the configured nodes.

    Returns:
        A ``(kv_namespaces, graph_namespaces)`` pair. The first set holds the
        ``op_name`` of every node whose function accepts ``kv_backend``. The
        second set is ``{"graph"}`` if any node's function accepts
        ``graph_backend``, otherwise empty.
    """
    # TODO: Temporarily hard-coded; node storage will be centrally managed later.
    op_names = [node.op_name for node in self.config.nodes]

    kv_required = {
        name for name in op_names if self._function_needs_param(name, "kv_backend")
    }
    # Every graph consumer shares the single "graph" namespace, so this set
    # collapses to at most one element.
    graph_required = {
        "graph"
        for name in op_names
        if self._function_needs_param(name, "graph_backend")
    }
    return kv_required, graph_required
90+
91+
def _function_needs_param(self, op_name: str, param_name: str) -> bool:
    """Report whether the callable registered for ``op_name`` accepts ``param_name``.

    Args:
        op_name: Key to look up in ``self.functions``.
        param_name: Parameter name to search for in the callable's signature.

    Returns:
        ``True`` if the signature contains ``param_name``. ``False`` if
        ``op_name`` is not registered or the signature cannot be inspected.
    """
    registry = self.functions
    if op_name not in registry:
        return False

    target = registry[op_name]
    # For a class, the parameters of interest are those of its constructor.
    if inspect.isclass(target):
        target = target.__init__

    try:
        return param_name in inspect.signature(target).parameters
    except (ValueError, TypeError):
        # Raised by inspect.signature when no signature is available.
        return False
62109

63110
@staticmethod
64111
def _topo_sort(nodes: List[Node]) -> List[Node]:

0 commit comments

Comments
 (0)