Commit f623d71

chore(wren-ai-service): revert add quotes (#1925)

1 parent 531eb7e commit f623d71

2 files changed: +9 −724 lines changed

wren-ai-service/src/core/engine.py

Lines changed: 9 additions & 199 deletions
@@ -4,9 +4,8 @@
 from typing import Any, Dict, Optional, Tuple
 
 import aiohttp
-import sqlparse
+import sqlglot
 from pydantic import BaseModel
-from sqlglot.tokens import Token, Tokenizer, TokenType
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -50,206 +49,17 @@ def remove_limit_statement(sql: str) -> str:
     return modified_sql
 
 
-def squish_sql(sql: str) -> str:
-    return (
-        sqlparse.format(
-            sql,
-            strip_comments=False,
-            reindent=False,  # don't add newlines/indent
-            keyword_case=None,  # don't change case
-        )
-        .replace("\n", " ")
-        .replace("\r", " ")
-        .strip()
-    )
-
-
 def add_quotes(sql: str) -> Tuple[str, str]:
-    def _quote_sql_identifiers_by_tokens(sql: str, quote_char: str = '"') -> str:
-        """
-        Add quotes around identifiers using SQLGlot's tokenizer positions.
-        """
-
-        def is_sql_keyword(text: str) -> bool:
-            """Check if the text is a SQL keyword that should not be quoted."""
-            # Common SQL keywords that should never be quoted
-            sql_keywords = {
-                # Basic SQL keywords
-                "SELECT",
-                "FROM",
-                "WHERE",
-                "JOIN",
-                "LEFT",
-                "RIGHT",
-                "INNER",
-                "OUTER",
-                "ON",
-                "AND",
-                "OR",
-                "NOT",
-                "IN",
-                "EXISTS",
-                "BETWEEN",
-                "LIKE",
-                "IS",
-                "NULL",
-                "ORDER",
-                "BY",
-                "GROUP",
-                "HAVING",
-                "LIMIT",
-                "OFFSET",
-                "UNION",
-                "INTERSECT",
-                "EXCEPT",
-                "AS",
-                "DISTINCT",
-                "ALL",
-                "TOP",
-                "WITH",
-                "RECURSIVE",
-                # Data types
-                "INTEGER",
-                "INT",
-                "BIGINT",
-                "SMALLINT",
-                "DECIMAL",
-                "NUMERIC",
-                "FLOAT",
-                "REAL",
-                "DOUBLE",
-                "PRECISION",
-                "VARCHAR",
-                "CHAR",
-                "TEXT",
-                "BOOLEAN",
-                "BOOL",
-                "DATE",
-                "TIME",
-                "TIMESTAMP",
-                "TIMESTAMPTZ",
-                "INTERVAL",
-                "WITH",
-                "WITHOUT",
-                # Time/date keywords
-                "YEAR",
-                "MONTH",
-                "DAY",
-                "HOUR",
-                "MINUTE",
-                "SECOND",
-                "TIMEZONE",
-                "EPOCH",
-                "AT",
-                "ZONE",
-                "CURRENT_DATE",
-                "CURRENT_TIME",
-                "CURRENT_TIMESTAMP",
-                # Other common keywords
-                "CASE",
-                "WHEN",
-                "THEN",
-                "ELSE",
-                "END",
-                "DESC",
-                "ASC",
-                "TRUE",
-                "FALSE",
-            }
-            return text.upper() in sql_keywords
-
-        def is_ident(tok: Token):
-            # SQLGlot uses VAR for identifiers, but also treats SQL keywords as identifiers in some contexts
-            if tok.token_type not in (
-                TokenType.VAR,
-                TokenType.SCHEMA,
-                TokenType.TABLE,
-                TokenType.COLUMN,
-                TokenType.DATABASE,
-                TokenType.INDEX,
-                TokenType.VIEW,
-            ):
-                return False
-
-            # Don't quote SQL keywords
-            token_text = sql[tok.start : tok.end + 1]
-            if is_sql_keyword(token_text):
-                return False
-
-            return True
-
-        def is_already_quoted_text(text: str) -> bool:
-            text = text.strip()
-            return (
-                (len(text) >= 2 and text[0] == '"' and text[-1] == '"')
-                or (len(text) >= 2 and text[0] == "`" and text[-1] == "`")
-                or (len(text) >= 2 and text[0] == "[" and text[-1] == "]")
-            )
-
-        toks = Tokenizer().tokenize(sql)
-        n = len(toks)
-        edits = []  # (start, end_exclusive, replacement)
-
-        i = 0
-        while i < n:
-            t = toks[i]
-
-            if not is_ident(t):
-                i += 1
-                continue
-
-            # Check for wildcard pattern: IDENT DOT STAR (e.g., t.*)
-            if (
-                i + 2 < n
-                and toks[i + 1].token_type == TokenType.DOT
-                and toks[i + 2].token_type == TokenType.STAR
-            ):
-                i += 3  # Skip the entire wildcard pattern
-                continue
-
-            # Check if this is part of a dotted chain
-            j = i
-            chain_tokens = [t]  # Start with current identifier
-
-            # Collect all tokens in the dotted chain: IDENT (DOT IDENT)*
-            while (
-                j + 2 < n
-                and toks[j + 1].token_type == TokenType.DOT
-                and is_ident(toks[j + 2])
-            ):
-                chain_tokens.append(toks[j + 1])  # DOT
-                chain_tokens.append(toks[j + 2])  # IDENT
-                j += 2
-
-            # If the next token after the chain is '(', it's a function call -> skip
-            if j + 1 < n and toks[j + 1].token_type == TokenType.L_PAREN:
-                i = j + 1
-                continue
-
-            # Process each identifier in the chain separately to ensure all are quoted
-            for k in range(
-                0, len(chain_tokens), 2
-            ):  # Process only identifiers (skip dots)
-                ident_token = chain_tokens[k]
-                token_text = sql[ident_token.start : ident_token.end + 1]
-
-                if not is_already_quoted_text(token_text):
-                    replacement = f"{quote_char}{token_text}{quote_char}"
-                    edits.append((ident_token.start, ident_token.end + 1, replacement))
-
-            i = j + 1
-
-        # Apply edits right-to-left to keep offsets valid
-        out = sql
-        for start, end, repl in sorted(edits, key=lambda x: x[0], reverse=True):
-            out = out[:start] + repl + out[end:]
-        return out
-
     try:
-        sql = squish_sql(sql)
-        quoted_sql = _quote_sql_identifiers_by_tokens(sql)
+        quoted_sql = sqlglot.transpile(
+            sql,
+            read=None,
+            identify=True,
+            error_level=sqlglot.ErrorLevel.RAISE,
+            unsupported_level=sqlglot.ErrorLevel.RAISE,
+        )[0]
     except Exception as e:
-        logger.exception(f"Error in adding quotes to {sql}: {e}")
+        logger.exception(f"Error in sqlglot.transpile to {sql}: {e}")
 
         return "", str(e)
 
