Skip to content

Commit 714e959

Browse files
fix: add robust handling for getting nodes or edges in kuzudb (#124)
* fix: add robust handling for getting nodes or edges in kuzudb
* fix: add safe_json_loads
1 parent c69483b commit 714e959

File tree

1 file changed

+63
-26
lines changed

1 file changed

+63
-26
lines changed

graphgen/models/storage/graph/kuzu_storage.py

Lines changed: 63 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import json
22
import os
3-
import shutil
43
from dataclasses import dataclass
54
from typing import Any
65

@@ -69,6 +68,16 @@ def _init_schema(self):
6968
def index_done_callback(self):
7069
"""KuzuDB is ACID, changes are immediate, but we can verify generic persistence here."""
7170

71+
@staticmethod
72+
def _safe_json_loads(data_str: str) -> dict:
73+
if not isinstance(data_str, str) or not data_str.strip():
74+
return {}
75+
try:
76+
return json.loads(data_str)
77+
except json.JSONDecodeError as e:
78+
print(f"Error decoding JSON: {e}")
79+
return {}
80+
7281
def has_node(self, node_id: str) -> bool:
7382
result = self._conn.execute(
7483
"MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
@@ -111,10 +120,11 @@ def get_node(self, node_id: str) -> Any:
111120
result = self._conn.execute(
112121
"MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
113122
)
114-
if result.has_next():
115-
data_str = result.get_next()[0]
116-
return json.loads(data_str) if data_str else {}
117-
return None
123+
if not result.has_next():
124+
return None
125+
126+
data_str = result.get_next()[0]
127+
return self._safe_json_loads(data_str)
118128

119129
def update_node(self, node_id: str, node_data: dict[str, str]):
120130
current_data = self.get_node(node_id)
@@ -124,7 +134,11 @@ def update_node(self, node_id: str, node_data: dict[str, str]):
124134

125135
# Merge existing data with new data
126136
current_data.update(node_data)
127-
json_data = json.dumps(current_data, ensure_ascii=False)
137+
try:
138+
json_data = json.dumps(current_data, ensure_ascii=False)
139+
except (TypeError, ValueError) as e:
140+
print(f"Error serializing JSON for node {node_id}: {e}")
141+
return
128142

129143
self._conn.execute(
130144
"MATCH (a:Entity {id: $id}) SET a.data = $data",
@@ -137,7 +151,11 @@ def get_all_nodes(self) -> Any:
137151
nodes = []
138152
while result.has_next():
139153
row = result.get_next()
140-
nodes.append((row[0], json.loads(row[1])))
154+
if row is None or len(row) < 2:
155+
continue
156+
node_id, data_str = row[0], row[1]
157+
data = self._safe_json_loads(data_str)
158+
nodes.append((node_id, data))
141159
return nodes
142160

143161
def get_edge(self, source_node_id: str, target_node_id: str):
@@ -149,10 +167,11 @@ def get_edge(self, source_node_id: str, target_node_id: str):
149167
result = self._conn.execute(
150168
query, {"src": source_node_id, "dst": target_node_id}
151169
)
152-
if result.has_next():
153-
data_str = result.get_next()[0]
154-
return json.loads(data_str) if data_str else {}
155-
return None
170+
if not result.has_next():
171+
return None
172+
173+
data_str = result.get_next()[0]
174+
return self._safe_json_loads(data_str)
156175

157176
def update_edge(
158177
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
@@ -163,14 +182,20 @@ def update_edge(
163182
return
164183

165184
current_data.update(edge_data)
166-
json_data = json.dumps(current_data, ensure_ascii=False)
185+
try:
186+
json_data = json.dumps(current_data, ensure_ascii=False)
187+
except (TypeError, ValueError) as e:
188+
print(
189+
f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
190+
)
191+
return
167192

168-
query = """
193+
self._conn.execute(
194+
"""
169195
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
170196
SET e.data = $data
171-
"""
172-
self._conn.execute(
173-
query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
197+
""",
198+
{"src": source_node_id, "dst": target_node_id, "data": json_data},
174199
)
175200

176201
def get_all_edges(self) -> Any:
@@ -180,7 +205,11 @@ def get_all_edges(self) -> Any:
180205
edges = []
181206
while result.has_next():
182207
row = result.get_next()
183-
edges.append((row[0], row[1], json.loads(row[2])))
208+
if row is None or len(row) < 3:
209+
continue
210+
src, dst, data_str = row[0], row[1], row[2]
211+
data = self._safe_json_loads(data_str)
212+
edges.append((src, dst, data))
184213
return edges
185214

186215
def get_node_edges(self, source_node_id: str) -> Any:
@@ -193,15 +222,23 @@ def get_node_edges(self, source_node_id: str) -> Any:
193222
edges = []
194223
while result.has_next():
195224
row = result.get_next()
196-
edges.append((row[0], row[1], json.loads(row[2])))
225+
if row is None or len(row) < 3:
226+
continue
227+
src, dst, data_str = row[0], row[1], row[2]
228+
data = self._safe_json_loads(data_str)
229+
edges.append((src, dst, data))
197230
return edges
198231

199232
def upsert_node(self, node_id: str, node_data: dict[str, str]):
200233
"""
201234
Insert or Update node.
202235
Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
203236
"""
204-
json_data = json.dumps(node_data, ensure_ascii=False)
237+
try:
238+
json_data = json.dumps(node_data, ensure_ascii=False)
239+
except (TypeError, ValueError) as e:
240+
print(f"Error serializing JSON for node {node_id}: {e}")
241+
return
205242
query = """
206243
MERGE (a:Entity {id: $id})
207244
ON MATCH SET a.data = $data
@@ -224,7 +261,13 @@ def upsert_edge(
224261
if not self.has_node(target_node_id):
225262
self.upsert_node(target_node_id, {})
226263

227-
json_data = json.dumps(edge_data, ensure_ascii=False)
264+
try:
265+
json_data = json.dumps(edge_data, ensure_ascii=False)
266+
except (TypeError, ValueError) as e:
267+
print(
268+
f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
269+
)
270+
return
228271
query = """
229272
MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
230273
MERGE (a)-[e:Relation]->(b)
@@ -248,9 +291,3 @@ def clear(self):
248291

249292
def reload(self):
250293
"""For databases that need reloading, KuzuDB auto-manages this."""
251-
252-
def drop(self):
253-
"""Completely remove the database folder."""
254-
if self.db_path and os.path.exists(self.db_path):
255-
shutil.rmtree(self.db_path)
256-
print(f"Dropped KuzuDB at {self.db_path}")

0 commit comments

Comments
 (0)