
Commit 4583670

Merge remote-tracking branch 'upstream/master'
2 parents: d7bb929 + 3de69f9

3 files changed: +50 -31 lines

data/notebooks/Export_Table_ACLs.py

Lines changed: 38 additions & 30 deletions
@@ -134,17 +134,7 @@ def create_grants_df(database_name: str,object_type: str, object_key: str):
   return grants_df
 
 
-def create_table_ACLSs_df_for_databases(database_names: List[str]):
-
-  # TODO check Catalog heuristic:
-  #  if all databases are exported, we include the Catalog grants as well
-  #. if only a few databases are exported: we exclude the Catalog
-  if database_names is None or database_names == '':
-    database_names = get_database_names()
-    include_catalog = True
-  else:
-    include_catalog = False
-
+def create_table_ACLSs_df_for_databases(database_names: List[str], include_catalog: bool):
   num_databases_processed = len(database_names)
   num_tables_or_views_processed = 0
 
@@ -201,35 +191,53 @@ def create_table_ACLSs_df_for_databases(database_names: List[str]):
 # COMMAND ----------
 
 # DBTITLE 1,Run Export
+def chunks(lst, n):
+  """Yield successive n-sized chunks from lst."""
+  for i in range(0, len(lst), n):
+    yield lst[i:i + n]
+
+
 databases_raw = dbutils.widgets.get("Databases")
 output_path = dbutils.widgets.get("OutputPath")
 
 if databases_raw.rstrip() == '':
-  databases = None
+  # TODO check Catalog heuristic:
+  #  if all databases are exported, we include the Catalog grants as well
+  databases = get_database_names()
+  include_catalog = True
   print(f"Exporting all databases")
 else:
+  #. if only a few databases are exported: we exclude the Catalog
   databases = [x.rstrip().lstrip() for x in databases_raw.split(",")]
+  include_catalog = False
   print(f"Exporting the following databases: {databases}")
 
+counter = 1
+for databases_chunks in chunks(databases, 1):
+  table_ACLs_df, num_databases_processed, num_tables_or_views_processed = create_table_ACLSs_df_for_databases(
+    databases_chunks, include_catalog
+  )
+
+  print(
+    f"{datetime.datetime.now()} total number processed chunk {counter}: databases: {num_databases_processed}, tables or views: {num_tables_or_views_processed}")
+  print(f"{datetime.datetime.now()} writing table ACLs to {output_path}")
+
+  # with table ACLS active, I direct write to DBFS is not allowed, so we store
+  # the dateframe as a table for single zipped JSON file sorted, for consitent file diffs
+  (
+    table_ACLs_df
+      # .coalesce(1)
+      .selectExpr("Database", "Principal", "ActionTypes", "ObjectType", "ObjectKey", "ExportTimestamp")
+      # .sort("Database","Principal","ObjectType","ObjectKey")
+      .write
+      .format("JSON")
+      .option("compression", "gzip")
+      .mode("append" if counter > 1 else "overwrite")
+      .save(output_path)
+  )
 
-table_ACLs_df,num_databases_processed, num_tables_or_views_processed = create_table_ACLSs_df_for_databases(databases)
-
-print(f"{datetime.datetime.now()} total number processed: databases: {num_databases_processed}, tables or views: {num_tables_or_views_processed}")
-print(f"{datetime.datetime.now()} writing table ACLs to {output_path}")
-
-# with table ACLS active, I direct write to DBFS is not allowed, so we store
-# the dateframe as a table for single zipped JSON file sorted, for consitent file diffs
-(
-  table_ACLs_df
-    .coalesce(1)
-    .selectExpr("Database","Principal","ActionTypes","ObjectType","ObjectKey","ExportTimestamp")
-    .sort("Database","Principal","ObjectType","ObjectKey")
-    .write
-    .format("JSON")
-    .option("compression","gzip")
-    .mode("overwrite")
-    .save(output_path)
-)
+  counter += 1
+  include_catalog = False
 
 
 # COMMAND ----------
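Note: the chunking pattern introduced above is plain Python and can be checked in isolation. Below is a minimal standalone sketch of how chunks() drives the overwrite-then-append write modes; the sample database list is hypothetical and the Spark write is reduced to a print.

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

# Hypothetical database names, purely for illustration.
databases = ["sales", "finance", "hr"]

for counter, databases_chunk in enumerate(chunks(databases, 1), start=1):
    # The first chunk overwrites the output path, later chunks append,
    # mirroring .mode("append" if counter > 1 else "overwrite") in the diff above.
    write_mode = "append" if counter > 1 else "overwrite"
    print(counter, databases_chunk, write_mode)

With a chunk size of 1, each database is exported and written separately, and only the very first write clears the output path.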

data/notebooks/Import_Table_ACLs.py

Lines changed: 6 additions & 0 deletions
@@ -190,6 +190,12 @@ def execute_sql_statements(sqls):
   l = [ str(o) for o in error_causing_sqls ]
   print("\n".join(l))
 
+# COMMAND ----------
+
+# DBTITLE 1,Nicer error output
+if len(error_causing_sqls) != 0:
+  l = [ {'sql': str(o.get('sql')), 'error': str(o.get('error'))} for o in error_causing_sqls ]
+  display(spark.createDataFrame(l))
 
 # COMMAND ----------
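Note: the added cell simply reshapes the failing statements into rows for tabular output. A minimal sketch of the same idea, assuming an active SparkSession named spark; the sample error_causing_sqls entries are made up, and show() stands in for the Databricks-only display():

error_causing_sqls = [
    {"sql": "GRANT SELECT ON TABLE db.t TO `user@example.com`",
     "error": "Table or view not found: db.t"},
]

if len(error_causing_sqls) != 0:
    rows = [{"sql": str(o.get("sql")), "error": str(o.get("error"))}
            for o in error_causing_sqls]
    spark.createDataFrame(rows).show(truncate=False)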

dbclient/JobsClient.py

Lines changed: 6 additions & 1 deletion
@@ -166,6 +166,11 @@ def import_job_configs(self, log_file='jobs.log', acl_file='acl_jobs.log', job_m
         wmconstants.WM_IMPORT, wmconstants.JOB_OBJECT)
 
     def adjust_ids_for_cluster(settings): #job_settings or task_settings
+      """
+      The task setting may have existing_cluster_id/new_cluster/job_cluster_key for cluster settings.
+      The job level setting may have existing_cluster_id/new_cluster for cluster settings.
+      Adjust cluster settings for existing_cluster_id and new_cluster scenario.
+      """
       if 'existing_cluster_id' in settings:
         old_cid = settings['existing_cluster_id']
         # set new cluster id for existing cluster attribute
@@ -176,7 +181,7 @@ def adjust_ids_for_cluster(settings): #job_settings or task_settings
           settings['new_cluster'] = self.get_jobs_default_cluster_conf()
         else:
           settings['existing_cluster_id'] = new_cid
-      else: # new cluster config
+      elif 'new_cluster' in settings: # new cluster config
         cluster_conf = settings['new_cluster']
         if 'policy_id' in cluster_conf:
           old_policy_id = cluster_conf['policy_id']
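Note: changing else to elif matters for task settings that only carry job_cluster_key (per the new docstring): those now pass through unchanged instead of hitting a missing new_cluster key. A minimal standalone sketch of the branch structure, where the remap helper and sample settings are hypothetical stand-ins for the surrounding class methods:

def adjust_ids_for_cluster(settings, remap_cluster_id):
    """Sketch: only touch settings that actually carry a cluster definition."""
    if 'existing_cluster_id' in settings:
        # remap the old cluster id to the one created in the target workspace
        settings['existing_cluster_id'] = remap_cluster_id(settings['existing_cluster_id'])
    elif 'new_cluster' in settings:  # new cluster config
        cluster_conf = settings['new_cluster']
        if 'policy_id' in cluster_conf:
            # the real method remaps the old policy id to the imported one
            cluster_conf['policy_id'] = 'remapped-' + cluster_conf['policy_id']
    # settings with only job_cluster_key fall through unchanged
    return settings

# A task that only references a job-level cluster is left as-is.
print(adjust_ids_for_cluster({'job_cluster_key': 'shared'}, lambda cid: 'new-' + cid))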
