11from __future__ import annotations
22
3+ import logging
34from ssl import SSLContext
45from types import TracebackType
56from typing import Any , Dict , List , Optional , Type , Union
67from uuid import uuid4
78
8- from httpx import Timeout , codes
9+ import trio
10+ from httpx import Request , Response , Timeout , codes
911
1012from firebolt .async_db .cursor import Cursor , CursorV1 , CursorV2
1113from firebolt .client import DEFAULT_API_URL
4547 validate_engine_name_and_url_v1 ,
4648)
4749
50+ logger = logging .getLogger (__name__ )
51+
4852
4953class Connection (BaseConnection ):
5054 """
@@ -78,9 +82,13 @@ class Connection(BaseConnection):
7882 "engine_url" ,
7983 "api_endpoint" ,
8084 "_is_closed" ,
85+ "_transaction_id" ,
86+ "_transaction_sequence_id" ,
87+ "_transaction_lock" ,
8188 "client_class" ,
8289 "cursor_type" ,
8390 "id" ,
91+ "_autocommit" ,
8492 )
8593
8694 def __init__ (
@@ -92,14 +100,17 @@ def __init__(
92100 api_endpoint : str ,
93101 init_parameters : Optional [Dict [str , Any ]] = None ,
94102 id : str = uuid4 ().hex ,
103+ autocommit : bool = True ,
95104 ):
96105 super ().__init__ (cursor_type )
97106 self .api_endpoint = api_endpoint
98107 self .engine_url = engine_url
99108 self ._cursors : List [Cursor ] = []
100109 self ._client = client
101110 self .id = id
111+ self ._transaction_lock : trio .Lock = trio .Lock ()
102112 self .init_parameters = init_parameters or {}
113+ self ._autocommit = autocommit
103114 if database :
104115 self .init_parameters ["database" ] = database
105116
@@ -192,6 +203,44 @@ async def cancel_async_query(self, token: str) -> None:
192203 cursor = self .cursor ()
193204 await cursor .execute (ASYNC_QUERY_CANCEL , [async_query_info [0 ].query_id ])
194205
206+ async def _execute_query_impl (self , request : Request ) -> Response :
207+ self ._add_transaction_params (request )
208+ response = await self ._client .send (request , stream = True )
209+ if not self .autocommit :
210+ self ._handle_transaction_updates (response .headers )
211+ return response
212+
213+ async def _begin_nolock (self , request : Request ) -> None :
214+ """Begin a transaction without a lock. Used internally."""
215+ # Create a copy of the request with "BEGIN" as the body content
216+ begin_request = self ._client .build_request (
217+ request .method , request .url , content = "BEGIN"
218+ )
219+ response = await self ._client .send (begin_request , stream = True )
220+ self ._handle_transaction_updates (response .headers )
221+
222+ async def _execute_query (self , request : Request ) -> Response :
223+ if self .in_transaction or not self .autocommit :
224+ async with self ._transaction_lock :
225+ # If autocommit is off we need to explicitly begin a transaction
226+ if not self .in_transaction :
227+ await self ._begin_nolock (request )
228+ return await self ._execute_query_impl (request )
229+ else :
230+ return await self ._execute_query_impl (request )
231+
232+ async def commit (self ) -> None :
233+ if self .closed :
234+ raise ConnectionClosedError ("Unable to commit: Connection closed." )
235+ # Commit is a no-op for V1
236+ if self .cursor_type != CursorV1 :
237+ await self .cursor ().execute ("COMMIT" )
238+
239+ async def rollback (self ) -> None :
240+ if self .closed :
241+ raise ConnectionClosedError ("Unable to rollback: Connection closed." )
242+ await self .cursor ().execute ("ROLLBACK" )
243+
195244 # Context manager support
196245 async def __aenter__ (self ) -> Connection :
197246 if self .closed :
@@ -203,6 +252,14 @@ async def aclose(self) -> None:
203252 if self .closed :
204253 return
205254
255+ # Only rollback if we have a transaction and autocommit is off
256+ if self .in_transaction and not self .autocommit :
257+ try :
258+ await self .rollback ()
259+ except Exception :
260+ # If rollback fails during close, continue closing
261+ logger .warning ("Rollback failed during close" )
262+
206263 # self._cursors is going to be changed during closing cursors
207264 # after this point no cursors would be added to _cursors, only removed since
208265 # closing lock is held, and later connection will be marked as closed
@@ -217,6 +274,10 @@ async def aclose(self) -> None:
217274 async def __aexit__ (
218275 self , exc_type : type , exc_val : Exception , exc_tb : TracebackType
219276 ) -> None :
277+ # If exiting normally (no exception) and we have a transaction with
278+ # autocommit=False, commit the transaction before closing
279+ if exc_type is None and not self .autocommit and self .in_transaction :
280+ await self .commit ()
220281 await self .aclose ()
221282
222283
@@ -229,6 +290,7 @@ async def connect(
229290 api_endpoint : str = DEFAULT_API_URL ,
230291 disable_cache : bool = False ,
231292 url : Optional [str ] = None ,
293+ autocommit : bool = True ,
232294 additional_parameters : Dict [str , Any ] = {},
233295) -> Connection :
234296 # auth parameter is optional in function signature
@@ -256,6 +318,7 @@ async def connect(
256318 user_agent_header = user_agent_header ,
257319 database = database ,
258320 connection_url = url ,
321+ autocommit = autocommit ,
259322 )
260323 elif auth_version == FireboltAuthVersion .V2 :
261324 assert account_name is not None
@@ -268,6 +331,7 @@ async def connect(
268331 api_endpoint = api_endpoint ,
269332 connection_id = connection_id ,
270333 disable_cache = disable_cache ,
334+ autocommit = autocommit ,
271335 )
272336 elif auth_version == FireboltAuthVersion .V1 :
273337 return await connect_v1 (
@@ -293,6 +357,7 @@ async def connect_v2(
293357 engine_name : Optional [str ] = None ,
294358 api_endpoint : str = DEFAULT_API_URL ,
295359 disable_cache : bool = False ,
360+ autocommit : bool = True ,
296361) -> Connection :
297362 """Connect to Firebolt.
298363
@@ -356,6 +421,7 @@ async def connect_v2(
356421 api_endpoint ,
357422 cursor .parameters | cursor ._set_parameters ,
358423 connection_id ,
424+ autocommit ,
359425 )
360426
361427
@@ -423,6 +489,7 @@ def connect_core(
423489 user_agent_header : str ,
424490 database : Optional [str ] = None ,
425491 connection_url : Optional [str ] = None ,
492+ autocommit : bool = True ,
426493) -> Connection :
427494 """Connect to Firebolt Core.
428495
@@ -460,6 +527,7 @@ def connect_core(
460527 client = client ,
461528 cursor_type = CursorV2 ,
462529 api_endpoint = verified_url ,
530+ autocommit = autocommit ,
463531 )
464532
465533
0 commit comments