Skip to content

Commit 3c6e6eb

Browse files
authored
Merge pull request #3932 from allegro/add-auto-region-assignment
Add auto region assignment based on group
2 parents d7079ed + bed5c9c commit 3c6e6eb

File tree

5 files changed

+161
-84
lines changed

5 files changed

+161
-84
lines changed

src/ralph/accounts/ldap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def mirror_groups(self):
5454
)
5555
if target_group_names != current_group_names:
5656
logger.info(
57-
"Modifing user groups: current = {}, target = {}".format(
57+
"Modifying user groups: current = {}, target = {}".format(
5858
", ".join(current_group_names), ", ".join(target_group_names)
5959
)
6060
)

src/ralph/accounts/ldap_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ def handle_groups(groups_dns):
7171
handle_groups(nested_groups_dns)
7272
return group_map
7373

74-
def group_name_from_info(self, group_info):
74+
def group_name_from_info(self, group_info) -> str | None:
7575
"""Map ldap group names into ralph names if mapping defined."""
7676
if self.ldap_groups:
7777
for dn in group_info[1]["distinguishedname"]:
78-
mapped = self.ldap_groups.get(dn)
79-
if mapped:
78+
if mapped := self.ldap_groups.get(dn):
8079
return mapped
80+
return None
8181
# return original name if mapping not defined
8282
else:
8383
return super(MappedGroupOfNamesType, self).group_name_from_info(group_info)

src/ralph/accounts/management/commands/ldap_sync.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from functools import lru_cache
1212
from ldap.controls import SimplePagedResultsControl
1313

14+
from ralph.accounts.models import RalphUser, Region
1415
from ralph.helpers import cache
1516

1617
logger = logging.getLogger(__name__)
@@ -68,20 +69,20 @@ def __exit__(self, type, value, traceback):
6869

6970

7071
@cache(seconds=600)
71-
def get_nested_groups():
72+
def get_nested_groups() -> tuple[dict[str, set[str]], defaultdict[str, set[str]]]:
7273
"""
7374
Fetching users in nested group based on custom LDAP filter
7475
(AUTH_LDAP_NESTED_FILTER) e.g. (memberOf:{}). AUTH_LDAP_NESTED_FILTER
7576
is a simple dictonary where key is the name of group in DB, the value
7677
contains DN for nested group.
7778
"""
7879
# mapping from django group name to set of users (usernames) belonging to it
79-
group_users = {}
80+
group_name_to_usernames: dict[str, set[str]] = {}
8081
# mapping from user (username) to set of groups DNs to which he belongs to
81-
users_groups = defaultdict(set)
82+
username_to_group_names: defaultdict[str, set[str]] = defaultdict(set)
8283
nested_groups = getattr(settings, "AUTH_LDAP_NESTED_GROUPS", None)
8384
if not nested_groups:
84-
return group_users, users_groups
85+
return group_name_to_usernames, username_to_group_names
8586
nested_filter = getattr(settings, "AUTH_LDAP_NESTED_FILTER", "(memberOf:{})")
8687
logger.info("Fetching nested groups from LDAP")
8788
with LDAPConnectionManager() as conn:
@@ -99,7 +100,7 @@ def get_nested_groups():
99100
settings.AUTH_LDAP_QUERY_PAGE_SIZE,
100101
)
101102
logger.info("{} fetched".format(ralph_group_name))
102-
group_users[ralph_group_name] = set(
103+
group_name_to_usernames[ralph_group_name] = set(
103104
[
104105
u[1][settings.AUTH_LDAP_USER_USERNAME_ATTR][0]
105106
.decode("utf-8")
@@ -109,13 +110,13 @@ def get_nested_groups():
109110
)
110111
logger.info(
111112
"Users in nested group {}: {}".format(
112-
ralph_group_name, group_users[ralph_group_name]
113+
ralph_group_name, group_name_to_usernames[ralph_group_name]
113114
)
114115
)
115-
for username in group_users[ralph_group_name]:
116+
for username in group_name_to_usernames[ralph_group_name]:
116117
# notice group DN here, not Django group name!
117-
users_groups[username].add(ldap_group_name)
118-
return group_users, users_groups
118+
username_to_group_names[username].add(ldap_group_name)
119+
return group_name_to_usernames, username_to_group_names
119120

120121

121122
def _make_paged_query(conn, search_base, search_scope, ad_query, attr_list, page_size):
@@ -156,7 +157,29 @@ def _make_paged_query(conn, search_base, search_scope, ad_query, attr_list, page
156157
return result
157158

158159

159-
class NestedGroups(object):
160+
def _add_regions(user: RalphUser, region_names: list[str]):
161+
for region_name in region_names:
162+
try:
163+
user.regions.add(Region.objects.get(name=region_name))
164+
logger.info("Assigned {} to region {}".format(user.username, region_name))
165+
except Region.DoesNotExist:
166+
logger.warning(
167+
"Region {} does not exist, cannot assign to user {}".format(
168+
region_name,
169+
user.username,
170+
)
171+
)
172+
173+
174+
def assign_user_to_group(user: RalphUser, group: Group):
175+
user.groups.add(group)
176+
default_regions = settings.DEFAULT_REGIONS_FOR_GROUP.get(group.name, [])
177+
_add_regions(user, default_regions)
178+
179+
logger.info("Added {} to {}".format(user.username, group.name))
180+
181+
182+
class NestedGroups:
160183
"""
161184
Class fetch nested groups and mapping them to standard Django's
162185
group (get or create). django_auth_ldap and their class for nested
@@ -170,19 +193,17 @@ def __init__(self):
170193
def get_group_from_db(self, name):
171194
return Group.objects.get_or_create(name=name)[0]
172195

173-
def handle(self, user):
196+
def handle(self, user: RalphUser):
174197
"""
175198
Match user to group in fetched groups from LDAP and assign user
176199
to Django's group.
177200
"""
178-
179201
if not self.group_users:
180202
return
181203
for group_name, users in self.group_users.items():
182204
if user.username in users:
183205
group = self.get_group_from_db(group_name)
184-
user.groups.add(group)
185-
logger.info("Added {} to {}".format(user.username, group_name))
206+
assign_user_to_group(user, group)
186207

187208

188209
class Command(BaseCommand):

src/ralph/accounts/tests/tests.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44

55
from django.conf import settings
66
from django.contrib.auth.hashers import check_password
7-
from django.contrib.auth.models import Permission
7+
from django.contrib.auth.models import Permission, Group
88
from django.test import TestCase
99
from django.urls import reverse
1010
from rest_framework import status
1111

1212
from ralph.accounts.ldap import manager_country_attribute_populate
13-
from ralph.accounts.management.commands.ldap_sync import _truncate, ldap_module_exists
13+
from ralph.accounts.management.commands.ldap_sync import (
14+
_truncate,
15+
ldap_module_exists,
16+
assign_user_to_group,
17+
)
1418
from ralph.accounts.models import RalphUser, Region
1519
from ralph.api.tests._base import RalphAPITestCase
1620
from ralph.assets.tests.factories import (
@@ -22,6 +26,7 @@
2226
from ralph.licences.tests.factories import LicenceFactory
2327
from ralph.tests import factories
2428
from ralph.tests.mixins import ClientMixin
29+
from django.test.utils import override_settings
2530

2631
NO_LDAP_MODULE = not ldap_module_exists
2732

@@ -234,3 +239,42 @@ def make_request():
234239
# Check if password is actually changed
235240
self.admin.refresh_from_db()
236241
check_password(new_password, self.admin.password)
242+
243+
244+
class RalphUserRegionTests(TestCase):
245+
def test_user_region_str(self):
246+
region = Region.objects.create(name="PL")
247+
user = factories.UserFactory()
248+
user.regions.add(region)
249+
self.assertEqual(str(user.regions.first()), "PL")
250+
251+
def test_automatic_region_assignment(self):
252+
region_pl = Region.objects.create(name="PL")
253+
region_cz = Region.objects.create(name="CZ")
254+
multi_region_group = Group.objects.create(name="Multi region group")
255+
user = factories.UserFactory()
256+
257+
with override_settings(
258+
DEFAULT_REGIONS_FOR_GROUP={"Multi region group": ["PL", "CZ"]}
259+
):
260+
assign_user_to_group(user, multi_region_group) # noqa
261+
262+
self.assertIn(region_cz, user.regions.all())
263+
self.assertIn(region_pl, user.regions.all())
264+
265+
def test_automatic_region_assignment_when_user_already_in_group(self):
266+
region_pl = Region.objects.create(name="PL")
267+
region_cz = Region.objects.create(name="CZ")
268+
multi_region_group = Group.objects.create(name="Multi region group")
269+
user = factories.UserFactory()
270+
assign_user_to_group(user, multi_region_group) # noqa
271+
272+
self.assertEqual(user.regions.all().count(), 0)
273+
274+
with override_settings(
275+
DEFAULT_REGIONS_FOR_GROUP={"Multi region group": ["PL", "CZ"]}
276+
):
277+
assign_user_to_group(user, multi_region_group) # noqa
278+
279+
self.assertIn(region_cz, user.regions.all())
280+
self.assertIn(region_pl, user.regions.all())

0 commit comments

Comments
 (0)