Skip to content

Commit f585d73

Browse files
authored
Merge pull request #38 from mindsdb/intersect-distinct
`DISTINCT` keyword for `UNION/INTERSECT/EXCEPT`
2 parents 16c9401 + ba5a0e6 commit f585d73

File tree

3 files changed: 40 additions and 21 deletions.

mindsdb_sql_parser/ast/select/union.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ def __init__(self,
99
left,
1010
right,
1111
unique=True,
12+
distinct_key=False,
1213
*args, **kwargs):
1314
super().__init__(*args, **kwargs)
1415
self.left = left
1516
self.right = right
1617
self.unique = unique
18+
self.distinct_key = distinct_key
1719

1820
if self.alias:
1921
self.parentheses = True
@@ -26,7 +28,7 @@ def to_tree(self, *args, level=0, **kwargs):
2628
right_str = f'\n{ind1}right=\n{self.right.to_tree(level=level + 2)},'
2729

2830
cls_name = self.__class__.__name__
29-
out_str = f'{ind}{cls_name}(unique={repr(self.unique)},' \
31+
out_str = f'{ind}{cls_name}(unique={repr(self.unique)}, distinct_key={repr(self.distinct_key)}' \
3032
f'{left_str}' \
3133
f'{right_str}' \
3234
f'\n{ind})'
@@ -38,6 +40,8 @@ def get_string(self, *args, **kwargs):
3840
keyword = self.operation
3941
if not self.unique:
4042
keyword += ' ALL'
43+
if self.distinct_key:
44+
keyword += ' DISTINCT'
4145
out_str = f"""{left_str}\n{keyword}\n{right_str}"""
4246

4347
return out_str

mindsdb_sql_parser/parser.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,25 +1061,35 @@ def database_engine(self, p):
10611061
@_('select UNION select',
10621062
'union UNION select',
10631063
'select UNION ALL select',
1064-
'union UNION ALL select')
1064+
'union UNION ALL select',
1065+
'select UNION DISTINCT select',
1066+
'union UNION DISTINCT select')
10651067
def union(self, p):
10661068
unique = not hasattr(p, 'ALL')
1067-
return Union(left=p[0], right=p[2] if unique else p[3], unique=unique)
1069+
distinct_key = hasattr(p, 'DISTINCT')
1070+
return Union(left=p[0], right=p[-1], unique=unique, distinct_key=distinct_key)
10681071

10691072
@_('select INTERSECT select',
10701073
'union INTERSECT select',
10711074
'select INTERSECT ALL select',
1072-
'union INTERSECT ALL select')
1075+
'union INTERSECT ALL select',
1076+
'select INTERSECT DISTINCT select',
1077+
'union INTERSECT DISTINCT select')
10731078
def union(self, p):
10741079
unique = not hasattr(p, 'ALL')
1075-
return Intersect(left=p[0], right=p[2] if unique else p[3], unique=unique)
1080+
distinct_key = hasattr(p, 'DISTINCT')
1081+
return Intersect(left=p[0], right=p[-1], unique=unique, distinct_key=distinct_key)
1082+
10761083
@_('select EXCEPT select',
10771084
'union EXCEPT select',
10781085
'select EXCEPT ALL select',
1079-
'union EXCEPT ALL select')
1086+
'union EXCEPT ALL select',
1087+
'select EXCEPT DISTINCT select',
1088+
'union EXCEPT DISTINCT select')
10801089
def union(self, p):
10811090
unique = not hasattr(p, 'ALL')
1082-
return Except(left=p[0], right=p[2] if unique else p[3], unique=unique)
1091+
distinct_key = hasattr(p, 'DISTINCT')
1092+
return Except(left=p[0], right=p[-1], unique=unique, distinct_key=distinct_key)
10831093

10841094
# tableau
10851095
@_('LPAREN select RPAREN')

tests/test_base_sql/test_union.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,26 @@ def test_single_select_error(self):
1212

1313
def test_union_base(self):
1414
for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items():
15-
sql = f"""SELECT col1 FROM tab1
16-
{keyword}
17-
SELECT col1 FROM tab2"""
15+
for rule in ['', 'distinct']:
16+
sql = f"""SELECT col1 FROM tab1
17+
{keyword} {rule}
18+
SELECT col1 FROM tab2"""
1819

19-
ast = parse_sql(sql)
20-
expected_ast = cls(unique=True,
21-
left=Select(targets=[Identifier('col1')],
22-
from_table=Identifier(parts=['tab1']),
23-
),
24-
right=Select(targets=[Identifier('col1')],
25-
from_table=Identifier(parts=['tab2']),
26-
),
27-
)
28-
assert ast.to_tree() == expected_ast.to_tree()
29-
assert str(ast) == str(expected_ast)
20+
ast = parse_sql(sql)
21+
expected_ast = cls(
22+
unique=True,
23+
distinct_key=rule == 'distinct',
24+
left=Select(
25+
targets=[Identifier('col1')],
26+
from_table=Identifier(parts=['tab1']),
27+
),
28+
right=Select(
29+
targets=[Identifier('col1')],
30+
from_table=Identifier(parts=['tab2']),
31+
),
32+
)
33+
assert ast.to_tree() == expected_ast.to_tree()
34+
assert str(ast) == str(expected_ast)
3035

3136
def test_union_all(self):
3237
for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items():

0 commit comments

Comments
 (0)