Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/databricks/sql/backend/databricks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def execute_command(
parameters: List,
async_op: bool,
enforce_embedded_schema_correctness: bool,
row_limit: Optional[int] = None,
) -> Union["ResultSet", None]:
"""
Executes a SQL command or query within the specified session.
Expand All @@ -103,6 +104,7 @@ def execute_command(
parameters: List of parameters to bind to the query
async_op: Whether to execute the command asynchronously
enforce_embedded_schema_correctness: Whether to enforce schema correctness
row_limit: Maximum number of rows to fetch overall. Only supported for SEA protocol.

Returns:
If async_op is False, returns a ResultSet object containing the
Expand Down
3 changes: 2 additions & 1 deletion src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def execute_command(
parameters: List[Dict[str, Any]],
async_op: bool,
enforce_embedded_schema_correctness: bool,
row_limit: Optional[int] = None,
) -> Union["ResultSet", None]:
"""
Execute a SQL command using the SEA backend.
Expand Down Expand Up @@ -462,7 +463,7 @@ def execute_command(
format=format,
wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value,
on_wait_timeout="CONTINUE",
row_limit=max_rows,
row_limit=row_limit,
parameters=sea_parameters if sea_parameters else None,
result_compression=result_compression,
)
Expand Down
3 changes: 2 additions & 1 deletion src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import time
import threading
from typing import List, Union, Any, TYPE_CHECKING
from typing import List, Optional, Union, Any, TYPE_CHECKING

if TYPE_CHECKING:
from databricks.sql.client import Cursor
Expand Down Expand Up @@ -929,6 +929,7 @@ def execute_command(
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
row_limit: Optional[int] = None,
) -> Union["ResultSet", None]:
thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
Expand Down
26 changes: 18 additions & 8 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def cursor(
self,
arraysize: int = DEFAULT_ARRAY_SIZE,
buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
row_limit: Optional[int] = None,
) -> "Cursor":
"""
Return a new Cursor object using the connection.
Expand All @@ -355,6 +356,7 @@ def cursor(
self.session.backend,
arraysize=arraysize,
result_buffer_size_bytes=buffer_size_bytes,
row_limit=row_limit,
)
self._cursors.append(cursor)
return cursor
Expand Down Expand Up @@ -388,6 +390,7 @@ def __init__(
backend: DatabricksClient,
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
arraysize: int = DEFAULT_ARRAY_SIZE,
row_limit: Optional[int] = None,
) -> None:
"""
These objects represent a database cursor, which is used to manage the context of a fetch
Expand All @@ -397,16 +400,23 @@ def __init__(
visible by other cursors or connections.
"""

self.connection = connection
self.rowcount = -1 # Return -1 as this is not supported
self.buffer_size_bytes = result_buffer_size_bytes
self.connection: Connection = connection

if not connection.session.use_sea and row_limit is not None:
logger.warning(
"Row limit is only supported for SEA protocol. Ignoring row_limit."
)

self.rowcount: int = -1 # Return -1 as this is not supported
self.buffer_size_bytes: int = result_buffer_size_bytes
self.active_result_set: Union[ResultSet, None] = None
self.arraysize = arraysize
self.arraysize: int = arraysize
self.row_limit: Optional[int] = row_limit
# Note that Cursor closed => active result set closed, but not vice versa
self.open = True
self.executing_command_id = None
self.backend = backend
self.active_command_id = None
self.open: bool = True
self.executing_command_id: Optional[CommandId] = None
self.backend: DatabricksClient = backend
self.active_command_id: Optional[CommandId] = None
self.escaper = ParamEscaper()
self.lastrowid = None

Expand Down
5 changes: 3 additions & 2 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def __init__(
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.use_sea = kwargs.get("use_sea", False)
self.backend = self._create_backend(
self.use_sea,
server_hostname,
http_path,
all_headers,
Expand All @@ -89,6 +91,7 @@ def __init__(

def _create_backend(
self,
use_sea: bool,
server_hostname: str,
http_path: str,
all_headers: List[Tuple[str, str]],
Expand All @@ -97,8 +100,6 @@ def _create_backend(
kwargs: dict,
) -> DatabricksClient:
"""Create and return the appropriate backend client."""
use_sea = kwargs.get("use_sea", False)

databricks_client_class: Type[DatabricksClient]
if use_sea:
logger.debug("Creating SEA backend client")
Expand Down
Loading