11import logging
22
3- from concurrent .futures import ThreadPoolExecutor
4- from dataclasses import dataclass
3+ from concurrent .futures import ThreadPoolExecutor , Future
54from typing import List , Union
65
76from databricks .sql .cloudfetch .downloader import (
87 ResultSetDownloadHandler ,
98 DownloadableResultSettings ,
9+ DownloadedFile ,
1010)
1111from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
1212
1313logger = logging .getLogger (__name__ )
1414
1515
class ResultFileDownloadManager:
    """
    Manages cloud fetch result-file downloads.

    Links with a positive row count are queued in arrival order; downloads are
    submitted to a thread pool so that up to ``max_download_threads`` files are
    being fetched ahead of the consumer. Files are handed back strictly in
    queue order via :meth:`get_next_downloaded_file`.
    """

    def __init__(
        self,
        links: List[TSparkArrowResultLink],
        max_download_threads: int,
        lz4_compressed: bool,
    ):
        # Links not yet submitted to the thread pool, in result order.
        self._pending_links: List[TSparkArrowResultLink] = []
        for link in links:
            # Links with no rows contribute nothing to the result; skip them.
            if link.rowCount <= 0:
                continue
            logger.debug(
                "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
                    link.startRowOffset, link.rowCount
                )
            )
            self._pending_links.append(link)

        # In-flight downloads, in the same order as the pending links they came from.
        self._download_tasks: List[Future[DownloadedFile]] = []
        self._max_download_threads: int = max_download_threads
        self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

        self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)

    def get_next_downloaded_file(
        self, next_row_offset: int
    ) -> Union[DownloadedFile, None]:
        """
        Return the next downloaded file from the queue, blocking until its
        download completes.

        Keeps the download queue topped up, then pops the oldest in-flight
        task and waits for it. If the batch is exhausted (nothing pending and
        nothing in flight), the manager is shut down and None is returned.
        A debug message is logged when the returned file does not actually
        contain `next_row_offset`; the file is still returned either way.

        Args:
            next_row_offset (int): The offset of the starting row of the next file we want data from.

        Raises:
            Whatever exception the download task raised, re-raised by
            `Future.result()`.
        """

        # Make sure the download queue is always full
        self._schedule_downloads()

        # No more files to download from this batch of links
        if len(self._download_tasks) == 0:
            self._shutdown_manager()
            return None

        task = self._download_tasks.pop(0)
        # Future's `result()` method will wait for the call to complete, and return
        # the value returned by the call. If the call throws an exception - `result()`
        # will throw the same exception
        file = task.result()
        # A file covers the half-open row range
        # [start_row_offset, start_row_offset + row_count), so the upper-bound
        # test uses `>=`: the row at exactly start_row_offset + row_count
        # belongs to the *next* file. (The original `>` missed that boundary.)
        if (next_row_offset < file.start_row_offset) or (
            next_row_offset >= file.start_row_offset + file.row_count
        ):
            logger.debug(
                "ResultFileDownloadManager: file does not contain row {}, start {}, row count {}".format(
                    next_row_offset, file.start_row_offset, file.row_count
                )
            )

        return file

    def _schedule_downloads(self):
        """
        While download queue has a capacity, peek pending links and submit them to thread pool.
        """
        logger.debug("ResultFileDownloadManager: schedule downloads")
        while (len(self._download_tasks) < self._max_download_threads) and (
            len(self._pending_links) > 0
        ):
            link = self._pending_links.pop(0)
            logger.debug(
                "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
            )
            handler = ResultSetDownloadHandler(self._downloadable_result_settings, link)
            task = self._thread_pool.submit(handler.run)
            self._download_tasks.append(task)

    def _shutdown_manager(self):
        # Drop all queued work and shut down the thread pool without waiting
        # for in-flight downloads to finish.
        self._pending_links = []
        self._download_tasks = []
        self._thread_pool.shutdown(wait=False)