Skip to content

Commit 07f03f7

Browse files
authored
Improved validating groups membership cli command (#816)
1 parent 2ab1321 commit 07f03f7

File tree

4 files changed

+213
-22
lines changed

4 files changed

+213
-22
lines changed

labs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ commands:
107107
- name: validate-groups-membership
108108
description: Validate the groups to see if the groups at account level and workspace level have different membership
109109
table_template: |-
110-
Workflow Group Name\tAccount Group Name
111-
{{range .}}{{.name_in_workspace}}\t{{.name_in_group}}
110+
Workspace Group Name\tMembers Count\tAccount Group Name\tMembers Count
111+
{{range .}}{{.wf_group_name}}\t{{.wf_group_members_count}}\t{{.acc_group_name}}\t{{.acc_group_members_count}}
112112
{{end}}
113113
114114
- name: save-aws-iam-profiles

src/databricks/labs/ucx/workspace_access/groups.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -409,19 +409,24 @@ def validate_group_membership(self) -> list[dict]:
409409
strategy = self._get_strategy(workspace_groups_in_workspace, account_groups_in_account)
410410
migrated_groups = strategy.generate_migrated_groups()
411411
mismatch_group = []
412-
for groups in migrated_groups:
413-
ws_members_set = set([m.get("display") for m in json.loads(groups.members)] if groups.members else [])
414-
account_group = account_groups_in_account[groups.name_in_account]
415-
ac_members_set = set(
416-
[a.as_dict().get("display") for a in account_group.members] if account_group.members else []
417-
)
418-
set_diff = (ws_members_set - ac_members_set).union(ac_members_set - ws_members_set)
412+
retry_on_internal_error = retried(on=[InternalError], timeout=self._verify_timeout)
413+
get_account_group = retry_on_internal_error(self._get_account_group)
414+
for ws_group in migrated_groups:
415+
# Membership is compared by display name only, so users who share a display name (even with different emails) collapse into a single entry!
416+
ws_members_set = {m.get("display") for m in json.loads(ws_group.members)} if ws_group.members else set()
417+
acc_group = get_account_group(account_groups_in_account[ws_group.name_in_account].id)
418+
if not acc_group:
419+
continue # group not present anymore
420+
acc_members_set = {a.as_dict().get("display") for a in acc_group.members} if acc_group.members else set()
421+
set_diff = (ws_members_set - acc_members_set).union(acc_members_set - ws_members_set)
419422
if not set_diff:
420423
continue
421424
mismatch_group.append(
422425
{
423-
"wf_group_name": groups.name_in_workspace,
424-
"ac_group_name": groups.name_in_account,
426+
"wf_group_name": ws_group.name_in_workspace,
427+
"wf_group_members_count": len(ws_members_set),
428+
"acc_group_name": ws_group.name_in_account,
429+
"acc_group_members_count": len(acc_members_set),
425430
}
426431
)
427432
if not mismatch_group:
@@ -500,6 +505,16 @@ def _get_group(self, group_id: str) -> iam.Group | None:
500505
# which will cause timeout errors because of groups no longer there.
501506
return None
502507

508+
@rate_limited(max_requests=20)
509+
def _get_account_group(self, group_id: str) -> Group | None:
510+
try:
511+
raw = self._ws.api_client.do("GET", f"/api/2.0/account/scim/v2/Groups/{group_id}")
512+
return iam.Group.from_dict(raw) # type: ignore[arg-type]
513+
except NotFound:
514+
# the group was removed from the account after the account groups were listed but before this lookup by ID ran
515+
logger.warning("Group with ID: %s does not exist anymore in the Databricks account.", group_id)
516+
return None
517+
503518
def _list_account_groups(self, scim_attributes: str) -> list[iam.Group]:
504519
# TODO: we should avoid using this method, as it's not documented
505520
# get account-level groups even if they're not (yet) assigned to a workspace

tests/integration/workspace_access/test_groups.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,36 @@ def test_group_matching_names_with_diff_users(
238238
assert len(t) > 0
239239

240240

241+
@retried(on=[NotFound], timeout=timedelta(minutes=2))
242+
def test_group_matching_names_with_same_users(
243+
ws, sql_backend, inventory_schema, make_random, make_user, make_group, make_acc_group
244+
):
245+
rand_elem = make_random(4)
246+
workspace_group_name = f"test_group_{rand_elem}"
247+
account_group_name = f"same_group_[{rand_elem}]"
248+
user1 = make_user()
249+
members1 = [user1.id]
250+
members2 = [user1.id]
251+
ws_group = make_group(display_name=workspace_group_name, members=members1, entitlements=["allow-cluster-create"])
252+
acc_group = make_acc_group(display_name=account_group_name, members=members2)
253+
254+
logger.info(
255+
f"Attempting Mapping From Workspace Group {ws_group.display_name} to Account Group {acc_group.display_name}"
256+
)
257+
group_manager = GroupManager(
258+
sql_backend,
259+
ws,
260+
inventory_schema,
261+
[ws_group.display_name],
262+
"ucx-temp-",
263+
workspace_group_regex=r"([0-9a-zA-Z]*)$",
264+
account_group_regex=r"\[([0-9a-zA-Z]*)\]",
265+
)
266+
267+
t = group_manager.validate_group_membership()
268+
assert len(t) == 0
269+
270+
241271
# average runtime is 100 seconds
242272
@retried(on=[NotFound], timeout=timedelta(minutes=3))
243273
def test_replace_workspace_groups_with_account_groups(

tests/unit/workspace_access/test_groups.py

Lines changed: 157 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from databricks.labs.blueprint.parallel import ManyError
77
from databricks.labs.blueprint.tui import MockPrompts
88
from databricks.sdk import WorkspaceClient
9-
from databricks.sdk.errors import DatabricksError, ResourceDoesNotExist
9+
from databricks.sdk.errors import DatabricksError, NotFound, ResourceDoesNotExist
1010
from databricks.sdk.service import iam
1111
from databricks.sdk.service.iam import ComplexValue, Group, ResourceMeta
1212

@@ -836,23 +836,121 @@ def test_validate_group_diff_membership():
836836
external_id="1234",
837837
display_name="test_(1234)",
838838
meta=ResourceMeta(resource_type="WorkspaceGroup"),
839-
members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")],
839+
members=[ComplexValue(display="test-user-1", value="1"), ComplexValue(display="test-user-2", value="2")],
840840
roles=[
841841
ComplexValue(value="arn:aws:iam::123456789098:instance-profile/ip1"),
842842
ComplexValue(value="arn:aws:iam::123456789098:instance-profile/ip2"),
843843
],
844844
entitlements=[ComplexValue(value="allow-cluster-create"), ComplexValue(value="allow-instance-pool-create")],
845845
)
846846
wsclient.groups.list.return_value = [group]
847-
wsclient.groups.get.return_value = group
848-
account_admins_group = Group(id="1234", external_id="1234", display_name="ac_test_1234")
849-
wsclient.api_client.do.return_value = {
850-
"Resources": [g.as_dict() for g in [account_admins_group]],
851-
}
847+
account_admins_group = Group(
848+
id="1234",
849+
external_id="1234",
850+
display_name="ac_test_1234",
851+
members=[ComplexValue(display="test-user-3", value="3")],
852+
)
853+
854+
def do_api_side_effect(*args, **kwargs):
855+
if args[0] == "GET":
856+
if args[1] == "/api/2.0/account/scim/v2/Groups":
857+
return {"Resources": [g.as_dict() for g in [account_admins_group]]}
858+
else:
859+
return account_admins_group.as_dict()
860+
else:
861+
raise RuntimeError()
862+
863+
wsclient.api_client.do.side_effect = do_api_side_effect
864+
wsclient.groups.get.side_effect = lambda group_id: group if group_id == "1" else account_admins_group
865+
grp_membership = GroupManager(
866+
backend, wsclient, inventory_database="inv", workspace_group_regex=r"\(([1-9]+)\)", account_group_regex="[1-9]+"
867+
).validate_group_membership()
868+
assert grp_membership == [
869+
{
870+
"wf_group_name": "test_(1234)",
871+
"wf_group_members_count": 2,
872+
"acc_group_name": "ac_test_1234",
873+
"acc_group_members_count": 1,
874+
}
875+
]
876+
877+
878+
def test_validate_group_diff_membership_no_members():
879+
backend = create_autospec(SqlBackend)
880+
wsclient = create_autospec(WorkspaceClient)
881+
group = Group(
882+
id="1",
883+
external_id="1234",
884+
display_name="test_(1234)",
885+
meta=ResourceMeta(resource_type="WorkspaceGroup"),
886+
members=None,
887+
roles=[
888+
ComplexValue(value="arn:aws:iam::123456789098:instance-profile/ip1"),
889+
],
890+
entitlements=[ComplexValue(value="allow-cluster-create"), ComplexValue(value="allow-instance-pool-create")],
891+
)
892+
wsclient.groups.list.return_value = [group]
893+
account_admins_group = Group(
894+
id="1234",
895+
external_id="1234",
896+
display_name="ac_test_1234",
897+
members=None,
898+
)
899+
900+
def do_api_side_effect(*args, **kwargs):
901+
if args[0] == "GET":
902+
if args[1] == "/api/2.0/account/scim/v2/Groups":
903+
return {"Resources": [g.as_dict() for g in [account_admins_group]]}
904+
else:
905+
return account_admins_group.as_dict()
906+
else:
907+
raise RuntimeError()
908+
909+
wsclient.api_client.do.side_effect = do_api_side_effect
910+
wsclient.groups.get.side_effect = lambda group_id: group if group_id == "1" else account_admins_group
852911
grp_membership = GroupManager(
853912
backend, wsclient, inventory_database="inv", workspace_group_regex=r"\(([1-9]+)\)", account_group_regex="[1-9]+"
854913
).validate_group_membership()
855-
assert grp_membership == [{"wf_group_name": "test_(1234)", "ac_group_name": "ac_test_1234"}]
914+
assert grp_membership == []
915+
916+
917+
def test_validate_group_diff_membership_no_account_group_found():
918+
backend = create_autospec(SqlBackend)
919+
wsclient = create_autospec(WorkspaceClient)
920+
group = Group(
921+
id="1",
922+
external_id="1234",
923+
display_name="test_(1234)",
924+
meta=ResourceMeta(resource_type="WorkspaceGroup"),
925+
members=None,
926+
roles=[
927+
ComplexValue(value="arn:aws:iam::123456789098:instance-profile/ip1"),
928+
],
929+
entitlements=[ComplexValue(value="allow-cluster-create"), ComplexValue(value="allow-instance-pool-create")],
930+
)
931+
wsclient.groups.list.return_value = [group]
932+
account_admins_group = Group(
933+
id="1234",
934+
external_id="1234",
935+
display_name="ac_test_1234",
936+
members=None,
937+
)
938+
939+
def do_api_side_effect(*args, **kwargs):
940+
if args[0] == "GET":
941+
if args[1] == "/api/2.0/account/scim/v2/Groups":
942+
return {"Resources": [g.as_dict() for g in [account_admins_group]]}
943+
else:
944+
return account_admins_group.as_dict()
945+
else:
946+
raise RuntimeError()
947+
948+
wsclient.api_client.do.side_effect = do_api_side_effect
949+
wsclient.groups.get.side_effect = lambda group_id: group if group_id == "1" else None
950+
grp_membership = GroupManager(
951+
backend, wsclient, inventory_database="inv", workspace_group_regex=r"\(([1-9]+)\)", account_group_regex="[1-9]+"
952+
).validate_group_membership()
953+
assert grp_membership == []
856954

857955

858956
def test_validate_group_same_membership():
@@ -878,9 +976,57 @@ def test_validate_group_same_membership():
878976
display_name="ac_test_1234",
879977
members=[ComplexValue(display="test-user-1", value="01"), ComplexValue(display="test-user-2", value="02")],
880978
)
881-
wsclient.api_client.do.return_value = {
882-
"Resources": [g.as_dict() for g in [account_admins_group]],
883-
}
979+
980+
def do_api_side_effect(*args, **kwargs):
981+
if args[0] == "GET":
982+
if args[1] == "/api/2.0/account/scim/v2/Groups":
983+
return {"Resources": [g.as_dict() for g in [account_admins_group]]}
984+
else:
985+
return account_admins_group.as_dict()
986+
else:
987+
raise RuntimeError()
988+
989+
wsclient.api_client.do.side_effect = do_api_side_effect
990+
grp_membership = GroupManager(
991+
backend, wsclient, inventory_database="inv", workspace_group_regex=r"\(([1-9]+)\)", account_group_regex="[1-9]+"
992+
).validate_group_membership()
993+
assert grp_membership == []
994+
995+
996+
def test_validate_acc_group_removed_after_listing():
997+
backend = MockBackend()
998+
wsclient = MagicMock()
999+
group = Group(
1000+
id="1",
1001+
external_id="1234",
1002+
display_name="test_(1234)",
1003+
meta=ResourceMeta(resource_type="WorkspaceGroup"),
1004+
members=[ComplexValue(display="test-user-1", value="01"), ComplexValue(display="test-user-2", value="02")],
1005+
roles=[
1006+
ComplexValue(value="arn:aws:iam::123456789098:instance-profile/test_ip1"),
1007+
ComplexValue(value="arn:aws:iam::123456789098:instance-profile/test_ip2"),
1008+
],
1009+
entitlements=[ComplexValue(value="allow-cluster-create"), ComplexValue(value="allow-instance-pool-create")],
1010+
)
1011+
wsclient.groups.list.return_value = [group]
1012+
wsclient.groups.get.return_value = group
1013+
account_admins_group = Group(
1014+
id="1234",
1015+
external_id="1234",
1016+
display_name="ac_test_1234",
1017+
members=[ComplexValue(display="test-user-1", value="01"), ComplexValue(display="test-user-2", value="02")],
1018+
)
1019+
1020+
def do_api_side_effect(*args, **kwargs):
1021+
if args[0] == "GET":
1022+
if args[1] == "/api/2.0/account/scim/v2/Groups":
1023+
return {"Resources": [g.as_dict() for g in [account_admins_group]]}
1024+
else:
1025+
raise NotFound()
1026+
else:
1027+
raise RuntimeError()
1028+
1029+
wsclient.api_client.do.side_effect = do_api_side_effect
8841030
grp_membership = GroupManager(
8851031
backend, wsclient, inventory_database="inv", workspace_group_regex=r"\(([1-9]+)\)", account_group_regex="[1-9]+"
8861032
).validate_group_membership()

0 commit comments

Comments
 (0)