 import logging
 from dataclasses import dataclass
-
 import requests
 import lz4.frame
 import threading
 import time
-
+import os
+import re
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))
+
 
 @dataclass
 class DownloadableResultSettings:
@@ -20,13 +22,15 @@ class DownloadableResultSettings:
         is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
         link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
         download_timeout (int): Timeout for download requests. Default 60 secs.
-        max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+        max_retries (int): Maximum number of consecutive download attempts before giving up.
+        backoff_factor (int): Factor by which to increase the wait time between retries.
     """
 
     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
-    download_timeout: int = 60
-    max_consecutive_file_download_retries: int = 0
+    download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
+    max_retries: int = 5
+    backoff_factor: int = 2
 
 
 class ResultSetDownloadHandler(threading.Thread):
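
The new defaults make download behavior tunable per deployment: the timeout can be overridden through the `DATABRICKS_CLOUD_FILE_TIMEOUT` environment variable (read once at import time), while `max_retries` and `backoff_factor` feed the retry helper added further down. A minimal sketch of constructing these settings; the module path and the import-order trick are assumptions for illustration, not confirmed by this diff:

```python
import os

# Assumption: the env var must be set before the module is imported, since
# DEFAULT_CLOUD_FILE_TIMEOUT is evaluated once at import time.
os.environ["DATABRICKS_CLOUD_FILE_TIMEOUT"] = "120"

from databricks.sql.cloudfetch.downloader import DownloadableResultSettings  # assumed path

settings = DownloadableResultSettings(
    is_lz4_compressed=True,      # result files arrive lz4-framed
    link_expiry_buffer_secs=10,  # treat links expiring within 10s as expired
    max_retries=3,               # three download attempts per file
    backoff_factor=2,            # exponential waits of 1, 2, 4 seconds
)
```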
@@ -57,16 +61,21 @@ def is_file_download_successful(self) -> bool:
             else None
         )
         try:
+            logger.debug(
+                f"waiting at most {timeout} seconds for download of file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
+
             if not self.is_download_finished.wait(timeout=timeout):
                 self.is_download_timedout = True
-                logger.debug(
-                    "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
-                        self.settings.download_timeout,
-                        self.result_link.startRowOffset,
-                        self.result_link.startRowOffset + self.result_link.rowCount,
-                    )
+                logger.error(
+                    f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
                 )
-                return False
+                # in some cases is_download_finished is not set even though the file downloaded successfully
+                return self.is_file_downloaded_successfully
+
+            logger.debug(
+                f"finished waiting for download of file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
         except Exception as e:
             logger.error(e)
             return False
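
The fallback return above exists because `threading.Event.wait` returning `False` only means the event was not set within the window; the downloader thread may still have recorded a successful download just after the deadline. A standalone sketch of that race, with illustrative names only:

```python
import threading
import time

done = threading.Event()
result = {"ok": False}

def worker():
    time.sleep(1.5)       # download outlasts the waiter's timeout
    result["ok"] = True   # success is recorded first...
    done.set()            # ...and the event is set afterwards

threading.Thread(target=worker).start()

if not done.wait(timeout=1.0):
    # Timed out waiting, so consult the recorded result directly,
    # mirroring `return self.is_file_downloaded_successfully` above.
    print("timed out; download result:", result["ok"])
```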
@@ -81,24 +90,36 @@ def run(self):
         """
         self._reset()
 
-        # Check if link is already expired or is expiring
-        if ResultSetDownloadHandler.check_link_expired(
-            self.result_link, self.settings.link_expiry_buffer_secs
-        ):
-            self.is_link_expired = True
-            return
+        try:
+            # Check if link is already expired or is expiring
+            if ResultSetDownloadHandler.check_link_expired(
+                self.result_link, self.settings.link_expiry_buffer_secs
+            ):
+                self.is_link_expired = True
+                return
 
-        session = requests.Session()
-        session.timeout = self.settings.download_timeout
+            logger.debug(
+                f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
 
-        try:
             # Get the file via HTTP request
-            response = session.get(self.result_link.fileLink)
+            response = http_get_with_retry(
+                url=self.result_link.fileLink,
+                max_retries=self.settings.max_retries,
+                backoff_factor=self.settings.backoff_factor,
+                download_timeout=self.settings.download_timeout,
+            )
 
-            if not response.ok:
-                self.is_file_downloaded_successfully = False
+            if not response:
+                logger.error(
+                    f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+                )
                 return
 
+            logger.debug(
+                f"successfully downloaded file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
+
             # Save (and decompress if needed) the downloaded file
             compressed_data = response.content
             decompressed_data = (
@@ -109,15 +130,22 @@ def run(self):
             self.result_file = decompressed_data
 
             # The size of the downloaded file should match the size specified from TSparkArrowResultLink
-            self.is_file_downloaded_successfully = (
-                len(self.result_file) == self.result_link.bytesNum
+            success = len(self.result_file) == self.result_link.bytesNum
+            logger.debug(
+                f"download of file complete: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
             )
+            self.is_file_downloaded_successfully = success
         except Exception as e:
+            logger.error(
+                f"exception while downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
             logger.error(e)
             self.is_file_downloaded_successfully = False
 
         finally:
-            session and session.close()
+            logger.debug(
+                f"signaling download finished for file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
+            )
             # Awaken threads waiting for this to be true which signals the run is complete
             self.is_download_finished.set()
 
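
Wrapping the whole body of `run` in one `try`/`finally` is what guarantees waiters are released: every exit path, including the early `return`s and any exception, now flows through `is_download_finished.set()`, instead of leaving waiters blocked for the full timeout. The pattern in isolation, with illustrative names:

```python
import threading

finished = threading.Event()

def run_guarded(work):
    try:
        work()  # may return early or raise
    finally:
        finished.set()  # waiters always wake, success or failure
```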
@@ -145,6 +173,7 @@ def check_link_expired(
             link.expiryTime < current_time
             or link.expiryTime - current_time < expiry_buffer_secs
         ):
+            logger.debug("link expired")
             return True
         return False
 
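
For reference, the expiry check treats a link as expired when it is already past its expiry time or will pass it within the buffer. Worked through with illustrative numbers:

```python
import time

# A link expiring 30s from now, checked with a 60s safety buffer,
# counts as expired under the logic above.
current_time = int(time.time())
expiryTime = current_time + 30
expiry_buffer_secs = 60

expired = (
    expiryTime < current_time
    or expiryTime - current_time < expiry_buffer_secs
)
print(expired)  # True
```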
@@ -171,3 +200,38 @@ def decompress_data(compressed_data: bytes) -> bytes:
             uncompressed_data += data
             start += num_bytes
         return uncompressed_data
+
+
+def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
+    attempts = 0
+    pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
+    mask = r"\1\2=<REDACTED>"
+
+    # TODO: introduce connection pooling. I am seeing weird errors without it.
+    while attempts < max_retries:
+        try:
+            session = requests.Session()
+            # requests ignores a `timeout` attribute on Session; pass it per request
+            response = session.get(url, timeout=download_timeout)
+
+            # Only HTTP 200 counts as success
+            if response.status_code == 200:
+                return response
+            else:
+                logger.error(response)
+        except requests.RequestException as e:
+            # redact query parameters so the pre-signed URL is not logged
+            logger.error(f"request failed with exception: {re.sub(pattern, mask, str(e))}")
+        finally:
+            session.close()
+
+        # Exponential backoff before the next attempt; kept out of `finally` so a success returns immediately
+        wait_time = backoff_factor**attempts
+        logger.info(f"retrying in {wait_time} seconds...")
+        time.sleep(wait_time)
+        attempts += 1
+
+    logger.error(
+        f"exceeded maximum number of retries ({max_retries}) while downloading result."
+    )
+    return None
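
The regex mask keeps pre-signed URL credentials out of the logs by redacting every query-string value. A quick sketch of the mask in action and of calling the helper; the URL is made up, so each attempt fails with a connection error (logged redacted) and `None` comes back once the retries are exhausted:

```python
import re

pattern = re.compile(r"(\?|&)([\w-]+)=([^&\s]+)")
mask = r"\1\2=<REDACTED>"

url = "https://storage.example.com/results/part-0?X-Amz-Signature=abc123&X-Amz-Expires=300"
print(re.sub(pattern, mask, url))
# https://storage.example.com/results/part-0?X-Amz-Signature=<REDACTED>&X-Amz-Expires=<REDACTED>

# Exercising the helper against the fake URL returns None after 3 attempts.
response = http_get_with_retry(url, max_retries=3, backoff_factor=2, download_timeout=5)
assert response is None
```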