Skip to content

Commit e906fcc

Browse files
authored
Merge pull request #51 from mindsdb/identifier-tostring-fix
Keep escape symbols in identifier parts when dumping to string
2 parents 7dbd454 + f602dec commit e906fcc

File tree

9 files changed

+76
-43
lines changed

9 files changed

+76
-43
lines changed

mindsdb_sql_parser/ast/select/identifier.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
from copy import copy, deepcopy
3-
from typing import List
3+
from typing import List, Optional
44

55
from mindsdb_sql_parser.ast.base import ASTNode
66
from mindsdb_sql_parser.utils import indent
@@ -27,19 +27,28 @@ def path_str_to_parts(path_str: str):
2727
}
2828

2929

30-
def get_reserved_words():
31-
from mindsdb_sql_parser.lexer import MindsDBLexer
30+
_reserved_keywords: set[str] = None
3231

33-
reserved = RESERVED_KEYWORDS
34-
for word in MindsDBLexer.tokens:
35-
if '_' not in word:
36-
# exclude combinations
37-
reserved.add(word)
38-
return reserved
32+
33+
def get_reserved_words() -> set[str]:
34+
global _reserved_keywords
35+
36+
if _reserved_keywords is None:
37+
from mindsdb_sql_parser.lexer import MindsDBLexer
38+
39+
_reserved_keywords = RESERVED_KEYWORDS
40+
for word in MindsDBLexer.tokens:
41+
if '_' not in word:
42+
# exclude combinations
43+
_reserved_keywords.add(word)
44+
return _reserved_keywords
3945

4046

4147
class Identifier(ASTNode):
42-
def __init__(self, path_str=None, parts=None, is_outer=False, with_rollup=False, *args, **kwargs):
48+
def __init__(
49+
self, path_str=None, parts=None, is_outer=False, with_rollup=False,
50+
is_quoted: Optional[List[bool]] = None, *args, **kwargs
51+
):
4352
super().__init__(*args, **kwargs)
4453
assert path_str or parts, "Either path_str or parts must be provided for an Identifier"
4554
assert not (path_str and parts), "Provide either path_str or parts, but not both"
@@ -48,7 +57,7 @@ def __init__(self, path_str=None, parts=None, is_outer=False, with_rollup=False,
4857

4958
if path_str and not parts:
5059
parts, is_quoted = path_str_to_parts(path_str)
51-
else:
60+
elif is_quoted is None:
5261
is_quoted = [False] * len(parts)
5362
assert isinstance(parts, list)
5463
self.parts = parts
@@ -63,22 +72,26 @@ def from_path_str(self, value, *args, **kwargs):
6372
parts, _ = path_str_to_parts(value)
6473
return Identifier(parts=parts, *args, **kwargs)
6574

66-
def parts_to_str(self):
67-
out_parts = []
75+
def append(self, other: "Identifier") -> None:
76+
self.parts += other.parts
77+
self.is_quoted += other.is_quoted
78+
79+
def iter_parts_str(self):
6880
reserved_words = get_reserved_words()
69-
for part in self.parts:
81+
for part, is_quoted in zip(self.parts, self.is_quoted):
7082
if isinstance(part, Star):
7183
part = str(part)
7284
else:
7385
if (
74-
not no_wrap_identifier_regex.fullmatch(part)
75-
or
76-
part.upper() in reserved_words
86+
is_quoted
87+
or not no_wrap_identifier_regex.fullmatch(part)
88+
or part.upper() in reserved_words
7789
):
7890
part = f'`{part}`'
91+
yield part
7992

80-
out_parts.append(part)
81-
return '.'.join(out_parts)
93+
def parts_to_str(self):
94+
return '.'.join(self.iter_parts_str())
8295

8396
def to_tree(self, *args, level=0, **kwargs):
8497
alias_str = f', alias={self.alias.to_tree()}' if self.alias else ''

mindsdb_sql_parser/ast/show.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def get_string(self, *args, **kwargs):
5353
from_str = ''
5454
if self.from_table:
5555
ar = [
56-
f'FROM {i}'
57-
for i in self.from_table.parts
56+
f'FROM {part}'
57+
for part in self.from_table.iter_parts_str()
5858
]
5959
ar.reverse()
6060
from_str = ' ' + ' '.join(ar)

mindsdb_sql_parser/parser.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def show(self, p):
538538
value0 = command.from_table
539539
value1 = p.identifier
540540
if value0 is not None:
541-
value1.parts = value1.parts + value0.parts
541+
value1.append(value0)
542542

543543
command.from_table = value1
544544
return command
@@ -549,7 +549,7 @@ def show(self, p):
549549
value0 = command.in_table
550550
value1 = p.identifier
551551
if value0 is not None:
552-
value1.parts = value1.parts + value0.parts
552+
value1.append(value0)
553553

554554
command.in_table = value1
555555
return command
@@ -1824,17 +1824,21 @@ def json_value(self, p):
18241824
'identifier DOT star')
18251825
def identifier(self, p):
18261826
node = p[0]
1827-
is_quoted = False
18281827
if isinstance(p[2], Star):
18291828
node.parts.append(p[2])
1829+
node.is_quoted.append(False)
18301830
elif isinstance(p[2], int):
18311831
node.parts.append(str(p[2]))
1832+
node.is_quoted.append(False)
18321833
elif isinstance(p[2], str):
18331834
node.parts.append(p[2])
1835+
node.is_quoted.append(False)
1836+
elif isinstance(p[2], Identifier):
1837+
node.append(p[2])
18341838
else:
1835-
node.parts += p[2].parts
1836-
is_quoted = p[2].is_quoted[0]
1837-
node.is_quoted.append(is_quoted)
1839+
# fallback, shouldn't happen
1840+
node.parts.append(str(p[2]))
1841+
node.is_quoted.append(False)
18381842
return node
18391843

18401844
@_('quote_string',

tests/test_base_sql/test_ast.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,21 @@ def test_identifier_deepcopy_is_quoted(self):
3131
ident = Identifier('`a`')
3232
ident2 = deepcopy(ident)
3333
assert ident2.is_quoted == [True]
34+
35+
def test_identifier_to_string(self):
36+
test_cases = [
37+
'test',
38+
'Test',
39+
'TEST',
40+
'`test`',
41+
'`Test`',
42+
'`TEST`'
43+
]
44+
45+
for test_case in test_cases:
46+
assert Identifier(test_case).to_string() == test_case
47+
48+
for i in range(len(test_cases)):
49+
for test_case in test_cases:
50+
test_str = f'{test_case}.{test_cases[i]}'
51+
assert Identifier(test_str).to_string() == test_str

tests/test_base_sql/test_select_structure.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -726,9 +726,10 @@ def test_backticks(self):
726726
sql = "SELECT `name`, `status` FROM `mindsdb`.`wow stuff predictors`.`even-dashes-work`.`nice`"
727727
ast = parse_sql(sql)
728728

729-
expected_ast = Select(targets=[Identifier(parts=['name']), Identifier(parts=['status'])],
730-
from_table=Identifier(parts=['mindsdb', 'wow stuff predictors', 'even-dashes-work', 'nice']),
731-
)
729+
expected_ast = Select(
730+
targets=[Identifier('`name`'), Identifier('`status`')],
731+
from_table=Identifier('`mindsdb`.`wow stuff predictors`.`even-dashes-work`.`nice`'),
732+
)
732733

733734
assert ast.to_tree() == expected_ast.to_tree()
734735
assert str(ast) == str(expected_ast)
@@ -1224,10 +1225,10 @@ def test_double_quote_render_skip(self):
12241225
sql = 'select `KEY_ID`, `a`.* from `Table1` where `id`=2'
12251226

12261227
expected_ast = Select(
1227-
targets=[Identifier('KEY_ID'), Identifier(parts=['a', Star()])],
1228-
from_table=Identifier(parts=['Table1']),
1228+
targets=[Identifier('`KEY_ID`'), Identifier(parts=['a', Star()], is_quoted=[True, False])],
1229+
from_table=Identifier(parts=['Table1'], is_quoted=[True]),
12291230
where=BinaryOperation(op='=', args=[
1230-
Identifier('id'), Constant(2)
1231+
Identifier('`id`'), Constant(2)
12311232
])
12321233
)
12331234

tests/test_base_sql/test_show.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ def test_from_where(self):
8888
assert ast.to_tree() == expected_ast.to_tree()
8989

9090
def test_full_columns(self):
91-
sql = "SHOW FULL COLUMNS FROM `concrete` FROM `files`"
91+
sql = "SHOW FULL COLUMNS FROM `ccc` FROM `fff`"
9292
ast = parse_sql(sql)
9393
expected_ast = Show(
9494
category='COLUMNS',
9595
modes=['FULL'],
96-
from_table=Identifier('files.concrete')
96+
from_table=Identifier('`fff`.`ccc`')
9797
)
9898

9999
assert str(ast) == str(expected_ast)

tests/test_mindsdb/test_create_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_create_predictor_quotes(self):
124124
ast = parse_sql(sql)
125125
expected_ast = CreatePredictor(
126126
name=Identifier('xxx'),
127-
integration_name=Identifier('yyy'),
127+
integration_name=Identifier('`yyy`'),
128128
query_str="SELECT * FROM zzz",
129129
targets=[Identifier('sss')],
130130
)

tests/test_mysql/test_mysql_parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,23 @@ def test_select_varialbe_complex(self):
4242
assert str(ast) == str(expected_ast)
4343

4444
def test_show_index(self):
45-
sql = "SHOW INDEX FROM predictors"
45+
sql = "SHOW INDEX FROM `predictors`"
4646
ast = parse_sql(sql)
4747
expected_ast = Show(
4848
category='INDEX',
49-
from_table=Identifier('predictors')
49+
from_table=Identifier('`predictors`')
5050
)
5151

5252
assert str(ast).lower() == sql.lower()
5353
assert str(ast) == str(expected_ast)
5454
assert ast.to_tree() == expected_ast.to_tree()
5555

5656
def test_show_index_from_db(self):
57-
sql = "SHOW INDEX FROM predictors FROM db"
57+
sql = "SHOW INDEX FROM `predictors` FROM db"
5858
ast = parse_sql(sql)
5959
expected_ast = Show(
6060
category='INDEX',
61-
from_table=Identifier('db.predictors'),
61+
from_table=Identifier('db.`predictors`'),
6262
)
6363

6464
# assert str(ast).lower() == sql.lower()

tests/test_standard_render.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,4 @@ def test_standard_render():
8282

8383
# inject function
8484
module.parse_sql = parse_sql2
85-
8685
check_module(module)
87-
88-

0 commit comments

Comments
 (0)