|
4 | 4 | from typing import Any, Dict, Optional, Tuple
|
5 | 5 |
|
6 | 6 | import aiohttp
|
7 |
| -import sqlparse |
| 7 | +import sqlglot |
8 | 8 | from pydantic import BaseModel
|
9 |
| -from sqlglot.tokens import Token, Tokenizer, TokenType |
10 | 9 |
|
11 | 10 | logger = logging.getLogger("wren-ai-service")
|
12 | 11 |
|
@@ -50,206 +49,17 @@ def remove_limit_statement(sql: str) -> str:
|
50 | 49 | return modified_sql
|
51 | 50 |
|
52 | 51 |
|
53 |
| -def squish_sql(sql: str) -> str: |
54 |
| - return ( |
55 |
| - sqlparse.format( |
56 |
| - sql, |
57 |
| - strip_comments=False, |
58 |
| - reindent=False, # don't add newlines/indent |
59 |
| - keyword_case=None, # don't change case |
60 |
| - ) |
61 |
| - .replace("\n", " ") |
62 |
| - .replace("\r", " ") |
63 |
| - .strip() |
64 |
| - ) |
65 |
| - |
66 |
| - |
67 | 52 | def add_quotes(sql: str) -> Tuple[str, str]:
|
68 |
| - def _quote_sql_identifiers_by_tokens(sql: str, quote_char: str = '"') -> str: |
69 |
| - """ |
70 |
| - Add quotes around identifiers using SQLGlot's tokenizer positions. |
71 |
| - """ |
72 |
| - |
73 |
| - def is_sql_keyword(text: str) -> bool: |
74 |
| - """Check if the text is a SQL keyword that should not be quoted.""" |
75 |
| - # Common SQL keywords that should never be quoted |
76 |
| - sql_keywords = { |
77 |
| - # Basic SQL keywords |
78 |
| - "SELECT", |
79 |
| - "FROM", |
80 |
| - "WHERE", |
81 |
| - "JOIN", |
82 |
| - "LEFT", |
83 |
| - "RIGHT", |
84 |
| - "INNER", |
85 |
| - "OUTER", |
86 |
| - "ON", |
87 |
| - "AND", |
88 |
| - "OR", |
89 |
| - "NOT", |
90 |
| - "IN", |
91 |
| - "EXISTS", |
92 |
| - "BETWEEN", |
93 |
| - "LIKE", |
94 |
| - "IS", |
95 |
| - "NULL", |
96 |
| - "ORDER", |
97 |
| - "BY", |
98 |
| - "GROUP", |
99 |
| - "HAVING", |
100 |
| - "LIMIT", |
101 |
| - "OFFSET", |
102 |
| - "UNION", |
103 |
| - "INTERSECT", |
104 |
| - "EXCEPT", |
105 |
| - "AS", |
106 |
| - "DISTINCT", |
107 |
| - "ALL", |
108 |
| - "TOP", |
109 |
| - "WITH", |
110 |
| - "RECURSIVE", |
111 |
| - # Data types |
112 |
| - "INTEGER", |
113 |
| - "INT", |
114 |
| - "BIGINT", |
115 |
| - "SMALLINT", |
116 |
| - "DECIMAL", |
117 |
| - "NUMERIC", |
118 |
| - "FLOAT", |
119 |
| - "REAL", |
120 |
| - "DOUBLE", |
121 |
| - "PRECISION", |
122 |
| - "VARCHAR", |
123 |
| - "CHAR", |
124 |
| - "TEXT", |
125 |
| - "BOOLEAN", |
126 |
| - "BOOL", |
127 |
| - "DATE", |
128 |
| - "TIME", |
129 |
| - "TIMESTAMP", |
130 |
| - "TIMESTAMPTZ", |
131 |
| - "INTERVAL", |
132 |
| - "WITH", |
133 |
| - "WITHOUT", |
134 |
| - # Time/date keywords |
135 |
| - "YEAR", |
136 |
| - "MONTH", |
137 |
| - "DAY", |
138 |
| - "HOUR", |
139 |
| - "MINUTE", |
140 |
| - "SECOND", |
141 |
| - "TIMEZONE", |
142 |
| - "EPOCH", |
143 |
| - "AT", |
144 |
| - "ZONE", |
145 |
| - "CURRENT_DATE", |
146 |
| - "CURRENT_TIME", |
147 |
| - "CURRENT_TIMESTAMP", |
148 |
| - # Other common keywords |
149 |
| - "CASE", |
150 |
| - "WHEN", |
151 |
| - "THEN", |
152 |
| - "ELSE", |
153 |
| - "END", |
154 |
| - "DESC", |
155 |
| - "ASC", |
156 |
| - "TRUE", |
157 |
| - "FALSE", |
158 |
| - } |
159 |
| - return text.upper() in sql_keywords |
160 |
| - |
161 |
| - def is_ident(tok: Token): |
162 |
| - # SQLGlot uses VAR for identifiers, but also treats SQL keywords as identifiers in some contexts |
163 |
| - if tok.token_type not in ( |
164 |
| - TokenType.VAR, |
165 |
| - TokenType.SCHEMA, |
166 |
| - TokenType.TABLE, |
167 |
| - TokenType.COLUMN, |
168 |
| - TokenType.DATABASE, |
169 |
| - TokenType.INDEX, |
170 |
| - TokenType.VIEW, |
171 |
| - ): |
172 |
| - return False |
173 |
| - |
174 |
| - # Don't quote SQL keywords |
175 |
| - token_text = sql[tok.start : tok.end + 1] |
176 |
| - if is_sql_keyword(token_text): |
177 |
| - return False |
178 |
| - |
179 |
| - return True |
180 |
| - |
181 |
| - def is_already_quoted_text(text: str) -> bool: |
182 |
| - text = text.strip() |
183 |
| - return ( |
184 |
| - (len(text) >= 2 and text[0] == '"' and text[-1] == '"') |
185 |
| - or (len(text) >= 2 and text[0] == "`" and text[-1] == "`") |
186 |
| - or (len(text) >= 2 and text[0] == "[" and text[-1] == "]") |
187 |
| - ) |
188 |
| - |
189 |
| - toks = Tokenizer().tokenize(sql) |
190 |
| - n = len(toks) |
191 |
| - edits = [] # (start, end_exclusive, replacement) |
192 |
| - |
193 |
| - i = 0 |
194 |
| - while i < n: |
195 |
| - t = toks[i] |
196 |
| - |
197 |
| - if not is_ident(t): |
198 |
| - i += 1 |
199 |
| - continue |
200 |
| - |
201 |
| - # Check for wildcard pattern: IDENT DOT STAR (e.g., t.*) |
202 |
| - if ( |
203 |
| - i + 2 < n |
204 |
| - and toks[i + 1].token_type == TokenType.DOT |
205 |
| - and toks[i + 2].token_type == TokenType.STAR |
206 |
| - ): |
207 |
| - i += 3 # Skip the entire wildcard pattern |
208 |
| - continue |
209 |
| - |
210 |
| - # Check if this is part of a dotted chain |
211 |
| - j = i |
212 |
| - chain_tokens = [t] # Start with current identifier |
213 |
| - |
214 |
| - # Collect all tokens in the dotted chain: IDENT (DOT IDENT)* |
215 |
| - while ( |
216 |
| - j + 2 < n |
217 |
| - and toks[j + 1].token_type == TokenType.DOT |
218 |
| - and is_ident(toks[j + 2]) |
219 |
| - ): |
220 |
| - chain_tokens.append(toks[j + 1]) # DOT |
221 |
| - chain_tokens.append(toks[j + 2]) # IDENT |
222 |
| - j += 2 |
223 |
| - |
224 |
| - # If the next token after the chain is '(', it's a function call -> skip |
225 |
| - if j + 1 < n and toks[j + 1].token_type == TokenType.L_PAREN: |
226 |
| - i = j + 1 |
227 |
| - continue |
228 |
| - |
229 |
| - # Process each identifier in the chain separately to ensure all are quoted |
230 |
| - for k in range( |
231 |
| - 0, len(chain_tokens), 2 |
232 |
| - ): # Process only identifiers (skip dots) |
233 |
| - ident_token = chain_tokens[k] |
234 |
| - token_text = sql[ident_token.start : ident_token.end + 1] |
235 |
| - |
236 |
| - if not is_already_quoted_text(token_text): |
237 |
| - replacement = f"{quote_char}{token_text}{quote_char}" |
238 |
| - edits.append((ident_token.start, ident_token.end + 1, replacement)) |
239 |
| - |
240 |
| - i = j + 1 |
241 |
| - |
242 |
| - # Apply edits right-to-left to keep offsets valid |
243 |
| - out = sql |
244 |
| - for start, end, repl in sorted(edits, key=lambda x: x[0], reverse=True): |
245 |
| - out = out[:start] + repl + out[end:] |
246 |
| - return out |
247 |
| - |
248 | 53 | try:
|
249 |
| - sql = squish_sql(sql) |
250 |
| - quoted_sql = _quote_sql_identifiers_by_tokens(sql) |
| 54 | + quoted_sql = sqlglot.transpile( |
| 55 | + sql, |
| 56 | + read=None, |
| 57 | + identify=True, |
| 58 | + error_level=sqlglot.ErrorLevel.RAISE, |
| 59 | + unsupported_level=sqlglot.ErrorLevel.RAISE, |
| 60 | + )[0] |
251 | 61 | except Exception as e:
|
252 |
| - logger.exception(f"Error in adding quotes to {sql}: {e}") |
| 62 | + logger.exception(f"Error in sqlglot.transpile to {sql}: {e}") |
253 | 63 |
|
254 | 64 | return "", str(e)
|
255 | 65 |
|
|
0 commit comments