2727 node_attr_dict_factory ,
2828 node_dict_factory ,
2929)
30+ from .dict .adj import AdjListOuterDict
31+ from .enum import TraversalDirection
3032from .function import get_node_id
3133from .reportviews import CustomEdgeView , CustomNodeView
3234
@@ -96,6 +98,7 @@ def __init__(
9698 # m = "Must set **graph_name** if passing **incoming_graph_data**"
9799 # raise ValueError(m)
98100
101+ loaded_incoming_graph_data = False
99102 if self ._graph_exists_in_db :
100103 if incoming_graph_data is not None :
101104 m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501
@@ -170,29 +173,44 @@ def edge_type_func(u: str, v: str) -> str:
170173 use_async = write_async ,
171174 )
172175
176+ loaded_incoming_graph_data = True
177+
173178 else :
174179 self .adb_graph = self .db .create_graph (
175180 self .__name ,
176181 edge_definitions = edge_definitions ,
177182 )
178183
179- # Let the parent class handle the incoming graph data
180- # if it is not a networkx.Graph object
181- kwargs ["incoming_graph_data" ] = incoming_graph_data
182-
183184 self ._set_factory_methods ()
184185 self ._set_arangodb_backend_config ()
185186 logger .info (f"Graph '{ name } ' created." )
186187 self ._graph_exists_in_db = True
187188
188- else :
189- kwargs ["incoming_graph_data" ] = incoming_graph_data
190-
191- if name is not None :
192- kwargs ["name" ] = name
189+ if self .__name is not None :
190+ kwargs ["name" ] = self .__name
193191
194192 super ().__init__ (* args , ** kwargs )
195193
194+ if self .is_directed () and self .graph_exists_in_db :
195+ assert isinstance (self ._succ , AdjListOuterDict )
196+ assert isinstance (self ._pred , AdjListOuterDict )
197+ self ._succ .mirror = self ._pred
198+ self ._pred .mirror = self ._succ
199+ self ._succ .traversal_direction = TraversalDirection .OUTBOUND
200+ self ._pred .traversal_direction = TraversalDirection .INBOUND
201+
202+ if incoming_graph_data is not None and not loaded_incoming_graph_data :
203+ nx .convert .to_networkx_graph (incoming_graph_data , create_using = self )
204+
205+ if self .graph_exists_in_db :
206+ self .copy = self .copy_override
207+ self .subgraph = self .subgraph_override
208+ self .clear = self .clear_override
209+ self .clear_edges = self .clear_edges_override
210+ self .add_node = self .add_node_override
211+ self .number_of_edges = self .number_of_edges_override
212+ self .nbunch_iter = self .nbunch_iter_override
213+
196214 #######################
197215 # Init helper methods #
198216 #######################
@@ -345,6 +363,9 @@ def _set_graph_name(self, graph_name: str | None = None) -> None:
345363 # ArangoDB Methods #
346364 ####################
347365
366+ def clear_nxcg_cache (self ):
367+ self .nxcg_graph = None
368+
348369 def aql (self , query : str , bind_vars : dict [str , Any ] = {}, ** kwargs : Any ) -> Cursor :
349370 return nxadb .classes .function .aql (self .db , query , bind_vars , ** kwargs )
350371
@@ -355,7 +376,7 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
355376 # NOTE: OUT OF SERVICE
356377 # def chat(self, prompt: str) -> str:
357378 # if self.__qa_chain is None:
358- # if not self.__graph_exists_in_db :
379+ # if not self.graph_exists_in_db :
359380 # return "Could not initialize QA chain: Graph does not exist"
360381
361382 # # try:
@@ -381,30 +402,6 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
381402 # nx.Graph Overides #
382403 #####################
383404
384- def copy (self , * args , ** kwargs ):
385- logger .warning ("Note that copying a graph loses the connection to the database" )
386- G = super ().copy (* args , ** kwargs )
387- G .node_dict_factory = nx .Graph .node_dict_factory
388- G .node_attr_dict_factory = nx .Graph .node_attr_dict_factory
389- G .edge_attr_dict_factory = nx .Graph .edge_attr_dict_factory
390- G .adjlist_inner_dict_factory = nx .Graph .adjlist_inner_dict_factory
391- G .adjlist_outer_dict_factory = nx .Graph .adjlist_outer_dict_factory
392- return G
393-
394- def subgraph (self , nbunch ):
395- raise NotImplementedError ("Subgraphing is not yet implemented" )
396-
397- def clear (self ):
398- logger .info ("Note that clearing only erases the local cache" )
399- super ().clear ()
400-
401- def clear_edges (self ):
402- logger .info ("Note that clearing edges ony erases the edges in the local cache" )
403- super ().clear_edges ()
404-
405- def clear_nxcg_cache (self ):
406- self .nxcg_graph = None
407-
408405 @cached_property
409406 def nodes (self ):
410407 if self .__use_experimental_views and self .graph_exists_in_db :
@@ -437,7 +434,30 @@ def edges(self):
437434
438435 return super ().edges
439436
440- def add_node (self , node_for_adding , ** attr ):
437+ def copy_override (self , * args , ** kwargs ):
438+ logger .warning ("Note that copying a graph loses the connection to the database" )
439+ G = super ().copy (* args , ** kwargs )
440+ G .node_dict_factory = nx .Graph .node_dict_factory
441+ G .node_attr_dict_factory = nx .Graph .node_attr_dict_factory
442+ G .edge_attr_dict_factory = nx .Graph .edge_attr_dict_factory
443+ G .adjlist_inner_dict_factory = nx .Graph .adjlist_inner_dict_factory
444+ G .adjlist_outer_dict_factory = nx .Graph .adjlist_outer_dict_factory
445+ return G
446+
447+ def subgraph_override (self , nbunch ):
448+ raise NotImplementedError ("Subgraphing is not yet implemented" )
449+
450+ def clear_override (self ):
451+ logger .info ("Note that clearing only erases the local cache" )
452+ super ().clear ()
453+
454+ def clear_edges_override (self ):
455+ logger .info ("Note that clearing edges ony erases the edges in the local cache" )
456+ for nbr_dict in self ._adj .data .values ():
457+ nbr_dict .clear ()
458+ nx ._clear_cache (self )
459+
460+ def add_node_override (self , node_for_adding , ** attr ):
441461 if node_for_adding not in self ._node :
442462 if node_for_adding is None :
443463 raise ValueError ("None cannot be a node" )
@@ -467,10 +487,7 @@ def add_node(self, node_for_adding, **attr):
467487
468488 nx ._clear_cache (self )
469489
470- def number_of_edges (self , u = None , v = None ):
471- if not self .graph_exists_in_db :
472- return super ().number_of_edges (u , v )
473-
490+ def number_of_edges_override (self , u = None , v = None ):
474491 if u is not None :
475492 return super ().number_of_edges (u , v )
476493
@@ -494,10 +511,7 @@ def number_of_edges(self, u=None, v=None):
494511 # It is more efficient to count the number of edges in the edge collections
495512 # compared to relying on the DegreeView.
496513
497- def nbunch_iter (self , nbunch = None ):
498- if not self ._graph_exists_in_db :
499- return super ().nbunch_iter (nbunch )
500-
514+ def nbunch_iter_override (self , nbunch = None ):
501515 if nbunch is None :
502516 bunch = iter (self ._adj )
503517 elif nbunch in self :
@@ -508,12 +522,13 @@ def nbunch_iter(self, nbunch=None):
508522 # Old: Nothing
509523
510524 # New:
511- if isinstance (nbunch , int ):
525+ if isinstance (nbunch , ( str , int ) ):
512526 nbunch = get_node_id (str (nbunch ), self .default_node_type )
513527
514528 # Reason:
515529 # ArangoDB only uses strings as node IDs. Therefore, we need to convert
516- # the integer node ID to a string before using it in an iterator.
530+ # the non-prefixed node ID to an ArangoDB ID before
531+ # using it in an iterator.
517532
518533 bunch = iter ([nbunch ])
519534 else :
@@ -528,13 +543,15 @@ def bunch_iter(nlist, adj):
528543 # Old: Nothing
529544
530545 # New:
531- if isinstance (n , int ):
546+ if isinstance (n , ( str , int ) ):
532547 n = get_node_id (str (n ), self .default_node_type )
533548
534549 # Reason:
535550 # ArangoDB only uses strings as node IDs. Therefore,
536- # we need to convert the integer node ID to a
537- # string before using it in an iterator.
551+ # we need to convert non-prefixed node IDs to an
552+ # ArangoDB ID before using it in an iterator.
553+
554+ ######################
538555
539556 if n in adj :
540557 yield n
0 commit comments