diff --git a/mindsdb_sql_parser/__about__.py b/mindsdb_sql_parser/__about__.py index ae5d36a..82b92ee 100644 --- a/mindsdb_sql_parser/__about__.py +++ b/mindsdb_sql_parser/__about__.py @@ -1,6 +1,6 @@ __title__ = 'mindsdb_sql_parser' __package_name__ = 'mindsdb_sql_parser' -__version__ = '0.2.0' +__version__ = '0.3.0' __description__ = "Mindsdb SQL parser" __email__ = "jorge@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/mindsdb_sql_parser/ast/select/identifier.py b/mindsdb_sql_parser/ast/select/identifier.py index 1cac18f..37558bb 100644 --- a/mindsdb_sql_parser/ast/select/identifier.py +++ b/mindsdb_sql_parser/ast/select/identifier.py @@ -1,5 +1,6 @@ import re from copy import copy, deepcopy +from typing import List from mindsdb_sql_parser.ast.base import ASTNode from mindsdb_sql_parser.utils import indent @@ -11,9 +12,13 @@ def path_str_to_parts(path_str: str): - match = re.finditer(path_str_parts_regex, path_str) - parts = [x[0].strip('`') for x in match] - return parts + parts, is_quoted = [], [] + for x in re.finditer(path_str_parts_regex, path_str): + part = x[0].strip('`') + parts.append(part) + is_quoted.append(x[0] != part) + + return parts, is_quoted RESERVED_KEYWORDS = { @@ -42,13 +47,17 @@ def __init__(self, path_str=None, parts=None, *args, **kwargs): parts = [Star()] if path_str and not parts: - parts = path_str_to_parts(path_str) + parts, is_quoted = path_str_to_parts(path_str) + else: + is_quoted = [False] * len(parts) assert isinstance(parts, list) self.parts = parts + # parts which were quoted + self.is_quoted: List[bool] = is_quoted @classmethod def from_path_str(self, value, *args, **kwargs): - parts = path_str_to_parts(value) + parts, _ = path_str_to_parts(value) return Identifier(parts=parts, *args, **kwargs) def parts_to_str(self): diff --git a/mindsdb_sql_parser/parser.py b/mindsdb_sql_parser/parser.py index 7d8d6ae..1757c6c 100644 --- a/mindsdb_sql_parser/parser.py +++ b/mindsdb_sql_parser/parser.py @@ -1264,7 +1264,7 @@ def from_table(self, p): @_('PLUGINS', 'ENGINES') def from_table(self, p): - return Identifier.from_path_str(p[0]) + return Identifier(p[0]) @_('identifier') def from_table(self, p): @@ -1739,6 +1739,7 @@ def identifier(self, p): node.parts.append(p[2]) else: node.parts += p[2].parts + node.is_quoted.append(p[2].is_quoted[0]) return node @_('quote_string', @@ -1749,7 +1750,7 @@ def string(self, p): @_('id', 'dquote_string') def identifier(self, p): value = p[0] - return Identifier.from_path_str(value) + return Identifier(value) @_('PARAMETER') def parameter(self, p): diff --git a/tests/test_base_sql/test_select_structure.py b/tests/test_base_sql/test_select_structure.py index 6bb334d..daf26b7 100644 --- a/tests/test_base_sql/test_select_structure.py +++ b/tests/test_base_sql/test_select_structure.py @@ -1182,6 +1182,43 @@ def test_table_double_quote(self): ast = parse_sql(sql) assert str(ast) == str(expected_ast) + def test_double_quote_render_skip(self): + sql = 'select `KEY_ID` from `Table1` where `id`=2' + + expected_ast = Select( + targets=[Identifier('KEY_ID')], + from_table=Identifier(parts=['Table1']), + where=BinaryOperation(op='=', args=[ + Identifier('id'), Constant(2) + ]) + ) + + ast = parse_sql(sql) + assert str(ast) == str(expected_ast) + + # check is quoted + assert ast.targets[0].is_quoted == [True] + assert ast.from_table.is_quoted == [True] + assert ast.where.args[0].is_quoted == [True] + + sql = 'select KEY_ID from Table1 where id=2' + + expected_ast = Select( + targets=[Identifier('KEY_ID')], + from_table=Identifier(parts=['Table1']), + where=BinaryOperation(op='=', args=[ + Identifier('id'), Constant(2) + ]) + ) + + ast = parse_sql(sql) + assert str(ast) == str(expected_ast) + + # check is not quoted + assert ast.targets[0].is_quoted == [False] + assert ast.from_table.is_quoted == [False] + assert ast.where.args[0].is_quoted == [False] + def test_window_function_mindsdb(self): # modifier diff --git a/tests/test_standard_render.py b/tests/test_standard_render.py index bdce825..611adbe 100644 --- a/tests/test_standard_render.py +++ b/tests/test_standard_render.py @@ -28,7 +28,7 @@ def check_module(module): tests = klass() for test_name, test_method in inspect.getmembers(tests, predicate=inspect.ismethod): - if not test_name.startswith('test_') or test_name.endswith('_error'): + if not test_name.startswith('test_') or test_name.endswith('_error') or test_name.endswith('_render_skip'): # skip tests that expected error continue sig = inspect.signature(test_method)