Skip to content

Commit 152fe62

Browse files
committed
smart graph support | initial commit
1 parent 7999151 commit 152fe62

File tree

3 files changed

+161
-100
lines changed

3 files changed

+161
-100
lines changed

nx_arangodb/classes/digraph.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
and not self._loaded_incoming_graph_data
195195
):
196196
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)
197+
self._loaded_incoming_graph_data = True
197198

198199
#######################
199200
# nx.DiGraph Overrides #
@@ -241,12 +242,25 @@ def add_node_override(self, node_for_adding, **attr):
241242
# attr_dict.update(attr)
242243

243244
# New:
245+
246+
node_attr_dict = self.node_attr_dict_factory()
247+
248+
if self.is_smart:
249+
if self.smart_field not in attr:
250+
m = f"Node {node_for_adding} missing smart field '{self.smart_field}'" # noqa: E501
251+
raise KeyError(m)
252+
253+
node_attr_dict.data[self.smart_field] = attr[self.smart_field]
254+
244255
self._node[node_for_adding] = self.node_attr_dict_factory()
245256
self._node[node_for_adding].update(attr)
246257

247258
# Reason:
248259
# Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set
249260
# i.e trying to update a node's attributes before we know _which_ node it is
261+
# Furthermore, support for ArangoDB Smart Graphs requires the smart field
262+
# to be set before adding the node to the graph. This is because the smart
263+
# field is used to generate the node's key.
250264

251265
###########################
252266

nx_arangodb/classes/graph.py

Lines changed: 145 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from typing import Any, Callable, ClassVar
44

55
import networkx as nx
6-
from adbnx_adapter import ADBNX_Adapter
6+
from adbnx_adapter import ADBNX_Adapter, ADBNX_Controller
7+
from adbnx_adapter.typings import NxData, NxId
78
from arango import ArangoClient
89
from arango.cursor import Cursor
910
from arango.database import StandardDatabase
@@ -186,19 +187,18 @@ def __init__(
186187
write_async: bool = True,
187188
symmetrize_edges: bool = False,
188189
use_arango_views: bool = False,
190+
overwrite_graph: bool = False,
189191
*args: Any,
190192
**kwargs: Any,
191193
):
192194
self.__db = None
193-
self.__name = None
194195
self.__use_arango_views = use_arango_views
195196
self.__graph_exists_in_db = False
196197

197198
self.__set_db(db)
198-
if self.__db is not None:
199-
self.__set_graph_name(name)
200-
201-
self.__set_edge_collections_attributes(edge_collections_attributes)
199+
if all([self.__db, name]):
200+
self.__set_graph(name, default_node_type, edge_type_func)
201+
self.__set_edge_collections_attributes(edge_collections_attributes)
202202

203203
# NOTE: Need to revisit these...
204204
# self.maintain_node_dict_cache = False
@@ -219,96 +219,25 @@ def __init__(
219219
# raise ValueError(m)
220220

221221
self._loaded_incoming_graph_data = False
222-
223-
if self.__graph_exists_in_db:
224-
if incoming_graph_data is not None:
225-
m = "Cannot pass both **incoming_graph_data** and **name** yet if the already graph exists" # noqa: E501
226-
raise NotImplementedError(m)
227-
228-
if edge_type_func is not None:
229-
m = "Cannot pass **edge_type_func** if the graph already exists"
230-
raise NotImplementedError(m)
231-
232-
self.adb_graph = self.db.graph(self.__name)
233-
vertex_collections = self.adb_graph.vertex_collections()
234-
edge_definitions = self.adb_graph.edge_definitions()
235-
236-
if default_node_type is None:
237-
default_node_type = list(vertex_collections)[0]
238-
logger.info(f"Default node type set to '{default_node_type}'")
239-
elif default_node_type not in vertex_collections:
240-
m = f"Default node type '{default_node_type}' not found in graph '{name}'" # noqa: E501
241-
raise InvalidDefaultNodeType(m)
242-
243-
node_types_to_edge_type_map: dict[tuple[str, str], str] = {}
244-
for e_d in edge_definitions:
245-
for f in e_d["from_vertex_collections"]:
246-
for t in e_d["to_vertex_collections"]:
247-
if (f, t) in node_types_to_edge_type_map:
248-
# TODO: Should we log a warning at least?
249-
continue
250-
251-
node_types_to_edge_type_map[(f, t)] = e_d["edge_collection"]
252-
253-
def edge_type_func(u: str, v: str) -> str:
254-
try:
255-
return node_types_to_edge_type_map[(u, v)]
256-
except KeyError:
257-
m = f"Edge type ambiguity between '{u}' and '{v}'"
258-
raise EdgeTypeAmbiguity(m)
259-
260-
self.edge_type_func = edge_type_func
261-
self.default_node_type = default_node_type
262-
222+
if self.graph_exists_in_db:
263223
self._set_factory_methods()
264224
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
265225

266-
elif self.__name:
226+
if overwrite_graph:
227+
logger.info("Truncating graph collections...")
267228

268-
prefix = f"{name}_" if name else ""
269-
if default_node_type is None:
270-
default_node_type = f"{prefix}node"
271-
if edge_type_func is None:
272-
edge_type_func = lambda u, v: f"{u}_to_{v}" # noqa: E731
229+
for col in self.adb_graph.vertex_collections():
230+
self.db.collection(col).truncate()
273231

274-
self.edge_type_func = edge_type_func
275-
self.default_node_type = default_node_type
276-
277-
# TODO: Parameterize the edge definitions
278-
# How can we work with a heterogenous **incoming_graph_data**?
279-
default_edge_type = edge_type_func(default_node_type, default_node_type)
280-
edge_definitions = [
281-
{
282-
"edge_collection": default_edge_type,
283-
"from_vertex_collections": [default_node_type],
284-
"to_vertex_collections": [default_node_type],
285-
}
286-
]
232+
for col in self.adb_graph.edge_definitions():
233+
self.db.collection(col["edge_collection"]).truncate()
287234

288235
if isinstance(incoming_graph_data, nx.Graph):
289-
self.adb_graph = ADBNX_Adapter(self.db).networkx_to_arangodb(
290-
self.__name,
291-
incoming_graph_data,
292-
edge_definitions=edge_definitions,
293-
batch_size=write_batch_size,
294-
use_async=write_async,
295-
)
296-
236+
self._load_nx_graph(incoming_graph_data, write_batch_size, write_async)
297237
self._loaded_incoming_graph_data = True
298238

299-
else:
300-
self.adb_graph = self.db.create_graph(
301-
self.__name,
302-
edge_definitions=edge_definitions,
303-
)
304-
305-
self._set_factory_methods()
306-
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
307-
logger.info(f"Graph '{name}' created.")
308-
self.__graph_exists_in_db = True
309-
310-
if self.__name is not None:
311-
kwargs["name"] = self.__name
239+
if name is not None:
240+
kwargs["name"] = name
312241

313242
super().__init__(*args, **kwargs)
314243

@@ -333,6 +262,7 @@ def edge_type_func(u: str, v: str) -> str:
333262
and not self._loaded_incoming_graph_data
334263
):
335264
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)
265+
self._loaded_incoming_graph_data = True
336266

337267
#######################
338268
# Init helper methods #
@@ -423,23 +353,118 @@ def __set_db(self, db: Any = None) -> None:
423353
self._db_name, self._username, self._password, verify=True
424354
)
425355

426-
def __set_graph_name(self, name: Any = None) -> None:
427-
if self.__db is None:
428-
m = "Cannot set graph name without setting the database first"
429-
raise DatabaseNotSet(m)
430-
431-
if not name:
432-
self.__graph_exists_in_db = False
433-
logger.warning(f"**name** not set for {self.__class__.__name__}")
434-
return
435-
356+
def __set_graph(
    self,
    name: Any,
    default_node_type: str | None = None,
    edge_type_func: Callable[[str, str], str] | None = None,
) -> None:
    """Attach this object to the ArangoDB graph **name**, creating it if needed.

    When the graph is not yet in the database, a general graph is created
    with a single edge definition going from ``default_node_type`` to
    ``default_node_type``. Missing arguments fall back to
    ``"<name>_node"`` for the node type and ``"<u>_to_<v>"`` for the edge
    type function.

    When the graph already exists, the edge type function is derived from
    the graph's edge definitions (the first definition seen for a given
    ``(from, to)`` pair wins), and ``default_node_type`` falls back to the
    first vertex collection reported by the graph.

    In both cases the graph name, the "exists in DB" flag, the edge type
    function, the default node type and the graph's ``smart`` /
    ``smart_field`` properties are recorded on ``self``.

    Parameters
    ----------
    name : Any
        Graph name. Must be a ``str``.
    default_node_type : str | None
        Vertex collection used when a node carries no explicit type.
    edge_type_func : Callable[[str, str], str] | None
        Maps a ``(from_type, to_type)`` pair to an edge collection name.
        Only accepted when the graph does not exist yet.

    Raises
    ------
    TypeError
        If **name** is not a string.
    NotImplementedError
        If **edge_type_func** is given for a graph that already exists.
    InvalidDefaultNodeType
        If **default_node_type** is not a vertex collection of the
        existing graph.
    """
    if not isinstance(name, str):
        raise TypeError("**name** must be a string")

    if not self.db.has_graph(name):
        # Brand-new graph: derive whatever the caller did not supply.
        prefix = f"{name}_" if name else ""

        if default_node_type is None:
            default_node_type = f"{prefix}node"

        if edge_type_func is None:
            edge_type_func = lambda u, v: f"{u}_to_{v}"  # noqa: E731

        # TODO: Parameterize the edge definitions
        # How can we work with a heterogenous **incoming_graph_data**?
        default_edge_type = edge_type_func(default_node_type, default_node_type)
        edge_definitions = [
            {
                "edge_collection": default_edge_type,
                "from_vertex_collections": [default_node_type],
                "to_vertex_collections": [default_node_type],
            }
        ]

        # Create a general graph if it doesn't exist
        self.adb_graph = self.db.create_graph(
            name=name,
            edge_definitions=edge_definitions,
        )

        logger.info(f"Graph '{name}' created.")

    else:
        logger.info(f"Graph '{name}' exists.")

        # The edge type mapping of an existing graph comes from the database.
        if edge_type_func is not None:
            m = "Cannot pass **edge_type_func** if the graph already exists"
            raise NotImplementedError(m)

        self.adb_graph = self.db.graph(name)
        vertex_collections = self.adb_graph.vertex_collections()
        edge_definitions = self.adb_graph.edge_definitions()

        if default_node_type is not None:
            if default_node_type not in vertex_collections:
                m = f"Default node type '{default_node_type}' not found in graph '{name}'"  # noqa: E501
                raise InvalidDefaultNodeType(m)
        else:
            default_node_type = list(vertex_collections)[0]
            logger.info(f"Default node type set to '{default_node_type}'")

        edge_type_by_endpoints: dict[tuple[str, str], str] = {}
        for definition in edge_definitions:
            for src in definition["from_vertex_collections"]:
                for dst in definition["to_vertex_collections"]:
                    # TODO: Should we log a warning at least?
                    # Keep the first edge collection seen for a given pair.
                    if (src, dst) not in edge_type_by_endpoints:
                        edge_type_by_endpoints[(src, dst)] = definition[
                            "edge_collection"
                        ]

        def edge_type_func(u: str, v: str) -> str:
            try:
                return edge_type_by_endpoints[(u, v)]
            except KeyError:
                m = f"Edge type ambiguity between '{u}' and '{v}'"
                raise EdgeTypeAmbiguity(m)

    self.__name = name
    self.__graph_exists_in_db = True
    self.edge_type_func = edge_type_func
    self.default_node_type = default_node_type

    # Smart-graph metadata as reported by the graph's properties.
    graph_properties = self.adb_graph.properties()
    self.__is_smart: bool = graph_properties.get("smart", False)
    self.__smart_field: str | None = graph_properties.get("smart_field")
437+
438+
def _load_nx_graph(
    self, nx_graph: nx.Graph, write_batch_size: int, write_async: bool
) -> None:
    """Write **nx_graph** into the ArangoDB graph bound to this object.

    The transfer is delegated to ``ADBNX_Adapter.networkx_to_arangodb``.
    For a smart graph (``is_smart`` is true and ``smart_field`` is set) a
    controller subclass is used that builds node keys as
    ``"<smart field value>:<i>"`` and strips ``_key`` from every edge
    before it is written.

    Parameters
    ----------
    nx_graph : nx.Graph
        The NetworkX graph to load.
    write_batch_size : int
        Forwarded to the adapter as ``batch_size``.
    write_async : bool
        Forwarded to the adapter as ``use_async``.

    Raises
    ------
    KeyError
        Raised by the smart controller when a node does not carry the
        smart field.
    """
    controller = ADBNX_Controller

    # Read the smart field once through the public property so that the
    # guard below and the controller closure agree on the same value.
    smart_field = self.smart_field

    if self.is_smart and smart_field:

        class SmartController(ADBNX_Controller):
            def _keyify_networkx_node(
                self, i: int, nx_node_id: NxId, nx_node: NxData, col: str
            ) -> str:
                if smart_field not in nx_node:
                    m = f"Node {nx_node_id} missing smart field '{smart_field}'"  # noqa: E501
                    raise KeyError(m)

                return f"{nx_node[smart_field]}:{i}"

            def _prepare_networkx_edge(self, nx_edge: NxData, col: str) -> None:
                # Drop any preset edge key; tolerate edges that never had one
                # instead of failing with KeyError.
                nx_edge.pop("_key", None)

        controller = SmartController
        logger.info(f"Using smart field '{smart_field}' for node keys")

    ADBNX_Adapter(self.db, controller()).networkx_to_arangodb(
        self.adb_graph.name,
        nx_graph,
        batch_size=write_batch_size,
        use_async=write_async,
    )
443468

444469
###########
445470
# Getters #
@@ -479,6 +504,14 @@ def graph_exists_in_db(self) -> bool:
479504
def edge_attributes(self) -> set[str]:
480505
return self._edge_collections_attributes
481506

507+
@property
def is_smart(self) -> bool:
    """Return the ``smart`` flag recorded from the graph's properties."""
    return self.__is_smart
510+
511+
@property
def smart_field(self) -> str | None:
    """Return the ``smart_field`` recorded from the graph's properties, if any."""
    return self.__smart_field
514+
482515
###########
483516
# Setters #
484517
###########
@@ -645,12 +678,24 @@ def add_node_override(self, node_for_adding, **attr):
645678
# attr_dict.update(attr)
646679

647680
# New:
648-
self._node[node_for_adding] = self.node_attr_dict_factory()
681+
node_attr_dict = self.node_attr_dict_factory()
682+
683+
if self.is_smart:
684+
if self.smart_field not in attr:
685+
m = f"Node {node_for_adding} missing smart field '{self.smart_field}'" # noqa: E501
686+
raise KeyError(m)
687+
688+
node_attr_dict.data[self.smart_field] = attr[self.smart_field]
689+
690+
self._node[node_for_adding] = node_attr_dict
649691
self._node[node_for_adding].update(attr)
650692

651693
# Reason:
652694
# Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set
653695
# i.e trying to update a node's attributes before we know _which_ node it is
696+
# Furthermore, support for ArangoDB Smart Graphs requires the smart field
697+
# to be set before adding the node to the graph. This is because the smart
698+
# field is used to generate the node's key.
654699

655700
###########################
656701

nx_arangodb/classes/multigraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ def __init__(
215215
else:
216216
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)
217217

218+
self._loaded_incoming_graph_data = True
219+
218220
#######################
219221
# Init helper methods #
220222
#######################

0 commit comments

Comments
 (0)