Skip to content

Commit 75a3660

Browse files
authored
GA-163 | test_digraph (#40)
* GA-163 | initial commit will fail * unlock adbnx * fix: `incoming_graph_data` * fix: incoming_graph_data * fix: off-by-one IDs * checkpoint * checkpoint: `BaseGraphTester` is passing * checkpoint: BaseGraphAttrTester * cleanup: `aql_fetch_data`, `aql_fetch_data_edge` * use pytest skip for failing tests * checkpoint: optimize `__iter__` * checkpoint: run `test_graph` * add comment * checkpoint * attempt: slleep * fix: lint * cleanup: getitem * cleanup: copy * attempt: shorten sleep * fix: `__set_adj_elements` * fix: mypy * attempt: decrease sleep * GA-163 | `test_digraph` * checkpoint lots of failures... * fix: set `self.Graph` * add type ignore * fix: graph name * fix: graph name * adjust assertions to exclude _rev, set `use_experimental_views` * Revert "adjust assertions to exclude _rev, set `use_experimental_views`" This reverts commit b805419. * fix: `_rev`, `use_experimental_views` * set `use_experimental_views` * fix: lint * new: `nbunch_iter` override * set experimental views to false * set experimental views to false * cleanup * fix: `function.py` * cleanup: `graph`, `digraph` * fix: `test_data_input` * attempt: wait for CircleCI * fix: nx graph * remove sleep * new: `override` suffix * enable more tests * fix: lint
1 parent 428dba2 commit 75a3660

File tree

7 files changed

+691
-64
lines changed

7 files changed

+691
-64
lines changed

nx_arangodb/classes/dict/adj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1557,7 +1557,7 @@ def keys(self) -> Any:
15571557

15581558
@logger_debug
15591559
def clear(self) -> None:
1560-
"""g._node.clear()"""
1560+
"""g._adj.clear()"""
15611561
self.data.clear()
15621562
self.FETCHED_ALL_DATA = False
15631563
self.FETCHED_ALL_IDS = False

nx_arangodb/classes/digraph.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
import nx_arangodb as nxadb
77
from nx_arangodb.classes.graph import Graph
8+
from nx_arangodb.logger import logger
89

910
from .dict.adj import AdjListOuterDict
1011
from .enum import TraversalDirection
12+
from .function import get_node_id
1113

1214
networkx_api = nxadb.utils.decorators.networkx_class(nx.DiGraph) # type: ignore
1315

@@ -59,12 +61,9 @@ def __init__(
5961
)
6062

6163
if self.graph_exists_in_db:
62-
assert isinstance(self._succ, AdjListOuterDict)
63-
assert isinstance(self._pred, AdjListOuterDict)
64-
self._succ.mirror = self._pred
65-
self._pred.mirror = self._succ
66-
self._succ.traversal_direction = TraversalDirection.OUTBOUND
67-
self._pred.traversal_direction = TraversalDirection.INBOUND
64+
self.clear_edges = self.clear_edges_override
65+
self.add_node = self.add_node_override
66+
self.remove_node = self.remove_node_override
6867

6968
#######################
7069
# nx.DiGraph Overides #
@@ -80,7 +79,14 @@ def __init__(
8079
# def out_edges(self):
8180
# pass
8281

83-
def add_node(self, node_for_adding, **attr):
82+
def clear_edges_override(self):
83+
logger.info("Note that clearing edges ony erases the edges in the local cache")
84+
for predecessor_dict in self._pred.data.values():
85+
predecessor_dict.clear()
86+
87+
super().clear_edges()
88+
89+
def add_node_override(self, node_for_adding, **attr):
8490
if node_for_adding not in self._succ:
8591
if node_for_adding is None:
8692
raise ValueError("None cannot be a node")
@@ -111,7 +117,10 @@ def add_node(self, node_for_adding, **attr):
111117

112118
nx._clear_cache(self)
113119

114-
def remove_node(self, n):
120+
def remove_node_override(self, n):
121+
if isinstance(n, (str, int)):
122+
n = get_node_id(str(n), self.default_node_type)
123+
115124
try:
116125

117126
######################
@@ -138,6 +147,22 @@ def remove_node(self, n):
138147
del self._pred[u][n] # remove all edges n-u in digraph
139148
del self._succ[n] # remove node from succ
140149
for u in nbrs_pred:
150+
######################
151+
# NOTE: Monkey patch #
152+
######################
153+
154+
# Old: Nothing
155+
156+
# New:
157+
if u == n:
158+
continue # skip self loops
159+
160+
# Reason: We need to skip self loops, as they are
161+
# already taken care of in the previous step. This
162+
# avoids getting a KeyError on the next line.
163+
164+
###########################
165+
141166
del self._succ[u][n] # remove all edges n-u in digraph
142167
del self._pred[n] # remove node from pred
143168
nx._clear_cache(self)

nx_arangodb/classes/graph.py

Lines changed: 65 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
node_attr_dict_factory,
2828
node_dict_factory,
2929
)
30+
from .dict.adj import AdjListOuterDict
31+
from .enum import TraversalDirection
3032
from .function import get_node_id
3133
from .reportviews import CustomEdgeView, CustomNodeView
3234

@@ -96,6 +98,7 @@ def __init__(
9698
# m = "Must set **graph_name** if passing **incoming_graph_data**"
9799
# raise ValueError(m)
98100

101+
loaded_incoming_graph_data = False
99102
if self._graph_exists_in_db:
100103
if incoming_graph_data is not None:
101104
m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501
@@ -170,29 +173,44 @@ def edge_type_func(u: str, v: str) -> str:
170173
use_async=write_async,
171174
)
172175

176+
loaded_incoming_graph_data = True
177+
173178
else:
174179
self.adb_graph = self.db.create_graph(
175180
self.__name,
176181
edge_definitions=edge_definitions,
177182
)
178183

179-
# Let the parent class handle the incoming graph data
180-
# if it is not a networkx.Graph object
181-
kwargs["incoming_graph_data"] = incoming_graph_data
182-
183184
self._set_factory_methods()
184185
self._set_arangodb_backend_config()
185186
logger.info(f"Graph '{name}' created.")
186187
self._graph_exists_in_db = True
187188

188-
else:
189-
kwargs["incoming_graph_data"] = incoming_graph_data
190-
191-
if name is not None:
192-
kwargs["name"] = name
189+
if self.__name is not None:
190+
kwargs["name"] = self.__name
193191

194192
super().__init__(*args, **kwargs)
195193

194+
if self.is_directed() and self.graph_exists_in_db:
195+
assert isinstance(self._succ, AdjListOuterDict)
196+
assert isinstance(self._pred, AdjListOuterDict)
197+
self._succ.mirror = self._pred
198+
self._pred.mirror = self._succ
199+
self._succ.traversal_direction = TraversalDirection.OUTBOUND
200+
self._pred.traversal_direction = TraversalDirection.INBOUND
201+
202+
if incoming_graph_data is not None and not loaded_incoming_graph_data:
203+
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)
204+
205+
if self.graph_exists_in_db:
206+
self.copy = self.copy_override
207+
self.subgraph = self.subgraph_override
208+
self.clear = self.clear_override
209+
self.clear_edges = self.clear_edges_override
210+
self.add_node = self.add_node_override
211+
self.number_of_edges = self.number_of_edges_override
212+
self.nbunch_iter = self.nbunch_iter_override
213+
196214
#######################
197215
# Init helper methods #
198216
#######################
@@ -345,6 +363,9 @@ def _set_graph_name(self, graph_name: str | None = None) -> None:
345363
# ArangoDB Methods #
346364
####################
347365

366+
def clear_nxcg_cache(self):
367+
self.nxcg_graph = None
368+
348369
def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor:
349370
return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)
350371

@@ -355,7 +376,7 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
355376
# NOTE: OUT OF SERVICE
356377
# def chat(self, prompt: str) -> str:
357378
# if self.__qa_chain is None:
358-
# if not self.__graph_exists_in_db:
379+
# if not self.graph_exists_in_db:
359380
# return "Could not initialize QA chain: Graph does not exist"
360381

361382
# # try:
@@ -381,30 +402,6 @@ def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Curs
381402
# nx.Graph Overides #
382403
#####################
383404

384-
def copy(self, *args, **kwargs):
385-
logger.warning("Note that copying a graph loses the connection to the database")
386-
G = super().copy(*args, **kwargs)
387-
G.node_dict_factory = nx.Graph.node_dict_factory
388-
G.node_attr_dict_factory = nx.Graph.node_attr_dict_factory
389-
G.edge_attr_dict_factory = nx.Graph.edge_attr_dict_factory
390-
G.adjlist_inner_dict_factory = nx.Graph.adjlist_inner_dict_factory
391-
G.adjlist_outer_dict_factory = nx.Graph.adjlist_outer_dict_factory
392-
return G
393-
394-
def subgraph(self, nbunch):
395-
raise NotImplementedError("Subgraphing is not yet implemented")
396-
397-
def clear(self):
398-
logger.info("Note that clearing only erases the local cache")
399-
super().clear()
400-
401-
def clear_edges(self):
402-
logger.info("Note that clearing edges ony erases the edges in the local cache")
403-
super().clear_edges()
404-
405-
def clear_nxcg_cache(self):
406-
self.nxcg_graph = None
407-
408405
@cached_property
409406
def nodes(self):
410407
if self.__use_experimental_views and self.graph_exists_in_db:
@@ -437,7 +434,30 @@ def edges(self):
437434

438435
return super().edges
439436

440-
def add_node(self, node_for_adding, **attr):
437+
def copy_override(self, *args, **kwargs):
438+
logger.warning("Note that copying a graph loses the connection to the database")
439+
G = super().copy(*args, **kwargs)
440+
G.node_dict_factory = nx.Graph.node_dict_factory
441+
G.node_attr_dict_factory = nx.Graph.node_attr_dict_factory
442+
G.edge_attr_dict_factory = nx.Graph.edge_attr_dict_factory
443+
G.adjlist_inner_dict_factory = nx.Graph.adjlist_inner_dict_factory
444+
G.adjlist_outer_dict_factory = nx.Graph.adjlist_outer_dict_factory
445+
return G
446+
447+
def subgraph_override(self, nbunch):
448+
raise NotImplementedError("Subgraphing is not yet implemented")
449+
450+
def clear_override(self):
451+
logger.info("Note that clearing only erases the local cache")
452+
super().clear()
453+
454+
def clear_edges_override(self):
455+
logger.info("Note that clearing edges ony erases the edges in the local cache")
456+
for nbr_dict in self._adj.data.values():
457+
nbr_dict.clear()
458+
nx._clear_cache(self)
459+
460+
def add_node_override(self, node_for_adding, **attr):
441461
if node_for_adding not in self._node:
442462
if node_for_adding is None:
443463
raise ValueError("None cannot be a node")
@@ -467,10 +487,7 @@ def add_node(self, node_for_adding, **attr):
467487

468488
nx._clear_cache(self)
469489

470-
def number_of_edges(self, u=None, v=None):
471-
if not self.graph_exists_in_db:
472-
return super().number_of_edges(u, v)
473-
490+
def number_of_edges_override(self, u=None, v=None):
474491
if u is not None:
475492
return super().number_of_edges(u, v)
476493

@@ -494,10 +511,7 @@ def number_of_edges(self, u=None, v=None):
494511
# It is more efficient to count the number of edges in the edge collections
495512
# compared to relying on the DegreeView.
496513

497-
def nbunch_iter(self, nbunch=None):
498-
if not self._graph_exists_in_db:
499-
return super().nbunch_iter(nbunch)
500-
514+
def nbunch_iter_override(self, nbunch=None):
501515
if nbunch is None:
502516
bunch = iter(self._adj)
503517
elif nbunch in self:
@@ -508,12 +522,13 @@ def nbunch_iter(self, nbunch=None):
508522
# Old: Nothing
509523

510524
# New:
511-
if isinstance(nbunch, int):
525+
if isinstance(nbunch, (str, int)):
512526
nbunch = get_node_id(str(nbunch), self.default_node_type)
513527

514528
# Reason:
515529
# ArangoDB only uses strings as node IDs. Therefore, we need to convert
516-
# the integer node ID to a string before using it in an iterator.
530+
# the non-prefixed node ID to an ArangoDB ID before
531+
# using it in an iterator.
517532

518533
bunch = iter([nbunch])
519534
else:
@@ -528,13 +543,15 @@ def bunch_iter(nlist, adj):
528543
# Old: Nothing
529544

530545
# New:
531-
if isinstance(n, int):
546+
if isinstance(n, (str, int)):
532547
n = get_node_id(str(n), self.default_node_type)
533548

534549
# Reason:
535550
# ArangoDB only uses strings as node IDs. Therefore,
536-
# we need to convert the integer node ID to a
537-
# string before using it in an iterator.
551+
# we need to convert non-prefixed node IDs to an
552+
# ArangoDB ID before using it in an iterator.
553+
554+
######################
538555

539556
if n in adj:
540557
yield n

nx_arangodb/classes/multigraph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def __init__(
5858
**kwargs,
5959
)
6060

61+
if self.graph_exists_in_db:
62+
self.add_edge = self.add_edge_override
63+
6164
#######################
6265
# Init helper methods #
6366
#######################
@@ -76,10 +79,7 @@ def _set_factory_methods(self) -> None:
7679
# nx.MultiGraph Overides #
7780
##########################
7881

79-
def add_edge(self, u_for_edge, v_for_edge, key=None, **attr):
80-
if not self.graph_exists_in_db:
81-
return super().add_edge(u_for_edge, v_for_edge, key=key, **attr)
82-
82+
def add_edge_override(self, u_for_edge, v_for_edge, key=None, **attr):
8383
if key is not None:
8484
m = "ArangoDB MultiGraph does not support custom edge keys yet."
8585
logger.warning(m)

tests/test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from .conftest import Capturing, create_grid_graph, create_line_graph, db, run_gpu_tests
1919

2020
G_NX = nx.karate_club_graph()
21+
G_NX_digraph = nx.DiGraph(G_NX)
22+
G_NX_multigraph = nx.MultiGraph(G_NX)
23+
G_NX_multidigraph = nx.MultiDiGraph(G_NX)
2124

2225

2326
def assert_remote_dict(G: nxadb.Graph) -> None:
@@ -1834,3 +1837,58 @@ def test_incoming_graph_data_not_nx_graph(
18341837
)
18351838
assert has_club == ("club" in G.nodes["0"])
18361839
assert has_weight == ("weight" in G.adj["0"]["1"])
1840+
1841+
1842+
@pytest.mark.parametrize(
1843+
"data_type, incoming_graph_data, has_club, has_weight",
1844+
[
1845+
("dict of dicts", G_NX_digraph._adj, False, True),
1846+
(
1847+
"dict of lists",
1848+
{k: list(v) for k, v in G_NX_digraph._adj.items()},
1849+
False,
1850+
False,
1851+
),
1852+
("container of edges", list(G_NX_digraph.edges), False, False),
1853+
("iterator of edges", iter(G_NX_digraph.edges), False, False),
1854+
("generator of edges", (e for e in G_NX_digraph.edges), False, False),
1855+
("2D numpy array", nx.to_numpy_array(G_NX_digraph), False, True),
1856+
(
1857+
"scipy sparse array",
1858+
nx.to_scipy_sparse_array(G_NX_digraph),
1859+
False,
1860+
True,
1861+
),
1862+
("Pandas EdgeList", nx.to_pandas_edgelist(G_NX_digraph), False, True),
1863+
("Pandas Adjacency", nx.to_pandas_adjacency(G_NX_digraph), False, True),
1864+
],
1865+
)
1866+
def test_incoming_graph_data_not_nx_graph_digraph(
1867+
data_type: str, incoming_graph_data: Any, has_club: bool, has_weight: bool
1868+
) -> None:
1869+
# See nx.convert.to_networkx_graph for the official supported types
1870+
name = "KarateGraph"
1871+
db.delete_graph(name, drop_collections=True, ignore_missing=True)
1872+
1873+
G = nxadb.DiGraph(incoming_graph_data=incoming_graph_data, name=name)
1874+
1875+
assert (
1876+
len(G.adj)
1877+
== len(G_NX_digraph.adj)
1878+
== db.collection(G.default_node_type).count()
1879+
)
1880+
assert (
1881+
len(G.nodes)
1882+
== len(G_NX_digraph.nodes)
1883+
== db.collection(G.default_node_type).count()
1884+
== G.number_of_nodes()
1885+
)
1886+
edge_col = G.edge_type_func(G.default_node_type, G.default_node_type)
1887+
assert (
1888+
len(G.edges)
1889+
== len(G_NX_digraph.edges)
1890+
== db.collection(edge_col).count()
1891+
== G.number_of_edges()
1892+
)
1893+
assert has_club == ("club" in G.nodes["0"])
1894+
assert has_weight == ("weight" in G.adj["0"]["1"])

0 commit comments

Comments
 (0)