Skip to content

Commit bd12648

Browse files
authored
Merge pull request #9 from databricks/BugFixes
Bug fixes
2 parents 7206358 + a601983 commit bd12648

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

GroupMigration/WSGroupMigration.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,21 @@ def __init__(self, groupL : list, cloud : str, inventoryTableName : str, workspa
5757
self.verbose = verbose
5858
self.numThreads = numThreads
5959

60+
6061
self.lastInventoryRun = None
6162
self.checkAllDB = False
6263
print(f'Clearing inventory table {self.inventoryTableName}')
6364
spark.sql(f"drop table if exists {self.inventoryTableName}")
6465
spark.sql(f"drop table if exists {self.inventoryTableName+'TableACL'}")
66+
6567

6668
#Check if we should automatically generate list, and do it immediately.
6769
#Implementers Note: Could change this section to a lazy calculation by setting groupL to nil or some sentinel value and adding checks before use.
70+
res=requests.get(f"{self.workspace_url}/api/2.0/preview/scim/v2/Me", headers=self.headers)
71+
#print(res.text)
72+
if res.status_code == 403:
73+
print("token not valid.")
74+
return
6875
if(autoGenerateList) :
6976
print("autoGenerateList parameter is set to TRUE. Ignoring groupL parameter and instead will automatically generate list of migraiton groups.")
7077
self.groupL = self.findMigrationEligibleGroups()
@@ -837,7 +844,7 @@ def getSingleFolderList(self, path:str, depth:int) -> dict:
837844
return (path, subFolders, notebooks, files)
838845

839846
for c in resFolderJson['objects']:
840-
if c['object_type']=="DIRECTORY" and c['path'].startswith('/Repos') == False and c['path'].startswith('/Shared') == False and c['path'].endswith('/Trash') == False:
847+
if c['object_type']=="DIRECTORY" and c['path'].startswith('/Shared') == False and c['path'].endswith('/Trash') == False:
841848
subFolders[c['object_id']] = c['path']
842849
elif c['object_type']=="NOTEBOOK" and c['path'].startswith('/Repos') == False and c['path'].startswith('/Shared') == False:
843850
notebooks[c['object_id']] = c['path']
@@ -1117,27 +1124,22 @@ def updateGroup2Permission(self, object:str, groupPermission : dict, level:str):
11171124
dataAcl=[]
11181125
for acl in aclList:
11191126
try:
1120-
if 'user_name' in acl.keys():
1121-
if acl['user_name']==self.userName:
1122-
addUser=False
11231127
gName=acl['group_name']
1128+
if gName=="ADMIN" and acl['permission_level']!='CAN_MANAGE':
1129+
dataAcl.append({'group_name': gName, 'permission_level': "CAN_MANAGE"})
11241130
if level=="Workspace":
11251131
if acl['group_name'] in self.WorkspaceGroupNames:
11261132
gName="db-temp-"+acl['group_name']
11271133
dataAcl.append({'group_name': gName, 'permission_level': acl['permission_level']})
11281134
elif level=="Account":
11291135
if acl['group_name'] in self.TempGroupNames:
11301136
gName=acl['group_name'][8:]
1131-
#dataAcl.append({'group_name': gName, 'permission_level': acl['permission_level']})
11321137
else:
11331138
gName=acl['group_name']
1134-
#acl['group_name']=gName
11351139
dataAcl.append(acl)
11361140
except KeyError:
11371141
dataAcl.append(acl)
11381142
continue
1139-
if addUser:
1140-
dataAcl.append({"user_name": self.userName,"permission_level": "CAN_MANAGE"})
11411143
data={"access_control_list":dataAcl}
11421144
resAppPerm=requests.post(f"{self.workspace_url}/api/2.0/preview/sql/permissions/{object}/{object_id}", headers=self.headers, data=json.dumps(data))
11431145
except Exception as e:
@@ -1187,15 +1189,16 @@ def getDBACL(self, db: str):
11871189
try:
11881190
aclList=[]
11891191
dbdf=self.getGrantsOnObjects(db, "DATABASE", db)
1192+
aclList+=dbdf.collect()
11901193
if not self.checkAllDB:
1191-
userListCollect=dbdf.filter(col('ObjectType')=="DATABASE").filter(array_contains(col('ActionTypes'),"USAGE")).select(col('Principal')).collect()
1194+
userListCollect=dbdf.filter(col('ObjectType')=="DATABASE").filter((array_contains(col('ActionTypes'),"USAGE") | array_contains(col('ActionTypes'),"OWN"))).select(col('Principal')).collect()
11921195
userList=[ p.Principal for p in userListCollect]
11931196
userList=list(set(userList))
11941197
if not self.checkPrincipalInGroupOrMember(userList, db):
1195-
#print(f'selected groups or members of the groups have no USAGE permission on database level. Skipping object level permission check for database {db}.')
1198+
#print(f'selected groups or members of the groups have no USAGE or OWN permission on database level. Skipping object level permission check for database {db}.')
11961199
return []
11971200

1198-
aclList+=dbdf.collect()
1201+
11991202
tables = self.runVerboseSql("show tables in spark_catalog.{}".format(db)).filter(col("isTemporary") == False)
12001203
for table in tables.collect():
12011204
try:
@@ -1221,15 +1224,15 @@ def getDBACL(self, db: str):
12211224
def checkPrincipalInGroupOrMember(self, principalList: list, name: str) -> bool:
    """Report whether any given principal appears in one of the tracked lists.

    The lists are consulted in a fixed priority order: every principal is
    first tested against ``self.groupGroupList``, then against
    ``self.groupUserList`` and finally against ``self.groupSPList``. The
    first hit prints a notice and ends the search, so at most one message
    is printed per call.

    Args:
        principalList: Iterable of principal names to look up. The visible
            caller (``getDBACL``) passes a de-duplicated list of strings,
            which is why the annotation is ``list`` rather than ``str``.
        name: Name of the object the permission applies to; it is only
            used in the printed message.

    Returns:
        True as soon as one principal is found in any of the three lists,
        False when none of them match (including an empty ``principalList``).
    """
    # (message label, attribute holding the members) in priority order.
    # Attributes are resolved lazily so a later list is never touched when
    # an earlier one already produced a match.
    lookups = (
        ("Group", "groupGroupList"),
        ("User", "groupUserList"),
        ("SP", "groupSPList"),
    )
    for label, attrName in lookups:
        members = getattr(self, attrName)
        for principal in principalList:
            if principal in members:
                print(f'{label} {principal} is given USAGE or OWN permission for {name}.')
                return True
    return False
12351238

@@ -1247,6 +1250,8 @@ def getTableACLs(self)-> list:
12471250
common_df = common_df.unionAll(self.getGrantsOnObjects(None, "ANY FILE", None))
12481251
# CATALOG
12491252
common_df = common_df.unionAll(self.getGrantsOnObjects(None, "CATALOG", None))
1253+
aclList = []
1254+
aclList = common_df.collect()
12501255
#check if any group is given permission at catalog level
12511256
userListCollect=common_df.filter(col('ObjectType')=="CATALOG$").filter(array_contains(col('ActionTypes'),"USAGE")).select(col('Principal')).collect()
12521257
userList=[ p.Principal for p in userListCollect]
@@ -1263,7 +1268,7 @@ def getTableACLs(self)-> list:
12631268
#database_names=['aaron_binns','hsdb']
12641269
currentCount=0
12651270
try:
1266-
aclList = []
1271+
#aclList = []
12671272
aclFinalList = []
12681273
with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor:
12691274
future_db = [executor.submit(self.getDBACL, f"`{databaseName}`" ) for databaseName in database_names]
@@ -1280,12 +1285,14 @@ def getTableACLs(self)-> list:
12801285

12811286
def generate_table_acls_command(self, action_types, object_type, object_key, groupName):
    """Build the SQL statements that re-apply a set of table-ACL actions.

    Args:
        action_types: Iterable of action names. Names starting with
            ``DENIED_`` become DENY privileges (prefix stripped), the
            literal ``OWN`` becomes an ownership transfer, and everything
            else becomes a GRANT privilege.
        object_type: Object keyword placed after ``ON`` / ``ALTER``
            (e.g. ``TABLE`` or ``DATABASE``).
        object_key: Identifier of the object; the visible caller passes an
            empty string for object types that take no key.
        groupName: Principal receiving the permissions; it is wrapped in
            backticks in the generated SQL.

    Returns:
        A list of SQL statement strings in the order GRANT, DENY, ALTER.
        Statements with nothing to apply are omitted, so the list is empty
        when ``action_types`` is empty.
    """
    deniedPrefix = "DENIED_"
    grant_privs = []
    deny_privs = []
    # Single pass: a name carrying the DENIED_ prefix can never equal "OWN",
    # so only the grant side needs to exclude ownership explicitly.
    for action in action_types:
        if action.startswith(deniedPrefix):
            deny_privs.append(action[len(deniedPrefix):])
        elif action != "OWN":
            grant_privs.append(action)

    lines = []
    if grant_privs:
        lines.append(f"GRANT {', '.join(grant_privs)} ON {object_type} {object_key} TO `{groupName}`;")
    if deny_privs:
        lines.append(f"DENY {', '.join(deny_privs)} ON {object_type} {object_key} TO `{groupName}`;")
    # Ownership cannot be granted; it is transferred with ALTER ... OWNER TO.
    if "OWN" in action_types:
        lines.append(f"ALTER {object_type} {object_key} OWNER TO `{groupName}`;")
    return lines
12901297

12911298
def updateDataObjectsPermission(self, aclList : list, level:str):
@@ -1297,7 +1304,16 @@ def updateDataObjectsPermission(self, aclList : list, level:str):
12971304
gName="db-temp-"+acl.Principal
12981305
elif level=="Account":
12991306
gName=acl.Principal[8:]
1300-
lines.extend(self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName))
1307+
if acl.ObjectType == "ANONYMOUS_FUNCTION":
1308+
lines.extend(self.generate_table_acls_command(acl.ActionTypes, 'ANONYMOUS FUNCTION', '', gName))
1309+
elif acl.ObjectType == "ANY_FILE":
1310+
lines.extend(self.generate_table_acls_command(acl.ActionTypes, 'ANY FILE', '', gName))
1311+
elif acl.ObjectType == "CATALOG$":
1312+
lines.extend(self.generate_table_acls_command(acl.ActionTypes, 'CATALOG', '', gName))
1313+
elif acl.ObjectType in ["DATABASE", "TABLE"]:
1314+
# DATABASE, TABLE, VIEW (view's seem to show up as tables)
1315+
lines.extend(self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName))
1316+
#lines.extend(self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName))
13011317
for aclQuery in lines:
13021318
#print(aclQuery)
13031319
self.runVerboseSql(aclQuery)

0 commit comments

Comments
 (0)