1+ import os
12import inspect
23import logging
34from collections import defaultdict , deque
45from functools import wraps
56from typing import Any , Callable , Dict , List , Set
7+ from dotenv import load_dotenv
68
79import ray
810import ray .data
911from ray .data import DataContext
1012
1113from graphgen .bases import Config , Node
1214from graphgen .utils import logger
15+ from graphgen .common import init_llm , init_storage
1316
17+ load_dotenv ()
1418
1519class Engine :
1620 def __init__ (
@@ -20,6 +24,8 @@ def __init__(
2024 self .global_params = self .config .global_params
2125 self .functions = functions
2226 self .datasets : Dict [str , ray .data .Dataset ] = {}
27+ self .llm_actors = {}
28+ self .storage_actors = {}
2329
2430 ctx = DataContext .get_current ()
2531 ctx .enable_rich_progress_bars = False
@@ -29,6 +35,16 @@ def __init__(
2935 ctx .enable_tensor_extension_casting = False
3036 ctx ._metrics_export_port = 0 # Disable metrics exporter to avoid RpcError
3137
38+ all_env_vars = os .environ .copy ()
39+ if "runtime_env" not in ray_init_kwargs :
40+ ray_init_kwargs ["runtime_env" ] = {}
41+
42+ existing_env_vars = ray_init_kwargs ["runtime_env" ].get ("env_vars" , {})
43+ ray_init_kwargs ["runtime_env" ]["env_vars" ] = {
44+ ** all_env_vars ,
45+ ** existing_env_vars
46+ }
47+
3248 if not ray .is_initialized ():
3349 context = ray .init (
3450 ignore_reinit_error = True ,
@@ -38,6 +54,59 @@ def __init__(
3854 )
3955 logger .info ("Ray Dashboard URL: %s" , context .dashboard_url )
4056
57+ self ._init_llms ()
58+ self ._init_storage ()
59+
def _init_llms(self):
    """Populate ``self.llm_actors`` with one entry per LLM role.

    Calls ``init_llm`` once for the ``"synthesizer"`` role and once for the
    ``"trainee"`` role, storing whatever it returns under the role name.
    """
    for role in ("synthesizer", "trainee"):
        self.llm_actors[role] = init_llm(role)
63+
def _init_storage(self):
    """Create the storage handles required by the configured nodes.

    The namespaces come from ``_scan_storage_requirements``. For each KV
    namespace ``init_storage`` is called with ``global_params["kv_backend"]``;
    for each graph namespace it is called with
    ``global_params["graph_backend"]``. Both use
    ``global_params["working_dir"]``. Results are stored in
    ``self.storage_actors`` under ``kv_<namespace>`` / ``graph_<namespace>``.

    Raises:
        KeyError: if ``working_dir`` is missing from ``global_params``, or if
            a backend key is missing while a namespace of that kind is needed.
    """
    kv_namespaces, graph_namespaces = self._scan_storage_requirements()
    working_dir = self.global_params["working_dir"]

    for namespace in kv_namespaces:
        # Backend is looked up lazily so a config without KV nodes does not
        # need to define "kv_backend".
        proxy = init_storage(self.global_params["kv_backend"], working_dir, namespace)
        # Key must be exactly "kv_<namespace>": no padding or trailing
        # whitespace, otherwise lookups by the plain key miss.
        self.storage_actors[f"kv_{namespace}"] = proxy
        logger.info("Create KV Storage Actor: namespace=%s", namespace)

    for namespace in graph_namespaces:
        proxy = init_storage(self.global_params["graph_backend"], working_dir, namespace)
        self.storage_actors[f"graph_{namespace}"] = proxy
        logger.info("Create Graph Storage Actor: namespace=%s", namespace)
77+
78+ def _scan_storage_requirements (self ) -> tuple [set [str ], set [str ]]:
79+ kv_namespaces = set ()
80+ graph_namespaces = set ()
81+
82+ # TODO: Temporarily hard-coded; node storage will be centrally managed later.
83+ for node in self .config .nodes :
84+ op_name = node .op_name
85+ if self ._function_needs_param (op_name , "kv_backend" ):
86+ kv_namespaces .add (op_name )
87+ if self ._function_needs_param (op_name , "graph_backend" ):
88+ graph_namespaces .add ("graph" )
89+ return kv_namespaces , graph_namespaces
90+
91+ def _function_needs_param (self , op_name : str , param_name : str ) -> bool :
92+ if op_name not in self .functions :
93+ return False
94+
95+ func = self .functions [op_name ]
96+
97+ if inspect .isclass (func ):
98+ try :
99+ sig = inspect .signature (func .__init__ )
100+ return param_name in sig .parameters
101+ except (ValueError , TypeError ):
102+ return False
103+
104+ try :
105+ sig = inspect .signature (func )
106+ return param_name in sig .parameters
107+ except (ValueError , TypeError ):
108+ return False
109+
41110 @staticmethod
42111 def _topo_sort (nodes : List [Node ]) -> List [Node ]:
43112 id_to_node : Dict [str , Node ] = {}
0 commit comments