@@ -214,11 +214,13 @@ def __init__(
214214 self .use_nxcg_cache = True
215215 self .nxcg_graph = None
216216
217+ self .edge_type_key = edge_type_key
218+ self .read_parallelism = read_parallelism
219+ self .read_batch_size = read_batch_size
220+
217221 # Does not apply to undirected graphs
218222 self .symmetrize_edges = symmetrize_edges
219223
220- self .edge_type_key = edge_type_key
221-
222224 # TODO: Consider this
223225 # if not self.__graph_name:
224226 # if incoming_graph_data is not None:
@@ -227,8 +229,8 @@ def __init__(
227229
228230 self ._loaded_incoming_graph_data = False
229231 if self .graph_exists_in_db :
230- self ._set_factory_methods ()
231- self .__set_arangodb_backend_config (read_parallelism , read_batch_size )
232+ self ._set_factory_methods (read_parallelism , read_batch_size )
233+ self .__set_arangodb_backend_config ()
232234
233235 if overwrite_graph :
234236 logger .info ("Overwriting graph..." )
@@ -284,7 +286,7 @@ def __init__(
284286 # Init helper methods #
285287 #######################
286288
287- def _set_factory_methods (self ) -> None :
289+ def _set_factory_methods (self , read_parallelism : int , read_batch_size : int ) -> None :
288290 """Set the factory methods for the graph, _node, and _adj dictionaries.
289291
290292 The ArangoDB CRUD operations are handled by the modified dictionaries.
@@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None:
299301 """
300302
301303 base_args = (self .db , self .adb_graph )
304+
302305 node_args = (* base_args , self .default_node_type )
303- adj_args = (
304- * node_args ,
305- self .edge_type_key ,
306- self .edge_type_func ,
307- self .__class__ .__name__ ,
306+ node_args_with_read = (* node_args , read_parallelism , read_batch_size )
307+
308+ adj_args = (self .edge_type_key , self .edge_type_func , self .__class__ .__name__ )
309+ adj_inner_args = (* node_args , * adj_args )
310+ adj_outer_args = (
311+ * node_args_with_read ,
312+ * adj_args ,
313+ self .symmetrize_edges ,
308314 )
309315
310316 self .graph_attr_dict_factory = graph_dict_factory (* base_args )
311317
312- self .node_dict_factory = node_dict_factory (* node_args )
318+ self .node_dict_factory = node_dict_factory (* node_args_with_read )
313319 self .node_attr_dict_factory = node_attr_dict_factory (* base_args )
314320
315321 self .edge_attr_dict_factory = edge_attr_dict_factory (* base_args )
316- self .adjlist_inner_dict_factory = adjlist_inner_dict_factory (* adj_args )
317- self .adjlist_outer_dict_factory = adjlist_outer_dict_factory (
318- * adj_args , self .symmetrize_edges
319- )
320-
321- def __set_arangodb_backend_config (
322- self , read_parallelism : int , read_batch_size : int
323- ) -> None :
324- if not all ([self ._host , self ._username , self ._password , self ._db_name ]):
325- m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
326- raise OSError (m )
322+ self .adjlist_inner_dict_factory = adjlist_inner_dict_factory (* adj_inner_args )
323+ self .adjlist_outer_dict_factory = adjlist_outer_dict_factory (* adj_outer_args )
327324
325+ def __set_arangodb_backend_config (self ) -> None :
328326 config = nx .config .backends .arangodb
329- config .host = self ._host
330- config .username = self ._username
331- config .password = self ._password
332- config .db_name = self ._db_name
333- config .read_parallelism = read_parallelism
334- config .read_batch_size = read_batch_size
335327 config .use_gpu = True # Only used by default if nx-cugraph is available
336328
337329 def __set_edge_collections_attributes (self , attributes : set [str ] | None ) -> None :
@@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
345337 self ._edge_collections_attributes .add ("_id" )
346338
347339 def __set_db (self , db : Any = None ) -> None :
348- self ._host = os .getenv ("DATABASE_HOST" )
340+ self ._hosts = os .getenv ("DATABASE_HOST" , "" ). split ( ", " )
349341 self ._username = os .getenv ("DATABASE_USERNAME" )
350342 self ._password = os .getenv ("DATABASE_PASSWORD" )
351343 self ._db_name = os .getenv ("DATABASE_NAME" )
@@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None:
355347 m = "arango.database.StandardDatabase"
356348 raise TypeError (m )
357349
358- db .version ()
350+ db .version () # make sure the connection is valid
359351 self .__db = db
352+ self ._db_name = db .name
353+ self ._hosts = db ._conn ._hosts
354+ self ._username , self ._password = db ._conn ._auth
360355 return
361356
362- if not all ([self ._host , self ._username , self ._password , self ._db_name ]):
357+ if not all ([self ._hosts , self ._username , self ._password , self ._db_name ]):
363358 m = "Database environment variables not set. Can't connect to the database"
364359 logger .warning (m )
365360 self .__db = None
366361 return
367362
368- self .__db = ArangoClient (hosts = self ._host , request_timeout = None ).db (
363+ self .__db = ArangoClient (hosts = self ._hosts , request_timeout = None ).db (
369364 self ._db_name , self ._username , self ._password , verify = True
370365 )
371366
0 commit comments