1212
1313from graphgen .bases import Config , Node
1414from graphgen .utils import logger
15- from graphgen .common import init_llm
15+ from graphgen .common import init_llm , init_storage
1616
1717load_dotenv ()
1818
@@ -25,6 +25,7 @@ def __init__(
2525 self .functions = functions
2626 self .datasets : Dict [str , ray .data .Dataset ] = {}
2727 self .llm_actors = {}
28+ self .storage_actors = {}
2829
2930 ctx = DataContext .get_current ()
3031 ctx .enable_rich_progress_bars = False
@@ -54,11 +55,57 @@ def __init__(
5455 logger .info ("Ray Dashboard URL: %s" , context .dashboard_url )
5556
5657 self ._init_llms ()
58+ self ._init_storage ()
5759
def _init_llms(self):
    """Populate ``self.llm_actors`` with one entry per LLM role.

    Each role name is passed to ``init_llm`` and the returned object is
    stored under that same name. Roles are initialised in the order
    "synthesizer", then "trainee".
    """
    for role in ("synthesizer", "trainee"):
        self.llm_actors[role] = init_llm(role)
6163
def _init_storage(self):
    """Create and register the storage handles required by the pipeline.

    The namespaces come from ``_scan_storage_requirements``. For every
    namespace, ``init_storage`` is called with the backend named in
    ``self.global_params`` ("kv_backend" or "graph_backend"), the shared
    working directory and the namespace; the result is stored in
    ``self.storage_actors`` under ``kv_<namespace>`` or
    ``graph_<namespace>``.
    """
    kv_needed, graph_needed = self._scan_storage_requirements()
    root_dir = self.global_params["working_dir"]

    # (key prefix, backend parameter name, namespaces, log message)
    plans = (
        ("kv", "kv_backend", kv_needed, "Create KV Storage Actor: namespace=%s"),
        (
            "graph",
            "graph_backend",
            graph_needed,
            "Create Graph Storage Actor: namespace=%s",
        ),
    )

    for prefix, backend_key, namespaces, message in plans:
        for namespace in namespaces:
            # The backend is looked up per namespace, so a missing backend
            # key only matters when something actually needs that storage.
            backend = self.global_params[backend_key]
            self.storage_actors[f"{prefix}_{namespace}"] = init_storage(
                backend, root_dir, namespace
            )
            logger.info(message, namespace)
77+
def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
    """Work out which storage namespaces the configured nodes need.

    Returns:
        A ``(kv_namespaces, graph_namespaces)`` pair. ``kv_namespaces``
        holds the ``op_name`` of every node in ``self.config.nodes`` whose
        function takes a ``kv_backend`` parameter. ``graph_namespaces`` is
        ``{"graph"}`` if any node's function takes a ``graph_backend``
        parameter, otherwise empty.
    """
    op_names = [node.op_name for node in self.config.nodes]

    kv_namespaces = {
        name for name in op_names if self._function_needs_param(name, "kv_backend")
    }

    # TODO: Temporarily hard-coded; node storage will be centrally managed later.
    uses_graph = any(
        self._function_needs_param(name, "graph_backend") for name in op_names
    )
    graph_namespaces = {"graph"} if uses_graph else set()

    return kv_namespaces, graph_namespaces
90+
def _function_needs_param(self, op_name: str, param_name: str) -> bool:
    """Report whether the operator registered as ``op_name`` accepts ``param_name``.

    Args:
        op_name: Key to look up in ``self.functions``.
        param_name: Parameter name to search for in the signature.

    Returns:
        ``True`` if the parameter appears in the callable's signature (for
        a class, the signature of its ``__init__``). ``False`` if the
        operator is not registered or its signature cannot be inspected.
    """
    registry = self.functions
    if op_name not in registry:
        return False

    candidate = registry[op_name]
    # For classes the constructor arguments are what matter.
    target = candidate.__init__ if inspect.isclass(candidate) else candidate

    try:
        return param_name in inspect.signature(target).parameters
    except (ValueError, TypeError):
        # Not callable, or no introspectable signature.
        return False
62109
63110 @staticmethod
64111 def _topo_sort (nodes : List [Node ]) -> List [Node ]:
0 commit comments