@@ -214,11 +214,13 @@ def __init__(
214214 self .use_nxcg_cache = True
215215 self .nxcg_graph = None
216216
217+ self .edge_type_key = edge_type_key
218+ self .read_parallelism = read_parallelism
219+ self .read_batch_size = read_batch_size
220+
217221 # Does not apply to undirected graphs
218222 self .symmetrize_edges = symmetrize_edges
219223
220- self .edge_type_key = edge_type_key
221-
222224 # TODO: Consider this
223225 # if not self.__graph_name:
224226 # if incoming_graph_data is not None:
@@ -227,8 +229,8 @@ def __init__(
227229
228230 self ._loaded_incoming_graph_data = False
229231 if self .graph_exists_in_db :
230- self ._set_factory_methods ()
231- self .__set_arangodb_backend_config (read_parallelism , read_batch_size )
232+ self ._set_factory_methods (read_parallelism , read_batch_size )
233+ self .__set_arangodb_backend_config ()
232234
233235 if isinstance (incoming_graph_data , nx .Graph ):
234236 self ._load_nx_graph (incoming_graph_data , write_batch_size , write_async )
@@ -267,7 +269,7 @@ def __init__(
267269 # Init helper methods #
268270 #######################
269271
270- def _set_factory_methods (self ) -> None :
272+ def _set_factory_methods (self , read_parallelism : int , read_batch_size : int ) -> None :
271273 """Set the factory methods for the graph, _node, and _adj dictionaries.
272274
273275 The ArangoDB CRUD operations are handled by the modified dictionaries.
@@ -282,39 +284,29 @@ def _set_factory_methods(self) -> None:
282284 """
283285
284286 base_args = (self .db , self .adb_graph )
287+
285288 node_args = (* base_args , self .default_node_type )
286- adj_args = (
287- * node_args ,
288- self .edge_type_key ,
289- self .edge_type_func ,
290- self .__class__ .__name__ ,
289+ node_args_with_read = (* node_args , read_parallelism , read_batch_size )
290+
291+ adj_args = (self .edge_type_key , self .edge_type_func , self .__class__ .__name__ )
292+ adj_inner_args = (* node_args , * adj_args )
293+ adj_outer_args = (
294+ * node_args_with_read ,
295+ * adj_args ,
296+ self .symmetrize_edges ,
291297 )
292298
293299 self .graph_attr_dict_factory = graph_dict_factory (* base_args )
294300
295- self .node_dict_factory = node_dict_factory (* node_args )
301+ self .node_dict_factory = node_dict_factory (* node_args_with_read )
296302 self .node_attr_dict_factory = node_attr_dict_factory (* base_args )
297303
298304 self .edge_attr_dict_factory = edge_attr_dict_factory (* base_args )
299- self .adjlist_inner_dict_factory = adjlist_inner_dict_factory (* adj_args )
300- self .adjlist_outer_dict_factory = adjlist_outer_dict_factory (
301- * adj_args , self .symmetrize_edges
302- )
303-
304- def __set_arangodb_backend_config (
305- self , read_parallelism : int , read_batch_size : int
306- ) -> None :
307- if not all ([self ._host , self ._username , self ._password , self ._db_name ]):
308- m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
309- raise OSError (m )
305+ self .adjlist_inner_dict_factory = adjlist_inner_dict_factory (* adj_inner_args )
306+ self .adjlist_outer_dict_factory = adjlist_outer_dict_factory (* adj_outer_args )
310307
308+ def __set_arangodb_backend_config (self ) -> None :
311309 config = nx .config .backends .arangodb
312- config .host = self ._host
313- config .username = self ._username
314- config .password = self ._password
315- config .db_name = self ._db_name
316- config .read_parallelism = read_parallelism
317- config .read_batch_size = read_batch_size
318310 config .use_gpu = True # Only used by default if nx-cugraph is available
319311
320312 def __set_edge_collections_attributes (self , attributes : set [str ] | None ) -> None :
@@ -328,7 +320,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
328320 self ._edge_collections_attributes .add ("_id" )
329321
330322 def __set_db (self , db : Any = None ) -> None :
331- self ._host = os .getenv ("DATABASE_HOST" )
323+ self ._hosts = os .getenv ("DATABASE_HOST" , "" ). split ( ", " )
332324 self ._username = os .getenv ("DATABASE_USERNAME" )
333325 self ._password = os .getenv ("DATABASE_PASSWORD" )
334326 self ._db_name = os .getenv ("DATABASE_NAME" )
@@ -338,17 +330,20 @@ def __set_db(self, db: Any = None) -> None:
338330 m = "arango.database.StandardDatabase"
339331 raise TypeError (m )
340332
341- db .version ()
333+ db .version () # make sure the connection is valid
342334 self .__db = db
335+ self ._db_name = db .name
336+ self ._hosts = db ._conn ._hosts
337+ self ._username , self ._password = db ._conn ._auth
343338 return
344339
345- if not all ([self ._host , self ._username , self ._password , self ._db_name ]):
340+ if not all ([self ._hosts , self ._username , self ._password , self ._db_name ]):
346341 m = "Database environment variables not set. Can't connect to the database"
347342 logger .warning (m )
348343 self .__db = None
349344 return
350345
351- self .__db = ArangoClient (hosts = self ._host , request_timeout = None ).db (
346+ self .__db = ArangoClient (hosts = self ._hosts , request_timeout = None ).db (
352347 self ._db_name , self ._username , self ._password , verify = True
353348 )
354349
0 commit comments