Commit b086c48

Merge pull request #244 from lyliyu/table_acl_par
add notebooks for exporting table acls from a large scale of databases
2 parents 16bb436 + 0b38d5a · commit b086c48

2 files changed: +532 −0 lines changed

This file: 320 additions & 0 deletions
# Databricks notebook source
# MAGIC %md #Export Table ACLs
# MAGIC
# MAGIC Exports Table ACLs to DBFS (written as a Delta table, see below), which can be imported using the Import_Table_ACLs script.
# MAGIC
# MAGIC Parameters:
# MAGIC - Databases: [Optional] comma-separated list of databases to be exported; if empty, all databases will be exported
# MAGIC - OutputPath: path to write the exported data to
# MAGIC
# MAGIC Returns (via `dbutils.notebook.exit(exit_JSON_string)`, see the caller sketch below):
# MAGIC - `{ "total_num_acls": <int>, "num_errors": <int> }`
# MAGIC   - total_num_acls : number of ACL entries in the exported output
# MAGIC   - num_errors : error entries in the exported output; for these, the principal is set to `ERROR_!!!` and the object fields are prefixed with `ERROR_!!!_`
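# MAGIC
# MAGIC A caller notebook could run this export and check the exit value along these lines (a sketch only; the notebook path, timeout and argument values are assumptions, not part of this repo's documented API):
# MAGIC
# MAGIC     import json
# MAGIC     result = json.loads(dbutils.notebook.run(
# MAGIC         "./Export_Table_ACLs", 0,
# MAGIC         {"Databases": "", "OutputPath": "dbfs:/tmp/migrate/test_table_acls.json.gz"}))
# MAGIC     if result["num_errors"] > 0:
# MAGIC         print(f"Export produced {result['num_errors']} error entries - review them before importing")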
# MAGIC
# MAGIC Execution: **Run the notebook on a cluster with Table ACLs enabled, as a user who is an admin**
# MAGIC
# MAGIC Supported object types:
# MAGIC - Catalog: included if all databases are exported, not included if the databases to be exported are specified explicitly
# MAGIC - Database: included
# MAGIC - Table: included
# MAGIC - View: included (views are treated as tables with ObjectType `TABLE`)
# MAGIC - Anonymous Function: included (testing pending)
# MAGIC - Any File: included
# MAGIC
# MAGIC Unsupported object types:
# MAGIC - User Function: not currently supported in Databricks SQL - support will be added later
# MAGIC
# MAGIC Output format: the ACL entries are written as a Delta table at OutputPath (not as a single gzipped JSON file)
# MAGIC
# MAGIC - written via `.write.format("delta").mode(...).save(output_path)` (see `write_grants_df` below)
# MAGIC - each record contains the following fields (see the example record below):
# MAGIC   - `Database`: string
# MAGIC   - `Principal`: string
# MAGIC   - `ActionTypes`: list of action strings
# MAGIC   - `ObjectType`: `(ANONYMOUS_FUNCTION|ANY_FILE|CATALOG$|DATABASE|TABLE|ERROR_!!!_<type>)` (views are treated as tables)
# MAGIC   - `ObjectKey`: string
# MAGIC   - `ExportTimestamp`: string
# MAGIC - error records contain:
# MAGIC   - the special `Principal` `ERROR_!!!`
# MAGIC   - `ActionTypes` with a single element: the error message, starting with `ERROR!!! :`
# MAGIC   - `Database`, `ObjectType`, `ObjectKey` all prefixed with `ERROR_!!!_`
# MAGIC   - error records are ignored by the Import_Table_ACLs
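# MAGIC
# MAGIC An illustrative example of one exported record, shown as JSON (hypothetical values):
# MAGIC
# MAGIC     {"Database": "db_acl_test", "Principal": "data-engineers", "ActionTypes": ["SELECT", "READ_METADATA"],
# MAGIC      "ObjectType": "TABLE", "ObjectKey": "`db_acl_test`.`some_table`", "ExportTimestamp": "2021-06-01 12:00:00"}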
# MAGIC
# MAGIC The output is written as a table because, on a cluster with Table ACLs activated, files cannot be written directly.
# MAGIC The output path will contain other files and directories, starting with `_`, which can be ignored.
# MAGIC
# MAGIC
# MAGIC What to do if the exported output contains errors (the notebook returns `num_errors` > 0):
# MAGIC - If there are only a few errors (e.g. less than 1% or less than a dozen):
# MAGIC   - proceed with the import (any error records will be ignored)
# MAGIC   - for each error, the object type, name and error message are included so the cause can be investigated
# MAGIC   - in most cases, these turn out to be broken or unused tables or views
# MAGIC - If there are many errors:
# MAGIC   - try executing some `SHOW GRANT` commands on the same cluster as the same user; there might be an underlying problem
# MAGIC   - review the errors and investigate
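# MAGIC
# MAGIC For example, the following commands mirror what the export logic runs and can be tried directly (a sketch; substitute your own database and table names):
# MAGIC
# MAGIC     SHOW GRANT ON DATABASE db_acl_test;
# MAGIC     SHOW GRANT ON TABLE db_acl_test.some_table;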

# COMMAND ----------

# DBTITLE 1,Declare Parameters
#dbutils.widgets.removeAll()
dbutils.widgets.text("Databases","db_acl_test,db_acl_test_restricted","1: Databases (opt)")
dbutils.widgets.text("OutputPath","dbfs:/tmp/migrate/test_table_acls.json.gz","2: Output Path")

# COMMAND ----------

# DBTITLE 1,Check Parameters

if not dbutils.widgets.get("OutputPath").startswith("dbfs:/"):
  raise Exception(f"Unexpected value for notebook parameter 'OutputPath', got <{dbutils.widgets.get('OutputPath')}>, but it must start with <dbfs:/........>")

# COMMAND ----------

# DBTITLE 1,Params
WRITE_TABLES_PER_BATCH = 100  # number of per-table grant DataFrames to union before each write
MAX_WORKERS = 8               # number of databases processed in parallel by the thread pool

# COMMAND ----------

# DBTITLE 1,Define Export Logic
import pyspark.sql.functions as sf
from typing import Callable, Iterator, Union, Optional, List
import datetime
import sys
from functools import reduce
from pyspark.sql import DataFrame

def create_error_grants_df(sys_exec_info_res, database_name: str, object_type: str, object_key: str):
  # Build a single-row "error" grants DataFrame from sys.exc_info(), so failures show up
  # in the export instead of aborting it. The import ignores these rows.
  msg_context = f"context: database_name: {database_name}, object_type: {object_type}, object_key: {object_key}"

  msg_lines = str(sys_exec_info_res[1]).split("\n")
  if len(msg_lines) <= 2:
    short_message = " ".join(msg_lines)
  else:
    short_message = " ".join(msg_lines[:2])
  error_message = f"ERROR!!! : exception class {sys_exec_info_res[0]}, message: {short_message}, {msg_context}".replace('"',"'")

  print(error_message)

  database_value = f"'ERROR_!!!_{database_name}'".replace('"',"'") if database_name else "NULL"
  object_key_value = f"'ERROR_!!!_{object_key}'".replace('"',"'") if object_key else "NULL"
  object_type_value = f"'ERROR_!!!_{object_type}'".replace('"',"'") if object_type else "NULL" # Import ignores this object type

  grants_df = spark.sql(f"""SELECT
      {database_value} AS Database,
      'ERROR_!!!' AS Principal,
      array("{error_message}") AS ActionTypes,
      {object_type_value} AS ObjectType,
      {object_key_value} AS ObjectKey,
      Now() AS ExportTimestamp
  """)

  return grants_df

def get_database_names():
  database_names = []
  for db in spark.sql("show databases").collect():
    if hasattr(db, "databaseName"): # older Spark versions name this column databaseName, newer ones namespace
      database_names.append(db.databaseName)
    else:
      database_names.append(db.namespace)
  return database_names

def write_grants_df(df, output_path, writeMode="append"):
  try:
    (
      df
      .selectExpr("Database","Principal","ActionTypes","ObjectType","ObjectKey","ExportTimestamp")
      .sort("Database","Principal","ObjectType","ObjectKey")
      .write
      .format("delta")
      .mode(writeMode)
      .save(output_path)
    )
  except Exception as ex:
    print(ex)

def write_grants_dfs(dfs, output_path, writeMode="append"):
  union_df = reduce(DataFrame.unionAll, dfs)
  write_grants_df(union_df, output_path, writeMode)

write_dfs = []

def create_grants_df(database_name: str, object_type: str, object_key: str):
  try:
    if object_type in ["CATALOG", "ANY FILE", "ANONYMOUS FUNCTION"]: # objects without an object key
      grants_df = (
        spark.sql(f"SHOW GRANT ON {object_type}")
        .groupBy("ObjectType","ObjectKey","Principal").agg(sf.collect_set("ActionType").alias("ActionTypes"))
        .selectExpr("CAST(NULL AS STRING) AS Database","Principal","ActionTypes","ObjectType","ObjectKey","Now() AS ExportTimestamp")
      )
    else:
      grants_df = (
        spark.sql(f"SHOW GRANT ON {object_type} {object_key}")
        .filter(sf.col("ObjectType") == f"{object_type}")
        .groupBy("ObjectType","ObjectKey","Principal").agg(sf.collect_set("ActionType").alias("ActionTypes"))
        .selectExpr(f"'{database_name}' AS Database","Principal","ActionTypes","ObjectType","ObjectKey","Now() AS ExportTimestamp")
      )
  except:
    grants_df = create_error_grants_df(sys.exc_info(), database_name, object_type, object_key)

  return grants_df

def create_table_ACLSs_df_for_databases(database_names: List[str]):

  # TODO check Catalog heuristic:
  #    if all databases are exported, we include the Catalog grants as well
  #    if only a few databases are exported: we exclude the Catalog
  #  if database_names is None or database_names == '':
  #    database_names = get_database_names()
  #    include_catalog = True
  #  else:
  #    include_catalog = False

  num_databases_processed = len(database_names)
  num_tables_or_views_processed = 0

  # # ANONYMOUS FUNCTION
  # combined_grant_dfs = create_grants_df(None, "ANONYMOUS FUNCTION", None)

  # # ANY FILE
  # combined_grant_dfs = combined_grant_dfs.unionAll(
  #   create_grants_df(None, "ANY FILE", None)
  # )

  # # CATALOG
  # if include_catalog:
  #   combined_grant_dfs = combined_grant_dfs.unionAll(
  #     create_grants_df(None, "CATALOG", None)
  #   )
  # TODO ELSE: consider pushing catalog grants down to DB level in this case

  for database_name in database_names:
    print(f"processing database {database_name}")
    # DATABASE
    grant_df = create_grants_df(database_name, "DATABASE", database_name)
    print("writing out grant_df")
    write_grants_df(grant_df, output_path)  # output_path is set from the OutputPath widget below, before this function is called
    print("finished writing")
    try:
      print("getting tables")
      tables_and_views_rows = spark.sql(
        f"SHOW TABLES IN {database_name}"
      ).filter(sf.col("isTemporary") == False).collect()

      print(f"{datetime.datetime.now()} working on database {database_name} with {len(tables_and_views_rows)} tables and views")
      num_tables_or_views_processed = num_tables_or_views_processed + len(tables_and_views_rows)

      write_dfs = []
      for table_row in tables_and_views_rows:
        # TABLE, VIEW
        write_dfs.append(create_grants_df(database_name, "TABLE", f"{table_row.database}.{table_row.tableName}"))
        if len(write_dfs) == WRITE_TABLES_PER_BATCH:
          print("writing out write_dfs")
          write_grants_dfs(write_dfs, output_path)
          write_dfs = []
      if write_dfs:
        print("flushing out what's left in write_dfs")
        write_grants_dfs(write_dfs, output_path)
        write_dfs = []

    except:
      # error in SHOW TABLES IN database_name; errors in create_grants_df have already been caught
      error_df = create_error_grants_df(sys.exc_info(), database_name, "DATABASE", database_name)
      write_grants_df(error_df, output_path)

  # TODO ADD USER FUNCTION - not supported in SQL Analytics, so this can wait a bit
  # ... SHOW USER FUNCTIONS LIKE <my db>.`*`;
  #     function_row['function'] ... nah does not seem to work

  return num_databases_processed, num_tables_or_views_processed

# COMMAND ----------

output_path = dbutils.widgets.get("OutputPath")

# COMMAND ----------

# DBTITLE 1,Run Export
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed
import os, fnmatch, shutil

databases_raw = dbutils.widgets.get("Databases")

def process_db(db: str, writeMode: str):
  print(f"processing db {db} with mode {writeMode}")
  num_databases_processed, num_tables_or_views_processed = create_table_ACLSs_df_for_databases([db])
  print(f"{datetime.datetime.now()} total number processed: databases: {num_databases_processed}, tables or views: {num_tables_or_views_processed}")


def append_db_acls(db: str):
  process_db(db, "append")

if databases_raw.rstrip() == '':
  databases = get_database_names()
  include_catalog = True
  print("Exporting all databases")
else:
  databases = [x.strip() for x in databases_raw.split(",")]
  include_catalog = False
  print(f"Exporting the following databases: {databases}")

# ANONYMOUS FUNCTION
combined_grant_dfs = create_grants_df(None, "ANONYMOUS FUNCTION", None)

# ANY FILE
combined_grant_dfs = combined_grant_dfs.unionAll(
  create_grants_df(None, "ANY FILE", None)
)

# CATALOG
if include_catalog:
  combined_grant_dfs = combined_grant_dfs.unionAll(
    create_grants_df(None, "CATALOG", None)
  )

# Overwrite any previous export at output_path with the catalog-level grants,
# then append the per-database grants in parallel.
write_grants_df(combined_grant_dfs, output_path, "overwrite")

with ThreadPoolExecutor(max_workers = MAX_WORKERS) as executor:
  results = executor.map(append_db_acls, databases)
  print(list(results))  # consuming the iterator also surfaces any exceptions raised in the worker threads

# print(f"{datetime.datetime.now()} total number processed: databases: {total_databases_processed}, tables or views: {total_tables_processed}")
print(f"{datetime.datetime.now()} finished writing table ACLs to {output_path}")

# COMMAND ----------

# DBTITLE 1,Optimize the output data
spark.sql(f"OPTIMIZE delta.`{output_path}`")

# COMMAND ----------

# DBTITLE 1,Exported Table ACLs
display(spark.read.format("delta").load(output_path))
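
# COMMAND ----------

# DBTITLE 1,Inspect error entries (optional)
# Optional check, not part of the original export flow: show only the error records
# (rows whose Principal is the special value 'ERROR_!!!') so they can be reviewed
# before running the Import_Table_ACLs notebook.
display(
  spark.read.format("delta").load(output_path)
    .filter(sf.col("Principal") == "ERROR_!!!")
)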

# COMMAND ----------

# DBTITLE 1,Write total_num_acls and num_errors to the notebook exit value
totals_df = (spark.read
  .format("delta")
  .load(output_path)
  .selectExpr(
     "sum(1) AS total_num_acls"
    ,"sum(CASE WHEN Principal = 'ERROR_!!!' THEN 1 ELSE 0 END) AS num_errors")
)

res_rows = totals_df.collect()

exit_JSON_string = '{ "total_num_acls": '+str(res_rows[0]["total_num_acls"])+', "num_errors": '+str(res_rows[0]["num_errors"])+' }'

print(exit_JSON_string)

dbutils.notebook.exit(exit_JSON_string)

# COMMAND ----------