Skip to content

Commit 18c93ed

Browse files
authored
feat(FIR-46457): Support remove parameters backend header (#468)
1 parent e6f70a6 commit 18c93ed

File tree

12 files changed

+449
-34
lines changed

12 files changed

+449
-34
lines changed

src/firebolt/async_db/cursor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
from firebolt.common._types import ColType, ParameterType, SetParameter
2222
from firebolt.common.constants import (
2323
JSON_OUTPUT_FORMAT,
24+
REMOVE_PARAMETERS_HEADER,
2425
RESET_SESSION_HEADER,
2526
UPDATE_ENDPOINT_HEADER,
2627
UPDATE_PARAMETERS_HEADER,
2728
CursorState,
2829
)
2930
from firebolt.common.cursor.base_cursor import (
3031
BaseCursor,
32+
_parse_remove_parameters,
3133
_parse_update_endpoint,
3234
_parse_update_parameters,
3335
_raise_if_internal_set_parameter,
@@ -194,6 +196,10 @@ async def _parse_response_headers(self, headers: Headers) -> None:
194196
param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER))
195197
self._update_set_parameters(param_dict)
196198

199+
if headers.get(REMOVE_PARAMETERS_HEADER):
200+
param_list = _parse_remove_parameters(headers.get(REMOVE_PARAMETERS_HEADER))
201+
self._remove_set_parameters(param_list)
202+
197203
async def _close_rowset_and_reset(self) -> None:
198204
"""Reset cursor state."""
199205
if self._row_set is not None:

src/firebolt/client/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
DEFAULT_API_URL: str = "api.app.firebolt.io"
77
PROTOCOL_VERSION_HEADER_NAME = "Firebolt-Protocol-Version"
8-
PROTOCOL_VERSION: str = "2.3"
8+
PROTOCOL_VERSION: str = "2.4"
99
_REQUEST_ERRORS: Tuple[Type, ...] = (
1010
HTTPError,
1111
InvalidURL,

src/firebolt/common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ class ParameterStyle(Enum):
3434
UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint"
3535
UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters"
3636
RESET_SESSION_HEADER = "Firebolt-Reset-Session"
37+
REMOVE_PARAMETERS_HEADER = "Firebolt-Remove-Parameters"

src/firebolt/common/cursor/base_cursor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def _parse_update_parameters(parameter_header: str) -> Dict[str, str]:
4242
return param_dict
4343

4444

45+
def _parse_remove_parameters(parameter_header: str) -> List[str]:
46+
"""Parse remove parameters header and return list of parameter names to remove."""
47+
# parse key1,key2,key3 comma separated string into list
48+
param_list = [item.strip() for item in parameter_header.split(",")]
49+
return param_list
50+
51+
4552
def _parse_update_endpoint(
4653
new_engine_endpoint_header: str,
4754
) -> Tuple[str, Dict[str, str]]:
@@ -223,6 +230,14 @@ def _update_set_parameters(self, parameters: Dict[str, Any]) -> None:
223230

224231
self._set_parameters.update(user_parameters)
225232

233+
def _remove_set_parameters(self, parameter_names: List[str]) -> None:
234+
"""Remove parameters from both user and immutable parameter collections."""
235+
for param_name in parameter_names:
236+
# Remove from user parameters
237+
self._set_parameters.pop(param_name, None)
238+
# Remove from immutable parameters
239+
self.parameters.pop(param_name, None)
240+
226241
def _update_server_parameters(self, parameters: Dict[str, Any]) -> None:
227242
for key, value in parameters.items():
228243
self.parameters[key] = value

src/firebolt/db/cursor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
from firebolt.common._types import ColType, ParameterType, SetParameter
3030
from firebolt.common.constants import (
3131
JSON_OUTPUT_FORMAT,
32+
REMOVE_PARAMETERS_HEADER,
3233
RESET_SESSION_HEADER,
3334
UPDATE_ENDPOINT_HEADER,
3435
UPDATE_PARAMETERS_HEADER,
3536
CursorState,
3637
)
3738
from firebolt.common.cursor.base_cursor import (
3839
BaseCursor,
40+
_parse_remove_parameters,
3941
_parse_update_endpoint,
4042
_parse_update_parameters,
4143
_raise_if_internal_set_parameter,
@@ -200,6 +202,10 @@ def _parse_response_headers(self, headers: Headers) -> None:
200202
param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER))
201203
self._update_set_parameters(param_dict)
202204

205+
if headers.get(REMOVE_PARAMETERS_HEADER):
206+
param_list = _parse_remove_parameters(headers.get(REMOVE_PARAMETERS_HEADER))
207+
self._remove_set_parameters(param_list)
208+
203209
def _close_rowset_and_reset(self) -> None:
204210
"""Reset the cursor state."""
205211
if self._row_set is not None:

tests/integration/dbapi/async/V2/test_queries_async.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,12 @@ async def test_parameterized_query_with_special_chars(connection: Connection) ->
289289
[
290290
(
291291
"fb_numeric",
292-
'INSERT INTO "test_tbl" VALUES ($1, $2)',
292+
'INSERT INTO "{table}" VALUES ($1, $2)',
293293
[(1, "alice"), (2, "bob"), (3, "charlie")],
294294
),
295295
(
296296
"qmark",
297-
'INSERT INTO "test_tbl" VALUES (?, ?)',
297+
'INSERT INTO "{table}" VALUES (?, ?)',
298298
[(4, "david"), (5, "eve"), (6, "frank")],
299299
),
300300
],
@@ -312,7 +312,7 @@ async def test_executemany_bulk_insert_paramstyles(
312312
firebolt.async_db.paramstyle = paramstyle
313313
# Generate a unique label for this test execution
314314
unique_label = f"test_bulk_insert_async_{paramstyle}_{randint(100000, 999999)}"
315-
table_name = "test_tbl"
315+
table_name = create_drop_test_table_setup_teardown_async
316316

317317
try:
318318
c = connection.cursor()
@@ -323,7 +323,7 @@ async def test_executemany_bulk_insert_paramstyles(
323323

324324
# Execute bulk insert
325325
await c.executemany(
326-
query,
326+
query.format(table=table_name),
327327
test_data,
328328
bulk_insert=True,
329329
)
@@ -767,3 +767,90 @@ async def test_select_quoted_bigint(
767767
assert result[0][0] == int(
768768
long_bigint_value
769769
), "Invalid data returned by fetchall"
770+
771+
772+
async def test_transaction_commit(
773+
connection: Connection, create_drop_test_table_setup_teardown_async: Callable
774+
) -> None:
775+
"""Test transaction SQL statements with COMMIT."""
776+
table_name = create_drop_test_table_setup_teardown_async
777+
async with connection.cursor() as c:
778+
# Test successful transaction with COMMIT
779+
result = await c.execute("BEGIN TRANSACTION")
780+
assert result == 0, "BEGIN TRANSACTION should return 0 rows"
781+
782+
await c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'committed')")
783+
784+
result = await c.execute("COMMIT TRANSACTION")
785+
assert result == 0, "COMMIT TRANSACTION should return 0 rows"
786+
787+
# Verify the data was committed
788+
await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
789+
data = await c.fetchall()
790+
assert len(data) == 1, "Committed data should be present"
791+
assert data[0] == [
792+
1,
793+
"committed",
794+
], "Committed data should match inserted values"
795+
796+
797+
async def test_transaction_rollback(
798+
connection: Connection, create_drop_test_table_setup_teardown_async: Callable
799+
) -> None:
800+
"""Test transaction SQL statements with ROLLBACK."""
801+
table_name = create_drop_test_table_setup_teardown_async
802+
async with connection.cursor() as c:
803+
# Test transaction with ROLLBACK
804+
result = await c.execute("BEGIN") # Test short form
805+
assert result == 0, "BEGIN should return 0 rows"
806+
807+
await c.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'rolled_back')")
808+
809+
# Verify data is visible within transaction
810+
await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
811+
data = await c.fetchall()
812+
assert len(data) == 1, "Data should be visible within transaction"
813+
814+
result = await c.execute("ROLLBACK") # Test short form
815+
assert result == 0, "ROLLBACK should return 0 rows"
816+
817+
# Verify the data was rolled back
818+
await c.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
819+
data = await c.fetchall()
820+
assert len(data) == 0, "Rolled back data should not be present"
821+
822+
823+
async def test_transaction_cursor_isolation(
824+
connection: Connection, create_drop_test_table_setup_teardown_async: Callable
825+
) -> None:
826+
"""Test that one cursor can't see another's data until it commits."""
827+
table_name = create_drop_test_table_setup_teardown_async
828+
cursor1 = connection.cursor()
829+
cursor2 = connection.cursor()
830+
831+
# Start transaction in cursor1 and insert data
832+
result = await cursor1.execute("BEGIN TRANSACTION")
833+
assert result == 0, "BEGIN TRANSACTION should return 0 rows"
834+
835+
await cursor1.execute(f"INSERT INTO \"{table_name}\" VALUES (1, 'isolated_data')")
836+
837+
# Verify cursor1 can see its own uncommitted data
838+
await cursor1.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
839+
data1 = await cursor1.fetchall()
840+
assert len(data1) == 1, "Cursor1 should see its own uncommitted data"
841+
assert data1[0] == [1, "isolated_data"], "Cursor1 data should match inserted values"
842+
843+
# Verify cursor2 cannot see cursor1's uncommitted data
844+
await cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
845+
data2 = await cursor2.fetchall()
846+
assert len(data2) == 0, "Cursor2 should not see cursor1's uncommitted data"
847+
848+
# Commit the transaction in cursor1
849+
result = await cursor1.execute("COMMIT TRANSACTION")
850+
assert result == 0, "COMMIT TRANSACTION should return 0 rows"
851+
852+
# Now cursor2 should be able to see the committed data
853+
await cursor2.execute(f'SELECT * FROM "{table_name}" WHERE id = 1')
854+
data2 = await cursor2.fetchall()
855+
assert len(data2) == 1, "Cursor2 should see committed data after commit"
856+
assert data2[0] == [1, "isolated_data"], "Cursor2 should see the committed data"

tests/integration/dbapi/conftest.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import uuid
23
from datetime import date, datetime, timedelta, timezone
34
from decimal import Decimal
45
from logging import getLogger
@@ -12,8 +13,8 @@
1213

1314
LOGGER = getLogger(__name__)
1415

15-
CREATE_TEST_TABLE = 'CREATE TABLE IF NOT EXISTS "test_tbl" (id int, name string)'
16-
DROP_TEST_TABLE = 'DROP TABLE IF EXISTS "test_tbl" CASCADE'
16+
CREATE_TEST_TABLE = 'CREATE TABLE IF NOT EXISTS "{table}" (id int, name string)'
17+
DROP_TEST_TABLE = 'DROP TABLE IF EXISTS "{table}" CASCADE'
1718

1819
LONG_SELECT_DEFAULT_V1 = 250000000000
1920
LONG_SELECT_DEFAULT_V2 = 350000000000
@@ -29,38 +30,27 @@ def long_test_value_with_default(default: int = 0) -> int:
2930
return long_test_value_with_default
3031

3132

32-
@fixture
33-
def create_drop_test_table_setup_teardown(connection: Connection) -> None:
34-
with connection.cursor() as c:
35-
c.execute(CREATE_TEST_TABLE)
36-
yield c
37-
c.execute(DROP_TEST_TABLE)
38-
39-
40-
@fixture
41-
async def create_server_side_test_table_setup_teardown_async(
42-
connection: Connection,
43-
) -> None:
44-
with connection.cursor() as c:
45-
await c.execute(CREATE_TEST_TABLE)
46-
yield c
47-
await c.execute(DROP_TEST_TABLE)
33+
def generate_unique_table_name() -> str:
34+
"""Generate a unique table name for testing purposes."""
35+
return f"test_table_{uuid.uuid4().hex}"
4836

4937

5038
@fixture
5139
def create_drop_test_table_setup_teardown(connection: Connection) -> None:
40+
table = generate_unique_table_name()
5241
with connection.cursor() as c:
53-
c.execute(CREATE_TEST_TABLE)
54-
yield c
55-
c.execute(DROP_TEST_TABLE)
42+
c.execute(CREATE_TEST_TABLE.format(table=table))
43+
yield table
44+
c.execute(DROP_TEST_TABLE.format(table=table))
5645

5746

5847
@fixture
5948
async def create_drop_test_table_setup_teardown_async(connection: Connection) -> None:
49+
table = generate_unique_table_name()
6050
async with connection.cursor() as c:
61-
await c.execute(CREATE_TEST_TABLE)
62-
yield c
63-
await c.execute(DROP_TEST_TABLE)
51+
await c.execute(CREATE_TEST_TABLE.format(table=table))
52+
yield table
53+
await c.execute(DROP_TEST_TABLE.format(table=table))
6454

6555

6656
@fixture

0 commit comments

Comments
 (0)