Skip to content

Commit 514f69b

Browse files
authored
cleanup function.py (#47)
* cleanup `function.py` * fix: typo, set `write_async` to False
1 parent 1bc28a4 commit 514f69b

File tree

3 files changed

+52
-66
lines changed

3 files changed

+52
-66
lines changed

nx_arangodb/classes/dict/adj.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,15 @@
3939
edge_link,
4040
get_arangodb_graph,
4141
get_node_id,
42+
get_node_type,
4243
get_node_type_and_id,
4344
get_update_dict,
44-
is_arangodb_id,
4545
json_serializable,
4646
key_is_adb_id_or_int,
4747
key_is_not_reserved,
4848
key_is_string,
4949
keys_are_not_reserved,
5050
keys_are_strings,
51-
read_collection_name_from_local_id,
5251
separate_edges_by_collections,
5352
upsert_collection_edges,
5453
)
@@ -1180,11 +1179,10 @@ def copy(self) -> Any:
11801179
return {key: value.copy() for key, value in self.data.items()}
11811180

11821181
@keys_are_strings
1183-
def update(self, edges: Any) -> None:
1182+
def update(self, edges: dict[str, dict[str, Any]]) -> None:
11841183
"""g._adj['node/1'].update({'node/2': {'foo': 'bar'}})"""
1185-
from_col_name = read_collection_name_from_local_id(
1186-
self.src_node_id, self.default_node_type
1187-
)
1184+
assert self.src_node_id
1185+
from_col_name = get_node_type(self.src_node_id, self.default_node_type)
11881186

11891187
to_upsert: Dict[str, List[Dict[str, Any]]] = {from_col_name: []}
11901188

@@ -1194,10 +1192,10 @@ def update(self, edges: Any) -> None:
11941192
edge_doc["_to"] = edge_id
11951193

11961194
edge_doc_id = edge_data.get("_id")
1197-
assert is_arangodb_id(edge_doc_id)
1198-
edge_col_name = read_collection_name_from_local_id(
1199-
edge_doc_id, self.default_node_type
1200-
)
1195+
if not edge_doc_id:
1196+
raise ValueError("Edge _id field is required for update.")
1197+
1198+
edge_col_name = get_node_type(edge_doc_id, self.default_node_type)
12011199

12021200
if to_upsert.get(edge_col_name) is None:
12031201
to_upsert[edge_col_name] = [edge_doc]

nx_arangodb/classes/function.py

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -637,17 +637,36 @@ def edge_link(
637637
return edge
638638

639639

640+
def is_arangodb_id(key):
641+
return "/" in key
642+
643+
644+
def get_node_type(key: str, default_node_type: str) -> str:
645+
"""Gets the node type."""
646+
return key.split("/")[0] if is_arangodb_id(key) else default_node_type
647+
648+
640649
def get_node_id(key: str, default_node_type: str) -> str:
641650
"""Gets the node ID."""
642-
return key if "/" in key else f"{default_node_type}/{key}"
651+
return key if is_arangodb_id(key) else f"{default_node_type}/{key}"
643652

644653

645654
def get_node_type_and_id(key: str, default_node_type: str) -> tuple[str, str]:
646655
"""Gets the node type and ID."""
647-
if "/" in key:
648-
return key.split("/")[0], key
656+
return (
657+
(key.split("/")[0], key)
658+
if is_arangodb_id(key)
659+
else (default_node_type, f"{default_node_type}/{key}")
660+
)
661+
662+
663+
def get_node_type_and_key(key: str, default_node_type: str) -> tuple[str, str]:
664+
"""Gets the node type and key."""
665+
if is_arangodb_id(key):
666+
col, key = key.split("/", 1)
667+
return col, key
649668

650-
return default_node_type, f"{default_node_type}/{key}"
669+
return default_node_type, key
651670

652671

653672
def get_update_dict(
@@ -683,38 +702,9 @@ def check_list_for_errors(lst):
683702
return True
684703

685704

686-
def is_arangodb_id(key):
687-
return "/" in key
688-
689-
690-
def get_arangodb_collection_key_tuple(key):
691-
if not is_arangodb_id(key):
692-
raise ValueError(f"Invalid ArangoDB key: {key}")
693-
return key.split("/", 1)
694-
695-
696-
def extract_arangodb_collection_name(arangodb_id: str) -> str:
697-
if not is_arangodb_id(arangodb_id):
698-
raise ValueError(f"Invalid ArangoDB key: {arangodb_id}")
699-
return arangodb_id.split("/")[0]
700-
701-
702-
def read_collection_name_from_local_id(
703-
local_id: Optional[str], default_collection: str
704-
) -> str:
705-
if local_id is None:
706-
print("local_id is None, cannot read collection name.")
707-
return ""
708-
709-
if is_arangodb_id(local_id):
710-
return extract_arangodb_collection_name(local_id)
711-
712-
assert default_collection is not None
713-
assert default_collection != ""
714-
return default_collection
715-
716-
717-
def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any:
705+
def separate_nodes_by_collections(
706+
nodes: dict[str, Any], default_collection: str
707+
) -> Any:
718708
"""
719709
Separate the dictionary into collections based on whether keys contain '/'.
720710
:param nodes:
@@ -728,15 +718,12 @@ def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any:
728718
separated: Any = {}
729719

730720
for key, value in nodes.items():
731-
if is_arangodb_id(key):
732-
collection, doc_key = get_arangodb_collection_key_tuple(key)
733-
if collection not in separated:
734-
separated[collection] = {}
735-
separated[collection][doc_key] = value
736-
else:
737-
if default_collection not in separated:
738-
separated[default_collection] = {}
739-
separated[default_collection][key] = value
721+
collection, doc_key = get_node_type_and_key(key, default_collection)
722+
723+
if collection not in separated:
724+
separated[collection] = {}
725+
726+
separated[collection][doc_key] = value
740727

741728
return separated
742729

tests/test.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def test_load_graph_from_nxadb():
113113
name=graph_name,
114114
incoming_graph_data=G_NX,
115115
default_node_type="person",
116+
write_async=False,
116117
)
117118

118119
assert db.has_graph(graph_name)
@@ -134,6 +135,7 @@ def test_load_graph_from_nxadb_w_specific_edge_attribute():
134135
incoming_graph_data=G_NX,
135136
default_node_type="person",
136137
edge_collections_attributes={"weight"},
138+
write_async=False,
137139
)
138140
# TODO: re-enable this line as soon as CPU based data caching is implemented
139141
# graph._adj._fetch_all()
@@ -163,6 +165,7 @@ def test_load_graph_from_nxadb_w_not_available_edge_attribute():
163165
default_node_type="person",
164166
# This will lead to weight not being loaded into the edge data
165167
edge_collections_attributes={"_id"},
168+
write_async=False,
166169
)
167170

168171
# Should just succeed without any errors (fallback to weight: 1 as above)
@@ -1592,15 +1595,9 @@ def test_graph_dict_clear_will_not_remove_remote_data(load_karate_graph: Any) ->
15921595

15931596

15941597
def test_graph_dict_set_item(load_karate_graph: Any) -> None:
1595-
try:
1596-
db.collection("nxadb_graphs").delete("KarateGraph")
1597-
except DocumentDeleteError:
1598-
pass
1599-
except Exception as e:
1600-
print(f"An unexpected error occurred: {e}")
1601-
raise
1602-
1603-
G = nxadb.Graph(name="KarateGraph", default_node_type="person")
1598+
name = "KarateGraph"
1599+
db.collection("nxadb_graphs").delete(name, ignore_missing=True)
1600+
G = nxadb.Graph(name=name, default_node_type="person")
16041601

16051602
json_values = [
16061603
"aString",
@@ -1819,7 +1816,9 @@ def test_incoming_graph_data_not_nx_graph(
18191816
name = "KarateGraph"
18201817
db.delete_graph(name, drop_collections=True, ignore_missing=True)
18211818

1822-
G = nxadb.Graph(incoming_graph_data=incoming_graph_data, name=name)
1819+
G = nxadb.Graph(
1820+
incoming_graph_data=incoming_graph_data, name=name, write_async=False
1821+
)
18231822

18241823
assert len(G.adj) == len(G_NX.adj) == db.collection(G.default_node_type).count()
18251824
assert (
@@ -1870,7 +1869,9 @@ def test_incoming_graph_data_not_nx_graph_digraph(
18701869
name = "KarateGraph"
18711870
db.delete_graph(name, drop_collections=True, ignore_missing=True)
18721871

1873-
G = nxadb.DiGraph(incoming_graph_data=incoming_graph_data, name=name)
1872+
G = nxadb.DiGraph(
1873+
incoming_graph_data=incoming_graph_data, name=name, write_async=False
1874+
)
18741875

18751876
assert (
18761877
len(G.adj)

0 commit comments

Comments
 (0)