Skip to content

Commit 8b2fa01

Browse files
committed
fix to_tree for queries with ROLLUP
1 parent 398f179 commit 8b2fa01

File tree

3 files changed

+40
-30
lines changed

3 files changed

+40
-30
lines changed

mindsdb_sql_parser/ast/select/identifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def parts_to_str(self):
9595

9696
def to_tree(self, *args, level=0, **kwargs):
9797
alias_str = f', alias={self.alias.to_tree()}' if self.alias else ''
98-
return indent(level) + f'Identifier(parts={[str(i) for i in self.parts]}{alias_str})'
98+
with_rollup_str = ', with_rollup=True' if self.with_rollup else ''
99+
return indent(level) + f'Identifier(parts={[str(i) for i in self.parts]}{alias_str}{with_rollup_str})'
99100

100101
def get_string(self, *args, **kwargs):
101102
return self.parts_to_str()

mindsdb_sql_parser/ast/show.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ def to_tree(self, *args, level=0, **kwargs):
3131
ind = indent(level)
3232
ind1 = indent(level+1)
3333
category_str = f'{ind1}category={repr(self.category)},'
34-
from_str = f'\n{ind1}from={self.from_table.to_string()},' if self.from_table else ''
35-
in_str = f'\n{ind1}in={self.in_table.to_tree(level=level + 2)},' if self.in_table else ''
34+
from_str = f'\n{ind1}from_table={self.from_table.to_tree(level=level + 2)},' if self.from_table else ''
35+
in_str = f'\n{ind1}in_table={self.in_table.to_tree(level=level + 2)},' if self.in_table else ''
3636
where_str = f'\n{ind1}where=\n{self.where.to_tree(level=level+2)},' if self.where else ''
37-
name_str = f'\n{ind1}name={self.name},' if self.name else ''
38-
like_str = f'\n{ind1}like={self.like},' if self.like else ''
39-
modes_str = f'\n{ind1}modes=[{",".join(self.modes)}],' if self.modes else ''
37+
name_str = f'\n{ind1}name={repr(self.name)},' if self.name else ''
38+
like_str = f'\n{ind1}like={repr(self.like)},' if self.like else ''
39+
modes_str = f'\n{ind1}modes=[{",".join([repr(m) for m in self.modes])}],' if self.modes else ''
4040
out_str = f'{ind}Show(' \
4141
f'{category_str}' \
4242
f'{name_str}' \
Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
from mindsdb_sql_parser import parse_sql
2-
from mindsdb_sql_parser.ast import Select, Identifier, BinaryOperation, Star
3-
from mindsdb_sql_parser.ast import Variable, Function
42
from mindsdb_sql_parser.parser import Show
3+
from mindsdb_sql_parser.ast import Select, Identifier, BinaryOperation, Star, Variable, Function, ASTNode
4+
5+
6+
def compare_ast(parsed: ASTNode, expected: ASTNode, sql: str) -> None:
7+
assert parsed.to_tree() == expected.to_tree()
8+
assert str(parsed).lower() == sql.lower()
9+
assert str(parsed) == str(expected)
10+
assert str(eval(parsed.to_tree())) == str(parsed)
511

612

713
class TestMySQLParser:
814
def test_select_variable(self):
915
sql = 'SELECT @version'
1016
ast = parse_sql(sql)
1117
expected_ast = Select(targets=[Variable('version')])
12-
assert ast.to_tree() == expected_ast.to_tree()
13-
assert str(ast).lower() == sql.lower()
14-
assert str(ast) == str(expected_ast)
18+
compare_ast(ast, expected_ast, sql)
1519

1620
sql = 'SELECT @@version'
1721
ast = parse_sql(sql)
1822
expected_ast = Select(targets=[Variable('version', is_system_var=True)])
19-
assert ast.to_tree() == expected_ast.to_tree()
20-
assert str(ast).lower() == sql.lower()
21-
assert str(ast) == str(expected_ast)
23+
compare_ast(ast, expected_ast, sql)
2224

2325
def test_select_varialbe_complex(self):
24-
sql = f"""SELECT * FROM tab1 WHERE column1 in (SELECT column2 + @variable FROM t2)"""
26+
sql = """SELECT * FROM tab1 WHERE column1 in (SELECT column2 + @variable FROM t2)"""
2527
ast = parse_sql(sql)
2628
expected_ast = Select(targets=[Star()],
2729
from_table=Identifier('tab1'),
@@ -36,10 +38,7 @@ def test_select_varialbe_complex(self):
3638
parentheses=True)
3739
)
3840
))
39-
40-
assert ast.to_tree() == expected_ast.to_tree()
41-
assert str(ast).lower() == sql.lower()
42-
assert str(ast) == str(expected_ast)
41+
compare_ast(ast, expected_ast, sql)
4342

4443
def test_show_index(self):
4544
sql = "SHOW INDEX FROM `predictors`"
@@ -48,10 +47,7 @@ def test_show_index(self):
4847
category='INDEX',
4948
from_table=Identifier('`predictors`')
5049
)
51-
52-
assert str(ast).lower() == sql.lower()
53-
assert str(ast) == str(expected_ast)
54-
assert ast.to_tree() == expected_ast.to_tree()
50+
compare_ast(ast, expected_ast, sql)
5551

5652
def test_show_index_from_db(self):
5753
sql = "SHOW INDEX FROM `predictors` FROM db"
@@ -60,10 +56,7 @@ def test_show_index_from_db(self):
6056
category='INDEX',
6157
from_table=Identifier('db.`predictors`'),
6258
)
63-
64-
# assert str(ast).lower() == sql.lower()
65-
assert str(ast) == str(expected_ast)
66-
assert ast.to_tree() == expected_ast.to_tree()
59+
compare_ast(ast, expected_ast, sql)
6760

6861
def test_with_rollup(self):
6962
sql = "SELECT country, SUM(sales) FROM booksales GROUP BY country WITH ROLLUP"
@@ -77,7 +70,23 @@ def test_with_rollup(self):
7770
from_table=Identifier('booksales'),
7871
group_by=[Identifier('country', with_rollup=True)]
7972
)
73+
compare_ast(ast, expected_ast, sql)
74+
75+
def test_with_rollup_multiple_columns(self):
76+
"""Test WITH ROLLUP with multiple GROUP BY columns"""
77+
sql = "SELECT year, country, SUM(sales) FROM booksales GROUP BY year, country WITH ROLLUP"
8078

81-
assert str(ast).lower() == sql.lower()
82-
assert str(ast) == str(expected_ast)
83-
assert ast.to_tree() == expected_ast.to_tree()
79+
ast = parse_sql(sql)
80+
expected_ast = Select(
81+
targets=[
82+
Identifier('year'),
83+
Identifier('country'),
84+
Function(op='SUM', args=[Identifier('sales')])
85+
],
86+
from_table=Identifier('booksales'),
87+
group_by=[
88+
Identifier('year'),
89+
Identifier('country', with_rollup=True)
90+
]
91+
)
92+
compare_ast(ast, expected_ast, sql)

0 commit comments

Comments
 (0)