1+ from contextlib import contextmanager
12import unittest
23from unittest .mock import Mock , patch , MagicMock
34
45import requests
56
67import databricks .sql .cloudfetch .downloader as downloader
8+ from databricks .sql .common .http import DatabricksHttpClient
79from databricks .sql .exc import Error
810from databricks .sql .types import SSLOptions
911
@@ -12,6 +14,7 @@ def create_response(**kwargs) -> requests.Response:
1214 result = requests .Response ()
1315 for k , v in kwargs .items ():
1416 setattr (result , k , v )
17+ result .close = Mock ()
1518 return result
1619
1720
@@ -52,91 +55,94 @@ def test_run_link_past_expiry_buffer(self, mock_time):
5255
5356 mock_time .assert_called_once ()
5457
55- @patch ("requests.Session" , return_value = MagicMock (get = MagicMock (return_value = None )))
5658 @patch ("time.time" , return_value = 1000 )
57- def test_run_get_response_not_ok (self , mock_time , mock_session ):
58- mock_session .return_value .get .return_value = create_response (status_code = 404 )
59-
59+ def test_run_get_response_not_ok (self , mock_time ):
60+ http_client = DatabricksHttpClient .get_instance ()
6061 settings = Mock (link_expiry_buffer_secs = 0 , download_timeout = 0 )
6162 settings .download_timeout = 0
6263 settings .use_proxy = False
6364 result_link = Mock (expiryTime = 1001 )
6465
65- d = downloader .ResultSetDownloadHandler (
66- settings , result_link , ssl_options = SSLOptions ()
67- )
68- with self .assertRaises (requests .exceptions .HTTPError ) as context :
69- d .run ()
70- self .assertTrue ("404" in str (context .exception ))
66+ with patch .object (
67+ http_client ,
68+ "execute" ,
69+ return_value = create_response (status_code = 404 , _content = b"1234" ),
70+ ):
71+ d = downloader .ResultSetDownloadHandler (
72+ settings , result_link , ssl_options = SSLOptions ()
73+ )
74+ with self .assertRaises (requests .exceptions .HTTPError ) as context :
75+ d .run ()
76+ self .assertTrue ("404" in str (context .exception ))
7177
72- @patch ("requests.Session" , return_value = MagicMock (get = MagicMock (return_value = None )))
7378 @patch ("time.time" , return_value = 1000 )
74- def test_run_uncompressed_successful (self , mock_time , mock_session ):
79+ def test_run_uncompressed_successful (self , mock_time ):
80+ http_client = DatabricksHttpClient .get_instance ()
7581 file_bytes = b"1234567890" * 10
76- mock_session .return_value .get .return_value = create_response (
77- status_code = 200 , _content = file_bytes
78- )
79-
8082 settings = Mock (link_expiry_buffer_secs = 0 , download_timeout = 0 , use_proxy = False )
8183 settings .is_lz4_compressed = False
8284 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
8385
84- d = downloader .ResultSetDownloadHandler (
85- settings , result_link , ssl_options = SSLOptions ()
86- )
87- file = d .run ()
86+ with patch .object (
87+ http_client ,
88+ "execute" ,
89+ return_value = create_response (status_code = 200 , _content = file_bytes ),
90+ ):
91+ d = downloader .ResultSetDownloadHandler (
92+ settings , result_link , ssl_options = SSLOptions ()
93+ )
94+ file = d .run ()
8895
89- assert file .file_bytes == b"1234567890" * 10
96+ assert file .file_bytes == b"1234567890" * 10
9097
91- @patch (
92- "requests.Session" ,
93- return_value = MagicMock (get = MagicMock (return_value = MagicMock (ok = True ))),
94- )
9598 @patch ("time.time" , return_value = 1000 )
96- def test_run_compressed_successful (self , mock_time , mock_session ):
99+ def test_run_compressed_successful (self , mock_time ):
100+ http_client = DatabricksHttpClient .get_instance ()
97101 file_bytes = b"1234567890" * 10
98102 compressed_bytes = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
99- mock_session .return_value .get .return_value = create_response (
100- status_code = 200 , _content = compressed_bytes
101- )
102103
103104 settings = Mock (link_expiry_buffer_secs = 0 , download_timeout = 0 , use_proxy = False )
104105 settings .is_lz4_compressed = True
105106 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
107+ with patch .object (
108+ http_client ,
109+ "execute" ,
110+ return_value = create_response (status_code = 200 , _content = compressed_bytes ),
111+ ):
112+ d = downloader .ResultSetDownloadHandler (
113+ settings , result_link , ssl_options = SSLOptions ()
114+ )
115+ file = d .run ()
116+
117+ assert file .file_bytes == b"1234567890" * 10
106118
107- d = downloader .ResultSetDownloadHandler (
108- settings , result_link , ssl_options = SSLOptions ()
109- )
110- file = d .run ()
111-
112- assert file .file_bytes == b"1234567890" * 10
113-
114- @patch ("requests.Session.get" , side_effect = ConnectionError ("foo" ))
115119 @patch ("time.time" , return_value = 1000 )
116- def test_download_connection_error (self , mock_time , mock_session ):
120+ def test_download_connection_error (self , mock_time ):
121+
122+ http_client = DatabricksHttpClient .get_instance ()
117123 settings = Mock (
118124 link_expiry_buffer_secs = 0 , use_proxy = False , is_lz4_compressed = True
119125 )
120126 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
121- mock_session .return_value .get .return_value .content = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
122127
123- d = downloader .ResultSetDownloadHandler (
124- settings , result_link , ssl_options = SSLOptions ()
125- )
126- with self .assertRaises (ConnectionError ):
127- d .run ()
128+ with patch .object (http_client , "execute" , side_effect = ConnectionError ("foo" )):
129+ d = downloader .ResultSetDownloadHandler (
130+ settings , result_link , ssl_options = SSLOptions ()
131+ )
132+ with self .assertRaises (ConnectionError ):
133+ d .run ()
128134
129- @patch ("requests.Session.get" , side_effect = TimeoutError ("foo" ))
130135 @patch ("time.time" , return_value = 1000 )
131- def test_download_timeout (self , mock_time , mock_session ):
136+ def test_download_timeout (self , mock_time ):
137+ http_client = DatabricksHttpClient .get_instance ()
132138 settings = Mock (
133139 link_expiry_buffer_secs = 0 , use_proxy = False , is_lz4_compressed = True
134140 )
135141 result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
136- mock_session .return_value .get .return_value .content = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
137142
138- d = downloader .ResultSetDownloadHandler (
139- settings , result_link , ssl_options = SSLOptions ()
140- )
141- with self .assertRaises (TimeoutError ):
142- d .run ()
143+ with patch .object (http_client , "execute" , side_effect = TimeoutError ("foo" )):
144+ d = downloader .ResultSetDownloadHandler (
145+ settings , result_link , ssl_options = SSLOptions ()
146+ )
147+ with self .assertRaises (TimeoutError ):
148+ d .run ()
0 commit comments