11from mindsdb_sql_parser import parse_sql
2- from mindsdb_sql_parser .ast import Select , Identifier , BinaryOperation , Star
3- from mindsdb_sql_parser .ast import Variable , Function
42from mindsdb_sql_parser .parser import Show
3+ from mindsdb_sql_parser .ast import Select , Identifier , BinaryOperation , Star , Variable , Function , ASTNode
4+
5+
6+ def compare_ast (parsed : ASTNode , expected : ASTNode , sql : str ) -> None :
7+ assert parsed .to_tree () == expected .to_tree ()
8+ assert str (parsed ).lower () == sql .lower ()
9+ assert str (parsed ) == str (expected )
10+ assert str (eval (parsed .to_tree ())) == str (parsed )
511
612
713class TestMySQLParser :
814 def test_select_variable (self ):
915 sql = 'SELECT @version'
1016 ast = parse_sql (sql )
1117 expected_ast = Select (targets = [Variable ('version' )])
12- assert ast .to_tree () == expected_ast .to_tree ()
13- assert str (ast ).lower () == sql .lower ()
14- assert str (ast ) == str (expected_ast )
18+ compare_ast (ast , expected_ast , sql )
1519
1620 sql = 'SELECT @@version'
1721 ast = parse_sql (sql )
1822 expected_ast = Select (targets = [Variable ('version' , is_system_var = True )])
19- assert ast .to_tree () == expected_ast .to_tree ()
20- assert str (ast ).lower () == sql .lower ()
21- assert str (ast ) == str (expected_ast )
23+ compare_ast (ast , expected_ast , sql )
2224
2325 def test_select_varialbe_complex (self ):
24- sql = f """SELECT * FROM tab1 WHERE column1 in (SELECT column2 + @variable FROM t2)"""
26+ sql = """SELECT * FROM tab1 WHERE column1 in (SELECT column2 + @variable FROM t2)"""
2527 ast = parse_sql (sql )
2628 expected_ast = Select (targets = [Star ()],
2729 from_table = Identifier ('tab1' ),
@@ -36,10 +38,7 @@ def test_select_varialbe_complex(self):
3638 parentheses = True )
3739 )
3840 ))
39-
40- assert ast .to_tree () == expected_ast .to_tree ()
41- assert str (ast ).lower () == sql .lower ()
42- assert str (ast ) == str (expected_ast )
41+ compare_ast (ast , expected_ast , sql )
4342
4443 def test_show_index (self ):
4544 sql = "SHOW INDEX FROM `predictors`"
@@ -48,10 +47,7 @@ def test_show_index(self):
4847 category = 'INDEX' ,
4948 from_table = Identifier ('`predictors`' )
5049 )
51-
52- assert str (ast ).lower () == sql .lower ()
53- assert str (ast ) == str (expected_ast )
54- assert ast .to_tree () == expected_ast .to_tree ()
50+ compare_ast (ast , expected_ast , sql )
5551
5652 def test_show_index_from_db (self ):
5753 sql = "SHOW INDEX FROM `predictors` FROM db"
@@ -60,10 +56,7 @@ def test_show_index_from_db(self):
6056 category = 'INDEX' ,
6157 from_table = Identifier ('db.`predictors`' ),
6258 )
63-
64- # assert str(ast).lower() == sql.lower()
65- assert str (ast ) == str (expected_ast )
66- assert ast .to_tree () == expected_ast .to_tree ()
59+ compare_ast (ast , expected_ast , sql )
6760
6861 def test_with_rollup (self ):
6962 sql = "SELECT country, SUM(sales) FROM booksales GROUP BY country WITH ROLLUP"
@@ -77,7 +70,23 @@ def test_with_rollup(self):
7770 from_table = Identifier ('booksales' ),
7871 group_by = [Identifier ('country' , with_rollup = True )]
7972 )
73+ compare_ast (ast , expected_ast , sql )
74+
75+ def test_with_rollup_multiple_columns (self ):
76+ """Test WITH ROLLUP with multiple GROUP BY columns"""
77+ sql = "SELECT year, country, SUM(sales) FROM booksales GROUP BY year, country WITH ROLLUP"
8078
81- assert str (ast ).lower () == sql .lower ()
82- assert str (ast ) == str (expected_ast )
83- assert ast .to_tree () == expected_ast .to_tree ()
79+ ast = parse_sql (sql )
80+ expected_ast = Select (
81+ targets = [
82+ Identifier ('year' ),
83+ Identifier ('country' ),
84+ Function (op = 'SUM' , args = [Identifier ('sales' )])
85+ ],
86+ from_table = Identifier ('booksales' ),
87+ group_by = [
88+ Identifier ('year' ),
89+ Identifier ('country' , with_rollup = True )
90+ ]
91+ )
92+ compare_ast (ast , expected_ast , sql )
0 commit comments