11import pytest
2-
32from mindsdb_sql_parser import parse_sql
43from mindsdb_sql_parser .ast import *
54from mindsdb_sql_parser .ast .mindsdb .evaluate import Evaluate
65from mindsdb_sql_parser .lexer import MindsDBLexer
7-
6+ from mindsdb_sql_parser . utils import to_single_line
87
98class TestEvaluate :
109 def test_evaluate_lexer (self ):
11- sql = "EVALUATE balanced_accuracy_score FROM (SELECT true , pred FROM table_1)"
10+ sql = "EVALUATE balanced_accuracy_score FROM (SELECT ground_truth , pred FROM table_1)"
1211 tokens = list (MindsDBLexer ().tokenize (sql ))
1312 assert tokens [0 ].type == 'EVALUATE'
1413 assert tokens [1 ].type == 'ID'
1514 assert tokens [1 ].value == 'balanced_accuracy_score'
1615
1716 def test_evaluate_full_1 (self ):
18- sql = "EVALUATE balanced_accuracy_score FROM (SELECT ground_truth, pred FROM table_1) USING adjusted=1, param2=2;" # noqa
17+ sql = "EVALUATE balanced_accuracy_score FROM (SELECT ground_truth, pred FROM table_1) USING adjusted=1, param2=2;"
1918 ast = parse_sql (sql )
2019 expected_ast = Evaluate (
2120 name = Identifier ('balanced_accuracy_score' ),
2221 query_str = "SELECT ground_truth, pred FROM table_1" ,
2322 using = {'adjusted' : 1 , 'param2' : 2 },
2423 )
25- assert ' ' .join (str (ast ).split ()).lower () == sql .lower ()
24+ assert to_single_line (str (ast )).lower () == to_single_line (sql ).lower () # Added .lower()
25+ assert to_single_line (str (ast )).lower () == to_single_line (str (expected_ast )).lower () # Added .lower()
2626 assert ast .to_tree () == expected_ast .to_tree ()
27- assert str (ast ) == str (expected_ast )
2827
2928 def test_evaluate_full_2 (self ):
30- query_str = """SELECT t.rental_price as ground_truth, m.rental_price as prediction FROM example_db.demo_data.home_rentals as t JOIN mindsdb.home_rentals_model as m limit 100""" # noqa
29+ query_str = """SELECT t.rental_price as ground_truth, m.rental_price as prediction FROM example_db.demo_data.home_rentals as t JOIN mindsdb.home_rentals_model as m limit 100"""
3130 sql = f"""EVALUATE r2_score FROM ({ query_str } );"""
3231 ast = parse_sql (sql )
3332 expected_ast = Evaluate (
3433 name = Identifier ('r2_score' ),
3534 query_str = query_str ,
3635 )
37- assert ' ' . join (str (ast ). split ()) .lower () == sql .lower ()
38- assert ast . to_tree () == expected_ast . to_tree ()
39- assert str ( ast ). lower () == str ( expected_ast ). lower ()
36+ assert to_single_line (str (ast )) .lower () == to_single_line ( sql ). lower () # Added .lower()
37+ assert to_single_line ( str ( ast )). lower () == to_single_line ( str ( expected_ast )). lower () # Added .lower ()
38+ assert ast . to_tree () == expected_ast . to_tree ()
0 commit comments