Skip to content

Commit e8f0e9a

Browse files
Merge branch 'master' into table_acls_in_chunks
2 parents: 5e79ca3 + 0066952 — commit e8f0e9a

File tree

4 files changed

+21
-10
lines changed

4 files changed

+21
-10
lines changed

dbclient/HiveClient.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,8 @@ def update_table_ddl(self, local_table_path, db_path):
102102
# check if the database location / path is the default DBFS path
103103
table_name = os.path.basename(local_table_path)
104104
is_db_default_path = db_path.startswith('dbfs:/user/hive/warehouse')
105-
if (not is_db_default_path) and (not self.is_table_location_defined(local_table_path)):
105+
ddl_statement = self.get_ddl_by_keyword_group(local_table_path)
106+
if (not is_db_default_path) and (not self.is_table_location_defined(local_table_path)) and (not self.is_ddl_a_view(ddl_statement)):
106107
# the LOCATION attribute is not defined and the Database has a custom location defined
107108
# therefore we need to add it to the DDL, e.g. dbfs:/db_path/table_name
108109
table_path = db_path + '/' + table_name
@@ -656,4 +657,4 @@ def get_or_launch_cluster(self, cluster_name=None):
656657
# is not persisted and checkpointed on system crash.
657658
def _persist_to_disk(self, fp):
658659
fp.flush()
659-
os.fsync(fp.fileno())
660+
os.fsync(fp.fileno())

dbclient/dbclient.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,7 @@ def __init__(self, configs):
6767
self._local = threading.local()
6868
self._retry_total = configs['retry_total']
6969
self._retry_backoff = configs['retry_backoff']
70+
self._timeout = configs['timeout']
7071
if configs['debug']:
7172
logging.getLogger("urllib3").setLevel(logging.DEBUG)
7273
if self._verify_ssl:
@@ -100,6 +101,9 @@ def is_skip_failed(self):
100101
def get_file_format(self):
101102
return self._file_format
102103

104+
def get_timeout(self):
105+
return self._timeout
106+
103107
def is_source_file_format(self):
104108
if self._file_format == 'SOURCE':
105109
return True
@@ -190,7 +194,7 @@ def req_session(self):
190194
self._local.session = session
191195
return self._local.session
192196

193-
def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_not_throw=False, timeout=10):
197+
def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_not_throw=False):
194198
if version:
195199
ver = version
196200
while True:
@@ -199,11 +203,12 @@ def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_no
199203
print("Get: {0}".format(full_endpoint))
200204
if json_params:
201205
raw_results = self.req_session().get(
202-
full_endpoint, headers=self._token, params=json_params, verify=self._verify_ssl, timeout=timeout
206+
full_endpoint, headers=self._token, params=json_params, verify=self._verify_ssl,
207+
timeout=self.get_timeout()
203208
)
204209
else:
205210
raw_results = self.req_session().get(
206-
full_endpoint, headers=self._token, verify=self._verify_ssl, timeout=timeout
211+
full_endpoint, headers=self._token, verify=self._verify_ssl, timeout=self.get_timeout()
207212
)
208213

209214
if self._should_retry_with_new_token(raw_results):
@@ -222,7 +227,7 @@ def get(self, endpoint, json_params=None, version='2.0', print_json=False, do_no
222227
results['http_status_code'] = http_status_code
223228
return results
224229

225-
def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=False, files_json=None, timeout=10):
230+
def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=False, files_json=None):
226231
if version:
227232
ver = version
228233
while True:
@@ -234,22 +239,22 @@ def http_req(self, http_type, endpoint, json_params, version='2.0', print_json=F
234239
if files_json:
235240
raw_results = self.req_session().post(
236241
full_endpoint, headers=self._token, data=json_params, files=files_json,
237-
verify=self._verify_ssl, timeout=timeout
242+
verify=self._verify_ssl, timeout=self.get_timeout()
238243
)
239244
else:
240245
raw_results = self.req_session().post(
241246
full_endpoint, headers=self._token, json=json_params, verify=self._verify_ssl,
242-
timeout=timeout
247+
timeout=self.get_timeout()
243248
)
244249
if http_type == 'put':
245250
raw_results = self.req_session().put(
246251
full_endpoint, headers=self._token, json=json_params, verify=self._verify_ssl,
247-
timeout=timeout
252+
timeout=self.get_timeout()
248253
)
249254
if http_type == 'patch':
250255
raw_results = self.req_session().patch(
251256
full_endpoint, headers=self._token, json=json_params, verify=self._verify_ssl,
252-
timeout=timeout
257+
timeout=self.get_timeout()
253258
)
254259
else:
255260
print("Must have a payload in json_args param.")

dbclient/parser.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -554,4 +554,7 @@ def get_pipeline_parser() -> argparse.ArgumentParser:
554554
parser.add_argument('--groups-to-keep', nargs='+', type=str, default=[],
555555
help='List of groups (and therefore users/notebooks) to keep if specified')
556556

557+
parser.add_argument('--timeout', type=float, default=10.0,
558+
help='Timeout for the calls to Databricks\' REST API, in seconds, defaults to 10.0 --use float e.g. 100.0 to make it bigger')
559+
557560
return parser

migration_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,8 @@ def build_pipeline(args) -> Pipeline:
5757

5858
client_config['verbose'] = args.verbose
5959

60+
client_config['timeout'] = args.timeout
61+
6062
if not args.dry_run:
6163
os.makedirs(client_config['export_dir'], exist_ok=True)
6264

0 commit comments

Comments
 (0)