33from typing import Any , Callable , ClassVar
44
55import networkx as nx
6- from adbnx_adapter import ADBNX_Adapter
6+ from adbnx_adapter import ADBNX_Adapter , ADBNX_Controller
7+ from adbnx_adapter .typings import NxData , NxId
78from arango import ArangoClient
89from arango .cursor import Cursor
910from arango .database import StandardDatabase
@@ -186,19 +187,18 @@ def __init__(
186187 write_async : bool = True ,
187188 symmetrize_edges : bool = False ,
188189 use_arango_views : bool = False ,
190+ overwrite_graph : bool = False ,
189191 * args : Any ,
190192 ** kwargs : Any ,
191193 ):
192194 self .__db = None
193- self .__name = None
194195 self .__use_arango_views = use_arango_views
195196 self .__graph_exists_in_db = False
196197
197198 self .__set_db (db )
198- if self .__db is not None :
199- self .__set_graph_name (name )
200-
201- self .__set_edge_collections_attributes (edge_collections_attributes )
199+ if all ([self .__db , name ]):
200+ self .__set_graph (name , default_node_type , edge_type_func )
201+ self .__set_edge_collections_attributes (edge_collections_attributes )
202202
203203 # NOTE: Need to revisit these...
204204 # self.maintain_node_dict_cache = False
@@ -219,96 +219,25 @@ def __init__(
219219 # raise ValueError(m)
220220
221221 self ._loaded_incoming_graph_data = False
222-
223- if self .__graph_exists_in_db :
224- if incoming_graph_data is not None :
225- m = "Cannot pass both **incoming_graph_data** and **name** yet if the already graph exists" # noqa: E501
226- raise NotImplementedError (m )
227-
228- if edge_type_func is not None :
229- m = "Cannot pass **edge_type_func** if the graph already exists"
230- raise NotImplementedError (m )
231-
232- self .adb_graph = self .db .graph (self .__name )
233- vertex_collections = self .adb_graph .vertex_collections ()
234- edge_definitions = self .adb_graph .edge_definitions ()
235-
236- if default_node_type is None :
237- default_node_type = list (vertex_collections )[0 ]
238- logger .info (f"Default node type set to '{ default_node_type } '" )
239- elif default_node_type not in vertex_collections :
240- m = f"Default node type '{ default_node_type } ' not found in graph '{ name } '" # noqa: E501
241- raise InvalidDefaultNodeType (m )
242-
243- node_types_to_edge_type_map : dict [tuple [str , str ], str ] = {}
244- for e_d in edge_definitions :
245- for f in e_d ["from_vertex_collections" ]:
246- for t in e_d ["to_vertex_collections" ]:
247- if (f , t ) in node_types_to_edge_type_map :
248- # TODO: Should we log a warning at least?
249- continue
250-
251- node_types_to_edge_type_map [(f , t )] = e_d ["edge_collection" ]
252-
253- def edge_type_func (u : str , v : str ) -> str :
254- try :
255- return node_types_to_edge_type_map [(u , v )]
256- except KeyError :
257- m = f"Edge type ambiguity between '{ u } ' and '{ v } '"
258- raise EdgeTypeAmbiguity (m )
259-
260- self .edge_type_func = edge_type_func
261- self .default_node_type = default_node_type
262-
222+ if self .graph_exists_in_db :
263223 self ._set_factory_methods ()
264224 self .__set_arangodb_backend_config (read_parallelism , read_batch_size )
265225
266- elif self .__name :
226+ if overwrite_graph :
227+ logger .info ("Truncating graph collections..." )
267228
268- prefix = f"{ name } _" if name else ""
269- if default_node_type is None :
270- default_node_type = f"{ prefix } node"
271- if edge_type_func is None :
272- edge_type_func = lambda u , v : f"{ u } _to_{ v } " # noqa: E731
229+ for col in self .adb_graph .vertex_collections ():
230+ self .db .collection (col ).truncate ()
273231
274- self .edge_type_func = edge_type_func
275- self .default_node_type = default_node_type
276-
277- # TODO: Parameterize the edge definitions
278- # How can we work with a heterogenous **incoming_graph_data**?
279- default_edge_type = edge_type_func (default_node_type , default_node_type )
280- edge_definitions = [
281- {
282- "edge_collection" : default_edge_type ,
283- "from_vertex_collections" : [default_node_type ],
284- "to_vertex_collections" : [default_node_type ],
285- }
286- ]
232+ for col in self .adb_graph .edge_definitions ():
233+ self .db .collection (col ["edge_collection" ]).truncate ()
287234
288235 if isinstance (incoming_graph_data , nx .Graph ):
289- self .adb_graph = ADBNX_Adapter (self .db ).networkx_to_arangodb (
290- self .__name ,
291- incoming_graph_data ,
292- edge_definitions = edge_definitions ,
293- batch_size = write_batch_size ,
294- use_async = write_async ,
295- )
296-
236+ self ._load_nx_graph (incoming_graph_data , write_batch_size , write_async )
297237 self ._loaded_incoming_graph_data = True
298238
299- else :
300- self .adb_graph = self .db .create_graph (
301- self .__name ,
302- edge_definitions = edge_definitions ,
303- )
304-
305- self ._set_factory_methods ()
306- self .__set_arangodb_backend_config (read_parallelism , read_batch_size )
307- logger .info (f"Graph '{ name } ' created." )
308- self .__graph_exists_in_db = True
309-
310- if self .__name is not None :
311- kwargs ["name" ] = self .__name
239+ if name is not None :
240+ kwargs ["name" ] = name
312241
313242 super ().__init__ (* args , ** kwargs )
314243
@@ -333,6 +262,7 @@ def edge_type_func(u: str, v: str) -> str:
333262 and not self ._loaded_incoming_graph_data
334263 ):
335264 nx .convert .to_networkx_graph (incoming_graph_data , create_using = self )
265+ self ._loaded_incoming_graph_data = True
336266
337267 #######################
338268 # Init helper methods #
@@ -423,23 +353,118 @@ def __set_db(self, db: Any = None) -> None:
423353 self ._db_name , self ._username , self ._password , verify = True
424354 )
425355
426- def __set_graph_name (self , name : Any = None ) -> None :
427- if self .__db is None :
428- m = "Cannot set graph name without setting the database first"
429- raise DatabaseNotSet (m )
430-
431- if not name :
432- self .__graph_exists_in_db = False
433- logger .warning (f"**name** not set for { self .__class__ .__name__ } " )
434- return
435-
356+ def __set_graph (
357+ self ,
358+ name : Any ,
359+ default_node_type : str | None = None ,
360+ edge_type_func : Callable [[str , str ], str ] | None = None ,
361+ ) -> None :
436362 if not isinstance (name , str ):
437363 raise TypeError ("**name** must be a string" )
438364
365+ if self .db .has_graph (name ):
366+ logger .info (f"Graph '{ name } ' exists." )
367+
368+ if edge_type_func is not None :
369+ m = "Cannot pass **edge_type_func** if the graph already exists"
370+ raise NotImplementedError (m )
371+
372+ self .adb_graph = self .db .graph (name )
373+ vertex_collections = self .adb_graph .vertex_collections ()
374+ edge_definitions = self .adb_graph .edge_definitions ()
375+
376+ if default_node_type is None :
377+ default_node_type = list (vertex_collections )[0 ]
378+ logger .info (f"Default node type set to '{ default_node_type } '" )
379+
380+ elif default_node_type not in vertex_collections :
381+ m = f"Default node type '{ default_node_type } ' not found in graph '{ name } '" # noqa: E501
382+ raise InvalidDefaultNodeType (m )
383+
384+ node_types_to_edge_type_map : dict [tuple [str , str ], str ] = {}
385+ for e_d in edge_definitions :
386+ for f in e_d ["from_vertex_collections" ]:
387+ for t in e_d ["to_vertex_collections" ]:
388+ if (f , t ) in node_types_to_edge_type_map :
389+ # TODO: Should we log a warning at least?
390+ continue
391+
392+ node_types_to_edge_type_map [(f , t )] = e_d ["edge_collection" ]
393+
394+ def edge_type_func (u : str , v : str ) -> str :
395+ try :
396+ return node_types_to_edge_type_map [(u , v )]
397+ except KeyError :
398+ m = f"Edge type ambiguity between '{ u } ' and '{ v } '"
399+ raise EdgeTypeAmbiguity (m )
400+
401+ else :
402+ prefix = f"{ name } _" if name else ""
403+
404+ if default_node_type is None :
405+ default_node_type = f"{ prefix } node"
406+
407+ if edge_type_func is None :
408+ edge_type_func = lambda u , v : f"{ u } _to_{ v } " # noqa: E731
409+
410+ # TODO: Parameterize the edge definitions
411+ # How can we work with a heterogenous **incoming_graph_data**?
412+ default_edge_type = edge_type_func (default_node_type , default_node_type )
413+ edge_definitions = [
414+ {
415+ "edge_collection" : default_edge_type ,
416+ "from_vertex_collections" : [default_node_type ],
417+ "to_vertex_collections" : [default_node_type ],
418+ }
419+ ]
420+
421+ # Create a general graph if it doesn't exist
422+ self .adb_graph = self .db .create_graph (
423+ name = name ,
424+ edge_definitions = edge_definitions ,
425+ )
426+
427+ logger .info (f"Graph '{ name } ' created." )
428+
439429 self .__name = name
440- self .__graph_exists_in_db = self .db .has_graph (name )
430+ self .__graph_exists_in_db = True
431+ self .edge_type_func = edge_type_func
432+ self .default_node_type = default_node_type
433+
434+ properties = self .adb_graph .properties ()
435+ self .__is_smart : bool = properties .get ("smart" , False )
436+ self .__smart_field : str | None = properties .get ("smart_field" )
437+
438+ def _load_nx_graph (
439+ self , nx_graph : nx .Graph , write_batch_size : int , write_async : bool
440+ ) -> None :
441+ controller = ADBNX_Controller
442+
443+ if all ([self .is_smart , self .smart_field ]):
444+ smart_field = self .__smart_field
441445
442- logger .info (f"Graph '{ name } ' exists: { self .__graph_exists_in_db } " )
446+ class SmartController (ADBNX_Controller ):
447+ def _keyify_networkx_node (
448+ self , i : int , nx_node_id : NxId , nx_node : NxData , col : str
449+ ) -> str :
450+ if smart_field not in nx_node :
451+ m = f"Node { nx_node_id } missing smart field '{ smart_field } '" # noqa: E501
452+ raise KeyError (m )
453+
454+ return f"{ nx_node [smart_field ]} :{ str (i )} "
455+
456+ def _prepare_networkx_edge (self , nx_edge : NxData , col : str ) -> None :
457+ del nx_edge ["_key" ]
458+
459+ controller = SmartController
460+ logger .info (f"Using smart field '{ smart_field } ' for node keys" )
461+
462+ ADBNX_Adapter (self .db , controller ()).networkx_to_arangodb (
463+ self .adb_graph .name ,
464+ nx_graph ,
465+ batch_size = write_batch_size ,
466+ use_async = write_async ,
467+ )
443468
444469 ###########
445470 # Getters #
@@ -479,6 +504,14 @@ def graph_exists_in_db(self) -> bool:
479504 def edge_attributes (self ) -> set [str ]:
480505 return self ._edge_collections_attributes
481506
507+ @property
508+ def is_smart (self ) -> bool :
509+ return self .__is_smart
510+
511+ @property
512+ def smart_field (self ) -> str | None :
513+ return self .__smart_field
514+
482515 ###########
483516 # Setters #
484517 ###########
@@ -645,12 +678,24 @@ def add_node_override(self, node_for_adding, **attr):
645678 # attr_dict.update(attr)
646679
647680 # New:
648- self ._node [node_for_adding ] = self .node_attr_dict_factory ()
681+ node_attr_dict = self .node_attr_dict_factory ()
682+
683+ if self .is_smart :
684+ if self .smart_field not in attr :
685+ m = f"Node { node_for_adding } missing smart field '{ self .smart_field } '" # noqa: E501
686+ raise KeyError (m )
687+
688+ node_attr_dict .data [self .smart_field ] = attr [self .smart_field ]
689+
690+ self ._node [node_for_adding ] = node_attr_dict
649691 self ._node [node_for_adding ].update (attr )
650692
651693 # Reason:
652694 # Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set
653695 # i.e trying to update a node's attributes before we know _which_ node it is
696+ # Furthermore, support for ArangoDB Smart Graphs requires the smart field
697+ # to be set before adding the node to the graph. This is because the smart
698+ # field is used to generate the node's key.
654699
655700 ###########################
656701
0 commit comments