Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindsdb_sql_parser/__about__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__title__ = 'mindsdb_sql_parser'
__package_name__ = 'mindsdb_sql_parser'
__version__ = '0.2.0'
__version__ = '0.3.0'
__description__ = "Mindsdb SQL parser"
__email__ = "[email protected]"
__author__ = 'MindsDB Inc'
Expand Down
19 changes: 14 additions & 5 deletions mindsdb_sql_parser/ast/select/identifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from copy import copy, deepcopy
from typing import List

from mindsdb_sql_parser.ast.base import ASTNode
from mindsdb_sql_parser.utils import indent
Expand All @@ -11,9 +12,13 @@


def path_str_to_parts(path_str: str):
match = re.finditer(path_str_parts_regex, path_str)
parts = [x[0].strip('`') for x in match]
return parts
parts, is_quoted = [], []
for x in re.finditer(path_str_parts_regex, path_str):
part = x[0].strip('`')
parts.append(part)
is_quoted.append(x[0] != part)

return parts, is_quoted


RESERVED_KEYWORDS = {
Expand Down Expand Up @@ -42,13 +47,17 @@ def __init__(self, path_str=None, parts=None, *args, **kwargs):
parts = [Star()]

if path_str and not parts:
parts = path_str_to_parts(path_str)
parts, is_quoted = path_str_to_parts(path_str)
else:
is_quoted = [False] * len(parts)
assert isinstance(parts, list)
self.parts = parts
# parts which were quoted
self.is_quoted: List[bool] = is_quoted

@classmethod
def from_path_str(self, value, *args, **kwargs):
parts = path_str_to_parts(value)
parts, _ = path_str_to_parts(value)
return Identifier(parts=parts, *args, **kwargs)

def parts_to_str(self):
Expand Down
5 changes: 3 additions & 2 deletions mindsdb_sql_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,7 +1264,7 @@ def from_table(self, p):
@_('PLUGINS',
'ENGINES')
def from_table(self, p):
return Identifier.from_path_str(p[0])
return Identifier(p[0])

@_('identifier')
def from_table(self, p):
Expand Down Expand Up @@ -1739,6 +1739,7 @@ def identifier(self, p):
node.parts.append(p[2])
else:
node.parts += p[2].parts
node.is_quoted.append(p[2].is_quoted[0])
return node

@_('quote_string',
Expand All @@ -1749,7 +1750,7 @@ def string(self, p):
@_('id', 'dquote_string')
def identifier(self, p):
value = p[0]
return Identifier.from_path_str(value)
return Identifier(value)

@_('PARAMETER')
def parameter(self, p):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_base_sql/test_select_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,43 @@ def test_table_double_quote(self):
ast = parse_sql(sql)
assert str(ast) == str(expected_ast)

def test_double_quote_render_skip(self):
sql = 'select `KEY_ID` from `Table1` where `id`=2'

expected_ast = Select(
targets=[Identifier('KEY_ID')],
from_table=Identifier(parts=['Table1']),
where=BinaryOperation(op='=', args=[
Identifier('id'), Constant(2)
])
)

ast = parse_sql(sql)
assert str(ast) == str(expected_ast)

# check is quoted
assert ast.targets[0].is_quoted == [True]
assert ast.from_table.is_quoted == [True]
assert ast.where.args[0].is_quoted == [True]

sql = 'select KEY_ID from Table1 where id=2'

expected_ast = Select(
targets=[Identifier('KEY_ID')],
from_table=Identifier(parts=['Table1']),
where=BinaryOperation(op='=', args=[
Identifier('id'), Constant(2)
])
)

ast = parse_sql(sql)
assert str(ast) == str(expected_ast)

# check is not quoted
assert ast.targets[0].is_quoted == [False]
assert ast.from_table.is_quoted == [False]
assert ast.where.args[0].is_quoted == [False]

def test_window_function_mindsdb(self):

# modifier
Expand Down
2 changes: 1 addition & 1 deletion tests/test_standard_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def check_module(module):

tests = klass()
for test_name, test_method in inspect.getmembers(tests, predicate=inspect.ismethod):
if not test_name.startswith('test_') or test_name.endswith('_error'):
if not test_name.startswith('test_') or test_name.endswith('_error') or test_name.endswith('_render_skip'):
# skip tests that expected error
continue
sig = inspect.signature(test_method)
Expand Down