Skip to content

Commit d26aa71

Browse files
authored
feat(FIR-46457): Add full transaction support (#471)
1 parent 0a85eaa commit d26aa71

File tree

23 files changed

+2316
-321
lines changed

23 files changed

+2316
-321
lines changed

src/firebolt/async_db/connection.py

Lines changed: 69 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
import logging
34
from ssl import SSLContext
45
from types import TracebackType
56
from typing import Any, Dict, List, Optional, Type, Union
67
from uuid import uuid4
78

8-
from httpx import Timeout, codes
9+
import trio
10+
from httpx import Request, Response, Timeout, codes
911

1012
from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2
1113
from firebolt.client import DEFAULT_API_URL
@@ -45,6 +47,8 @@
4547
validate_engine_name_and_url_v1,
4648
)
4749

50+
logger = logging.getLogger(__name__)
51+
4852

4953
class Connection(BaseConnection):
5054
"""
@@ -78,9 +82,13 @@ class Connection(BaseConnection):
7882
"engine_url",
7983
"api_endpoint",
8084
"_is_closed",
85+
"_transaction_id",
86+
"_transaction_sequence_id",
87+
"_transaction_lock",
8188
"client_class",
8289
"cursor_type",
8390
"id",
91+
"_autocommit",
8492
)
8593

8694
def __init__(
@@ -92,14 +100,17 @@ def __init__(
92100
api_endpoint: str,
93101
init_parameters: Optional[Dict[str, Any]] = None,
94102
id: str = uuid4().hex,
103+
autocommit: bool = True,
95104
):
96105
super().__init__(cursor_type)
97106
self.api_endpoint = api_endpoint
98107
self.engine_url = engine_url
99108
self._cursors: List[Cursor] = []
100109
self._client = client
101110
self.id = id
111+
self._transaction_lock: trio.Lock = trio.Lock()
102112
self.init_parameters = init_parameters or {}
113+
self._autocommit = autocommit
103114
if database:
104115
self.init_parameters["database"] = database
105116

@@ -192,6 +203,44 @@ async def cancel_async_query(self, token: str) -> None:
192203
cursor = self.cursor()
193204
await cursor.execute(ASYNC_QUERY_CANCEL, [async_query_info[0].query_id])
194205

206+
async def _execute_query_impl(self, request: Request) -> Response:
207+
self._add_transaction_params(request)
208+
response = await self._client.send(request, stream=True)
209+
if not self.autocommit:
210+
self._handle_transaction_updates(response.headers)
211+
return response
212+
213+
async def _begin_nolock(self, request: Request) -> None:
214+
"""Begin a transaction without a lock. Used internally."""
215+
# Create a copy of the request with "BEGIN" as the body content
216+
begin_request = self._client.build_request(
217+
request.method, request.url, content="BEGIN"
218+
)
219+
response = await self._client.send(begin_request, stream=True)
220+
self._handle_transaction_updates(response.headers)
221+
222+
async def _execute_query(self, request: Request) -> Response:
223+
if self.in_transaction or not self.autocommit:
224+
async with self._transaction_lock:
225+
# If autocommit is off we need to explicitly begin a transaction
226+
if not self.in_transaction:
227+
await self._begin_nolock(request)
228+
return await self._execute_query_impl(request)
229+
else:
230+
return await self._execute_query_impl(request)
231+
232+
async def commit(self) -> None:
233+
if self.closed:
234+
raise ConnectionClosedError("Unable to commit: Connection closed.")
235+
# Commit is a no-op for V1
236+
if self.cursor_type != CursorV1:
237+
await self.cursor().execute("COMMIT")
238+
239+
async def rollback(self) -> None:
240+
if self.closed:
241+
raise ConnectionClosedError("Unable to rollback: Connection closed.")
242+
await self.cursor().execute("ROLLBACK")
243+
195244
# Context manager support
196245
async def __aenter__(self) -> Connection:
197246
if self.closed:
@@ -203,6 +252,14 @@ async def aclose(self) -> None:
203252
if self.closed:
204253
return
205254

255+
# Only rollback if we have a transaction and autocommit is off
256+
if self.in_transaction and not self.autocommit:
257+
try:
258+
await self.rollback()
259+
except Exception:
260+
# If rollback fails during close, continue closing
261+
logger.warning("Rollback failed during close")
262+
206263
# self._cursors is going to be changed during closing cursors
207264
# after this point no cursors would be added to _cursors, only removed since
208265
# closing lock is held, and later connection will be marked as closed
@@ -217,6 +274,10 @@ async def aclose(self) -> None:
217274
async def __aexit__(
218275
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
219276
) -> None:
277+
# If exiting normally (no exception) and we have a transaction with
278+
# autocommit=False, commit the transaction before closing
279+
if exc_type is None and not self.autocommit and self.in_transaction:
280+
await self.commit()
220281
await self.aclose()
221282

222283

@@ -229,6 +290,7 @@ async def connect(
229290
api_endpoint: str = DEFAULT_API_URL,
230291
disable_cache: bool = False,
231292
url: Optional[str] = None,
293+
autocommit: bool = True,
232294
additional_parameters: Dict[str, Any] = {},
233295
) -> Connection:
234296
# auth parameter is optional in function signature
@@ -256,6 +318,7 @@ async def connect(
256318
user_agent_header=user_agent_header,
257319
database=database,
258320
connection_url=url,
321+
autocommit=autocommit,
259322
)
260323
elif auth_version == FireboltAuthVersion.V2:
261324
assert account_name is not None
@@ -268,6 +331,7 @@ async def connect(
268331
api_endpoint=api_endpoint,
269332
connection_id=connection_id,
270333
disable_cache=disable_cache,
334+
autocommit=autocommit,
271335
)
272336
elif auth_version == FireboltAuthVersion.V1:
273337
return await connect_v1(
@@ -293,6 +357,7 @@ async def connect_v2(
293357
engine_name: Optional[str] = None,
294358
api_endpoint: str = DEFAULT_API_URL,
295359
disable_cache: bool = False,
360+
autocommit: bool = True,
296361
) -> Connection:
297362
"""Connect to Firebolt.
298363
@@ -356,6 +421,7 @@ async def connect_v2(
356421
api_endpoint,
357422
cursor.parameters | cursor._set_parameters,
358423
connection_id,
424+
autocommit,
359425
)
360426

361427

@@ -423,6 +489,7 @@ def connect_core(
423489
user_agent_header: str,
424490
database: Optional[str] = None,
425491
connection_url: Optional[str] = None,
492+
autocommit: bool = True,
426493
) -> Connection:
427494
"""Connect to Firebolt Core.
428495
@@ -460,6 +527,7 @@ def connect_core(
460527
client=client,
461528
cursor_type=CursorV2,
462529
api_endpoint=verified_url,
530+
autocommit=autocommit,
463531
)
464532

465533

src/firebolt/async_db/cursor.py

Lines changed: 4 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -8,30 +8,13 @@
88
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
99
from urllib.parse import urljoin
1010

11-
from httpx import (
12-
URL,
13-
USE_CLIENT_DEFAULT,
14-
Headers,
15-
Response,
16-
TimeoutException,
17-
codes,
18-
)
11+
from httpx import URL, USE_CLIENT_DEFAULT, Response, TimeoutException, codes
1912

2013
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
2114
from firebolt.common._types import ColType, ParameterType, SetParameter
22-
from firebolt.common.constants import (
23-
JSON_OUTPUT_FORMAT,
24-
REMOVE_PARAMETERS_HEADER,
25-
RESET_SESSION_HEADER,
26-
UPDATE_ENDPOINT_HEADER,
27-
UPDATE_PARAMETERS_HEADER,
28-
CursorState,
29-
)
15+
from firebolt.common.constants import JSON_OUTPUT_FORMAT, CursorState
3016
from firebolt.common.cursor.base_cursor import (
3117
BaseCursor,
32-
_parse_remove_parameters,
33-
_parse_update_endpoint,
34-
_parse_update_parameters,
3518
_raise_if_internal_set_parameter,
3619
)
3720
from firebolt.common.cursor.decorators import (
@@ -135,7 +118,7 @@ async def _api_request(
135118
content=query,
136119
timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT,
137120
)
138-
return await self._client.send(req, stream=True)
121+
return await self.connection._execute_query(req)
139122
except TimeoutException:
140123
raise QueryTimeoutError()
141124

@@ -181,25 +164,6 @@ async def _validate_set_parameter(
181164
# append empty result set
182165
await self._append_row_set_from_response(None)
183166

184-
async def _parse_response_headers(self, headers: Headers) -> None:
185-
if headers.get(UPDATE_ENDPOINT_HEADER):
186-
endpoint, params = _parse_update_endpoint(
187-
headers.get(UPDATE_ENDPOINT_HEADER)
188-
)
189-
self._update_set_parameters(params)
190-
self.engine_url = endpoint
191-
192-
if headers.get(RESET_SESSION_HEADER):
193-
self.flush_parameters()
194-
195-
if headers.get(UPDATE_PARAMETERS_HEADER):
196-
param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER))
197-
self._update_set_parameters(param_dict)
198-
199-
if headers.get(REMOVE_PARAMETERS_HEADER):
200-
param_list = _parse_remove_parameters(headers.get(REMOVE_PARAMETERS_HEADER))
201-
self._remove_set_parameters(param_list)
202-
203167
async def _close_rowset_and_reset(self) -> None:
204168
"""Reset cursor state."""
205169
if self._row_set is not None:
@@ -305,7 +269,7 @@ async def _execute_single_query(
305269
await resp.aread()
306270
self._parse_async_response(resp)
307271
else:
308-
await self._parse_response_headers(resp.headers)
272+
self._parse_response_headers(resp.headers)
309273
await self._append_row_set_from_response(resp)
310274

311275
if not async_execution:

src/firebolt/common/base_connection.py

Lines changed: 76 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,32 @@
11
from collections import namedtuple
22
from typing import Any, Dict, List, Optional, Tuple, Type
33

4+
from httpx import Headers, Request
5+
46
from firebolt.client.auth.base import Auth
57
from firebolt.common._types import ColType
8+
from firebolt.common.constants import (
9+
REMOVE_PARAMETERS_HEADER,
10+
RESET_SESSION_HEADER,
11+
TRANSACTION_ID_SETTING,
12+
TRANSACTION_SEQUENCE_ID_SETTING,
13+
UPDATE_PARAMETERS_HEADER,
14+
)
615
from firebolt.utils.cache import (
716
ConnectionInfo,
817
EngineInfo,
918
SecureCacheKey,
1019
_firebolt_cache,
1120
)
12-
from firebolt.utils.exception import ConnectionClosedError, FireboltError
21+
from firebolt.utils.exception import FireboltError
1322
from firebolt.utils.usage_tracker import (
1423
get_cache_tracking_params,
1524
get_user_agent_header,
1625
)
26+
from firebolt.utils.util import (
27+
_parse_remove_parameters,
28+
_parse_update_parameters,
29+
)
1730

1831
ASYNC_QUERY_STATUS_RUNNING = "RUNNING"
1932
ASYNC_QUERY_STATUS_SUCCESSFUL = "ENDED_SUCCESSFULLY"
@@ -68,6 +81,9 @@ def __init__(self, cursor_type: Type) -> None:
6881
self.cursor_type = cursor_type
6982
self._cursors: List[Any] = []
7083
self._is_closed = False
84+
self._transaction_id: Optional[str] = None
85+
self._transaction_sequence_id: Optional[str] = None
86+
self._autocommit: bool = True
7187

7288
def _remove_cursor(self, cursor: Any) -> None:
7389
# This way it's atomic
@@ -76,17 +92,70 @@ def _remove_cursor(self, cursor: Any) -> None:
7692
except ValueError:
7793
pass
7894

95+
@property
96+
def in_transaction(self) -> bool:
97+
"""`True` if connection is in a transaction; `False` otherwise."""
98+
return self._transaction_id is not None
99+
100+
def _parse_response_headers_transaction(self, headers: Headers) -> None:
101+
parameters_header = headers.get(UPDATE_PARAMETERS_HEADER)
102+
if not parameters_header:
103+
return
104+
parameters = _parse_update_parameters(parameters_header)
105+
transaction_id = parameters.get(TRANSACTION_ID_SETTING)
106+
if transaction_id:
107+
self._transaction_id = transaction_id
108+
sequence_id = parameters.get(TRANSACTION_SEQUENCE_ID_SETTING)
109+
if sequence_id:
110+
self._transaction_sequence_id = sequence_id
111+
112+
def _parse_remove_headers_transaction(self, headers: Headers) -> None:
113+
parameters_header = headers.get(REMOVE_PARAMETERS_HEADER)
114+
if not parameters_header:
115+
return
116+
parameters = _parse_remove_parameters(parameters_header)
117+
for param in parameters:
118+
if param == TRANSACTION_ID_SETTING:
119+
self._transaction_id = None
120+
elif param == TRANSACTION_SEQUENCE_ID_SETTING:
121+
self._transaction_sequence_id = None
122+
123+
def _reset_transaction_state(self) -> None:
124+
self._transaction_id = None
125+
self._transaction_sequence_id = None
126+
127+
def create_transaction_params(self) -> Dict[str, str]:
128+
params: Dict[str, str] = {}
129+
if self._transaction_id:
130+
params[TRANSACTION_ID_SETTING] = self._transaction_id
131+
if self._transaction_sequence_id is not None:
132+
params[TRANSACTION_SEQUENCE_ID_SETTING] = str(self._transaction_sequence_id)
133+
return params
134+
135+
def _add_transaction_params(self, request: Request) -> None:
136+
transaction_params = self.create_transaction_params()
137+
for key, value in transaction_params.items():
138+
request.url = request.url.copy_add_param(key, value)
139+
140+
def _handle_transaction_updates(self, headers: Headers) -> None:
141+
self._parse_response_headers_transaction(headers)
142+
if headers.get(RESET_SESSION_HEADER):
143+
self._reset_transaction_state()
144+
if headers.get(REMOVE_PARAMETERS_HEADER):
145+
self._parse_remove_headers_transaction(headers)
146+
147+
@property
148+
def autocommit(self) -> bool:
149+
"""
150+
`True` if connection is in autocommit mode; `False` otherwise.
151+
"""
152+
return self._autocommit
153+
79154
@property
80155
def closed(self) -> bool:
81156
"""`True` if connection is closed; `False` otherwise."""
82157
return self._is_closed
83158

84-
def commit(self) -> None:
85-
"""Does nothing since Firebolt doesn't have transactions."""
86-
87-
if self.closed:
88-
raise ConnectionClosedError("Unable to commit: Connection closed.")
89-
90159

91160
def get_cached_system_engine_info(
92161
auth: Auth,

src/firebolt/common/constants.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,15 @@ class ParameterStyle(Enum):
2525
FB_NUMERIC = "fb_numeric" # $1, $2, ... as placeholders (server-side)
2626

2727

28+
TRANSACTION_ID_SETTING = "transaction_id"
29+
TRANSACTION_SEQUENCE_ID_SETTING = "transaction_sequence_id"
30+
2831
# Parameters that should be set using USE instead of SET
2932
USE_PARAMETER_LIST = ["database", "engine"]
3033
# parameters that can only be set by the backend
3134
DISALLOWED_PARAMETER_LIST = ["output_format"]
35+
# Connection level transaction management
36+
TRANSACTION_PARAMETER_LIST = [TRANSACTION_ID_SETTING, TRANSACTION_SEQUENCE_ID_SETTING]
3237
# parameters that are set by the backend and should not be set by the user
3338
IMMUTABLE_PARAMETER_LIST = USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST
3439
UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint"

0 commit comments

Comments
 (0)