11import logging
2+ from dataclasses import dataclass , field
23from typing import ClassVar
34
45from databricks .labs .blueprint .installation import Installation
1011logger = logging .getLogger (__name__ )
1112
1213
14+ @dataclass
15+ class AccountGroupDetails :
16+ id : str
17+ members : list [ComplexValue ] | None = None
18+
19+
20+ @dataclass
21+ class AccountGroupCreationContext :
22+ valid_workspace_groups : dict [str , Group ] = field (default_factory = dict )
23+ created_groups : dict [str , Group ] = field (default_factory = dict )
24+ renamed_groups : dict [str , str ] = field (default_factory = dict )
25+ preexisting_account_groups : dict [str , AccountGroupDetails ] = field (default_factory = dict )
26+
27+
1328class AccountWorkspaces :
1429 SYNC_FILE_NAME : ClassVar [str ] = "workspaces.json"
1530
@@ -76,21 +91,101 @@ def sync_workspace_info(self, workspaces: list[Workspace] | None = None):
7691 except (PermissionDenied , NotFound , ValueError ):
7792 logger .warning (f"Failed to save workspace info for { ws .config .host } " )
7893
79- def create_account_level_groups (self , prompts : Prompts ):
80- acc_groups = self ._get_account_groups ()
94+ def create_account_level_groups (self , prompts : Prompts ) -> None :
95+ """
96+ Create account level groups from workspace groups
97+
98+ The following approach is used:
99+ Get all valid worskpace groups from all workspaces
100+
101+ For each group:
102+ - Check if the group already exists in the account
103+ - If it does not exist, check if it is a nested group (users are added directly)
104+ - If its a nested group follow the same approach recursively
105+ - If it is a regular group, create the group in the account and add all members to the group
106+ """
107+ context = AccountGroupCreationContext ()
108+ context .preexisting_account_groups = self ._get_account_groups ()
81109 workspace_ids = [workspace .workspace_id for workspace in self ._workspaces ()]
82110 if not workspace_ids :
83111 raise ValueError ("The workspace ids provided are not found in the account, Please check and try again." )
84- all_valid_workspace_groups = self ._get_valid_workspaces_groups (prompts , workspace_ids )
112+ context . valid_workspace_groups = self ._get_valid_workspaces_groups (prompts , workspace_ids , context )
85113
86- for group_name , valid_group in all_valid_workspace_groups .items ():
87- acc_group = self ._try_create_account_groups (group_name , acc_groups )
114+ for group_name , valid_group in context . valid_workspace_groups .items ():
115+ self ._create_account_groups_recursively (group_name , valid_group , context )
88116
89- if not acc_group or not valid_group .members or not acc_group .id :
90- continue
91- if len (valid_group .members ) > 0 :
92- self ._add_members_to_acc_group (self ._ac , acc_group .id , group_name , valid_group )
93- logger .info (f"Group { group_name } created in the account" )
117+ def _create_account_groups_recursively (
118+ self , group_name : str , valid_group : Group , context : AccountGroupCreationContext
119+ ) -> None :
120+ """
121+ Function recursively crawls through all group and nested groups to create account level groups
122+ """
123+ if group_name in context .created_groups :
124+ logger .info (f"Group { group_name } already exist in the account, ignoring" )
125+ return
126+
127+ members_to_add = []
128+ assert valid_group .members is not None , "group members undefined"
129+ for member in valid_group .members :
130+ if member .ref and member .ref .startswith ("Users" ):
131+ members_to_add .append (member )
132+ elif member .ref and member .ref .startswith ("Groups" ):
133+ assert member .display is not None , "group name undefined"
134+ members_to_append = self ._handle_nested_group (member .display , context )
135+ if members_to_append :
136+ members_to_add .append (members_to_append )
137+ else :
138+ logger .warning (f"Member { member .ref } is not a user or group, skipping" )
139+
140+ acc_group = self ._try_create_account_groups (group_name , context .preexisting_account_groups )
141+ if acc_group :
142+ assert valid_group .display_name is not None , "group name undefined"
143+ logger .info (f"Successfully created account group { acc_group .display_name } " )
144+ if members_to_add and acc_group .id :
145+ self ._add_members_to_acc_group (self ._ac , acc_group .id , valid_group .display_name , members_to_add )
146+ created_acc_group = self ._safe_groups_get (self ._ac , acc_group .id )
147+ if not created_acc_group :
148+ logger .warning (f"Newly created group { valid_group .display_name } could not be fetched, skipping" )
149+ return
150+ context .created_groups [valid_group .display_name ] = created_acc_group
151+
152+ def _handle_nested_group (self , group_name : str , context : AccountGroupCreationContext ) -> ComplexValue | None :
153+ """
154+ Function to handle nested groups
155+ Checks if the group has already been created at account level
156+ If not, it creates the group by calling _create_account_groups_recursively
157+ """
158+ # check if group name is in the renamed groups
159+ if group_name in context .renamed_groups :
160+ group_name = context .renamed_groups [group_name ]
161+
162+ # check if account group was created before this run
163+ if group_name in context .preexisting_account_groups :
164+ logger .info (f"Group { group_name } already exist in the account, ignoring" )
165+ acc_group_id = context .preexisting_account_groups [group_name ].id
166+ full_account_group = self ._safe_groups_get (self ._ac , acc_group_id )
167+ if not full_account_group :
168+ logger .warning (f"Group { group_name } could not be fetched, skipping" )
169+ return None
170+ context .created_groups [group_name ] = full_account_group
171+
172+ # check if workspace group is already created at account level in current run
173+ if group_name not in context .created_groups :
174+ # if there is no account group created for the group, create one
175+ self ._create_account_groups_recursively (group_name , context .valid_workspace_groups [group_name ], context )
176+
177+ if group_name not in context .created_groups :
178+ logger .warning (f"Group { group_name } could not be fetched, skipping" )
179+ return None
180+
181+ created_acc_group = context .created_groups [group_name ]
182+
183+ # the AccountGroupsAPI expects the members to be in the form of ComplexValue
184+ return ComplexValue (
185+ display = created_acc_group .display_name ,
186+ ref = f"Groups/{ created_acc_group .id } " ,
187+ value = created_acc_group .id ,
188+ )
94189
95190 def get_accessible_workspaces (self ) -> list [Workspace ]:
96191 """
@@ -126,9 +221,7 @@ def can_administer(self, workspace: Workspace) -> bool:
126221 return False
127222 return True
128223
129- def _try_create_account_groups (
130- self , group_name : str , acc_groups : dict [str | None , list [ComplexValue ] | None ]
131- ) -> Group | None :
224+ def _try_create_account_groups (self , group_name : str , acc_groups : dict [str , AccountGroupDetails ]) -> Group | None :
132225 try :
133226 if group_name in acc_groups :
134227 logger .info (f"Group { group_name } already exist in the account, ignoring" )
@@ -139,9 +232,9 @@ def _try_create_account_groups(
139232 return None
140233
141234 def _add_members_to_acc_group (
142- self , acc_client : AccountClient , acc_group_id : str , group_name : str , valid_group : Group
235+ self , acc_client : AccountClient , acc_group_id : str , group_name : str , group_members : list [ ComplexValue ] | None
143236 ):
144- for chunk in self ._chunks (valid_group . members , 20 ):
237+ for chunk in self ._chunks (group_members , 20 ):
145238 logger .debug (f"Adding { len (chunk )} members to acc group { group_name } " )
146239 acc_client .groups .patch (
147240 acc_group_id ,
@@ -155,17 +248,25 @@ def _chunks(lst, chunk_size):
155248 for i in range (0 , len (lst ), chunk_size ):
156249 yield lst [i : i + chunk_size ]
157250
158- def _get_valid_workspaces_groups (self , prompts : Prompts , workspace_ids : list [int ]) -> dict [str , Group ]:
251+ def _get_valid_workspaces_groups (
252+ self , prompts : Prompts , workspace_ids : list [int ], context : AccountGroupCreationContext
253+ ) -> dict [str , Group ]:
159254 all_workspaces_groups : dict [str , Group ] = {}
160255
161256 for workspace in self ._workspaces ():
162257 if workspace .workspace_id not in workspace_ids :
163258 continue
164- self ._load_workspace_groups (prompts , workspace , all_workspaces_groups )
259+ self ._load_workspace_groups (prompts , workspace , all_workspaces_groups , context )
165260
166261 return all_workspaces_groups
167262
168- def _load_workspace_groups (self , prompts , workspace , all_workspaces_groups ):
263+ def _load_workspace_groups (
264+ self ,
265+ prompts : Prompts ,
266+ workspace : Workspace ,
267+ all_workspaces_groups : dict [str , Group ],
268+ context : AccountGroupCreationContext ,
269+ ) -> None :
169270 client = self .client_for (workspace )
170271 logger .info (f"Crawling groups in workspace { client .config .host } " )
171272 ws_group_ids = client .groups .list (attributes = "id" )
@@ -188,6 +289,7 @@ def _load_workspace_groups(self, prompts, workspace, all_workspaces_groups):
188289 f"it will be created at the account with name : { workspace .workspace_name } _{ group_name } "
189290 ):
190291 all_workspaces_groups [f"{ workspace .workspace_name } _{ group_name } " ] = full_workspace_group
292+ context .renamed_groups [group_name ] = f"{ workspace .workspace_name } _{ group_name } "
191293 continue
192294 logger .info (f"Found new group { group_name } " )
193295 all_workspaces_groups [group_name ] = full_workspace_group
@@ -212,7 +314,7 @@ def _has_same_members(group_1: Group, group_2: Group) -> bool:
212314 ws_members_set_2 = set ([m .display for m in group_2 .members ] if group_2 .members else [])
213315 return not bool ((ws_members_set_1 - ws_members_set_2 ).union (ws_members_set_2 - ws_members_set_1 ))
214316
215- def _get_account_groups (self ) -> dict [str | None , list [ ComplexValue ] | None ]:
317+ def _get_account_groups (self ) -> dict [str , AccountGroupDetails ]:
216318 logger .debug ("Listing groups in account" )
217319 acc_groups = {}
218320 for acc_grp_id in self ._ac .groups .list (attributes = "id" ):
@@ -222,7 +324,10 @@ def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]:
222324 if not full_account_group :
223325 continue
224326 logger .debug (f"Found account group { full_account_group .display_name } " )
225- acc_groups [full_account_group .display_name ] = full_account_group .members
327+ assert full_account_group .display_name is not None , "group name undefined"
328+ acc_groups [full_account_group .display_name ] = AccountGroupDetails (
329+ id = acc_grp_id .id , members = full_account_group .members
330+ )
226331
227332 logger .info (f"{ len (acc_groups )} account groups found" )
228333 return acc_groups
0 commit comments