Skip to content

Commit 0066952

Browse files
authored
Merge pull request #230 from lorenzorubi-db/master
configurable timeout for REST API
2 parents 33257e7 + f91b6a1 commit 0066952

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

dbclient/dbclient.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self, configs):
6767
self._local = threading.local()
6868
self._retry_total = configs['retry_total']
6969
self._retry_backoff = configs['retry_backoff']
70+
self._timeout = configs['timeout']
7071
if configs['debug']:
7172
logging.getLogger("urllib3").setLevel(logging.DEBUG)
7273
if self._verify_ssl:
@@ -100,6 +101,9 @@ def is_skip_failed(self):
100101
def get_file_format(self):
101102
return self._file_format
102103

104+
def get_timeout(self):
105+
return self._timeout
106+
103107
def is_source_file_format(self):
104108
if self._file_format == 'SOURCE':
105109
return True
@@ -190,7 +194,7 @@ def req_session(self):
190194
self._local.session = session
191195
return self._local.session
192196

193-
def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_not_throw=False, timeout=10):
197+
def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_not_throw=False):
194198
if version:
195199
ver = version
196200
while True:
@@ -199,11 +203,12 @@ def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_no
199203
print("Get: {0}".format(full_endpoint))
200204
if json_params:
201205
raw_results = self.req_session().get(
202-
full_endpoint, headers=self._token, params=json_params, verify=self._verify_ssl, timeout=timeout
206+
full_endpoint, headers=self._token, params=json_params, verify=self._verify_ssl,
207+
timeout=self.get_timeout()
203208
)
204209
else:
205210
raw_results = self.req_session().get(
206-
full_endpoint, headers=self._token, verify=self._verify_ssl, timeout=timeout
211+
full_endpoint, headers=self._token, verify=self._verify_ssl, timeout=self.get_timeout()
207212
)
208213

209214
if self._should_retry_with_new_token(raw_results):
@@ -222,7 +227,7 @@ def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_no
222227
results['http_status_code'] = http_status_code
223228
return results
224229

225-
def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=False, files_json=None, timeout=10):
230+
def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=False, files_json=None):
226231
if version:
227232
ver = version
228233
while True:
@@ -234,22 +239,22 @@ def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=F
234239
if files_json:
235240
raw_results = self.req_session().post(
236241
full_endpoint, headers=self._token, data=json_params, files=files_json,
237-
verify=self._verify_ssl, timeout=timeout
242+
verify=self._verify_ssl, timeout=self.get_timeout()
238243
)
239244
else:
240245
raw_results = self.req_session().post(
241246
full_endpoint, headers=self._token, json=json_params, verify=self._verify_ssl,
242-
timeout=timeout
247+
timeout=self.get_timeout()
243248
)
244249
if http_type == 'put':
245250
raw_results = self.req_session().put(
246251
full_endpoint, headers=self._token, json=json_params, verify=self._verify_ssl,
247-
timeout=timeout
252+
timeout=self.get_timeout()
248253
)
249254
if http_type == 'patch':
250255
raw_results = self.req_session().patch(
251256
full_endpoint, headers=self._token, json=json_params, verify=self._verify_ssl,
252-
timeout=timeout
257+
timeout=self.get_timeout()
253258
)
254259
else:
255260
print("Must have a payload in json_args param.")

dbclient/parser.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,4 +554,7 @@ def get_pipeline_parser() -> argparse.ArgumentParser:
554554
parser.add_argument('--groups-to-keep', nargs='+', type=str, default=[],
555555
help='List of groups (and therefore users/notebooks) to keep if specified')
556556

557+
parser.add_argument('--timeout', type=float, default=10.0,
558+
help='Timeout for the calls to Databricks\' REST API, in seconds, defaults to 10.0 --use float e.g. 100.0 to make it bigger')
559+
557560
return parser

migration_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def build_pipeline(args) -> Pipeline:
5757

5858
client_config['verbose'] = args.verbose
5959

60+
client_config['timeout'] = args.timeout
61+
6062
if not args.dry_run:
6163
os.makedirs(client_config['export_dir'], exist_ok=True)
6264

0 commit comments

Comments
 (0)