Skip to content

Commit f6eec46

Browse files
FastLee and nfx authored
Added default owner group selection to the installer (#3370)
related to #3111 --------- Co-authored-by: Serge Smertin <[email protected]>
1 parent 60718a7 commit f6eec46

File tree

8 files changed

+150
-96
lines changed

8 files changed

+150
-96
lines changed

src/databricks/labs/ucx/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from databricks.labs.ucx.hive_metastore.tables import What
2121
from databricks.labs.ucx.install import AccountInstaller
2222
from databricks.labs.ucx.source_code.linters.files import LocalCodeLinter
23+
from databricks.labs.ucx.workspace_access.groups import AccountGroupLookup
2324

2425
ucx = App(__file__)
2526
logger = get_logger(__file__)
@@ -657,7 +658,7 @@ def assign_owner_group(
657658
else:
658659
workspace_contexts = _get_workspace_contexts(w, a, run_as_collection)
659660

660-
owner_group = workspace_contexts[0].group_manager.pick_owner_group(prompts)
661+
owner_group = AccountGroupLookup(workspace_contexts[0].workspace_client).pick_owner_group(prompts)
661662
if not owner_group:
662663
return
663664
for workspace_context in workspace_contexts:

src/databricks/labs/ucx/hive_metastore/ownership.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _application_principal(self) -> str | None:
115115
@cached_property
116116
def _static_owner(self) -> str | None:
117117
# If the default owner group is not valid, fall back to the application principal
118-
if self._default_owner_group and self._group_manager.validate_owner_group(self._default_owner_group):
118+
if self._default_owner_group and self._group_manager.current_user_in_owner_group(self._default_owner_group):
119119
logger.warning("Default owner group is not valid, falling back to administrator ownership.")
120120
return self._default_owner_group
121121
return self._application_principal

src/databricks/labs/ucx/install.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,23 @@
2020
from databricks.labs.blueprint.parallel import ManyError, Threads
2121
from databricks.labs.blueprint.tui import Prompts
2222
from databricks.labs.blueprint.upgrades import Upgrades
23-
from databricks.labs.blueprint.wheels import (
24-
ProductInfo,
25-
Version,
26-
find_project_root,
27-
)
23+
from databricks.labs.blueprint.wheels import ProductInfo, Version, find_project_root
2824
from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend
2925
from databricks.labs.lsql.dashboards import DashboardMetadata, Dashboards
3026
from databricks.labs.lsql.deployment import SchemaDeployer
31-
from databricks.sdk import WorkspaceClient, AccountClient
32-
from databricks.sdk.useragent import with_extra
27+
from databricks.sdk import AccountClient, WorkspaceClient
3328
from databricks.sdk.errors import (
3429
AlreadyExists,
3530
BadRequest,
3631
DeadlineExceeded,
3732
InternalError,
3833
InvalidParameterValue,
3934
NotFound,
35+
OperationFailed,
4036
PermissionDenied,
4137
ResourceAlreadyExists,
4238
ResourceDoesNotExist,
4339
Unauthenticated,
44-
OperationFailed,
4540
)
4641
from databricks.sdk.retries import retried
4742
from databricks.sdk.service.dashboards import LifecycleState
@@ -51,7 +46,7 @@
5146
EndpointInfoWarehouseType,
5247
SpotInstancePolicy,
5348
)
54-
49+
from databricks.sdk.useragent import with_extra
5550
from databricks.labs.ucx.__about__ import __version__
5651
from databricks.labs.ucx.assessment.azure import AzureServicePrincipalInfo
5752
from databricks.labs.ucx.assessment.clusters import ClusterInfo, PolicyInfo
@@ -81,7 +76,7 @@
8176
from databricks.labs.ucx.source_code.queries import QueryProblem
8277
from databricks.labs.ucx.workspace_access.base import Permissions
8378
from databricks.labs.ucx.workspace_access.generic import WorkspaceObjectInfo
84-
from databricks.labs.ucx.workspace_access.groups import ConfigureGroups, MigratedGroup
79+
from databricks.labs.ucx.workspace_access.groups import AccountGroupLookup, ConfigureGroups, MigratedGroup
8580

8681
TAG_STEP = "step"
8782
WAREHOUSE_PREFIX = "Unity Catalog Migration"
@@ -245,6 +240,12 @@ def _prompt_for_new_installation(self) -> WorkspaceConfig:
245240
configure_groups = ConfigureGroups(self.prompts)
246241
configure_groups.run()
247242
include_databases = self._select_databases()
243+
244+
# Checking if the user wants to define a default owner group.
245+
default_owner_group = None
246+
if self.prompts.confirm("Do you want to define a default owner group for all tables and schemas? "):
247+
default_owner_group = AccountGroupLookup(self.workspace_client).pick_owner_group(self.prompts)
248+
248249
upload_dependencies = self.prompts.confirm(
249250
f"Does given workspace {self.workspace_client.get_workspace_id()} block Internet access?"
250251
)
@@ -267,6 +268,7 @@ def _prompt_for_new_installation(self) -> WorkspaceConfig:
267268
trigger_job=trigger_job,
268269
recon_tolerance_percent=recon_tolerance_percent,
269270
upload_dependencies=upload_dependencies,
271+
default_owner_group=default_owner_group,
270272
)
271273

272274
def _compare_remote_local_versions(self):

src/databricks/labs/ucx/workspace_access/groups.py

Lines changed: 88 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from collections.abc import Iterable, Collection
77
from dataclasses import dataclass
88
from datetime import timedelta
9-
from typing import ClassVar
109

1110
from databricks.labs.blueprint.limiter import rate_limited
1211
from databricks.labs.blueprint.parallel import ManyError, Threads
@@ -22,13 +21,15 @@
2221
)
2322
from databricks.sdk.retries import retried
2423
from databricks.sdk.service import iam
25-
from databricks.sdk.service.iam import Group
24+
from databricks.sdk.service.iam import Group, User
2625

2726
from databricks.labs.ucx.framework.crawlers import CrawlerBase
2827
from databricks.labs.ucx.framework.utils import escape_sql_identifier
2928

3029
logger = logging.getLogger(__name__)
3130

31+
SYSTEM_GROUPS: list[str] = ["users", "admins", "account users"]
32+
3233

3334
@dataclass
3435
class MigratedGroup:
@@ -402,7 +403,6 @@ def __init__(self, group_id: str, old_name: str, new_name: str) -> None:
402403

403404

404405
class GroupManager(CrawlerBase[MigratedGroup]):
405-
_SYSTEM_GROUPS: ClassVar[list[str]] = ["users", "admins", "account users"]
406406

407407
def __init__( # pylint: disable=too-many-arguments
408408
self,
@@ -430,6 +430,7 @@ def __init__( # pylint: disable=too-many-arguments
430430
self._account_group_regex = account_group_regex
431431
self._external_id_match = external_id_match
432432
self._verify_timeout = verify_timeout
433+
self._account_groups_lookup = AccountGroupLookup(ws)
433434

434435
def rename_groups(self):
435436
account_groups_in_workspace = self._account_groups_in_workspace()
@@ -484,6 +485,9 @@ def rename_groups(self):
484485
# Step 3: Wait for enumeration to also reflect the updated information.
485486
self._wait_for_renamed_groups(renamed_groups)
486487

488+
def current_user_in_owner_group(self, group_name: str) -> bool:
489+
return self._account_groups_lookup.user_in_group(group_name, self._ws.current_user.me())
490+
487491
def _rename_group(self, group_id: str, old_group_name: str, new_group_name: str) -> None:
488492
logger.debug(f"Renaming group: {old_group_name} (id={group_id}) -> {new_group_name}")
489493
self._rate_limited_rename_group_with_retry(group_id, new_group_name)
@@ -555,7 +559,7 @@ def _check_for_renamed_groups(self, expected_groups: Collection[tuple[str, str]]
555559

556560
def reflect_account_groups_on_workspace(self):
557561
tasks = []
558-
account_groups_in_account = self._account_groups_in_account()
562+
account_groups_in_account = self._account_groups_lookup.get_mapping()
559563
account_groups_in_workspace = self._account_groups_in_workspace()
560564
workspace_groups_in_workspace = self._workspace_groups_in_workspace()
561565
groups_to_migrate = self.get_migration_state().groups
@@ -624,50 +628,6 @@ def delete_original_workspace_groups(self):
624628
# Step 3: Confirm that enumeration no longer returns the deleted groups.
625629
self._wait_for_deleted_workspace_groups(deleted_groups)
626630

627-
def pick_owner_group(self, prompt: Prompts) -> str | None:
628-
# This method is used to select the group that will be used as the owner group.
629-
# The owner group will be assigned by default to all migrated tables/schemas
630-
user_id = self._ws.current_user.me().id
631-
if not user_id:
632-
logger.error("Couldn't find the user id of the current user.")
633-
return None
634-
groups = self._user_account_groups(user_id)
635-
if not groups:
636-
logger.warning("No account groups found for the current user.")
637-
return None
638-
if len(groups) == 1:
639-
return groups[0].display_name
640-
group_names = [group.display_name for group in groups]
641-
return prompt.choice("Select the group to be used as the owner group", group_names, max_attempts=3)
642-
643-
def validate_owner_group(self, group_name: str) -> bool:
644-
# This method is used to validate that the current owner is a member of the group
645-
user_id = self._ws.current_user.me().id
646-
if not user_id:
647-
logger.warning("No user found for the current session.")
648-
return False
649-
groups = self._user_account_groups(user_id)
650-
if not groups:
651-
logger.warning("No account groups found for the current user.")
652-
return False
653-
group_names = [group.display_name for group in groups]
654-
return group_name in group_names
655-
656-
def _user_account_groups(self, user_id: str) -> list[Group]:
657-
# This method is used to find all the account groups that a user is a member of.
658-
groups: list[Group] = []
659-
account_groups = self._list_account_groups("id,displayName,externalId,members")
660-
if not account_groups:
661-
return groups
662-
for group in account_groups:
663-
if not group.members:
664-
continue
665-
for member in group.members:
666-
if member.value == user_id:
667-
groups.append(group)
668-
break
669-
return groups
670-
671631
def _try_fetch(self) -> Iterable[MigratedGroup]:
672632
state = []
673633
for row in self._sql_backend.fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
@@ -690,13 +650,13 @@ def _try_fetch(self) -> Iterable[MigratedGroup]:
690650

691651
def _crawl(self) -> Iterable[MigratedGroup]:
692652
workspace_groups_in_workspace = self._workspace_groups_in_workspace()
693-
account_groups_in_account = self._account_groups_in_account()
653+
account_groups_in_account = self._account_groups_lookup.get_mapping()
694654
strategy = self._get_strategy(workspace_groups_in_workspace, account_groups_in_account)
695655
yield from strategy.generate_migrated_groups()
696656

697657
def validate_group_membership(self) -> list[dict]:
698658
workspace_groups_in_workspace = self._workspace_groups_in_workspace()
699-
account_groups_in_account = self._account_groups_in_account()
659+
account_groups_in_account = self._account_groups_lookup.get_mapping()
700660
strategy = self._get_strategy(workspace_groups_in_workspace, account_groups_in_account)
701661
migrated_groups = list(strategy.generate_migrated_groups())
702662
mismatch_group = []
@@ -762,17 +722,9 @@ def _account_groups_in_workspace(self) -> dict[str, Group]:
762722
groups[group.display_name] = group
763723
return groups
764724

765-
def _account_groups_in_account(self) -> dict[str, Group]:
766-
groups = {}
767-
for group in self._list_account_groups("id,displayName,externalId"):
768-
if not group.display_name:
769-
logger.debug(f"Ignoring account group in without name: {group.id}")
770-
continue
771-
groups[group.display_name] = group
772-
return groups
773-
774-
def _is_group_out_of_scope(self, group: iam.Group, resource_type: str) -> bool:
775-
if group.display_name in self._SYSTEM_GROUPS:
725+
@staticmethod
726+
def _is_group_out_of_scope(group: iam.Group, resource_type: str) -> bool:
727+
if group.display_name in SYSTEM_GROUPS:
776728
return True
777729
meta = group.meta
778730
if not meta:
@@ -826,23 +778,6 @@ def _get_account_group(self, group_id: str) -> Group | None:
826778
logger.warning(f"Group with ID {group_id} does not exist anymore in the Databricks account.")
827779
return None
828780

829-
def _list_account_groups(self, scim_attributes: str) -> list[iam.Group]:
830-
# TODO: we should avoid using this method, as it's not documented
831-
# get account-level groups even if they're not (yet) assigned to a workspace
832-
logger.info(f"Listing account groups with {scim_attributes}...")
833-
account_groups = []
834-
raw = self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query={"attributes": scim_attributes})
835-
for resource in raw.get("Resources", []): # type: ignore[union-attr]
836-
group = iam.Group.from_dict(resource)
837-
if group.display_name in self._SYSTEM_GROUPS:
838-
continue
839-
account_groups.append(group)
840-
logger.info(f"Found {len(account_groups)} account groups")
841-
sorted_groups: list[iam.Group] = sorted(
842-
account_groups, key=lambda _: _.display_name if _.display_name else ""
843-
) # type: ignore[arg-type,return-value]
844-
return sorted_groups
845-
846781
def _delete_workspace_group_and_wait_for_deletion(self, group_id: str, display_name: str) -> str:
847782
logger.debug(f"Deleting workspace group: {display_name} (id={group_id})")
848783
self._delete_workspace_group(group_id, display_name)
@@ -975,6 +910,81 @@ def _get_strategy(
975910
)
976911

977912

913+
class AccountGroupLookup:
914+
def __init__(self, ws: WorkspaceClient):
915+
self._ws = ws
916+
917+
def pick_owner_group(self, prompt: Prompts) -> str | None:
918+
# This method is used to select the group that will be used as the owner group.
919+
# The owner group will be assigned by default to all migrated tables/schemas
920+
user_id = self._ws.current_user.me().id
921+
if not user_id:
922+
logger.error("Couldn't find the user id of the current user.")
923+
return None
924+
groups = self._user_account_groups(user_id)
925+
if not groups:
926+
logger.warning("No account groups found for the current user.")
927+
return None
928+
if len(groups) == 1:
929+
return groups[0].display_name
930+
group_names = [group.display_name for group in groups]
931+
return prompt.choice("Select the group to be used as the owner group", group_names, max_attempts=3)
932+
933+
def user_in_group(self, group_name: str, user: User) -> bool:
934+
# This method is used to validate that the current user is a member of the group
935+
user_id = user.id
936+
if not user_id:
937+
logger.warning("No user found for the current session.")
938+
return False
939+
groups = self._user_account_groups(user_id)
940+
if not groups:
941+
logger.warning("No account groups found for the current user.")
942+
return False
943+
group_names = [group.display_name for group in groups]
944+
return group_name in group_names
945+
946+
def _user_account_groups(self, user_id: str) -> list[Group]:
947+
# This method is used to find all the account groups that a user is a member of.
948+
groups: list[Group] = []
949+
account_groups = self._list_account_groups("id,displayName,externalId,members")
950+
if not account_groups:
951+
return groups
952+
for group in account_groups:
953+
if not group.members:
954+
continue
955+
for member in group.members:
956+
if member.value == user_id:
957+
groups.append(group)
958+
break
959+
return groups
960+
961+
def _list_account_groups(self, scim_attributes: str) -> list[iam.Group]:
962+
# TODO: we should avoid using this method, as it's not documented
963+
# get account-level groups even if they're not (yet) assigned to a workspace
964+
logger.info(f"Listing account groups with {scim_attributes}...")
965+
account_groups = []
966+
raw = self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query={"attributes": scim_attributes})
967+
for resource in raw.get("Resources", []): # type: ignore[union-attr]
968+
group = iam.Group.from_dict(resource)
969+
if group.display_name in SYSTEM_GROUPS:
970+
continue
971+
account_groups.append(group)
972+
logger.info(f"Found {len(account_groups)} account groups")
973+
sorted_groups: list[iam.Group] = sorted(
974+
account_groups, key=lambda _: _.display_name if _.display_name else ""
975+
) # type: ignore[arg-type,return-value]
976+
return sorted_groups
977+
978+
def get_mapping(self) -> dict[str, Group]:
979+
groups: dict[str, Group] = {}
980+
for group in self._list_account_groups("id,displayName,externalId"):
981+
if not group.display_name:
982+
logger.debug(f"Ignoring account group in without name: {group.id}")
983+
continue
984+
groups[group.display_name] = group
985+
return groups
986+
987+
978988
class ConfigureGroups:
979989
renamed_group_prefix = "db-temp-"
980990
workspace_group_regex = None

tests/unit/hive_metastore/test_grants.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from databricks.labs.ucx.hive_metastore.tables import Table, TablesCrawler
1515
from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler
1616
from databricks.labs.ucx.progress.history import ProgressEncoder
17-
from databricks.labs.ucx.workspace_access.groups import GroupManager
17+
from databricks.labs.ucx.workspace_access.groups import GroupManager, AccountGroupLookup
1818
from tests.unit import mock_workspace_client
1919

2020

@@ -956,7 +956,6 @@ def test_grant_supports_history(mock_backend, grant_record: Grant, history_recor
956956
# Testing the validation in retrieval of the default owner group. 666 is the current_user user_id.
957957
@pytest.mark.parametrize("user_id, expected", [("666", True), ("777", False)])
958958
def test_default_owner(user_id, expected) -> None:
959-
sql_backend = MockBackend()
960959
ws = mock_workspace_client()
961960

962961
account_admins_group = Group(
@@ -966,5 +965,5 @@ def test_default_owner(user_id, expected) -> None:
966965
"Resources": [account_admins_group.as_dict()],
967966
}
968967

969-
group_manager = GroupManager(sql_backend, ws, "ucx")
970-
assert group_manager.validate_owner_group("owners") == expected
968+
account_group_lookup = AccountGroupLookup(ws)
969+
assert account_group_lookup.user_in_group("owners", ws.current_user.me()) == expected

tests/unit/hive_metastore/test_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def test_default_securable_ownership(
836836
Table("main", "foo", "baz", "VIEW", "UNKNOWN", None, "select * from bar"),
837837
]
838838
group_manager = create_autospec(GroupManager)
839-
group_manager.validate_owner_group.return_value = valid_admin
839+
group_manager.current_user_in_owner_group.return_value = valid_admin
840840

841841
ownership = DefaultSecurableOwnership(
842842
admin_locator, table_crawler, group_manager, default_owner_group, lambda: cli_user

tests/unit/install/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def download(path: str) -> io.StringIO | io.BytesIO:
7777
workspace_client = create_autospec(WorkspaceClient)
7878

7979
workspace_client.current_user.me = lambda: iam.User(
80-
user_name="[email protected]", groups=[iam.ComplexValue(display="admins")]
80+
user_name="[email protected]", id="666", groups=[iam.ComplexValue(display="admins")]
8181
)
8282
workspace_client.config.host = "https://foo"
8383
workspace_client.config.is_aws = True

0 commit comments

Comments
 (0)