66from collections .abc import Iterable , Collection
77from dataclasses import dataclass
88from datetime import timedelta
9- from typing import ClassVar
109
1110from databricks .labs .blueprint .limiter import rate_limited
1211from databricks .labs .blueprint .parallel import ManyError , Threads
2221)
2322from databricks .sdk .retries import retried
2423from databricks .sdk .service import iam
25- from databricks .sdk .service .iam import Group
24+ from databricks .sdk .service .iam import Group , User
2625
2726from databricks .labs .ucx .framework .crawlers import CrawlerBase
2827from databricks .labs .ucx .framework .utils import escape_sql_identifier
2928
3029logger = logging .getLogger (__name__ )
3130
31+ SYSTEM_GROUPS : list [str ] = ["users" , "admins" , "account users" ]
32+
3233
3334@dataclass
3435class MigratedGroup :
@@ -402,7 +403,6 @@ def __init__(self, group_id: str, old_name: str, new_name: str) -> None:
402403
403404
404405class GroupManager (CrawlerBase [MigratedGroup ]):
405- _SYSTEM_GROUPS : ClassVar [list [str ]] = ["users" , "admins" , "account users" ]
406406
407407 def __init__ ( # pylint: disable=too-many-arguments
408408 self ,
@@ -430,6 +430,7 @@ def __init__( # pylint: disable=too-many-arguments
430430 self ._account_group_regex = account_group_regex
431431 self ._external_id_match = external_id_match
432432 self ._verify_timeout = verify_timeout
433+ self ._account_groups_lookup = AccountGroupLookup (ws )
433434
434435 def rename_groups (self ):
435436 account_groups_in_workspace = self ._account_groups_in_workspace ()
@@ -484,6 +485,9 @@ def rename_groups(self):
484485 # Step 3: Wait for enumeration to also reflect the updated information.
485486 self ._wait_for_renamed_groups (renamed_groups )
486487
488+ def current_user_in_owner_group (self , group_name : str ) -> bool :
489+ return self ._account_groups_lookup .user_in_group (group_name , self ._ws .current_user .me ())
490+
487491 def _rename_group (self , group_id : str , old_group_name : str , new_group_name : str ) -> None :
488492 logger .debug (f"Renaming group: { old_group_name } (id={ group_id } ) -> { new_group_name } " )
489493 self ._rate_limited_rename_group_with_retry (group_id , new_group_name )
@@ -555,7 +559,7 @@ def _check_for_renamed_groups(self, expected_groups: Collection[tuple[str, str]]
555559
556560 def reflect_account_groups_on_workspace (self ):
557561 tasks = []
558- account_groups_in_account = self ._account_groups_in_account ()
562+ account_groups_in_account = self ._account_groups_lookup . get_mapping ()
559563 account_groups_in_workspace = self ._account_groups_in_workspace ()
560564 workspace_groups_in_workspace = self ._workspace_groups_in_workspace ()
561565 groups_to_migrate = self .get_migration_state ().groups
@@ -624,50 +628,6 @@ def delete_original_workspace_groups(self):
624628 # Step 3: Confirm that enumeration no longer returns the deleted groups.
625629 self ._wait_for_deleted_workspace_groups (deleted_groups )
626630
627- def pick_owner_group (self , prompt : Prompts ) -> str | None :
628- # This method is used to select the group that will be used as the owner group.
629- # The owner group will be assigned by default to all migrated tables/schemas
630- user_id = self ._ws .current_user .me ().id
631- if not user_id :
632- logger .error ("Couldn't find the user id of the current user." )
633- return None
634- groups = self ._user_account_groups (user_id )
635- if not groups :
636- logger .warning ("No account groups found for the current user." )
637- return None
638- if len (groups ) == 1 :
639- return groups [0 ].display_name
640- group_names = [group .display_name for group in groups ]
641- return prompt .choice ("Select the group to be used as the owner group" , group_names , max_attempts = 3 )
642-
643- def validate_owner_group (self , group_name : str ) -> bool :
644- # This method is used to validate that the current owner is a member of the group
645- user_id = self ._ws .current_user .me ().id
646- if not user_id :
647- logger .warning ("No user found for the current session." )
648- return False
649- groups = self ._user_account_groups (user_id )
650- if not groups :
651- logger .warning ("No account groups found for the current user." )
652- return False
653- group_names = [group .display_name for group in groups ]
654- return group_name in group_names
655-
656- def _user_account_groups (self , user_id : str ) -> list [Group ]:
657- # This method is used to find all the account groups that a user is a member of.
658- groups : list [Group ] = []
659- account_groups = self ._list_account_groups ("id,displayName,externalId,members" )
660- if not account_groups :
661- return groups
662- for group in account_groups :
663- if not group .members :
664- continue
665- for member in group .members :
666- if member .value == user_id :
667- groups .append (group )
668- break
669- return groups
670-
671631 def _try_fetch (self ) -> Iterable [MigratedGroup ]:
672632 state = []
673633 for row in self ._sql_backend .fetch (f"SELECT * FROM { escape_sql_identifier (self .full_name )} " ):
@@ -690,13 +650,13 @@ def _try_fetch(self) -> Iterable[MigratedGroup]:
690650
691651 def _crawl (self ) -> Iterable [MigratedGroup ]:
692652 workspace_groups_in_workspace = self ._workspace_groups_in_workspace ()
693- account_groups_in_account = self ._account_groups_in_account ()
653+ account_groups_in_account = self ._account_groups_lookup . get_mapping ()
694654 strategy = self ._get_strategy (workspace_groups_in_workspace , account_groups_in_account )
695655 yield from strategy .generate_migrated_groups ()
696656
697657 def validate_group_membership (self ) -> list [dict ]:
698658 workspace_groups_in_workspace = self ._workspace_groups_in_workspace ()
699- account_groups_in_account = self ._account_groups_in_account ()
659+ account_groups_in_account = self ._account_groups_lookup . get_mapping ()
700660 strategy = self ._get_strategy (workspace_groups_in_workspace , account_groups_in_account )
701661 migrated_groups = list (strategy .generate_migrated_groups ())
702662 mismatch_group = []
@@ -762,17 +722,9 @@ def _account_groups_in_workspace(self) -> dict[str, Group]:
762722 groups [group .display_name ] = group
763723 return groups
764724
765- def _account_groups_in_account (self ) -> dict [str , Group ]:
766- groups = {}
767- for group in self ._list_account_groups ("id,displayName,externalId" ):
768- if not group .display_name :
769- logger .debug (f"Ignoring account group in without name: { group .id } " )
770- continue
771- groups [group .display_name ] = group
772- return groups
773-
774- def _is_group_out_of_scope (self , group : iam .Group , resource_type : str ) -> bool :
775- if group .display_name in self ._SYSTEM_GROUPS :
725+ @staticmethod
726+ def _is_group_out_of_scope (group : iam .Group , resource_type : str ) -> bool :
727+ if group .display_name in SYSTEM_GROUPS :
776728 return True
777729 meta = group .meta
778730 if not meta :
@@ -826,23 +778,6 @@ def _get_account_group(self, group_id: str) -> Group | None:
826778 logger .warning (f"Group with ID { group_id } does not exist anymore in the Databricks account." )
827779 return None
828780
829- def _list_account_groups (self , scim_attributes : str ) -> list [iam .Group ]:
830- # TODO: we should avoid using this method, as it's not documented
831- # get account-level groups even if they're not (yet) assigned to a workspace
832- logger .info (f"Listing account groups with { scim_attributes } ..." )
833- account_groups = []
834- raw = self ._ws .api_client .do ("GET" , "/api/2.0/account/scim/v2/Groups" , query = {"attributes" : scim_attributes })
835- for resource in raw .get ("Resources" , []): # type: ignore[union-attr]
836- group = iam .Group .from_dict (resource )
837- if group .display_name in self ._SYSTEM_GROUPS :
838- continue
839- account_groups .append (group )
840- logger .info (f"Found { len (account_groups )} account groups" )
841- sorted_groups : list [iam .Group ] = sorted (
842- account_groups , key = lambda _ : _ .display_name if _ .display_name else ""
843- ) # type: ignore[arg-type,return-value]
844- return sorted_groups
845-
846781 def _delete_workspace_group_and_wait_for_deletion (self , group_id : str , display_name : str ) -> str :
847782 logger .debug (f"Deleting workspace group: { display_name } (id={ group_id } )" )
848783 self ._delete_workspace_group (group_id , display_name )
@@ -975,6 +910,81 @@ def _get_strategy(
975910 )
976911
977912
class AccountGroupLookup:
    """Look up account-level groups, and user membership in them, via a workspace client.

    Groups whose display name is listed in ``SYSTEM_GROUPS`` are never returned
    by any method of this class.
    """

    def __init__(self, ws: WorkspaceClient):
        self._ws = ws

    def pick_owner_group(self, prompt: Prompts) -> str | None:
        """Select the group that will be used as the owner group.

        The owner group will be assigned by default to all migrated tables/schemas.

        Args:
            prompt: Used to ask for a choice when the current user is a member
                of more than one named account group.

        Returns:
            The display name of the selected group, or ``None`` when the current
            user has no id or is not a member of any named account group.
        """
        user_id = self._ws.current_user.me().id
        if not user_id:
            logger.error("Couldn't find the user id of the current user.")
            return None
        # Groups without a display name cannot be referenced as an owner, so they
        # are ignored here just as `get_mapping` ignores them.
        group_names = [group.display_name for group in self._user_account_groups(user_id) if group.display_name]
        if not group_names:
            logger.warning("No account groups found for the current user.")
            return None
        if len(group_names) == 1:
            return group_names[0]
        return prompt.choice("Select the group to be used as the owner group", group_names, max_attempts=3)

    def user_in_group(self, group_name: str, user: User) -> bool:
        """Check whether `user` is a member of the account group named `group_name`.

        Args:
            group_name: Display name of the account group to check.
            user: The user whose membership is verified; only its ``id`` is used.

        Returns:
            ``True`` when an account group with that display name lists the
            user's id among its members, ``False`` otherwise (including when the
            user has no id or belongs to no account group).
        """
        user_id = user.id
        if not user_id:
            logger.warning("No user found for the current session.")
            return False
        groups = self._user_account_groups(user_id)
        if not groups:
            logger.warning("No account groups found for the current user.")
            return False
        return any(group.display_name == group_name for group in groups)

    def _user_account_groups(self, user_id: str) -> list[Group]:
        """Return all account groups that list `user_id` among their members."""
        account_groups = self._list_account_groups("id,displayName,externalId,members")
        return [
            group
            for group in account_groups
            # `members` may be missing or empty; such groups cannot contain the user.
            if group.members and any(member.value == user_id for member in group.members)
        ]

    def _list_account_groups(self, scim_attributes: str) -> list[iam.Group]:
        """List account groups, excluding system groups, sorted by display name.

        Args:
            scim_attributes: Comma-separated attribute names passed as the
                ``attributes`` query parameter of the request.

        Returns:
            The groups found under ``Resources`` in the response, sorted by
            display name; groups without a display name sort as the empty string.
        """
        # TODO: we should avoid using this method, as it's not documented
        # get account-level groups even if they're not (yet) assigned to a workspace
        logger.info(f"Listing account groups with {scim_attributes}...")
        raw = self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query={"attributes": scim_attributes})
        account_groups = []
        for resource in raw.get("Resources", []):  # type: ignore[union-attr]
            group = iam.Group.from_dict(resource)
            if group.display_name in SYSTEM_GROUPS:
                continue
            account_groups.append(group)
        logger.info(f"Found {len(account_groups)} account groups")
        return sorted(account_groups, key=lambda group: group.display_name or "")

    def get_mapping(self) -> dict[str, Group]:
        """Return the account groups keyed by display name.

        Groups without a display name are skipped. When several groups share a
        display name, the one listed last wins.
        """
        groups: dict[str, Group] = {}
        for group in self._list_account_groups("id,displayName,externalId"):
            if not group.display_name:
                logger.debug(f"Ignoring account group in without name: {group.id}")
                continue
            groups[group.display_name] = group
        return groups
986+
987+
978988class ConfigureGroups :
979989 renamed_group_prefix = "db-temp-"
980990 workspace_group_regex = None
0 commit comments