Skip to content

Commit 4c9f038

Browse files
authored
Merge pull request #62 from mindsdb/fix-CONN-1340
Fix CTE with UNION
2 parents a2b697f + 2e38e1a commit 4c9f038

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

mindsdb_sql_parser/parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,14 +1106,15 @@ def select(self, p):
11061106
select.cte = p.ctes
11071107
return select
11081108

1109-
@_('ctes COMMA identifier cte_columns_or_nothing AS LPAREN select RPAREN')
1109+
@_('ctes COMMA identifier cte_columns_or_nothing AS LPAREN select RPAREN',
1110+
'ctes COMMA identifier cte_columns_or_nothing AS LPAREN union RPAREN')
11101111
def ctes(self, p):
11111112
ctes = p.ctes
11121113
ctes = ctes + [
11131114
CommonTableExpression(
11141115
name=p.identifier,
11151116
columns=p.cte_columns_or_nothing,
1116-
query=p.select)
1117+
query=p[6])
11171118
]
11181119
return ctes
11191120

tests/test_base_sql/test_select_common_table_expression.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,48 @@ def test_cte_nested(self):
8787
assert str(ast).lower() == sql.lower()
8888
assert str(ast) == str(expected_ast)
8989
assert ast.to_tree() == expected_ast.to_tree()
90+
91+
def test_cte_union(self):
92+
sql = """
93+
WITH ta AS (
94+
SELECT 'a' AS a
95+
UNION
96+
SELECT 'b' AS a
97+
), tb AS (
98+
SELECT 'c' AS a
99+
UNION
100+
SELECT 'd' AS a
101+
)
102+
SELECT a FROM ta
103+
UNION
104+
SELECT a FROM tb
105+
"""
106+
ast = parse_sql(sql)
107+
108+
expected_ast = Union(
109+
left=Select(
110+
cte=[
111+
CommonTableExpression(
112+
name=Identifier('ta'),
113+
query=Union(
114+
left=Select(targets=[Constant('a', alias=Identifier('a'))]),
115+
right=Select(targets=[Constant('b', alias=Identifier('a'))])
116+
)
117+
),
118+
CommonTableExpression(
119+
name=Identifier('tb'),
120+
query=Union(
121+
left=Select(targets=[Constant('c', alias=Identifier('a'))]),
122+
right=Select(targets=[Constant('d', alias=Identifier('a'))])
123+
)
124+
),
125+
],
126+
targets=[Identifier('a')],
127+
from_table=Identifier('ta')
128+
),
129+
right=Select(targets=[Identifier('a')], from_table=Identifier('tb'))
130+
)
131+
132+
assert (' '.join(str(ast).split())).lower() == (' '.join(sql.split())).lower()
133+
assert str(ast) == str(expected_ast)
134+
assert ast.to_tree() == expected_ast.to_tree()

0 commit comments

Comments
 (0)