Skip to content

Commit 4bd290e

Browse files
simplify download process: no pre-fetching
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent: 811205e; commit: 4bd290e

File tree

4 files changed: 36 additions and 110 deletions.

src/databricks/sql/backend/sea/queue.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,6 @@ def __init__(
170170

171171
# Track the current chunk we're processing
172172
self._current_chunk_link: Optional["ExternalLink"] = initial_link
173-
self._download_current_link()
174173

175174
# Initialize table and position
176175
self.table = self._create_next_table()
@@ -188,18 +187,6 @@ def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink
188187
httpHeaders=link.http_headers or {},
189188
)
190189

191-
def _download_current_link(self):
192-
"""Download the current chunk link."""
193-
if not self._current_chunk_link:
194-
return None
195-
196-
if not self.download_manager:
197-
logger.debug("SeaCloudFetchQueue: No download manager, returning")
198-
return None
199-
200-
thrift_link = self._convert_to_thrift_link(self._current_chunk_link)
201-
self.download_manager.add_link(thrift_link)
202-
203190
def _progress_chunk_link(self):
204191
"""Progress to the next chunk link."""
205192
if not self._current_chunk_link:
@@ -221,19 +208,26 @@ def _progress_chunk_link(self):
221208
next_chunk_index, e
222209
)
223210
)
211+
self._current_chunk_link = None
224212
return None
225213

226214
logger.debug(
227215
f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}"
228216
)
229-
self._download_current_link()
230217

231218
def _create_next_table(self) -> Union["pyarrow.Table", None]:
232219
"""Create next table by retrieving the logical next downloaded file."""
233220
if not self._current_chunk_link:
234221
logger.debug("SeaCloudFetchQueue: No current chunk link, returning")
235222
return None
236223

224+
if not self.download_manager:
225+
logger.debug("SeaCloudFetchQueue: No download manager, returning")
226+
return None
227+
228+
thrift_link = self._convert_to_thrift_link(self._current_chunk_link)
229+
self.download_manager.add_link(thrift_link)
230+
237231
row_offset = self._current_chunk_link.row_offset
238232
arrow_table = self._create_table_at_offset(row_offset)
239233

src/databricks/sql/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -237,7 +237,7 @@ def __init__(
237237
self.table = None
238238
self.table_row_index = 0
239239

240-
# Initialize download manager - will be set by subclasses
240+
# Initialize download manager
241241
self.download_manager: Optional["ResultFileDownloadManager"] = None
242242

243243
def remaining_rows(self) -> "pyarrow.Table":

tests/unit/test_sea_queue.py

Lines changed: 1 addition & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -487,76 +487,6 @@ def test_init_non_zero_chunk_index(
487487
# Verify download manager wasn't created (no chunk 0)
488488
mock_download_manager_class.assert_not_called()
489489

490-
@patch("databricks.sql.backend.sea.queue.logger")
491-
def test_download_current_link_no_current_link(self, mock_logger):
492-
"""Test _download_current_link with no current link."""
493-
# Create a queue instance without initializing
494-
queue = Mock(spec=SeaCloudFetchQueue)
495-
queue._current_chunk_link = None
496-
497-
# Call the method directly
498-
result = SeaCloudFetchQueue._download_current_link(queue)
499-
500-
# Verify the result is None
501-
assert result is None
502-
503-
@patch("databricks.sql.backend.sea.queue.logger")
504-
def test_download_current_link_no_download_manager(
505-
self, mock_logger, mock_sea_client, ssl_options
506-
):
507-
"""Test _download_current_link with no download manager."""
508-
# Create a queue instance without initializing
509-
queue = Mock(spec=SeaCloudFetchQueue)
510-
queue._current_chunk_link = ExternalLink(
511-
external_link="https://example.com/data/chunk0",
512-
expiration="2025-07-03T05:51:18.118009",
513-
row_count=100,
514-
byte_count=1024,
515-
row_offset=0,
516-
chunk_index=0,
517-
next_chunk_index=1,
518-
http_headers={"Authorization": "Bearer token123"},
519-
)
520-
queue.download_manager = None
521-
522-
# Call the method directly
523-
result = SeaCloudFetchQueue._download_current_link(queue)
524-
525-
# Verify debug message was logged
526-
mock_logger.debug.assert_called_with(
527-
"SeaCloudFetchQueue: No download manager, returning"
528-
)
529-
530-
# Verify the result is None
531-
assert result is None
532-
533-
@patch("databricks.sql.backend.sea.queue.logger")
534-
def test_download_current_link_success(self, mock_logger):
535-
"""Test _download_current_link with successful download."""
536-
# Create a queue instance without initializing
537-
queue = Mock(spec=SeaCloudFetchQueue)
538-
queue._current_chunk_link = ExternalLink(
539-
external_link="https://example.com/data/chunk0",
540-
expiration="2025-07-03T05:51:18.118009",
541-
row_count=100,
542-
byte_count=1024,
543-
row_offset=0,
544-
chunk_index=0,
545-
next_chunk_index=1,
546-
http_headers={"Authorization": "Bearer token123"},
547-
)
548-
queue.download_manager = Mock()
549-
550-
# Mock the _convert_to_thrift_link method
551-
mock_thrift_link = Mock()
552-
queue._convert_to_thrift_link = Mock(return_value=mock_thrift_link)
553-
554-
# Call the method directly
555-
SeaCloudFetchQueue._download_current_link(queue)
556-
557-
# Verify the download manager add_link was called
558-
queue.download_manager.add_link.assert_called_once_with(mock_thrift_link)
559-
560490
@patch("databricks.sql.backend.sea.queue.logger")
561491
def test_progress_chunk_link_no_current_link(self, mock_logger):
562492
"""Test _progress_chunk_link with no current link."""
@@ -610,7 +540,6 @@ def test_progress_chunk_link_success(self, mock_logger, mock_sea_client):
610540
)
611541
queue._sea_client = mock_sea_client
612542
queue._statement_id = "test-statement-123"
613-
queue._download_current_link = Mock()
614543

615544
# Setup the mock client to return a new link
616545
next_link = ExternalLink(
@@ -636,9 +565,6 @@ def test_progress_chunk_link_success(self, mock_logger, mock_sea_client):
636565
f"SeaCloudFetchQueue: Progressed to link for chunk 1: {next_link}"
637566
)
638567

639-
# Verify _download_current_link was called
640-
queue._download_current_link.assert_called_once()
641-
642568
@patch("databricks.sql.backend.sea.queue.logger")
643569
def test_progress_chunk_link_error(self, mock_logger, mock_sea_client):
644570
"""Test _progress_chunk_link with error during chunk fetch."""
@@ -710,6 +636,7 @@ def test_create_next_table_success(self, mock_logger):
710636
next_chunk_index=1,
711637
http_headers={"Authorization": "Bearer token123"},
712638
)
639+
queue.download_manager = Mock()
713640

714641
# Mock the dependencies
715642
mock_table = Mock()

tests/unit/test_thrift_field_ids.py

Lines changed: 26 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -16,72 +16,77 @@ class TestThriftFieldIds:
1616

1717
# Known exceptions that exceed the field ID limit
1818
KNOWN_EXCEPTIONS = {
19-
('TExecuteStatementReq', 'enforceEmbeddedSchemaCorrectness'): 3353,
20-
('TSessionHandle', 'serverProtocolVersion'): 3329,
19+
("TExecuteStatementReq", "enforceEmbeddedSchemaCorrectness"): 3353,
20+
("TSessionHandle", "serverProtocolVersion"): 3329,
2121
}
2222

2323
def test_all_thrift_field_ids_are_within_allowed_range(self):
2424
"""
2525
Validates that all field IDs in Thrift-generated classes are within the allowed range.
26-
26+
2727
This test prevents field ID conflicts and ensures compatibility with different
2828
Thrift implementations and protocols.
2929
"""
3030
violations = []
31-
31+
3232
# Get all classes from the ttypes module
3333
for name, obj in inspect.getmembers(ttypes):
34-
if (inspect.isclass(obj) and
35-
hasattr(obj, 'thrift_spec') and
36-
obj.thrift_spec is not None):
37-
34+
if (
35+
inspect.isclass(obj)
36+
and hasattr(obj, "thrift_spec")
37+
and obj.thrift_spec is not None
38+
):
39+
3840
self._check_class_field_ids(obj, name, violations)
39-
41+
4042
if violations:
4143
error_message = self._build_error_message(violations)
4244
pytest.fail(error_message)
4345

4446
def _check_class_field_ids(self, cls, class_name, violations):
4547
"""
4648
Checks all field IDs in a Thrift class and reports violations.
47-
49+
4850
Args:
4951
cls: The Thrift class to check
5052
class_name: Name of the class for error reporting
5153
violations: List to append violation messages to
5254
"""
5355
thrift_spec = cls.thrift_spec
54-
56+
5557
if not isinstance(thrift_spec, (tuple, list)):
5658
return
57-
59+
5860
for spec_entry in thrift_spec:
5961
if spec_entry is None:
6062
continue
61-
63+
6264
# Thrift spec format: (field_id, field_type, field_name, ...)
6365
if isinstance(spec_entry, (tuple, list)) and len(spec_entry) >= 3:
6466
field_id = spec_entry[0]
6567
field_name = spec_entry[2]
66-
68+
6769
# Skip known exceptions
6870
if (class_name, field_name) in self.KNOWN_EXCEPTIONS:
6971
continue
70-
72+
7173
if isinstance(field_id, int) and field_id >= self.MAX_ALLOWED_FIELD_ID:
7274
violations.append(
7375
"{} field '{}' has field ID {} (exceeds maximum of {})".format(
74-
class_name, field_name, field_id, self.MAX_ALLOWED_FIELD_ID - 1
76+
class_name,
77+
field_name,
78+
field_id,
79+
self.MAX_ALLOWED_FIELD_ID - 1,
7580
)
7681
)
7782

7883
def _build_error_message(self, violations):
7984
"""
8085
Builds a comprehensive error message for field ID violations.
81-
86+
8287
Args:
8388
violations: List of violation messages
84-
89+
8590
Returns:
8691
Formatted error message
8792
"""
@@ -90,8 +95,8 @@ def _build_error_message(self, violations):
9095
"This can cause compatibility issues and conflicts with reserved ID ranges.\n"
9196
"Violations found:\n".format(self.MAX_ALLOWED_FIELD_ID - 1)
9297
)
93-
98+
9499
for violation in violations:
95100
error_message += " - {}\n".format(violation)
96-
97-
return error_message
101+
102+
return error_message

0 commit comments

Comments
 (0)