Skip to content

Commit 414117a

Browse files
add e2e test for thrift resultRowLimit
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 304ef0e commit 414117a

File tree

3 files changed

+62
-10
lines changed

3 files changed

+62
-10
lines changed

src/databricks/sql/client.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,11 +402,6 @@ def __init__(
402402

403403
self.connection: Connection = connection
404404

405-
if not connection.session.use_sea and row_limit is not None:
406-
logger.warning(
407-
"Row limit is only supported for SEA protocol. Ignoring row_limit."
408-
)
409-
410405
self.rowcount: int = -1 # Return -1 as this is not supported
411406
self.buffer_size_bytes: int = result_buffer_size_bytes
412407
self.active_result_set: Union[ResultSet, None] = None
@@ -802,6 +797,7 @@ def execute(
802797
parameters=prepared_params,
803798
async_op=False,
804799
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
800+
row_limit=self.row_limit,
805801
)
806802

807803
if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -858,6 +854,7 @@ def execute_async(
858854
parameters=prepared_params,
859855
async_op=True,
860856
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
857+
row_limit=self.row_limit,
861858
)
862859

863860
return self

src/databricks/sql/session.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def __init__(
7676
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
7777
)
7878

79-
self.use_sea = kwargs.get("use_sea", False)
8079
self.backend = self._create_backend(
81-
self.use_sea,
8280
server_hostname,
8381
http_path,
8482
all_headers,
@@ -91,7 +89,6 @@ def __init__(
9189

9290
def _create_backend(
9391
self,
94-
use_sea: bool,
9592
server_hostname: str,
9693
http_path: str,
9794
all_headers: List[Tuple[str, str]],
@@ -100,6 +97,8 @@ def _create_backend(
10097
kwargs: dict,
10198
) -> DatabricksClient:
10299
"""Create and return the appropriate backend client."""
100+
use_sea = kwargs.get("use_sea", False)
101+
103102
databricks_client_class: Type[DatabricksClient]
104103
if use_sea:
105104
logger.debug("Creating SEA backend client")

tests/e2e/test_driver.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ def connection(self, extra_params=()):
113113
conn.close()
114114

115115
@contextmanager
116-
def cursor(self, extra_params=()):
116+
def cursor(self, extra_params=(), extra_cursor_params=()):
117117
with self.connection(extra_params) as conn:
118118
cursor = conn.cursor(
119-
arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes
119+
arraysize=self.arraysize,
120+
buffer_size_bytes=self.buffer_size_bytes,
121+
**extra_cursor_params,
120122
)
121123
try:
122124
yield cursor
@@ -945,6 +947,60 @@ def test_result_set_close(self):
945947
finally:
946948
cursor.close()
947949

950+
def test_row_limit_with_larger_result(self):
951+
"""Test that row_limit properly constrains results when query would return more rows"""
952+
row_limit = 1000
953+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
954+
# Execute a query that returns more than row_limit rows
955+
cursor.execute("SELECT * FROM range(2000)")
956+
rows = cursor.fetchall()
957+
958+
# Check if the number of rows is limited to row_limit
959+
assert len(rows) == row_limit, f"Expected {row_limit} rows, got {len(rows)}"
960+
961+
def test_row_limit_with_smaller_result(self):
962+
"""Test that row_limit doesn't affect results when query returns fewer rows than limit"""
963+
row_limit = 100
964+
expected_rows = 50
965+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
966+
# Execute a query that returns fewer than row_limit rows
967+
cursor.execute(f"SELECT * FROM range({expected_rows})")
968+
rows = cursor.fetchall()
969+
970+
# Check if all rows are returned (not limited by row_limit)
971+
assert (
972+
len(rows) == expected_rows
973+
), f"Expected {expected_rows} rows, got {len(rows)}"
974+
975+
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
976+
def test_row_limit_with_arrow_larger_result(self):
977+
"""Test that row_limit properly constrains arrow results when query would return more rows"""
978+
row_limit = 800
979+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
980+
# Execute a query that returns more than row_limit rows
981+
cursor.execute("SELECT * FROM range(1500)")
982+
arrow_table = cursor.fetchall_arrow()
983+
984+
# Check if the number of rows in the arrow table is limited to row_limit
985+
assert (
986+
arrow_table.num_rows == row_limit
987+
), f"Expected {row_limit} rows, got {arrow_table.num_rows}"
988+
989+
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
990+
def test_row_limit_with_arrow_smaller_result(self):
991+
"""Test that row_limit doesn't affect arrow results when query returns fewer rows than limit"""
992+
row_limit = 200
993+
expected_rows = 100
994+
with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
995+
# Execute a query that returns fewer than row_limit rows
996+
cursor.execute(f"SELECT * FROM range({expected_rows})")
997+
arrow_table = cursor.fetchall_arrow()
998+
999+
# Check if all rows are returned (not limited by row_limit)
1000+
assert (
1001+
arrow_table.num_rows == expected_rows
1002+
), f"Expected {expected_rows} rows, got {arrow_table.num_rows}"
1003+
9481004

9491005
# use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep
9501006
# the 429/503 subsuites separate since they execute under different circumstances.

0 commit comments

Comments
 (0)