2626 Column ,
2727 ParameterType ,
2828 RawColType ,
29+ SetParameter ,
2930 parse_type ,
3031 parse_value ,
3132 split_format_sql ,
@@ -100,6 +101,7 @@ class BaseCursor:
100101 "_idx_lock" ,
101102 "_row_sets" ,
102103 "_next_set_idx" ,
104+ "_set_parameters" ,
103105 )
104106
105107 default_arraysize = 1
@@ -114,6 +116,7 @@ def __init__(self, client: AsyncClient, connection: Connection):
114116 self ._row_sets : List [
115117 Tuple [int , Optional [List [Column ]], Optional [List [List [RawColType ]]]]
116118 ] = []
119+ self ._set_parameters : Dict [str , Any ] = dict ()
117120 self ._rowcount = - 1
118121 self ._idx = 0
119122 self ._next_set_idx = 0
@@ -172,37 +175,6 @@ def close(self) -> None:
172175 # remove typecheck skip after connection is implemented
173176 self .connection ._remove_cursor (self ) # type: ignore
174177
175- def _append_query_data (self , response : Response ) -> None :
176- """Store information about executed query from httpx response."""
177-
178- row_set : Tuple [
179- int , Optional [List [Column ]], Optional [List [List [RawColType ]]]
180- ] = (- 1 , None , None )
181-
182- # Empty response is returned for insert query
183- if response .headers .get ("content-length" , "" ) != "0" :
184- try :
185- # Skip parsing floats to properly parse them later
186- query_data = response .json (parse_float = str )
187- rowcount = int (query_data ["rows" ])
188- descriptions = [
189- Column (
190- d ["name" ], parse_type (d ["type" ]), None , None , None , None , None
191- )
192- for d in query_data ["meta" ]
193- ]
194-
195- # Parse data during fetch
196- rows = query_data ["data" ]
197- row_set = (rowcount , descriptions , rows )
198- except (KeyError , ValueError ) as err :
199- raise DataError (f"Invalid query data format: { str (err )} " )
200-
201- self ._row_sets .append (row_set )
202- if self ._next_set_idx == 0 :
203- # Populate values for first set
204- self ._pop_next_set ()
205-
206178 @check_not_closed
207179 @check_query_executed
208180 def nextset (self ) -> Optional [bool ]:
@@ -227,6 +199,9 @@ def _pop_next_set(self) -> Optional[bool]:
227199 self ._next_set_idx += 1
228200 return True
229201
202+ def flush_parameters (self ) -> None :
203+ self ._set_parameters = dict ()
204+
230205 async def _raise_if_error (self , resp : Response ) -> None :
231206 """Raise a proper error if any"""
232207 if resp .status_code == codes .INTERNAL_SERVER_ERROR :
@@ -260,39 +235,105 @@ def _reset(self) -> None:
260235 self ._row_sets = []
261236 self ._next_set_idx = 0
262237
263- async def _do_execute_request (
238+ def _row_set_from_response (
239+ self , response : Response
240+ ) -> Tuple [int , Optional [List [Column ]], Optional [List [List [RawColType ]]]]:
241+ """Fetch information about executed query from http response"""
242+
243+ # Empty response is returned for insert query
244+ if response .headers .get ("content-length" , "" ) == "0" :
245+ return (- 1 , None , None )
246+
247+ try :
248+ # Skip parsing floats to properly parse them later
249+ query_data = response .json (parse_float = str )
250+ rowcount = int (query_data ["rows" ])
251+ descriptions = [
252+ Column (d ["name" ], parse_type (d ["type" ]), None , None , None , None , None )
253+ for d in query_data ["meta" ]
254+ ]
255+
256+ # Parse data during fetch
257+ rows = query_data ["data" ]
258+ return (rowcount , descriptions , rows )
259+ except (KeyError , ValueError ) as err :
260+ raise DataError (f"Invalid query data format: { str (err )} " )
261+
262+ def _append_row_set (
264263 self ,
265- query : str ,
264+ row_set : Tuple [int , Optional [List [Column ]], Optional [List [List [RawColType ]]]],
265+ ) -> None :
266+ """Store information about executed query."""
267+ self ._row_sets .append (row_set )
268+ if self ._next_set_idx == 0 :
269+ # Populate values for first set
270+ self ._pop_next_set ()
271+
272+ async def _api_request (
273+ self , query : str , set_parameters : Optional [dict ]
274+ ) -> Response :
275+ return await self ._client .request (
276+ url = "/" ,
277+ method = "POST" ,
278+ params = {
279+ "database" : self .connection .database ,
280+ "output_format" : JSON_OUTPUT_FORMAT ,
281+ ** self ._set_parameters ,
282+ ** (set_parameters or dict ()),
283+ },
284+ content = query ,
285+ )
286+
287+ async def _do_execute (
288+ self ,
289+ raw_query : str ,
266290 parameters : Sequence [Sequence [ParameterType ]],
267291 set_parameters : Optional [Dict ] = None ,
268292 ) -> None :
269293 self ._reset ()
294+ if set_parameters is not None :
295+ logger .warning (
296+ "Passing set parameters as an argument is deprecated. Please run "
297+ "a query 'SET <param> = <value>'"
298+ )
270299 try :
271300
272- queries = split_format_sql (query , parameters )
301+ queries = split_format_sql (raw_query , parameters )
273302
274303 for query in queries :
275304
276305 start_time = time .time ()
277306 # our CREATE EXTERNAL TABLE queries currently require credentials,
278307 # so we will skip logging those queries.
279308 # https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table
280- if not re .search ("aws_key_id|credentials" , query , flags = re .IGNORECASE ):
309+ if isinstance (query , SetParameter ) or not re .search (
310+ "aws_key_id|credentials" , query , flags = re .IGNORECASE
311+ ):
281312 logger .debug (f"Running query: { query } " )
282313
283- resp = await self ._client .request (
284- url = "/" ,
285- method = "POST" ,
286- params = {
287- "database" : self .connection .database ,
288- "output_format" : JSON_OUTPUT_FORMAT ,
289- ** (set_parameters or dict ()),
290- },
291- content = query ,
292- )
314+ # Define type for mypy
315+ row_set : Tuple [
316+ int , Optional [List [Column ]], Optional [List [List [RawColType ]]]
317+ ] = (- 1 , None , None )
318+ if isinstance (query , SetParameter ):
319+ # Validate parameter by executing simple query with it
320+ resp = await self ._api_request (
321+ "select 1" , {query .name : query .value }
322+ )
323+ # Handle invalid set parameter
324+ if resp .status_code == codes .BAD_REQUEST :
325+ raise OperationalError (resp .text )
326+ await self ._raise_if_error (resp )
327+
328+ # set parameter passed validation
329+ self ._set_parameters [query .name ] = query .value
330+ else :
331+ resp = await self ._api_request (query , set_parameters )
332+ await self ._raise_if_error (resp )
333+ row_set = self ._row_set_from_response (resp )
334+
335+ self ._append_row_set (row_set )
293336
294- await self ._raise_if_error (resp )
295- self ._append_query_data (resp )
296337 logger .info (
297338 f"Query fetched { self .rowcount } rows in"
298339 f" { time .time () - start_time } seconds"
@@ -314,7 +355,7 @@ async def execute(
314355 """Prepare and execute a database query. Return row count."""
315356
316357 params_list = [parameters ] if parameters else []
317- await self ._do_execute_request (query , params_list , set_parameters )
358+ await self ._do_execute (query , params_list , set_parameters )
318359 return self .rowcount
319360
320361 @check_not_closed
@@ -325,7 +366,7 @@ async def executemany(
325366 Prepare and execute a database query against all parameter
326367 sequences provided. Return last query row count.
327368 """
328- await self ._do_execute_request (query , parameters_seq )
369+ await self ._do_execute (query , parameters_seq )
329370 return self .rowcount
330371
331372 def _parse_row (self , row : List [RawColType ]) -> List [ColType ]:
0 commit comments