Skip to content

Commit db1693a

Browse files
community: fix issue #29429 in age_graph.py (#29506)
## Description: This PR addresses issue #29429 by fixing the _wrap_query method in langchain_community/graphs/age_graph.py. The method now correctly handles Cypher queries with UNION and EXCEPT operators, ensuring that the fields in the SQL query are ordered as they appear in the Cypher query. Additionally, the method now properly handles cases where RETURN * is not supported. ### Issue: #29429 ### Dependencies: None ### Add tests and docs: Added unit tests in tests/unit_tests/graphs/test_age_graph.py to validate the changes. No new integrations were added, so no example notebook is necessary. Lint and test: Ran make format, make lint, and make test to ensure code quality and functionality.
1 parent 2f97916 commit db1693a

File tree

2 files changed

+221
-64
lines changed

2 files changed

+221
-64
lines changed

libs/community/langchain_community/graphs/age_graph.py

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -473,71 +473,78 @@ def _get_col_name(field: str, idx: int) -> str:
473473
@staticmethod
474474
def _wrap_query(query: str, graph_name: str) -> str:
475475
"""
476-
Convert a cypher query to an Apache Age compatible
477-
sql query by wrapping the cypher query in ag_catalog.cypher,
478-
casting results to agtype and building a select statement
476+
Convert a Cyper query to an Apache Age compatible Sql Query.
477+
Handles combined queries with UNION/EXCEPT operators
479478
480479
Args:
481-
query (str): a valid cypher query
482-
graph_name (str): the name of the graph to query
480+
query (str) : A valid cypher query, can include UNION/EXCEPT operators
481+
graph_name (str) : The name of the graph to query
483482
484-
Returns:
485-
str: an equivalent pgsql query
483+
Returns :
484+
str : An equivalent pgSql query wrapped with ag_catalog.cypher
485+
486+
Raises:
487+
ValueError : If query is empty, contain RETURN *, or has invalid field names
486488
"""
487489

490+
if not query.strip():
491+
raise ValueError("Empty query provided")
492+
488493
# pgsql template
489494
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
490495
{query}
491496
$$) AS ({fields});"""
492497

493-
# if there are any returned fields they must be added to the pgsql query
494-
return_match = re.search(r'\breturn\b(?![^"]*")', query, re.IGNORECASE)
495-
if return_match:
496-
# Extract the part of the query after the RETURN keyword
497-
return_clause = query[return_match.end() :]
498-
499-
# parse return statement to identify returned fields
500-
fields = (
501-
return_clause.lower()
502-
.split("distinct")[-1]
503-
.split("order by")[0]
504-
.split("skip")[0]
505-
.split("limit")[0]
506-
.split(",")
507-
)
508-
509-
# raise exception if RETURN * is found as we can't resolve the fields
510-
if "*" in [x.strip() for x in fields]:
511-
raise ValueError(
512-
"AGE graph does not support 'RETURN *'"
513-
+ " statements in Cypher queries"
498+
# split the query into parts based on UNION and EXCEPT
499+
parts = re.split(r"\b(UNION\b|\bEXCEPT)\b", query, flags=re.IGNORECASE)
500+
501+
all_fields = []
502+
503+
for part in parts:
504+
if part.strip().upper() in ("UNION", "EXCEPT"):
505+
continue
506+
507+
# if there are any returned fields they must be added to the pgsql query
508+
return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE)
509+
if return_match:
510+
# Extract the part of the query after the RETURN keyword
511+
return_clause = part[return_match.end() :]
512+
513+
# parse return statement to identify returned fields
514+
fields = (
515+
return_clause.lower()
516+
.split("distinct")[-1]
517+
.split("order by")[0]
518+
.split("skip")[0]
519+
.split("limit")[0]
520+
.split(",")
514521
)
515522

516-
# get pgsql formatted field names
517-
fields = [
518-
AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields)
519-
]
520-
521-
# build resulting pgsql relation
522-
fields_str = ", ".join(
523-
[
524-
field.split(".")[-1] + " agtype"
525-
for field in fields
526-
if field.split(".")[-1]
527-
]
528-
)
523+
# raise exception if RETURN * is found as we can't resolve the fields
524+
clean_fileds = [f.strip() for f in fields if f.strip()]
525+
if "*" in clean_fileds:
526+
raise ValueError(
527+
"Apache Age does not support RETURN * in Cypher queries"
528+
)
529529

530-
# if no return statement we still need to return a single field of type agtype
531-
else:
530+
# Format fields and maintain order of appearance
531+
for idx, field in enumerate(clean_fileds):
532+
field_name = AGEGraph._get_col_name(field, idx)
533+
if field_name not in all_fields:
534+
all_fields.append(field_name)
535+
536+
# if no return statements found in any part
537+
if not all_fields:
532538
fields_str = "a agtype"
533539

534-
select_str = "*"
540+
else:
541+
fields_str = ", ".join(f"{field} agtype" for field in all_fields)
535542

536543
return template.format(
537544
graph_name=graph_name,
538545
query=query,
539546
fields=fields_str,
540-
projection=select_str,
547+
projection="*",
541548
)
542549

543550
@staticmethod

libs/community/tests/unit_tests/graphs/test_age_graph.py

Lines changed: 169 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_get_col_name(self) -> None:
5353
self.assertEqual(AGEGraph._get_col_name(*value), expected[idx])
5454

5555
def test_wrap_query(self) -> None:
56+
"""Test basic query wrapping functionality."""
5657
inputs = [
5758
# Positive case: Simple return clause
5859
"""
@@ -76,46 +77,195 @@ def test_wrap_query(self) -> None:
7677

7778
expected = [
7879
# Expected output for the first positive case
80+
"""
81+
SELECT * FROM ag_catalog.cypher('test', $$
82+
MATCH (keanu:Person {name:'Keanu Reeves'})
83+
RETURN keanu.name AS name, keanu.born AS born
84+
$$) AS (name agtype, born agtype);
85+
""",
86+
# Second test case (no RETURN clause)
87+
"""
88+
SELECT * FROM ag_catalog.cypher('test', $$
89+
MERGE (n:a {id: 1})
90+
$$) AS (a agtype);
91+
""",
92+
# Expected output for the negative cases (no RETURN clause)
93+
"""
94+
SELECT * FROM ag_catalog.cypher('test', $$
95+
MATCH (n {description: "This will return a value"})
96+
MERGE (n)-[:RELATED]->(m)
97+
$$) AS (a agtype);
98+
""",
99+
"""
100+
SELECT * FROM ag_catalog.cypher('test', $$
101+
MATCH (n {returnValue: "some value"})
102+
MERGE (n)-[:RELATED]->(m)
103+
$$) AS (a agtype);
104+
""",
105+
]
106+
107+
for idx, value in enumerate(inputs):
108+
result = AGEGraph._wrap_query(value, "test")
109+
expected_result = expected[idx]
110+
self.assertEqual(
111+
re.sub(r"\s", "", result),
112+
re.sub(r"\s", "", expected_result),
113+
(
114+
f"Failed on test case {idx + 1}\n"
115+
f"Input:\n{value}\n"
116+
f"Expected:\n{expected_result}\n"
117+
f"Got:\n{result}"
118+
),
119+
)
120+
121+
def test_wrap_query_union_except(self) -> None:
122+
"""Test query wrapping with UNION and EXCEPT operators."""
123+
inputs = [
124+
# UNION case
125+
"""
126+
MATCH (n:Person)
127+
RETURN n.name AS name, n.age AS age
128+
UNION
129+
MATCH (n:Employee)
130+
RETURN n.name AS name, n.salary AS salary
131+
""",
132+
"""
133+
MATCH (a:Employee {name: "Alice"})
134+
RETURN a.name AS name
135+
UNION
136+
MATCH (b:Manager {name: "Bob"})
137+
RETURN b.name AS name
138+
""",
139+
# Complex UNION case
140+
"""
141+
MATCH (n)-[r]->(m)
142+
RETURN n.name AS source, type(r) AS relationship, m.name AS target
143+
UNION
144+
MATCH (m)-[r]->(n)
145+
RETURN m.name AS source, type(r) AS relationship, n.name AS target
146+
""",
147+
"""
148+
MATCH (a:Person)-[:FRIEND]->(b:Person)
149+
WHERE a.age > 30
150+
RETURN a.name AS name
151+
UNION
152+
MATCH (c:Person)-[:FRIEND]->(d:Person)
153+
WHERE c.age < 25
154+
RETURN c.name AS name
155+
""",
156+
# EXCEPT case
157+
"""
158+
MATCH (n:Person)
159+
RETURN n.name AS name
160+
EXCEPT
161+
MATCH (n:Employee)
162+
RETURN n.name AS name
163+
""",
164+
"""
165+
MATCH (a:Person)
166+
RETURN a.name AS name, a.age AS age
167+
EXCEPT
168+
MATCH (b:Person {name: "Alice", age: 30})
169+
RETURN b.name AS name, b.age AS age
170+
""",
171+
]
172+
173+
expected = [
79174
"""
80175
SELECT * FROM ag_catalog.cypher('test', $$
81-
MATCH (keanu:Person {name:'Keanu Reeves'})
82-
RETURN keanu.name AS name, keanu.born AS born
83-
$$) AS (name agtype, born agtype);
176+
MATCH (n:Person)
177+
RETURN n.name AS name, n.age AS age
178+
UNION
179+
MATCH (n:Employee)
180+
RETURN n.name AS name, n.salary AS salary
181+
$$) AS (name agtype, age agtype, salary agtype);
84182
""",
85183
"""
86184
SELECT * FROM ag_catalog.cypher('test', $$
87-
MERGE (n:a {id: 1})
88-
$$) AS (a agtype);
185+
MATCH (a:Employee {name: "Alice"})
186+
RETURN a.name AS name
187+
UNION
188+
MATCH (b:Manager {name: "Bob"})
189+
RETURN b.name AS name
190+
$$) AS (name agtype);
89191
""",
90-
# Expected output for the negative cases (no return clause)
91192
"""
92193
SELECT * FROM ag_catalog.cypher('test', $$
93-
MATCH (n {description: "This will return a value"})
94-
MERGE (n)-[:RELATED]->(m)
95-
$$) AS (a agtype);
194+
MATCH (n)-[r]->(m)
195+
RETURN n.name AS source, type(r) AS relationship, m.name AS target
196+
UNION
197+
MATCH (m)-[r]->(n)
198+
RETURN m.name AS source, type(r) AS relationship, n.name AS target
199+
$$) AS (source agtype, relationship agtype, target agtype);
96200
""",
97201
"""
98202
SELECT * FROM ag_catalog.cypher('test', $$
99-
MATCH (n {returnValue: "some value"})
100-
MERGE (n)-[:RELATED]->(m)
101-
$$) AS (a agtype);
203+
MATCH (a:Person)-[:FRIEND]->(b:Person)
204+
WHERE a.age > 30
205+
RETURN a.name AS name
206+
UNION
207+
MATCH (c:Person)-[:FRIEND]->(d:Person)
208+
WHERE c.age < 25
209+
RETURN c.name AS name
210+
$$) AS (name agtype);
211+
""",
212+
"""
213+
SELECT * FROM ag_catalog.cypher('test', $$
214+
MATCH (n:Person)
215+
RETURN n.name AS name
216+
EXCEPT
217+
MATCH (n:Employee)
218+
RETURN n.name AS name
219+
$$) AS (name agtype);
220+
""",
221+
"""
222+
SELECT * FROM ag_catalog.cypher('test', $$
223+
MATCH (a:Person)
224+
RETURN a.name AS name, a.age AS age
225+
EXCEPT
226+
MATCH (b:Person {name: "Alice", age: 30})
227+
RETURN b.name AS name, b.age AS age
228+
$$) AS (name agtype, age agtype);
102229
""",
103230
]
104231

105232
for idx, value in enumerate(inputs):
233+
result = AGEGraph._wrap_query(value, "test")
234+
expected_result = expected[idx]
106235
self.assertEqual(
107-
re.sub(r"\s", "", AGEGraph._wrap_query(value, "test")),
108-
re.sub(r"\s", "", expected[idx]),
236+
re.sub(r"\s", "", result),
237+
re.sub(r"\s", "", expected_result),
238+
(
239+
f"Failed on test case {idx + 1}\n"
240+
f"Input:\n{value}\n"
241+
f"Expected:\n{expected_result}\n"
242+
f"Got:\n{result}"
243+
),
109244
)
110245

111-
with self.assertRaises(ValueError):
112-
AGEGraph._wrap_query(
113-
"""
246+
def test_wrap_query_errors(self) -> None:
247+
"""Test error cases for query wrapping."""
248+
error_cases = [
249+
# Empty query
250+
"",
251+
# Return * case
252+
"""
114253
MATCH ()
115254
RETURN *
116255
""",
117-
"test",
118-
)
256+
# Return * in UNION
257+
"""
258+
MATCH (n:Person)
259+
RETURN n.name
260+
UNION
261+
MATCH ()
262+
RETURN *
263+
""",
264+
]
265+
266+
for query in error_cases:
267+
with self.assertRaises(ValueError):
268+
AGEGraph._wrap_query(query, "test")
119269

120270
def test_format_properties(self) -> None:
121271
inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]

0 commit comments

Comments
 (0)