Skip to content

Commit eb9c23f

Browse files
Author: Dylan Storey (committed)
Fix multiple Cypher parser and executor bugs
- Fix keys() and properties() returning empty results by changing EXISTS (... UNION ALL ...) to EXISTS (...) OR EXISTS (...)
- Allow 'end' as identifier in Cypher queries (variable names, property access) — only reserved in CASE...END context
- Fix list functions (range, tail, split) to return proper arrays instead of expanding into multiple rows
- Fix UNWIND with list literals and expressions
- Fix WITH clause variable scoping and expression handling
- Update Rust and Python bindings to handle new array return format with extract_algo_array() helpers

Includes regression tests for all fixes.
1 parent 0036e58 commit eb9c23f

27 files changed (+920, −179 lines)

bindings/python/src/graphqlite/algorithms/_parsing.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,39 @@
11
"""Shared parsing helpers for algorithm results."""
22

3-
from typing import Any, Optional
3+
from typing import Any, List, Optional
4+
5+
6+
# Known column names for graph algorithm results
7+
ALGO_COLUMN_NAMES = [
8+
"column_0", "wcc()", "scc()", "pagerank()", "degree_centrality()",
9+
"betweenness_centrality()", "closeness_centrality()", "eigenvector_centrality()",
10+
"labelPropagation()", "louvain()"
11+
]
12+
13+
14+
def extract_algo_array(result: List[dict]) -> List[dict]:
15+
"""Extract wrapped array results from graph algorithms.
16+
17+
Graph algorithms return results in one of two formats:
18+
1. Old format: Multiple rows with fields directly accessible
19+
2. New format: Single row with a column containing an array of objects
20+
21+
This function detects the new format and extracts the array elements.
22+
"""
23+
# If multiple rows, assume old format - return as-is
24+
if len(result) != 1:
25+
return result
26+
27+
# Single row - check if it has an array column
28+
row = result[0]
29+
30+
# Try common column names for wrapped array results
31+
for col_name in ALGO_COLUMN_NAMES:
32+
if col_name in row and isinstance(row[col_name], list):
33+
return row[col_name]
34+
35+
# No array column found, return original result
36+
return result
437

538

639
def parse_score_result(row: dict, score_key: str = "score") -> Optional[dict]:

bindings/python/src/graphqlite/algorithms/centrality.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any
44

55
from ..graph._base import BaseMixin
6-
from ._parsing import safe_float, safe_int
6+
from ._parsing import extract_algo_array, safe_float, safe_int
77

88

99
class CentralityMixin(BaseMixin):
@@ -28,9 +28,10 @@ def pagerank(
2828
result = self._conn.cypher(
2929
f"RETURN pageRank({damping}, {iterations})"
3030
)
31+
rows = extract_algo_array(result)
3132

3233
ranks = []
33-
for row in result:
34+
for row in rows:
3435
node_id = row.get("node_id")
3536
user_id = row.get("user_id")
3637
score = row.get("score")
@@ -54,9 +55,10 @@ def degree_centrality(self) -> list[dict]:
5455
'out_degree', 'degree'
5556
"""
5657
result = self._conn.cypher("RETURN degreeCentrality()")
58+
rows = extract_algo_array(result)
5759

5860
degrees = []
59-
for row in result:
61+
for row in rows:
6062
node_id = row.get("node_id")
6163
user_id = row.get("user_id")
6264
in_degree = row.get("in_degree")
@@ -86,9 +88,10 @@ def betweenness_centrality(self) -> list[dict]:
8688
where score is the betweenness centrality value
8789
"""
8890
result = self._conn.cypher("RETURN betweennessCentrality()")
91+
rows = extract_algo_array(result)
8992

9093
scores = []
91-
for row in result:
94+
for row in rows:
9295
node_id = row.get("node_id")
9396
user_id = row.get("user_id")
9497
score = row.get("score")
@@ -118,9 +121,10 @@ def closeness_centrality(self) -> list[dict]:
118121
where score is the closeness centrality value (0 to 1)
119122
"""
120123
result = self._conn.cypher("RETURN closenessCentrality()")
124+
rows = extract_algo_array(result)
121125

122126
scores = []
123-
for row in result:
127+
for row in rows:
124128
node_id = row.get("node_id")
125129
user_id = row.get("user_id")
126130
score = row.get("score")
@@ -156,9 +160,10 @@ def eigenvector_centrality(self, iterations: int = 100) -> list[dict]:
156160
"""
157161
query = f"RETURN eigenvectorCentrality({iterations})"
158162
result = self._conn.cypher(query)
163+
rows = extract_algo_array(result)
159164

160165
scores = []
161-
for row in result:
166+
for row in rows:
162167
node_id = row.get("node_id")
163168
user_id = row.get("user_id")
164169
score = row.get("score")

bindings/python/src/graphqlite/algorithms/community.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional
44

55
from ..graph._base import BaseMixin
6-
from ._parsing import safe_int
6+
from ._parsing import extract_algo_array, safe_int
77

88

99
class CommunityMixin(BaseMixin):
@@ -20,9 +20,10 @@ def community_detection(self, iterations: int = 10) -> list[dict]:
2020
List of dicts with 'node_id', 'user_id', 'community'
2121
"""
2222
result = self._conn.cypher(f"RETURN labelPropagation({iterations})")
23+
rows = extract_algo_array(result)
2324

2425
communities = []
25-
for row in result:
26+
for row in rows:
2627
node_id = row.get("node_id")
2728
user_id = row.get("user_id")
2829
community = row.get("community")
@@ -50,9 +51,10 @@ def louvain(self, resolution: float = 1.0) -> list[dict]:
5051
List of dicts with 'node_id', 'user_id', 'community'
5152
"""
5253
result = self._conn.cypher(f"RETURN louvain({resolution})")
54+
rows = extract_algo_array(result)
5355

5456
communities = []
55-
for row in result:
57+
for row in rows:
5658
node_id = row.get("node_id")
5759
user_id = row.get("user_id")
5860
community = row.get("community")

bindings/python/src/graphqlite/algorithms/components.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Connected components algorithms mixin."""
22

33
from ..graph._base import BaseMixin
4-
from ._parsing import safe_int
4+
from ._parsing import extract_algo_array, safe_int
55

66

77
class ComponentsMixin(BaseMixin):
@@ -19,9 +19,10 @@ def weakly_connected_components(self) -> list[dict]:
1919
where nodes in the same component share the same component number
2020
"""
2121
result = self._conn.cypher("RETURN wcc()")
22+
rows = extract_algo_array(result)
2223

2324
components = []
24-
for row in result:
25+
for row in rows:
2526
node_id = row.get("node_id")
2627
user_id = row.get("user_id")
2728
component = row.get("component")
@@ -52,9 +53,10 @@ def strongly_connected_components(self) -> list[dict]:
5253
where nodes in the same SCC share the same component number
5354
"""
5455
result = self._conn.cypher("RETURN scc()")
56+
rows = extract_algo_array(result)
5557

5658
components = []
57-
for row in result:
59+
for row in rows:
5860
node_id = row.get("node_id")
5961
user_id = row.get("user_id")
6062
component = row.get("component")

bindings/python/tests/test_connection.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -769,9 +769,11 @@ def test_tail_function(db):
769769
"""Test tail() function."""
770770
results = db.cypher("RETURN tail([1, 2, 3]) AS result")
771771
assert len(results) == 1
772-
# tail returns [2, 3] as JSON
773-
import json
774-
tail = json.loads(results[0]["result"])
772+
# tail returns [2, 3] as a list (may be native list or JSON string)
773+
tail = results[0]["result"]
774+
if isinstance(tail, str):
775+
import json
776+
tail = json.loads(tail)
775777
assert tail == [2, 3]
776778

777779

@@ -786,8 +788,11 @@ def test_range_function(db):
786788
"""Test range() function."""
787789
results = db.cypher("RETURN range(1, 5) AS result")
788790
assert len(results) == 1
789-
import json
790-
r = json.loads(results[0]["result"])
791+
# range returns a list (may be native list or JSON string)
792+
r = results[0]["result"]
793+
if isinstance(r, str):
794+
import json
795+
r = json.loads(r)
791796
assert r == [1, 2, 3, 4, 5]
792797

793798

@@ -1078,9 +1083,11 @@ def test_path_nodes(db):
10781083
""")
10791084
assert len(results) == 1
10801085
# Column alias may not be applied, check both possibilities
1081-
import json
10821086
col_name = "path_nodes" if "path_nodes" in results[0] else "result"
1083-
nodes = json.loads(results[0][col_name])
1087+
nodes = results[0][col_name]
1088+
if isinstance(nodes, str):
1089+
import json
1090+
nodes = json.loads(nodes)
10841091
assert len(nodes) == 2
10851092

10861093

@@ -1093,9 +1100,11 @@ def test_path_relationships(db):
10931100
""")
10941101
assert len(results) == 1
10951102
# Column alias may not be applied, check both possibilities
1096-
import json
10971103
col_name = "rels" if "rels" in results[0] else "result"
1098-
rels = json.loads(results[0][col_name])
1104+
rels = results[0][col_name]
1105+
if isinstance(rels, str):
1106+
import json
1107+
rels = json.loads(rels)
10991108
assert len(rels) == 1
11001109

11011110

@@ -1171,10 +1180,12 @@ def test_labels_function(db):
11711180
db.cypher("CREATE (n:Person:Employee:Manager {name: 'Alice'})")
11721181
results = db.cypher("MATCH (n:Person {name: 'Alice'}) RETURN labels(n) AS lbls")
11731182
assert len(results) == 1
1174-
import json
11751183
# Column alias may not be applied, check both possibilities
11761184
col_name = "lbls" if "lbls" in results[0] else "result"
1177-
labels = json.loads(results[0][col_name])
1185+
labels = results[0][col_name]
1186+
if isinstance(labels, str):
1187+
import json
1188+
labels = json.loads(labels)
11781189
assert "Person" in labels
11791190
assert "Employee" in labels
11801191
assert "Manager" in labels

bindings/rust/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

bindings/rust/src/algorithms/centrality.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::Result;
55
use super::{
66
PageRankResult, DegreeCentralityResult, BetweennessCentralityResult,
77
ClosenessCentralityResult, EigenvectorCentralityResult,
8-
parsing::{extract_node_id, extract_user_id, extract_float, extract_int},
8+
parsing::{extract_algo_array, extract_node_id, extract_user_id, extract_float, extract_int},
99
};
1010

1111
impl Graph {
@@ -18,9 +18,10 @@ impl Graph {
1818
pub fn pagerank(&self, damping: f64, iterations: i32) -> Result<Vec<PageRankResult>> {
1919
let query = format!("RETURN pageRank({}, {})", damping, iterations);
2020
let result = self.connection().cypher(&query)?;
21+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
2122

2223
let mut ranks = Vec::new();
23-
for row in result.iter() {
24+
for row in rows.iter() {
2425
if let Some(node_id) = extract_node_id(row) {
2526
ranks.push(PageRankResult {
2627
node_id,
@@ -35,9 +36,10 @@ impl Graph {
3536
/// Calculate degree centrality for all nodes.
3637
pub fn degree_centrality(&self) -> Result<Vec<DegreeCentralityResult>> {
3738
let result = self.connection().cypher("RETURN degreeCentrality()")?;
39+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
3840

3941
let mut degrees = Vec::new();
40-
for row in result.iter() {
42+
for row in rows.iter() {
4143
if let Some(node_id) = extract_node_id(row) {
4244
degrees.push(DegreeCentralityResult {
4345
node_id,
@@ -54,9 +56,10 @@ impl Graph {
5456
/// Calculate betweenness centrality for all nodes.
5557
pub fn betweenness_centrality(&self) -> Result<Vec<BetweennessCentralityResult>> {
5658
let result = self.connection().cypher("RETURN betweennessCentrality()")?;
59+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
5760

5861
let mut scores = Vec::new();
59-
for row in result.iter() {
62+
for row in rows.iter() {
6063
if let Some(node_id) = extract_node_id(row) {
6164
scores.push(BetweennessCentralityResult {
6265
node_id,
@@ -71,9 +74,10 @@ impl Graph {
7174
/// Calculate closeness centrality for all nodes.
7275
pub fn closeness_centrality(&self) -> Result<Vec<ClosenessCentralityResult>> {
7376
let result = self.connection().cypher("RETURN closenessCentrality()")?;
77+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
7478

7579
let mut scores = Vec::new();
76-
for row in result.iter() {
80+
for row in rows.iter() {
7781
if let Some(node_id) = extract_node_id(row) {
7882
scores.push(ClosenessCentralityResult {
7983
node_id,
@@ -93,9 +97,10 @@ impl Graph {
9397
pub fn eigenvector_centrality(&self, iterations: i32) -> Result<Vec<EigenvectorCentralityResult>> {
9498
let query = format!("RETURN eigenvectorCentrality({})", iterations);
9599
let result = self.connection().cypher(&query)?;
100+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
96101

97102
let mut scores = Vec::new();
98-
for row in result.iter() {
103+
for row in rows.iter() {
99104
if let Some(node_id) = extract_node_id(row) {
100105
scores.push(EigenvectorCentralityResult {
101106
node_id,

bindings/rust/src/algorithms/community.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::graph::Graph;
44
use crate::Result;
55
use super::{
66
CommunityResult,
7-
parsing::{extract_node_id, extract_user_id, extract_int},
7+
parsing::{extract_algo_array, extract_node_id, extract_user_id, extract_int},
88
};
99

1010
impl Graph {
@@ -16,9 +16,10 @@ impl Graph {
1616
pub fn community_detection(&self, iterations: i32) -> Result<Vec<CommunityResult>> {
1717
let query = format!("RETURN labelPropagation({})", iterations);
1818
let result = self.connection().cypher(&query)?;
19+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
1920

2021
let mut communities = Vec::new();
21-
for row in result.iter() {
22+
for row in rows.iter() {
2223
if let Some(node_id) = extract_node_id(row) {
2324
communities.push(CommunityResult {
2425
node_id,
@@ -38,9 +39,10 @@ impl Graph {
3839
pub fn louvain(&self, resolution: f64) -> Result<Vec<CommunityResult>> {
3940
let query = format!("RETURN louvain({})", resolution);
4041
let result = self.connection().cypher(&query)?;
42+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
4143

4244
let mut communities = Vec::new();
43-
for row in result.iter() {
45+
for row in rows.iter() {
4446
if let Some(node_id) = extract_node_id(row) {
4547
communities.push(CommunityResult {
4648
node_id,

bindings/rust/src/algorithms/components.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::graph::Graph;
44
use crate::Result;
55
use super::{
66
ComponentResult,
7-
parsing::{extract_node_id, extract_user_id, extract_int},
7+
parsing::{extract_algo_array, extract_node_id, extract_user_id, extract_int},
88
};
99

1010
impl Graph {
@@ -13,9 +13,10 @@ impl Graph {
1313
/// Treats the graph as undirected and finds connected components.
1414
pub fn wcc(&self) -> Result<Vec<ComponentResult>> {
1515
let result = self.connection().cypher("RETURN wcc()")?;
16+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
1617

1718
let mut components = Vec::new();
18-
for row in result.iter() {
19+
for row in rows.iter() {
1920
if let Some(node_id) = extract_node_id(row) {
2021
components.push(ComponentResult {
2122
node_id,
@@ -33,9 +34,10 @@ impl Graph {
3334
/// other node following edge directions. Uses Tarjan's algorithm.
3435
pub fn scc(&self) -> Result<Vec<ComponentResult>> {
3536
let result = self.connection().cypher("RETURN scc()")?;
37+
let rows = extract_algo_array(result.iter().collect::<Vec<_>>().as_slice());
3638

3739
let mut components = Vec::new();
38-
for row in result.iter() {
40+
for row in rows.iter() {
3941
if let Some(node_id) = extract_node_id(row) {
4042
components.push(ComponentResult {
4143
node_id,

0 commit comments

Comments (0)