diff --git a/src/backend/marsha/account/management/commands/dedupe_accounts.py b/src/backend/marsha/account/management/commands/dedupe_accounts.py index 1fe9ac7c5e..d1dc6cbf7a 100644 --- a/src/backend/marsha/account/management/commands/dedupe_accounts.py +++ b/src/backend/marsha/account/management/commands/dedupe_accounts.py @@ -1,114 +1,13 @@ """Command to dedupe accounts.""" -from difflib import SequenceMatcher import logging -from typing import Any -from django.contrib.auth import get_user_model from django.core.management.base import BaseCommand -from django.db import transaction -from django.db.models import Count +from marsha.account.utils.dedupe_accounts import dedupe_accounts -logger = logging.getLogger(__name__) - - -# pylint: disable=too-many-locals -def dedupe_accounts( - options: dict[str, Any], -) -> tuple[list[Any], dict[Any, Any], dict[Any, Any], list[Any], list[Any]]: - """Deduplicate accounts.""" - # pylint: disable=invalid-name - User = get_user_model() - - if options["email"]: - duplicates = [{"email": options["email"]}] - else: - duplicates = ( - User.objects.values("email") - .annotate(count=Count("id")) - .filter(count__gt=1) - .order_by("email") - ) - - accounts_to_delete = [] - duped_users = {} - organizations = {} - skipped_accounts = [] - users_to_delete = [] - for dup in duplicates: - email = dup["email"] - if not email: - continue - - logger.info("Deduping %s", email) - - users = list(User.objects.filter(email=email).order_by("date_joined")) - original_user, *duplicate_users = users - original_social = original_user.social_auth.first() - - for duplicate_user in duplicate_users: - new_social = duplicate_user.social_auth.first() - if not new_social: - continue - - old_account_email = original_social.uid.split(":")[1] - new_account_email = new_social.uid.split(":")[1] - - old_organization_uid = original_social.uid.split(":")[0] - new_organization_uid = new_social.uid.split(":")[0] - - account_email_ratio = SequenceMatcher( - None, 
old_account_email, new_account_email - ).ratio() - organization_ratio = SequenceMatcher( - None, old_organization_uid, new_organization_uid - ).ratio() - if old_account_email != new_account_email: - skipped_accounts.append( - [ - email, - [ - original_social.uid, - new_social.uid, - str(organization_ratio), - str(account_email_ratio), - ], - ] - ) - continue - - if old_organization_uid not in organizations: - organizations[old_organization_uid] = [new_organization_uid] - else: - if new_organization_uid not in organizations[old_organization_uid]: - organizations[old_organization_uid].append(new_organization_uid) - - if original_user.email not in duped_users: - duped_users[original_user.email] = [original_social.uid, new_social.uid] - else: - duped_users[original_user.email].append(new_social.uid) - users_to_delete.append(duplicate_user.email) - accounts_to_delete.append(original_social.uid) - - if not options["dry_run"]: - with transaction.atomic(): - original_user.social_auth.first().delete() - original_user.social_auth.set([new_social]) - for playlist in duplicate_user.playlists.exclude( - id__in=original_user.playlists.values_list("id", flat=True) - ): - original_user.playlists.add(playlist) - duplicate_user.delete() - - return ( - accounts_to_delete, - duped_users, - organizations, - skipped_accounts, - users_to_delete, - ) +logger = logging.getLogger(__name__) class Command(BaseCommand): @@ -118,52 +17,13 @@ class Command(BaseCommand): def add_arguments(self, parser): """Add arguments to the command.""" + parser.add_argument("--email", type=str, help="Email to dedupe") parser.add_argument("--dry-run", action="store_true") - parser.add_argument( - "--email", type=str, help="Email to dedupe (for testing purposes)" - ) def handle(self, *args, **options): """Handle command.""" - if options["dry_run"]: + dry_run = options["dry_run"] + if dry_run: logger.info("[DRY-RUN] No changes will be made.") - ( - accounts_to_delete, - duped_users, - organizations, - 
skipped_accounts, - users_to_delete, - ) = dedupe_accounts(options) - - logger.info("-" * 80) - logger.info( - "Deduping complete. %d SSO accounts deleted, %d users deleted", - len(accounts_to_delete), - len(users_to_delete), - ) - logger.info("- " * 40) - - logger.info("%d accounts skipped:", len(skipped_accounts)) - for email, accounts in skipped_accounts: - logger.info(" - %s | %s", email, " | ".join(accounts)) - logger.info("- " * 40) - - logger.info("%d organizations impacted:", len(organizations)) - for org_id, new_orgs in organizations.items(): - logger.info(" - %s -> %s", org_id, " -> ".join(new_orgs)) - logger.info("- " * 40) - - logger.info("%d users impacted:", len(duped_users)) - for email, accounts in duped_users.items(): - logger.info(" - %s -> %s", email, " -> ".join(accounts)) - logger.info("- " * 40) - - logger.info("Summary:") - logger.info(" %d organizations impacted", len(organizations)) - logger.info(" %d users processed", len(duped_users)) - logger.info(" %d users deleted", len(users_to_delete)) - logger.info(" %d SSO accounts deleted", len(accounts_to_delete)) - - if options["dry_run"]: - logger.info("[DRY-RUN] No changes made.") + dedupe_accounts(options["email"], dry_run) diff --git a/src/backend/marsha/account/tests/dedupe_accounts/__init__.py b/src/backend/marsha/account/tests/dedupe_accounts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/backend/marsha/account/tests/dedupe_accounts/test_dedupe_accounts_command.py b/src/backend/marsha/account/tests/dedupe_accounts/test_dedupe_accounts_command.py new file mode 100644 index 0000000000..4d04413fd3 --- /dev/null +++ b/src/backend/marsha/account/tests/dedupe_accounts/test_dedupe_accounts_command.py @@ -0,0 +1,503 @@ +"""Tests for the `dedupe_accounts` management command.""" + +from io import StringIO + +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.test import TestCase + +from social_django.models 
import UserSocialAuth + +from marsha.core.factories import ( + LtiUserAssociationFactory, + OrganizationAccessFactory, + PlaylistAccessFactory, +) +from marsha.core.models import ADMINISTRATOR, LtiUserAssociation, OrganizationAccess + + +# pylint: disable=invalid-name +User = get_user_model() + + +class DedupeAccountsCommandTest(TestCase): + """Test suite for dedupe_accounts command.""" + + maxDiff = None + + def setUp(self): + """ + Set up the test case. + """ + self.stdout = StringIO() + self.command_log_prefix = ( + "INFO:marsha.account.management.commands.dedupe_accounts:" + ) + self.log_prefix = "INFO:marsha.account.utils.dedupe_accounts.dedupe_tracker:" + + @staticmethod + def _add_social(user, uid="uid-xxx"): + return UserSocialAuth.objects.create(user=user, uid=uid) + + def _build_expected_logs( + self, expected_data, dry_run=False, include_command_prefix=False + ): + """Build expected log output from structured data.""" + logs = [] + + if include_command_prefix: + logs.append(f"{self.command_log_prefix}[DRY-RUN] No changes will be made.") + + # Header + logs.extend( + [ + f"{self.log_prefix}{'-' * 80}", + f"{self.log_prefix}Deduping complete", + f"{self.log_prefix}{'- ' * 40}", + ] + ) + + # Define sections in order + detail_sections = [ + ("organization UID migrations", "items"), + ("same account migrations", "items"), + ("different account merges", "items"), + ("organizations transferred", "items"), + ("LTI associations transferred", "items"), + ("LTI passports transferred", "items"), + ("consumer site accesses transferred", "items"), + ("playlist accesses transferred", "items"), + ("created playlists transferred", "items"), + ("created videos transferred", "items"), + ("created documents transferred", "items"), + ("created markdown documents transferred", "items"), + ("portability requests transferred", "items"), + ("user accounts deleted", "items"), + ("SSO accounts deleted", "items"), + ] + + # Detail sections + for section_name, _ in detail_sections: 
+ section_data = expected_data.get(section_name, {}) + count = section_data.get("count", 0) + items = section_data.get("items", []) + + logs.append(f"{self.log_prefix}{count} {section_name}:") + for item in items: + logs.append(f"{self.log_prefix} - {item}") + logs.append(f"{self.log_prefix}{'- ' * 40}") + + # Summary section + logs.append(f"{self.log_prefix}Summary:") + for section_name, _ in detail_sections: + section_data = expected_data.get(section_name, {}) + count = section_data.get("count", 0) + logs.append(f"{self.log_prefix} {count} {section_name}") + + if dry_run: + logs.append(f"{self.log_prefix}[DRY-RUN] No changes made.") + + return logs + + def test_dedupe_accounts_dry_run(self): + """ + With the --dry-run flag, no changes should be made. + """ + playlist_access = PlaylistAccessFactory() + original_user = playlist_access.user + original_organization = OrganizationAccessFactory(user=original_user) + original_lti_association = LtiUserAssociationFactory(user=original_user) + original_social = self._add_social( + original_user, uid=f"old-uid:{original_user.email}" + ) + + playlist_access_duplicate = PlaylistAccessFactory( + user__email=original_user.email + ) + duplicate_user = playlist_access_duplicate.user + duplicate_organization = OrganizationAccessFactory(user=duplicate_user) + duplicate_lti_association = LtiUserAssociationFactory(user=duplicate_user) + duplicate_social = self._add_social( + duplicate_user, uid=f"new-uid:{original_user.email}" + ) + + with self.assertLogs("marsha.account", "INFO") as logs: + call_command("dedupe_accounts", "--dry-run", stdout=self.stdout) + + original_user.refresh_from_db() + self.assertEqual(original_user.social_auth.first(), original_social) + duplicate_user.refresh_from_db() + self.assertEqual(duplicate_user.social_auth.first(), duplicate_social) + self.assertTrue( + original_user.playlists.filter(id=playlist_access.playlist.id).exists() + ) + self.assertTrue( + duplicate_user.playlists.filter( + 
id=playlist_access_duplicate.playlist.id + ).exists() + ) + self.assertTrue( + OrganizationAccess.objects.filter( + id=original_organization.id, user=original_user + ).exists() + ) + self.assertTrue( + OrganizationAccess.objects.filter( + id=duplicate_organization.id, user=duplicate_user + ).exists() + ) + self.assertTrue( + LtiUserAssociation.objects.filter( + id=original_lti_association.id, user=original_user + ).exists() + ) + self.assertTrue( + LtiUserAssociation.objects.filter( + id=duplicate_lti_association.id, user=duplicate_user + ).exists() + ) + + lti_sites = " + ".join( + [ + original_lti_association.consumer_site.name, + duplicate_lti_association.consumer_site.name, + ] + ) + + expected_data = { + "organization UID migrations": { + "count": 1, + "items": ["old-uid -> new-uid"], + }, + "same account migrations": { + "count": 1, + "items": [ + f"{original_user.email} -> {original_social.uid} -> {duplicate_social.uid}" + ], + }, + "organizations transferred": { + "count": 1, + "items": [ + f"{original_user.email} : {original_organization.organization.name}" + f" + {duplicate_organization.organization.name}" + ], + }, + "LTI associations transferred": { + "count": 1, + "items": [f"{original_user.email} : {lti_sites}"], + }, + "playlist accesses transferred": { + "count": 1, + "items": [f"{original_user.email} : 1"], + }, + "user accounts deleted": {"count": 1, "items": [duplicate_user.username]}, + "SSO accounts deleted": {"count": 1, "items": [original_social.uid]}, + } + + self.assertListEqual( + self._build_expected_logs( + expected_data, dry_run=True, include_command_prefix=True + ), + logs.output, + ) + + def test_dedupe_accounts_1_duplicate(self): + """ + Accounts with 2 duplicates should be transferred correctly. 
+ """ + playlist_access = PlaylistAccessFactory() + user = playlist_access.user + original_social = self._add_social(user, uid=f"old-uid:{user.email}") + playlist_access_duplicate = PlaylistAccessFactory(user__email=user.email) + duplicate_user = playlist_access_duplicate.user + duplicate_social = self._add_social(duplicate_user, uid=f"new-uid:{user.email}") + + with self.assertLogs("marsha.account", "INFO") as logs: + call_command("dedupe_accounts", stdout=self.stdout) + + user.refresh_from_db() + self.assertTrue(user.social_auth.filter(id=duplicate_social.id).exists()) + self.assertTrue( + user.playlists.filter(id=playlist_access_duplicate.playlist.id).exists() + ) + self.assertFalse(UserSocialAuth.objects.filter(id=original_social.id).exists()) + self.assertFalse(User.objects.filter(id=duplicate_user.id).exists()) + + expected_data = { + "organization UID migrations": { + "count": 1, + "items": ["old-uid -> new-uid"], + }, + "same account migrations": { + "count": 1, + "items": [ + f"{user.email} -> {original_social.uid} -> {duplicate_social.uid}" + ], + }, + "user accounts deleted": {"count": 1, "items": [duplicate_user.username]}, + "SSO accounts deleted": {"count": 1, "items": [original_social.uid]}, + } + + self.assertListEqual( + self._build_expected_logs(expected_data), + logs.output, + ) + + def test_dedupe_accounts_2_duplicates(self): + """ + Accounts with 3 duplicates should be transferred correctly. 
+ """ + playlist_access = PlaylistAccessFactory() + user = playlist_access.user + original_social = self._add_social(user, uid=f"old-uid:{user.email}") + + playlist_access_duplicate = PlaylistAccessFactory(user__email=user.email) + duplicate_user_1 = playlist_access_duplicate.user + duplicate_social_1 = self._add_social( + duplicate_user_1, uid=f"new-uid:{user.email}" + ) + + playlist_access_duplicate_2 = PlaylistAccessFactory(user__email=user.email) + duplicate_user_2 = playlist_access_duplicate_2.user + duplicate_social_2 = self._add_social( + duplicate_user_2, uid=f"new-uid-2:{user.email}" + ) + + with self.assertLogs("marsha.account", "INFO") as logs: + call_command("dedupe_accounts", stdout=self.stdout) + + user.refresh_from_db() + self.assertFalse(UserSocialAuth.objects.filter(id=original_social.id).exists()) + self.assertFalse( + UserSocialAuth.objects.filter(id=duplicate_social_1.id).exists() + ) + self.assertTrue(user.social_auth.filter(id=duplicate_social_2.id).exists()) + + self.assertTrue(user.playlists.filter(id=playlist_access.playlist.id).exists()) + self.assertTrue( + user.playlists.filter(id=playlist_access_duplicate.playlist.id).exists() + ) + self.assertTrue( + user.playlists.filter(id=playlist_access_duplicate_2.playlist.id).exists() + ) + + self.assertTrue(User.objects.filter(id=user.id).exists()) + self.assertFalse(User.objects.filter(id=duplicate_user_1.id).exists()) + self.assertFalse(User.objects.filter(id=duplicate_user_2.id).exists()) + + expected_data = { + "organization UID migrations": { + "count": 1, + "items": ["old-uid -> new-uid -> new-uid-2"], + }, + "same account migrations": { + "count": 1, + "items": [ + f"{user.email} -> {original_social.uid}" + f" -> {duplicate_social_1.uid} -> {duplicate_social_2.uid}" + ], + }, + "user accounts deleted": { + "count": 2, + "items": [duplicate_user_1.username, duplicate_user_2.username], + }, + "SSO accounts deleted": { + "count": 2, + "items": [original_social.uid, duplicate_social_1.uid], + 
}, + } + + self.assertListEqual( + self._build_expected_logs(expected_data), + logs.output, + ) + + # pylint: disable=too-many-locals,too-many-statements + def test_dedupe_accounts_2_duplicates_2_skips(self): + """ + Accounts with 3 duplicates should be transferred correctly. + 2 duplicates should be skipped. + """ + playlist_access = PlaylistAccessFactory( + user__email="user@example.com", + user__username="user@example.com", + ) + user = playlist_access.user + original_social = self._add_social(user, uid=f"old-uid:{user.email}") + original_orga = OrganizationAccessFactory( + user=user, role=ADMINISTRATOR + ).organization + original_lti_association = LtiUserAssociationFactory(user=user) + + playlist_access_duplicate = PlaylistAccessFactory( + user__email=user.email, + user__username=f"{user.email}-1", + ) + duplicate_user_1 = playlist_access_duplicate.user + duplicate_social_1 = self._add_social( + duplicate_user_1, uid=f"new-uid:{user.email}" + ) + duplicate_orga_1 = OrganizationAccessFactory( + user=duplicate_user_1, role=ADMINISTRATOR + ).organization + duplicate_lti_association_1 = LtiUserAssociationFactory(user=duplicate_user_1) + + playlist_access_duplicate_2 = PlaylistAccessFactory( + user__email=user.email, user__username=f"{user.email}-2" + ) + duplicate_user_2 = playlist_access_duplicate_2.user + duplicate_social_2 = self._add_social( + duplicate_user_2, uid=f"new-uid-2:{user.email}" + ) + duplicate_orga_2 = OrganizationAccessFactory( + user=duplicate_user_2, role=ADMINISTRATOR + ).organization + duplicate_lti_association_2 = LtiUserAssociationFactory(user=duplicate_user_2) + + playlist_access_skip = PlaylistAccessFactory( + user__email=user.email, user__username=f"{user.email}-3" + ) + skip_user_1 = playlist_access_skip.user + kept_social_1 = self._add_social(skip_user_1, uid=f"new-uid:sk_{user.email}") + kept_orga_1 = OrganizationAccessFactory( + user=skip_user_1, role=ADMINISTRATOR + ).organization + kept_lti_association_1 = 
LtiUserAssociationFactory(user=skip_user_1) + + playlist_access_skip_2 = PlaylistAccessFactory( + user__email=user.email, user__username=f"{user.email}-4" + ) + skip_user_2 = playlist_access_skip_2.user + kept_social_2 = self._add_social(skip_user_2, uid=f"new-uid-2:sk_{user.email}") + kept_orga_2 = OrganizationAccessFactory( + user=skip_user_2, role=ADMINISTRATOR + ).organization + kept_lti_association_2 = LtiUserAssociationFactory(user=skip_user_2) + + with self.assertLogs("marsha.account", "INFO") as logs: + call_command("dedupe_accounts", stdout=self.stdout) + + user.refresh_from_db() + self.assertFalse(UserSocialAuth.objects.filter(id=original_social.id).exists()) + self.assertFalse( + UserSocialAuth.objects.filter(id=duplicate_social_1.id).exists() + ) + self.assertTrue( + user.social_auth.filter(id=duplicate_social_2.id).exists(), + duplicate_social_2.uid, + ) + + self.assertTrue(user.playlists.filter(id=playlist_access.playlist.id).exists()) + self.assertTrue( + user.playlists.filter(id=playlist_access_duplicate.playlist.id).exists() + ) + self.assertTrue( + user.playlists.filter(id=playlist_access_duplicate_2.playlist.id).exists() + ) + self.assertTrue( + user.playlists.filter(id=playlist_access_skip.playlist.id).exists() + ) + self.assertTrue( + user.playlists.filter(id=playlist_access_skip_2.playlist.id).exists() + ) + + self.assertTrue(User.objects.filter(id=user.id).exists()) + self.assertFalse(User.objects.filter(id=duplicate_user_1.id).exists()) + self.assertFalse(User.objects.filter(id=duplicate_user_2.id).exists()) + self.assertFalse(User.objects.filter(id=skip_user_1.id).exists()) + self.assertFalse(User.objects.filter(id=skip_user_2.id).exists()) + + self.assertTrue(user.organization_set.filter(id=original_orga.id).exists()) + self.assertTrue(user.organization_set.filter(id=duplicate_orga_1.id).exists()) + self.assertTrue(user.organization_set.filter(id=duplicate_orga_2.id).exists()) + 
self.assertTrue(user.organization_set.filter(id=kept_orga_1.id).exists()) + self.assertTrue(user.organization_set.filter(id=kept_orga_2.id).exists()) + + self.assertEqual( + user.organization_accesses.get(organization=original_orga).role, + ADMINISTRATOR, + ) + self.assertEqual( + user.organization_accesses.get(organization=duplicate_orga_1).role, + ADMINISTRATOR, + ) + self.assertEqual( + user.organization_accesses.get(organization=duplicate_orga_2).role, + ADMINISTRATOR, + ) + self.assertEqual( + user.organization_accesses.get(organization=kept_orga_1).role, ADMINISTRATOR + ) + self.assertEqual( + user.organization_accesses.get(organization=kept_orga_2).role, ADMINISTRATOR + ) + + self.assertEqual(original_lti_association.user, user) + duplicate_lti_association_1.refresh_from_db() + self.assertEqual(duplicate_lti_association_1.user, user) + duplicate_lti_association_2.refresh_from_db() + self.assertEqual(duplicate_lti_association_2.user, user) + kept_lti_association_1.refresh_from_db() + self.assertEqual(kept_lti_association_1.user, user) + kept_lti_association_2.refresh_from_db() + self.assertEqual(kept_lti_association_2.user, user) + + lti_sites = " + ".join( + [ + original_lti_association.consumer_site.name, + duplicate_lti_association_1.consumer_site.name, + duplicate_lti_association_2.consumer_site.name, + kept_lti_association_1.consumer_site.name, + kept_lti_association_2.consumer_site.name, + ] + ) + + expected_data = { + "organization UID migrations": { + "count": 1, + "items": ["old-uid -> new-uid -> new-uid-2"], + }, + "same account migrations": { + "count": 1, + "items": [ + f"{user.email} -> {original_social.uid}" + f" -> {duplicate_social_1.uid} -> {duplicate_social_2.uid}" + ], + }, + "different account merges": { + "count": 1, + "items": [ + f"{user.email} : {duplicate_social_2.uid}" + f" + {kept_social_1.uid} + {kept_social_2.uid}" + ], + }, + "organizations transferred": { + "count": 1, + "items": [ + f"{user.email} : {original_orga.name} + 
{duplicate_orga_1.name}" + f" + {duplicate_orga_2.name} + {kept_orga_1.name} + {kept_orga_2.name}" + ], + }, + "LTI associations transferred": { + "count": 1, + "items": [f"{user.email} : {lti_sites}"], + }, + "user accounts deleted": { + "count": 4, + "items": [ + duplicate_user_1.username, + duplicate_user_2.username, + skip_user_1.username, + skip_user_2.username, + ], + }, + "SSO accounts deleted": { + "count": 2, + "items": [original_social.uid, duplicate_social_1.uid], + }, + } + + self.assertListEqual( + self._build_expected_logs(expected_data), + logs.output, + ) diff --git a/src/backend/marsha/account/tests/dedupe_accounts/test_dedupe_tracker.py b/src/backend/marsha/account/tests/dedupe_accounts/test_dedupe_tracker.py new file mode 100644 index 0000000000..353140408b --- /dev/null +++ b/src/backend/marsha/account/tests/dedupe_accounts/test_dedupe_tracker.py @@ -0,0 +1,439 @@ +"""Tests for the DedupeTracker class.""" + +from django.test import TestCase + +from marsha.account.utils.dedupe_accounts.dedupe_tracker import DedupeTracker +from marsha.core.factories import ( + LtiUserAssociationFactory, + OrganizationAccessFactory, + UserFactory, +) +from marsha.core.models import ADMINISTRATOR, LtiUserAssociation + + +class DedupeTrackerTest(TestCase): + """Test suite for DedupeTracker class.""" + + def setUp(self): + """Set up test case.""" + self.tracker = DedupeTracker() + + def test_initialization(self): + """Test tracker initializes with empty collections.""" + self.assertEqual(self.tracker.deleted_sso_accounts, []) + self.assertEqual(self.tracker.same_account_migrations, {}) + self.assertEqual(self.tracker.organization_uid_migrations, {}) + self.assertEqual(self.tracker.deleted_user_accounts, []) + self.assertEqual(self.tracker.different_account_merges, {}) + self.assertEqual(self.tracker.transferred_organizations, {}) + self.assertEqual(self.tracker.transferred_lti_associations, {}) + self.assertEqual(self.tracker.transferred_lti_passports, {}) + 
self.assertEqual(self.tracker.transferred_consumersite_accesses, {}) + self.assertEqual(self.tracker.transferred_playlist_accesses, {}) + self.assertEqual(self.tracker.transferred_created_playlists, {}) + self.assertEqual(self.tracker.transferred_created_videos, {}) + self.assertEqual(self.tracker.transferred_created_documents, {}) + self.assertEqual(self.tracker.transferred_created_markdown_documents, {}) + self.assertEqual(self.tracker.transferred_portability_requests, {}) + + def test_track_merged_account_first_entry(self): + """Test tracking merged account creates new entry.""" + self.tracker.track_different_account_merge( + "user@test.com", "original-uid", "new-uid" + ) + + self.assertEqual( + self.tracker.different_account_merges["user@test.com"], + ["original-uid", "new-uid"], + ) + + def test_track_merged_account_append(self): + """Test tracking merged account appends to existing entry.""" + self.tracker.track_different_account_merge("user@test.com", "uid-1", "uid-2") + self.tracker.track_different_account_merge("user@test.com", "uid-1", "uid-3") + + self.assertEqual( + self.tracker.different_account_merges["user@test.com"], + ["uid-1", "uid-2", "uid-3"], + ) + + def test_track_duplicate_account_first_entry(self): + """Test tracking duplicate account creates new entry.""" + self.tracker.track_same_account_migration( + "user@test.com", "original-uid", "new-uid" + ) + + self.assertEqual( + self.tracker.same_account_migrations["user@test.com"], + ["original-uid", "new-uid"], + ) + + def test_track_duplicate_account_append(self): + """Test tracking duplicate account appends to existing entry.""" + self.tracker.track_same_account_migration("user@test.com", "uid-1", "uid-2") + self.tracker.track_same_account_migration("user@test.com", "uid-1", "uid-3") + + self.assertEqual( + self.tracker.same_account_migrations["user@test.com"], + ["uid-1", "uid-2", "uid-3"], + ) + + def test_track_organization_duplicate_first_entry(self): + """Test tracking organization 
duplicate creates new entry.""" + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid") + + self.assertEqual( + self.tracker.organization_uid_migrations["old-org-uid"], ["new-org-uid"] + ) + + def test_track_organization_duplicate_append(self): + """Test tracking organization duplicate appends new UIDs.""" + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid-1") + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid-2") + + self.assertEqual( + self.tracker.organization_uid_migrations["old-org-uid"], + ["new-org-uid-1", "new-org-uid-2"], + ) + + def test_track_organization_duplicate_no_duplicates(self): + """Test tracking organization duplicate doesn't add duplicates.""" + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid") + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid") + + self.assertEqual( + self.tracker.organization_uid_migrations["old-org-uid"], ["new-org-uid"] + ) + + def test_track_organization_duplicate_no_consecutive_duplicates(self): + """Test tracking organization duplicate doesn't add consecutive duplicates.""" + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid-1") + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid-2") + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid-2") + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid-3") + self.tracker.track_organization_uid_migration("old-org-uid", "new-org-uid-2") + + self.assertEqual( + self.tracker.organization_uid_migrations["old-org-uid"], + ["new-org-uid-1", "new-org-uid-2", "new-org-uid-3", "new-org-uid-2"], + ) + + def test_mark_user_for_deletion(self): + """Test marking user for deletion.""" + self.tracker.mark_user_for_deletion("user1") + self.tracker.mark_user_for_deletion("user2") + + self.assertEqual(self.tracker.deleted_user_accounts, ["user1", "user2"]) + + def 
test_mark_account_for_deletion(self): + """Test marking account for deletion.""" + self.tracker.mark_account_for_deletion("account-uid-1") + self.tracker.mark_account_for_deletion("account-uid-2") + + self.assertEqual( + self.tracker.deleted_sso_accounts, ["account-uid-1", "account-uid-2"] + ) + + def test_get_or_init_transferred_orgs_creates_new(self): + """Test get_or_init_transferred_orgs creates new entry.""" + access_1 = OrganizationAccessFactory(role=ADMINISTRATOR) + access_2 = OrganizationAccessFactory(role=ADMINISTRATOR) + user = access_1.user + OrganizationAccessFactory( + user=user, organization=access_2.organization, role=ADMINISTRATOR + ) + + result = self.tracker.get_or_init_transferred_orgs(user) + + self.assertIn(access_1.organization.name, result) + self.assertIn(access_2.organization.name, result) + self.assertEqual(self.tracker.transferred_organizations[user.email], result) + + def test_get_or_init_transferred_orgs_returns_existing(self): + """Test get_or_init_transferred_orgs returns existing entry.""" + access = OrganizationAccessFactory(role=ADMINISTRATOR) + user = access.user + self.tracker.transferred_organizations[user.email] = [access.organization.name] + + result = self.tracker.get_or_init_transferred_orgs(user) + + self.assertEqual(result, [access.organization.name]) + + def test_mark_transferred_orgs_first_time(self): + """Test mark_transferred_orgs creates new entry and adds organizations.""" + access_1 = OrganizationAccessFactory(role=ADMINISTRATOR) + access_2 = OrganizationAccessFactory(role=ADMINISTRATOR) + user = access_1.user + OrganizationAccessFactory( + user=user, organization=access_2.organization, role=ADMINISTRATOR + ) + + # Create organizations to transfer + duplicate_org_access_1 = OrganizationAccessFactory(role=ADMINISTRATOR) + duplicate_org_access_2 = OrganizationAccessFactory(role=ADMINISTRATOR) + organizations_to_transfer = [ + duplicate_org_access_1.organization, + duplicate_org_access_2.organization, + ] + + 
self.tracker.mark_transferred_orgs(user, organizations_to_transfer) + + self.assertIn(user.email, self.tracker.transferred_organizations) + transferred = self.tracker.transferred_organizations[user.email] + # Should include original orgs plus transferred ones + self.assertIn(access_1.organization.name, transferred) + self.assertIn(access_2.organization.name, transferred) + self.assertIn(duplicate_org_access_1.organization.name, transferred) + self.assertIn(duplicate_org_access_2.organization.name, transferred) + + def test_mark_transferred_orgs_appends_to_existing(self): + """Test mark_transferred_orgs appends to existing entry.""" + access = OrganizationAccessFactory(role=ADMINISTRATOR) + user = access.user + + # Pre-populate with one organization + self.tracker.transferred_organizations[user.email] = [access.organization.name] + + # Add more organizations + duplicate_org_access = OrganizationAccessFactory(role=ADMINISTRATOR) + organizations_to_transfer = [duplicate_org_access.organization] + + self.tracker.mark_transferred_orgs(user, organizations_to_transfer) + + transferred = self.tracker.transferred_organizations[user.email] + self.assertEqual(len(transferred), 2) + self.assertIn(access.organization.name, transferred) + self.assertIn(duplicate_org_access.organization.name, transferred) + + def test_mark_transferred_orgs_empty_list(self): + """Test mark_transferred_orgs with empty list.""" + access = OrganizationAccessFactory(role=ADMINISTRATOR) + user = access.user + + self.tracker.mark_transferred_orgs(user, []) + + self.assertIn(user.email, self.tracker.transferred_organizations) + # Should only have the user's existing organization + transferred = self.tracker.transferred_organizations[user.email] + self.assertEqual(len(transferred), 1) + self.assertIn(access.organization.name, transferred) + + def test_get_or_init_transferred_lti_associations_creates_new(self): + """Test get_or_init_transferred_lti_associations creates new entry.""" + lti_association_1 = 
LtiUserAssociationFactory() + lti_association_2 = LtiUserAssociationFactory() + user = lti_association_1.user + LtiUserAssociationFactory( + user=user, consumer_site=lti_association_2.consumer_site + ) + + result = self.tracker.get_or_init_transferred_lti_associations(user) + + self.assertIn(lti_association_1.consumer_site.name, result) + self.assertIn(lti_association_2.consumer_site.name, result) + self.assertEqual(self.tracker.transferred_lti_associations[user.email], result) + + def test_get_or_init_transferred_lti_associations_returns_existing(self): + """Test get_or_init_transferred_lti_associations returns existing entry.""" + lti_association = LtiUserAssociationFactory() + user = lti_association.user + self.tracker.transferred_lti_associations[user.email] = [ + lti_association.consumer_site.name + ] + + result = self.tracker.get_or_init_transferred_lti_associations(user) + + self.assertEqual(result, [lti_association.consumer_site.name]) + + def test_mark_transferred_lti_association_first_time(self): + """Test mark_transferred_lti_association creates new entry and adds associations.""" + lti_association_1 = LtiUserAssociationFactory() + lti_association_2 = LtiUserAssociationFactory() + user = lti_association_1.user + LtiUserAssociationFactory( + user=user, consumer_site=lti_association_2.consumer_site + ) + + # Create LTI associations to transfer + duplicate_lti_association_1 = LtiUserAssociationFactory() + duplicate_lti_association_2 = LtiUserAssociationFactory() + lti_associations_to_transfer = LtiUserAssociation.objects.filter( + id__in=[duplicate_lti_association_1.id, duplicate_lti_association_2.id] + ) + + self.tracker.mark_transferred_lti_association( + user, lti_associations_to_transfer + ) + + self.assertIn(user.email, self.tracker.transferred_lti_associations) + transferred = self.tracker.transferred_lti_associations[user.email] + # Should include original associations plus transferred ones + self.assertIn(lti_association_1.consumer_site.name, 
transferred) + self.assertIn(lti_association_2.consumer_site.name, transferred) + self.assertIn(duplicate_lti_association_1.consumer_site.name, transferred) + self.assertIn(duplicate_lti_association_2.consumer_site.name, transferred) + + def test_mark_transferred_lti_association_appends_to_existing(self): + """Test mark_transferred_lti_association appends to existing entry.""" + lti_association = LtiUserAssociationFactory() + user = lti_association.user + + # Pre-populate with one association + self.tracker.transferred_lti_associations[user.email] = [ + lti_association.consumer_site.name + ] + + # Add more associations + duplicate_lti_association = LtiUserAssociationFactory() + lti_associations_to_transfer = LtiUserAssociation.objects.filter( + id=duplicate_lti_association.id + ) + + self.tracker.mark_transferred_lti_association( + user, lti_associations_to_transfer + ) + + transferred = self.tracker.transferred_lti_associations[user.email] + self.assertEqual(len(transferred), 2) + self.assertIn(lti_association.consumer_site.name, transferred) + self.assertIn(duplicate_lti_association.consumer_site.name, transferred) + + def test_mark_transferred_lti_association_empty_queryset(self): + """Test mark_transferred_lti_association with empty queryset.""" + lti_association = LtiUserAssociationFactory() + user = lti_association.user + + # Empty queryset + empty_queryset = LtiUserAssociation.objects.none() + + self.tracker.mark_transferred_lti_association(user, empty_queryset) + + self.assertIn(user.email, self.tracker.transferred_lti_associations) + # Should only have the user's existing association + transferred = self.tracker.transferred_lti_associations[user.email] + self.assertEqual(len(transferred), 1) + self.assertIn(lti_association.consumer_site.name, transferred) + + def test_mark_transferred_lti_passports_first_time(self): + """Test mark_transferred_lti_passports creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_lti_passports(user, 3) + 
+ self.assertEqual(self.tracker.transferred_lti_passports[user.email], 3) + + def test_mark_transferred_lti_passports_accumulates(self): + """Test mark_transferred_lti_passports accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_lti_passports(user, 2) + self.tracker.mark_transferred_lti_passports(user, 3) + + self.assertEqual(self.tracker.transferred_lti_passports[user.email], 5) + + def test_mark_transferred_consumersite_accesses_first_time(self): + """Test mark_transferred_consumersite_accesses creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_consumersite_accesses(user, 2) + + self.assertEqual(self.tracker.transferred_consumersite_accesses[user.email], 2) + + def test_mark_transferred_consumersite_accesses_accumulates(self): + """Test mark_transferred_consumersite_accesses accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_consumersite_accesses(user, 1) + self.tracker.mark_transferred_consumersite_accesses(user, 2) + + self.assertEqual(self.tracker.transferred_consumersite_accesses[user.email], 3) + + def test_mark_transferred_playlist_accesses_first_time(self): + """Test mark_transferred_playlist_accesses creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_playlist_accesses(user, 4) + + self.assertEqual(self.tracker.transferred_playlist_accesses[user.email], 4) + + def test_mark_transferred_playlist_accesses_accumulates(self): + """Test mark_transferred_playlist_accesses accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_playlist_accesses(user, 2) + self.tracker.mark_transferred_playlist_accesses(user, 1) + + self.assertEqual(self.tracker.transferred_playlist_accesses[user.email], 3) + + def test_mark_transferred_created_playlists_first_time(self): + """Test mark_transferred_created_playlists creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_created_playlists(user, 5) + + 
self.assertEqual(self.tracker.transferred_created_playlists[user.email], 5) + + def test_mark_transferred_created_playlists_accumulates(self): + """Test mark_transferred_created_playlists accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_created_playlists(user, 3) + self.tracker.mark_transferred_created_playlists(user, 2) + + self.assertEqual(self.tracker.transferred_created_playlists[user.email], 5) + + def test_mark_transferred_created_videos_first_time(self): + """Test mark_transferred_created_videos creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_created_videos(user, 10) + + self.assertEqual(self.tracker.transferred_created_videos[user.email], 10) + + def test_mark_transferred_created_videos_accumulates(self): + """Test mark_transferred_created_videos accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_created_videos(user, 7) + self.tracker.mark_transferred_created_videos(user, 3) + + self.assertEqual(self.tracker.transferred_created_videos[user.email], 10) + + def test_mark_transferred_created_documents_first_time(self): + """Test mark_transferred_created_documents creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_created_documents(user, 6) + + self.assertEqual(self.tracker.transferred_created_documents[user.email], 6) + + def test_mark_transferred_created_documents_accumulates(self): + """Test mark_transferred_created_documents accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_created_documents(user, 4) + self.tracker.mark_transferred_created_documents(user, 2) + + self.assertEqual(self.tracker.transferred_created_documents[user.email], 6) + + def test_mark_transferred_created_markdown_documents_first_time(self): + """Test mark_transferred_created_markdown_documents creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_created_markdown_documents(user, 3) + + self.assertEqual( + 
self.tracker.transferred_created_markdown_documents[user.email], 3 + ) + + def test_mark_transferred_created_markdown_documents_accumulates(self): + """Test mark_transferred_created_markdown_documents accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_created_markdown_documents(user, 1) + self.tracker.mark_transferred_created_markdown_documents(user, 2) + + self.assertEqual( + self.tracker.transferred_created_markdown_documents[user.email], 3 + ) + + def test_mark_transferred_portability_requests_first_time(self): + """Test mark_transferred_portability_requests creates new entry.""" + user = UserFactory() + self.tracker.mark_transferred_portability_requests(user, 2) + + self.assertEqual(self.tracker.transferred_portability_requests[user.email], 2) + + def test_mark_transferred_portability_requests_accumulates(self): + """Test mark_transferred_portability_requests accumulates counts.""" + user = UserFactory() + self.tracker.mark_transferred_portability_requests(user, 1) + self.tracker.mark_transferred_portability_requests(user, 3) + + self.assertEqual(self.tracker.transferred_portability_requests[user.email], 4) diff --git a/src/backend/marsha/account/tests/dedupe_accounts/test_user_deduplicator.py b/src/backend/marsha/account/tests/dedupe_accounts/test_user_deduplicator.py new file mode 100644 index 0000000000..6199d9ea4e --- /dev/null +++ b/src/backend/marsha/account/tests/dedupe_accounts/test_user_deduplicator.py @@ -0,0 +1,721 @@ +"""Tests for the UserDeduplicator class.""" + +from unittest.mock import Mock + +from django.contrib.auth import get_user_model +from django.test import TestCase + +from social_django.models import UserSocialAuth + +from marsha.account.utils.dedupe_accounts import UserDeduplicator +from marsha.account.utils.dedupe_accounts.dedupe_tracker import DedupeTracker +from marsha.core.factories import ( + ConsumerSiteAccessFactory, + ConsumerSiteFactory, + ConsumerSiteLTIPassportFactory, + DocumentFactory, + 
LtiUserAssociationFactory, + OrganizationAccessFactory, + PlaylistAccessFactory, + PlaylistFactory, + PortabilityRequestFactory, + UserFactory, + VideoFactory, +) +from marsha.core.models import ADMINISTRATOR, ConsumerSiteAccess, PlaylistAccess +from marsha.markdown.factories import MarkdownDocumentFactory + + +# pylint: disable=invalid-name +User = get_user_model() + + +class UserDeduplicatorTest(TestCase): + """Test suite for UserDeduplicator class.""" + + def setUp(self): + """Set up test case.""" + self.deduplicator = UserDeduplicator(dry_run=False) + + def test_initialization(self): + """Test deduplicator initializes correctly.""" + self.assertFalse(self.deduplicator.dry_run) + self.assertIsInstance(self.deduplicator.tracker, DedupeTracker) + + def test_initialization_dry_run(self): + """Test deduplicator initializes with dry_run.""" + deduplicator = UserDeduplicator(dry_run=True) + self.assertTrue(deduplicator.dry_run) + + def test_get_duplicate_emails_with_specific_email(self): + """Test get_duplicate_emails with specific email.""" + result = UserDeduplicator.get_duplicate_emails("test@example.com") + + self.assertEqual(result, [{"email": "test@example.com"}]) + + def test_get_duplicate_emails_without_specific_email(self): + """Test get_duplicate_emails queries duplicates.""" + PlaylistAccessFactory(user__email="dup@test.com") + PlaylistAccessFactory(user__email="dup@test.com") + PlaylistAccessFactory(user__email="unique@test.com") + + result = list(UserDeduplicator.get_duplicate_emails(None)) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["email"], "dup@test.com") + + def test_parse_social_uid_with_valid_uid(self): + """Test parse_social_uid with valid format.""" + mock_social = Mock(uid="org-123:user@example.com") + + org_uid, account_email = UserDeduplicator.parse_social_uid(mock_social) + + self.assertEqual(org_uid, "org-123") + self.assertEqual(account_email, "user@example.com") + + def test_parse_social_uid_with_no_colon(self): +
"""Test parse_social_uid with uid without colon.""" + mock_social = Mock(uid="org-123") + + org_uid, account_email = UserDeduplicator.parse_social_uid(mock_social) + + self.assertEqual(org_uid, "org-123") + self.assertIsNone(account_email) + + def test_parse_social_uid_with_none(self): + """Test parse_social_uid with None.""" + org_uid, account_email = UserDeduplicator.parse_social_uid(None) + + self.assertIsNone(org_uid) + self.assertIsNone(account_email) + + def test_merge_playlists_dry_run(self): + """Test transfer_playlists in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory() + + deduplicator.transfer_playlists(original_access.user, duplicate_access.user) + + self.assertFalse( + original_access.user.playlists.filter( + id=duplicate_access.playlist.id + ).exists() + ) + + def test_merge_playlists_regular_run(self): + """Test transfer_playlists in regular mode.""" + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory() + + self.deduplicator.transfer_playlists( + original_access.user, duplicate_access.user + ) + + self.assertTrue( + original_access.user.playlists.filter( + id=duplicate_access.playlist.id + ).exists() + ) + + def test_merge_playlists_no_duplicates(self): + """Test transfer_playlists doesn't add duplicate playlists.""" + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory() + duplicate_access.user.playlists.add(original_access.playlist) + initial_count = duplicate_access.user.playlists.count() + + self.deduplicator.transfer_playlists( + original_access.user, duplicate_access.user + ) + + self.assertEqual(original_access.user.playlists.count(), initial_count) + + def test_transfer_organizations_dry_run(self): + """Test transfer_organizations in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_org = OrganizationAccessFactory(role=ADMINISTRATOR) + duplicate_org =
OrganizationAccessFactory(role=ADMINISTRATOR) + + deduplicator.transfer_organizations(original_org.user, duplicate_org.user) + + duplicate_org.refresh_from_db() + self.assertNotEqual(original_org.user.id, duplicate_org.user.id) + + def test_transfer_organizations_regular_run(self): + """Test transfer_organizations in regular mode.""" + original_org = OrganizationAccessFactory(role=ADMINISTRATOR) + duplicate_org = OrganizationAccessFactory(role=ADMINISTRATOR) + + self.deduplicator.transfer_organizations(original_org.user, duplicate_org.user) + + duplicate_org.refresh_from_db() + self.assertEqual(original_org.user.id, duplicate_org.user.id) + + def test_transfer_organizations_no_duplicate_orgs(self): + """Test transfer_organizations with no new organizations.""" + org_access = OrganizationAccessFactory(role=ADMINISTRATOR) + user = org_access.user + + self.deduplicator.transfer_organizations(user, user) + + self.assertEqual(self.deduplicator.tracker.transferred_organizations, {}) + + def test_transfer_organizations_excludes_duplicate_orgs(self): + """ + Test transfer_organizations doesn't transfer access + to organizations user already has.
+ """ + original_user = UserFactory() + duplicate_user = UserFactory() + # Create two different organizations + org_access1 = OrganizationAccessFactory(user=original_user, role=ADMINISTRATOR) + shared_org = org_access1.organization + # Create access to the same organization for duplicate user + org_access2 = OrganizationAccessFactory( + user=duplicate_user, organization=shared_org, role=ADMINISTRATOR + ) + # Create access to a different organization for duplicate user + org_access3 = OrganizationAccessFactory(user=duplicate_user, role=ADMINISTRATOR) + different_org = org_access3.organization + + # Verify setup: both users should have access to shared_org + self.assertIn(shared_org, original_user.organization_set.all()) + self.assertIn(shared_org, duplicate_user.organization_set.all()) + self.assertIn(different_org, duplicate_user.organization_set.all()) + + self.deduplicator.transfer_organizations(original_user, duplicate_user) + + # org_access2 should not be transferred (original user already has access to shared_org) + org_access2.refresh_from_db() + self.assertEqual(org_access2.user, duplicate_user) + # org_access3 should be transferred (original user didn't have access to different_org) + org_access3.refresh_from_db() + self.assertEqual(org_access3.user, original_user) + + def test_transfer_lti_associations_dry_run(self): + """Test transfer_lti_associations in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_association = LtiUserAssociationFactory() + duplicate_association = LtiUserAssociationFactory() + user = original_association.user + + deduplicator.transfer_lti_associations(user, duplicate_association.user) + + # In dry run mode, no transfer should be made + user.refresh_from_db() + self.assertFalse( + user.lti_user_associations.filter(id=duplicate_association.id).exists() + ) + # But tracking should happen + self.assertIn(user.email, deduplicator.tracker.transferred_lti_associations) + transferred = 
deduplicator.tracker.transferred_lti_associations[user.email] + self.assertEqual(len(transferred), 2) + self.assertIn(original_association.consumer_site.name, transferred) + self.assertIn(duplicate_association.consumer_site.name, transferred) + + def test_transfer_lti_associations_regular_run(self): + """Test transfer_lti_associations in regular mode.""" + deduplicator = UserDeduplicator(dry_run=False) + original_association = LtiUserAssociationFactory() + duplicate_association = LtiUserAssociationFactory() + user = original_association.user + + deduplicator.transfer_lti_associations(user, duplicate_association.user) + + # In regular mode, the transfer should be made + user.refresh_from_db() + self.assertTrue( + user.lti_user_associations.filter(id=duplicate_association.id).exists() + ) + # And tracking should happen + self.assertIn(user.email, deduplicator.tracker.transferred_lti_associations) + transferred = deduplicator.tracker.transferred_lti_associations[user.email] + self.assertEqual(len(transferred), 2) + self.assertIn(original_association.consumer_site.name, transferred) + self.assertIn(duplicate_association.consumer_site.name, transferred) + + def test_transfer_lti_associations_no_associations(self): + """Test transfer_lti_associations when duplicate user has no LTI associations.""" + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory() + original_user = original_access.user + duplicate_user = duplicate_access.user + + # NOTE(review): mock is attached as `lti_associations`, but sibling tests read `lti_user_associations` - confirm the code under test consults this mock + mock_lti = Mock() + mock_lti.count.return_value = 0 + duplicate_user.lti_associations = mock_lti + + self.deduplicator.transfer_lti_associations(original_user, duplicate_user) + + # Should not call update or track when count is 0 + mock_lti.update.assert_not_called() + self.assertNotIn( + original_user.email, self.deduplicator.tracker.transferred_lti_associations + ) + + def test_transfer_lti_passports_regular_run(self): + """Test
transfer_lti_passports in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + passport1 = ConsumerSiteLTIPassportFactory(created_by=duplicate_user) + passport2 = ConsumerSiteLTIPassportFactory(created_by=duplicate_user) + + self.deduplicator.transfer_lti_passports(original_user, duplicate_user) + + passport1.refresh_from_db() + passport2.refresh_from_db() + self.assertEqual(passport1.created_by, original_user) + self.assertEqual(passport2.created_by, original_user) + self.assertEqual( + self.deduplicator.tracker.transferred_lti_passports[original_user.email], 2 + ) + + def test_transfer_lti_passports_dry_run(self): + """Test transfer_lti_passports in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + passport = ConsumerSiteLTIPassportFactory(created_by=duplicate_user) + + deduplicator.transfer_lti_passports(original_user, duplicate_user) + + passport.refresh_from_db() + self.assertEqual(passport.created_by, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_lti_passports[original_user.email], 1 + ) + + def test_transfer_lti_passports_no_passports(self): + """Test transfer_lti_passports when duplicate user has no passports.""" + original_user = UserFactory() + duplicate_user = UserFactory() + + self.deduplicator.transfer_lti_passports(original_user, duplicate_user) + + self.assertNotIn( + original_user.email, self.deduplicator.tracker.transferred_lti_passports + ) + + def test_transfer_consumersite_accesses_regular_run(self): + """Test transfer_consumersite_accesses in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + consumer_site1 = ConsumerSiteFactory() + consumer_site2 = ConsumerSiteFactory() + ConsumerSiteAccessFactory(user=original_user, consumer_site=consumer_site1) + access2 = ConsumerSiteAccessFactory( + user=duplicate_user, consumer_site=consumer_site2 + ) + + 
self.deduplicator.transfer_consumersite_accesses(original_user, duplicate_user) + + access2.refresh_from_db() + self.assertEqual(access2.user, original_user) + self.assertEqual( + self.deduplicator.tracker.transferred_consumersite_accesses[ + original_user.email + ], + 1, + ) + + def test_transfer_consumersite_accesses_with_duplicates(self): + """Test transfer_consumersite_accesses excludes duplicate accesses.""" + original_user = UserFactory() + duplicate_user = UserFactory() + consumer_site = ConsumerSiteFactory() + ConsumerSiteAccessFactory(user=original_user, consumer_site=consumer_site) + duplicate_access = ConsumerSiteAccessFactory( + user=duplicate_user, consumer_site=consumer_site + ) + + self.deduplicator.transfer_consumersite_accesses(original_user, duplicate_user) + + self.assertNotIn( + original_user.email, + self.deduplicator.tracker.transferred_consumersite_accesses, + ) + self.assertFalse( + ConsumerSiteAccess.objects.filter(id=duplicate_access.id).exists() + ) + + def test_transfer_consumersite_accesses_dry_run(self): + """Test transfer_consumersite_accesses in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + access = ConsumerSiteAccessFactory(user=duplicate_user) + + deduplicator.transfer_consumersite_accesses(original_user, duplicate_user) + + access.refresh_from_db() + self.assertEqual(access.user, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_consumersite_accesses[original_user.email], + 1, + ) + + def test_transfer_playlist_accesses_regular_run(self): + """Test transfer_playlist_accesses in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + playlist1 = PlaylistFactory() + playlist2 = PlaylistFactory() + PlaylistAccessFactory(user=original_user, playlist=playlist1) + access2 = PlaylistAccessFactory(user=duplicate_user, playlist=playlist2) + + self.deduplicator.transfer_playlist_accesses(original_user, 
duplicate_user) + + access2.refresh_from_db() + self.assertEqual(access2.user, original_user) + self.assertEqual( + self.deduplicator.tracker.transferred_playlist_accesses[ + original_user.email + ], + 1, + ) + + def test_transfer_playlist_accesses_with_duplicates(self): + """Test transfer_playlist_accesses excludes duplicate accesses.""" + original_user = UserFactory() + duplicate_user = UserFactory() + playlist = PlaylistFactory() + PlaylistAccessFactory(user=original_user, playlist=playlist) + duplicate_access = PlaylistAccessFactory(user=duplicate_user, playlist=playlist) + + self.deduplicator.transfer_playlist_accesses(original_user, duplicate_user) + + self.assertNotIn( + original_user.email, + self.deduplicator.tracker.transferred_playlist_accesses, + ) + self.assertFalse(PlaylistAccess.objects.filter(id=duplicate_access.id).exists()) + + def test_transfer_playlist_accesses_dry_run(self): + """Test transfer_playlist_accesses in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + access = PlaylistAccessFactory(user=duplicate_user) + + deduplicator.transfer_playlist_accesses(original_user, duplicate_user) + + access.refresh_from_db() + self.assertEqual(access.user, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_playlist_accesses[original_user.email], 1 + ) + + def test_transfer_created_playlists_regular_run(self): + """Test transfer_created_playlists in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + playlist1 = PlaylistFactory(created_by=duplicate_user) + playlist2 = PlaylistFactory(created_by=duplicate_user) + + self.deduplicator.transfer_created_playlists(original_user, duplicate_user) + + playlist1.refresh_from_db() + playlist2.refresh_from_db() + self.assertEqual(playlist1.created_by, original_user) + self.assertEqual(playlist2.created_by, original_user) + self.assertEqual( + 
self.deduplicator.tracker.transferred_created_playlists[ + original_user.email + ], + 2, + ) + + def test_transfer_created_playlists_dry_run(self): + """Test transfer_created_playlists in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + playlist = PlaylistFactory(created_by=duplicate_user) + + deduplicator.transfer_created_playlists(original_user, duplicate_user) + + playlist.refresh_from_db() + self.assertEqual(playlist.created_by, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_created_playlists[original_user.email], 1 + ) + + def test_transfer_created_videos_regular_run(self): + """Test transfer_created_videos in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + video1 = VideoFactory(created_by=duplicate_user) + video2 = VideoFactory(created_by=duplicate_user) + + self.deduplicator.transfer_created_videos(original_user, duplicate_user) + + video1.refresh_from_db() + video2.refresh_from_db() + self.assertEqual(video1.created_by, original_user) + self.assertEqual(video2.created_by, original_user) + self.assertEqual( + self.deduplicator.tracker.transferred_created_videos[original_user.email], 2 + ) + + def test_transfer_created_videos_dry_run(self): + """Test transfer_created_videos in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + video = VideoFactory(created_by=duplicate_user) + + deduplicator.transfer_created_videos(original_user, duplicate_user) + + video.refresh_from_db() + self.assertEqual(video.created_by, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_created_videos[original_user.email], 1 + ) + + def test_transfer_created_documents_regular_run(self): + """Test transfer_created_documents in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + doc1 = 
DocumentFactory(created_by=duplicate_user) + doc2 = DocumentFactory(created_by=duplicate_user) + + self.deduplicator.transfer_created_documents(original_user, duplicate_user) + + doc1.refresh_from_db() + doc2.refresh_from_db() + self.assertEqual(doc1.created_by, original_user) + self.assertEqual(doc2.created_by, original_user) + self.assertEqual( + self.deduplicator.tracker.transferred_created_documents[ + original_user.email + ], + 2, + ) + + def test_transfer_created_documents_dry_run(self): + """Test transfer_created_documents in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + doc = DocumentFactory(created_by=duplicate_user) + + deduplicator.transfer_created_documents(original_user, duplicate_user) + + doc.refresh_from_db() + self.assertEqual(doc.created_by, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_created_documents[original_user.email], 1 + ) + + def test_transfer_created_markdown_documents_regular_run(self): + """Test transfer_created_markdown_documents in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + md1 = MarkdownDocumentFactory(created_by=duplicate_user) + md2 = MarkdownDocumentFactory(created_by=duplicate_user) + + self.deduplicator.transfer_created_markdown_documents( + original_user, duplicate_user + ) + + md1.refresh_from_db() + md2.refresh_from_db() + self.assertEqual(md1.created_by, original_user) + self.assertEqual(md2.created_by, original_user) + self.assertEqual( + self.deduplicator.tracker.transferred_created_markdown_documents[ + original_user.email + ], + 2, + ) + + def test_transfer_created_markdown_documents_dry_run(self): + """Test transfer_created_markdown_documents in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + md = MarkdownDocumentFactory(created_by=duplicate_user) + + 
deduplicator.transfer_created_markdown_documents(original_user, duplicate_user) + + md.refresh_from_db() + self.assertEqual(md.created_by, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_created_markdown_documents[ + original_user.email + ], + 1, + ) + + def test_transfer_portability_requests_regular_run(self): + """Test transfer_portability_requests in regular mode.""" + original_user = UserFactory() + duplicate_user = UserFactory() + req1 = PortabilityRequestFactory(from_user=duplicate_user) + req2 = PortabilityRequestFactory(updated_by_user=duplicate_user) + + self.deduplicator.transfer_portability_requests(original_user, duplicate_user) + + req1.refresh_from_db() + req2.refresh_from_db() + self.assertEqual(req1.from_user, original_user) + self.assertEqual(req2.updated_by_user, original_user) + self.assertEqual( + self.deduplicator.tracker.transferred_portability_requests[ + original_user.email + ], + 2, + ) + + def test_transfer_portability_requests_dry_run(self): + """Test transfer_portability_requests in dry run mode.""" + deduplicator = UserDeduplicator(dry_run=True) + original_user = UserFactory() + duplicate_user = UserFactory() + req = PortabilityRequestFactory(from_user=duplicate_user) + + deduplicator.transfer_portability_requests(original_user, duplicate_user) + + req.refresh_from_db() + self.assertEqual(req.from_user, duplicate_user) + self.assertEqual( + deduplicator.tracker.transferred_portability_requests[original_user.email], + 1, + ) + + def test_handle_different_accounts(self): + """Test handle_different_accounts merges and deletes correctly.""" + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory() + original_social = UserSocialAuth.objects.create( + user=original_access.user, uid="org-1:original@test.com" + ) + duplicate_social = UserSocialAuth.objects.create( + user=duplicate_access.user, uid="org-1:duplicate@test.com" + ) + + self.deduplicator.handle_different_accounts( + 
original_access.user, + duplicate_access.user, + original_social, + duplicate_social, + ) + + self.assertIn( + original_access.user.email, + self.deduplicator.tracker.different_account_merges, + ) + self.assertIn( + duplicate_access.user.username, + self.deduplicator.tracker.deleted_user_accounts, + ) + self.assertFalse(User.objects.filter(username=duplicate_access.user.username).exists()) + self.assertTrue( + original_access.user.social_auth.filter(id=original_social.id).exists() + ) + self.assertTrue( + original_access.user.social_auth.filter(id=duplicate_social.id).exists() + ) + + def test_handle_same_account(self): + """Test handle_same_account updates social auth correctly.""" + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory(user__email=original_access.user.email) + original_social = UserSocialAuth.objects.create( + user=original_access.user, uid=f"org-1:{original_access.user.email}" + ) + duplicate_social = UserSocialAuth.objects.create( + user=duplicate_access.user, uid=f"org-2:{original_access.user.email}" + ) + + self.deduplicator.handle_same_account( + original_access.user, + duplicate_access.user, + original_social, + duplicate_social, + "org-1", + "org-2", + ) + + original_access.user.refresh_from_db() + self.assertTrue( + original_access.user.social_auth.filter(id=duplicate_social.id).exists() + ) + self.assertFalse(UserSocialAuth.objects.filter(id=original_social.id).exists()) + self.assertFalse(User.objects.filter(username=duplicate_access.user.username).exists()) + + def test_process_duplicate_user_no_social_auth(self): + """Test process_duplicate_user with no social auth transfers relations and deletes user.""" + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory() + duplicate_user_id = duplicate_access.user.id + duplicate_username = duplicate_access.user.username + + self.deduplicator.process_duplicate_user( + original_access.user, duplicate_access.user + ) + + # Duplicate user should be deleted +
self.assertFalse(User.objects.filter(id=duplicate_user_id).exists()) + # Duplicate's playlist access should be transferred to original user + duplicate_access.refresh_from_db() + self.assertEqual(duplicate_access.user, original_access.user) + # User should be marked for deletion in tracker + self.assertIn( + duplicate_username, + self.deduplicator.tracker.deleted_user_accounts, + ) + + def test_process_duplicate_user_different_accounts(self): + """Test process_duplicate_user with different account emails.""" + original_access = PlaylistAccessFactory() + duplicate_access = PlaylistAccessFactory() + UserSocialAuth.objects.create( + user=original_access.user, uid="org-1:original@test.com" + ) + UserSocialAuth.objects.create( + user=duplicate_access.user, uid="org-1:duplicate@test.com" + ) + + self.deduplicator.process_duplicate_user( + original_access.user, duplicate_access.user + ) + + self.assertFalse(User.objects.filter(username=duplicate_access.user.username).exists()) + + def test_process_duplicate_user_same_account(self): + """Test process_duplicate_user with same account email.""" + email = "test@example.com" + original_access = PlaylistAccessFactory(user__email=email) + duplicate_access = PlaylistAccessFactory(user__email=email) + UserSocialAuth.objects.create(user=original_access.user, uid=f"org-1:{email}") + UserSocialAuth.objects.create(user=duplicate_access.user, uid=f"org-2:{email}") + + self.deduplicator.process_duplicate_user( + original_access.user, duplicate_access.user + ) + + self.assertFalse(User.objects.filter(username=duplicate_access.user.username).exists()) + + def test_deduplicate_skips_empty_emails(self): + """Test deduplicate skips users with empty emails.""" + PlaylistAccessFactory(user__email="") + PlaylistAccessFactory(user__email="") + + self.deduplicator.deduplicate(None) + + self.assertEqual(len(self.deduplicator.tracker.deleted_user_accounts), 0) + + def test_deduplicate_with_specific_email(self): + """Test deduplicate with specific email.""" + email =
"specific@test.com" + access1 = PlaylistAccessFactory(user__email=email) + access2 = PlaylistAccessFactory(user__email=email) + UserSocialAuth.objects.create(user=access1.user, uid=f"org-1:{email}") + UserSocialAuth.objects.create(user=access2.user, uid=f"org-2:{email}") + + self.deduplicator.deduplicate(email) + + self.assertEqual(len(self.deduplicator.tracker.deleted_user_accounts), 1) diff --git a/src/backend/marsha/account/tests/management_commands/test_dedupe_accounts.py b/src/backend/marsha/account/tests/management_commands/test_dedupe_accounts.py deleted file mode 100644 index 051cb0ebd0..0000000000 --- a/src/backend/marsha/account/tests/management_commands/test_dedupe_accounts.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Tests for the `dedupe_accounts` management command.""" - -from io import StringIO - -from django.contrib.auth import get_user_model -from django.core.management import call_command -from django.test import TestCase - -from social_django.models import UserSocialAuth - -from marsha.core.factories import PlaylistAccessFactory - - -class DedupeAccountsCommandTest(TestCase): - """Test suite for dedupe_accounts command.""" - - maxDiff = None - - def setUp(self): - """ - Set up the test case. - """ - self.stdout = StringIO() - self.log_prefix = "INFO:marsha.account.management.commands.dedupe_accounts:" - - @staticmethod - def _add_social(user, uid="uid-xxx"): - return UserSocialAuth.objects.create(user=user, uid=uid) - - def test_dedupe_accounts_dry_run(self): - """ - With the --dry-run flag, no changes should be made. 
- """ - playlist_access = PlaylistAccessFactory() - user = playlist_access.user - old_social = self._add_social(user, uid=f"old-uid:{user.email}") - playlist_access_duplicate = PlaylistAccessFactory(user__email=user.email) - duplicate_user = playlist_access_duplicate.user - new_social = self._add_social(duplicate_user, uid=f"new-uid:{user.email}") - - with self.assertLogs("marsha.account", "INFO") as logs: - call_command("dedupe_accounts", "--dry-run", stdout=self.stdout) - - user.refresh_from_db() - self.assertEqual(user.social_auth.first(), old_social) - duplicate_user.refresh_from_db() - self.assertEqual(duplicate_user.social_auth.first(), new_social) - - self.assertListEqual( - [ - f"{self.log_prefix}[DRY-RUN] No changes will be made.", - f"{self.log_prefix}Deduping {user.email}", - f"{self.log_prefix}{"-" * 80}", - f"{self.log_prefix}Deduping complete. 1 SSO accounts deleted, 1 users deleted", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}0 accounts skipped:", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 organizations impacted:", - f"{self.log_prefix} - old-uid -> new-uid", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 users impacted:", - f"{self.log_prefix} - {user.email} -> {old_social.uid} -> {new_social.uid}", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}Summary:", - f"{self.log_prefix} 1 organizations impacted", - f"{self.log_prefix} 1 users processed", - f"{self.log_prefix} 1 users deleted", - f"{self.log_prefix} 1 SSO accounts deleted", - f"{self.log_prefix}[DRY-RUN] No changes made.", - ], - logs.output, - ) - - def test_dedupe_accounts_1_duplicate(self): - """ - Accounts with 2 duplicates should be merged correctly. 
- """ - playlist_access = PlaylistAccessFactory() - user = playlist_access.user - old_social = self._add_social(user, uid=f"old-uid:{user.email}") - playlist_access_duplicate = PlaylistAccessFactory(user__email=user.email) - duplicate_user = playlist_access_duplicate.user - new_social = self._add_social(duplicate_user, uid=f"new-uid:{user.email}") - - with self.assertLogs("marsha.account", "INFO") as logs: - call_command("dedupe_accounts", stdout=self.stdout) - - user.refresh_from_db() - self.assertTrue(user.social_auth.filter(id=new_social.id).exists()) - self.assertTrue( - user.playlists.filter( - id__in=duplicate_user.playlists.values_list("id", flat=True) - ).exists() - ) - self.assertFalse(UserSocialAuth.objects.filter(id=old_social.id).exists()) - self.assertFalse(get_user_model().objects.filter(id=duplicate_user.id).exists()) - - self.assertListEqual( - [ - f"{self.log_prefix}Deduping {user.email}", - f"{self.log_prefix}{"-" * 80}", - f"{self.log_prefix}Deduping complete. 1 SSO accounts deleted, 1 users deleted", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}0 accounts skipped:", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 organizations impacted:", - f"{self.log_prefix} - old-uid -> new-uid", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 users impacted:", - f"{self.log_prefix} - {user.email} -> {old_social.uid} -> {new_social.uid}", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}Summary:", - f"{self.log_prefix} 1 organizations impacted", - f"{self.log_prefix} 1 users processed", - f"{self.log_prefix} 1 users deleted", - f"{self.log_prefix} 1 SSO accounts deleted", - ], - logs.output, - ) - - def test_dedupe_accounts_2_duplicates(self): - """ - Accounts with 3 duplicates should be merged correctly. 
- """ - playlist_access = PlaylistAccessFactory() - user = playlist_access.user - old_social = self._add_social(user, uid=f"old-uid:{user.email}") - - playlist_access_duplicate = PlaylistAccessFactory(user__email=user.email) - duplicate_user_1 = playlist_access_duplicate.user - new_social_1 = self._add_social(duplicate_user_1, uid=f"new-uid:{user.email}") - - playlist_access_duplicate_2 = PlaylistAccessFactory(user__email=user.email) - duplicate_user_2 = playlist_access_duplicate_2.user - new_social_2 = self._add_social(duplicate_user_2, uid=f"new-uid-2:{user.email}") - - with self.assertLogs("marsha.account", "INFO") as logs: - call_command("dedupe_accounts", stdout=self.stdout) - - user.refresh_from_db() - self.assertFalse(UserSocialAuth.objects.filter(id=old_social.id).exists()) - self.assertFalse(UserSocialAuth.objects.filter(id=new_social_1.id).exists()) - self.assertTrue(user.social_auth.filter(id=new_social_2.id).exists()) - - self.assertTrue(user.playlists.filter(id=playlist_access.playlist.id).exists()) - self.assertTrue( - user.playlists.filter(id=playlist_access_duplicate.playlist.id).exists() - ) - self.assertTrue( - user.playlists.filter(id=playlist_access_duplicate_2.playlist.id).exists() - ) - - self.assertTrue(get_user_model().objects.filter(id=user.id).exists()) - self.assertFalse( - get_user_model().objects.filter(id=duplicate_user_1.id).exists() - ) - self.assertFalse( - get_user_model().objects.filter(id=duplicate_user_2.id).exists() - ) - - self.assertListEqual( - [ - f"{self.log_prefix}Deduping {user.email}", - f"{self.log_prefix}{"-" * 80}", - f"{self.log_prefix}Deduping complete. 
2 SSO accounts deleted, 2 users deleted", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}0 accounts skipped:", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 organizations impacted:", - f"{self.log_prefix} - old-uid -> new-uid -> new-uid-2", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 users impacted:", - f"{self.log_prefix} - {user.email} -> {old_social.uid}" - f" -> {new_social_1.uid} -> {new_social_2.uid}", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}Summary:", - f"{self.log_prefix} 1 organizations impacted", - f"{self.log_prefix} 1 users processed", - f"{self.log_prefix} 2 users deleted", - f"{self.log_prefix} 2 SSO accounts deleted", - ], - logs.output, - ) - - def test_dedupe_accounts_2_duplicates_2_skips(self): - """ - Accounts with 3 duplicates should be merged correctly. - 2 duplicates should be skipped. - """ - playlist_access = PlaylistAccessFactory(user__email="user@example.com") - user = playlist_access.user - old_social = self._add_social(user, uid=f"old-uid:{user.email}") - - playlist_access_duplicate = PlaylistAccessFactory(user__email=user.email) - duplicate_user_1 = playlist_access_duplicate.user - new_social_1 = self._add_social(duplicate_user_1, uid=f"new-uid:{user.email}") - - playlist_access_duplicate_2 = PlaylistAccessFactory(user__email=user.email) - duplicate_user_2 = playlist_access_duplicate_2.user - new_social_2 = self._add_social(duplicate_user_2, uid=f"new-uid-2:{user.email}") - - playlist_access_skip = PlaylistAccessFactory(user__email=user.email) - skip_user_1 = playlist_access_skip.user - kept_social_1 = self._add_social(skip_user_1, uid=f"new-uid:sk_{user.email}") - - playlist_access_skip_2 = PlaylistAccessFactory(user__email=user.email) - skip_user_2 = playlist_access_skip_2.user - kept_social_2 = self._add_social(skip_user_2, uid=f"new-uid-2:sk_{user.email}") - - with self.assertLogs("marsha.account", "INFO") as logs: - call_command("dedupe_accounts", stdout=self.stdout) - - 
user.refresh_from_db() - self.assertFalse(UserSocialAuth.objects.filter(id=old_social.id).exists()) - self.assertFalse(UserSocialAuth.objects.filter(id=new_social_1.id).exists()) - self.assertTrue(user.social_auth.filter(id=new_social_2.id).exists()) - - self.assertTrue(user.playlists.filter(id=playlist_access.playlist.id).exists()) - self.assertTrue( - user.playlists.filter(id=playlist_access_duplicate.playlist.id).exists() - ) - self.assertTrue( - user.playlists.filter(id=playlist_access_duplicate_2.playlist.id).exists() - ) - self.assertFalse( - user.playlists.filter(id=playlist_access_skip.playlist.id).exists() - ) - self.assertFalse( - user.playlists.filter(id=playlist_access_skip_2.playlist.id).exists() - ) - - self.assertTrue(get_user_model().objects.filter(id=user.id).exists()) - self.assertFalse( - get_user_model().objects.filter(id=duplicate_user_1.id).exists() - ) - self.assertFalse( - get_user_model().objects.filter(id=duplicate_user_2.id).exists() - ) - self.assertTrue(get_user_model().objects.filter(id=skip_user_1.id).exists()) - self.assertTrue(get_user_model().objects.filter(id=skip_user_2.id).exists()) - - self.assertListEqual( - [ - f"{self.log_prefix}Deduping {user.email}", - f"{self.log_prefix}{"-" * 80}", - f"{self.log_prefix}Deduping complete. 
2 SSO accounts deleted, 2 users deleted", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}2 accounts skipped:", - f"{self.log_prefix} - {user.email} | {old_social.uid} | {kept_social_1.uid}" - f" | 0.5714285714285714 | 0.9142857142857143", - f"{self.log_prefix} - {user.email} | {old_social.uid} | {kept_social_2.uid}" - f" | 0.5 | 0.9142857142857143", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 organizations impacted:", - f"{self.log_prefix} - old-uid -> new-uid -> new-uid-2", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}1 users impacted:", - f"{self.log_prefix} - {user.email} -> {old_social.uid}" - f" -> {new_social_1.uid} -> {new_social_2.uid}", - f"{self.log_prefix}{"- " * 40}", - f"{self.log_prefix}Summary:", - f"{self.log_prefix} 1 organizations impacted", - f"{self.log_prefix} 1 users processed", - f"{self.log_prefix} 2 users deleted", - f"{self.log_prefix} 2 SSO accounts deleted", - ], - logs.output, - ) diff --git a/src/backend/marsha/account/tests/test_social_pipeline_social_auth.py b/src/backend/marsha/account/tests/test_social_pipeline_social_auth.py index 478a55fd0c..3b00f16af2 100644 --- a/src/backend/marsha/account/tests/test_social_pipeline_social_auth.py +++ b/src/backend/marsha/account/tests/test_social_pipeline_social_auth.py @@ -5,6 +5,7 @@ from django.conf import settings from django.test import TestCase, override_settings +from social_django.models import UserSocialAuth from social_django.utils import load_backend, load_strategy from waffle.testutils import override_switch @@ -15,6 +16,7 @@ social_details, ) from marsha.core.defaults import RENATER_FER_SAML +from marsha.core.factories import UserFactory class AuthAllowedPipelineTestCase(TestCase): @@ -276,3 +278,22 @@ def test_associate_by_email_step_disabled(self): associate_by_email(backend, details, strategy, 42, some_kwargs=18) ) self.assertFalse(social_associate_by_email_mock.called) + + @override_switch(RENATER_FER_SAML, active=True) + def 
def dedupe_accounts(duplicate_email=None, dry_run=False):
    """Run the account deduplication and return whatever the deduplicator returns.

    Args:
        duplicate_email: when given, only this email is deduplicated;
            otherwise every email shared by several users is processed.
        dry_run: when true, the deduplicator is built in dry-run mode.
    """
    return UserDeduplicator(dry_run=dry_run).deduplicate(duplicate_email)
@dataclass
class DedupeTracker:  # pylint: disable=too-many-instance-attributes
    """Collect what a deduplication run did so it can be reported afterwards.

    The tracker never writes to the database: it stores what callers hand it
    (and, for organizations and LTI associations, names read from the given
    user) and logs a report in ``log_results``.

    Three shapes of data are kept:

    - ``Dict[str, List[str]]``: a key followed by a chain of values,
    - ``Dict[str, int]``: a counter keyed by the receiving user's email,
    - ``List[str]``: flat lists of deleted items.
    """

    # UIDs of the SSO accounts passed to ``mark_account_for_deletion``.
    deleted_sso_accounts: List[str] = field(default_factory=list)
    # user email -> [original uid, new uid, ...] for same-account migrations.
    same_account_migrations: Dict[str, List[str]] = field(default_factory=dict)
    # first organization uid seen -> chain of organization uids it migrated to.
    organization_uid_migrations: Dict[str, List[str]] = field(default_factory=dict)
    # usernames passed to ``mark_user_for_deletion``.
    deleted_user_accounts: List[str] = field(default_factory=list)
    # user email -> [original uid, new uid, ...] for different-account merges.
    different_account_merges: Dict[str, List[str]] = field(default_factory=dict)
    # user email -> organization names (seeded with the user's current ones).
    transferred_organizations: Dict[str, List[str]] = field(default_factory=dict)
    # user email -> consumer site names (seeded with the user's current ones).
    transferred_lti_associations: Dict[str, List[str]] = field(default_factory=dict)
    # user email -> number of objects of each kind transferred to that user.
    transferred_lti_passports: Dict[str, int] = field(default_factory=dict)
    transferred_consumersite_accesses: Dict[str, int] = field(default_factory=dict)
    transferred_playlist_accesses: Dict[str, int] = field(default_factory=dict)
    transferred_created_playlists: Dict[str, int] = field(default_factory=dict)
    transferred_created_videos: Dict[str, int] = field(default_factory=dict)
    transferred_created_documents: Dict[str, int] = field(default_factory=dict)
    transferred_created_markdown_documents: Dict[str, int] = field(default_factory=dict)
    transferred_portability_requests: Dict[str, int] = field(default_factory=dict)

    @staticmethod
    def _extend_chain(chains, key, first_value, new_value):
        """Append ``new_value`` to the chain of ``key``, seeding it with ``first_value``."""
        chains.setdefault(key, [first_value]).append(new_value)

    @staticmethod
    def _add_count(counters, user, count):
        """Add ``count`` to the counter stored under ``user.email``."""
        counters[user.email] = counters.get(user.email, 0) + count

    def track_different_account_merge(self, email, original_uid, new_uid):
        """Track merge of accounts with different SSO account emails."""
        self._extend_chain(self.different_account_merges, email, original_uid, new_uid)

    def track_same_account_migration(self, email, original_uid, new_uid):
        """Track migration of same SSO account email across organizations."""
        self._extend_chain(self.same_account_migrations, email, original_uid, new_uid)

    def track_organization_uid_migration(self, original_org_uid, new_org_uid):
        """Track organization UID migrations as de-duplicated chains.

        When ``original_org_uid`` is already part of a recorded chain, the new
        UID continues that chain; otherwise a chain keyed by
        ``original_org_uid`` is started or continued. A UID is never added
        twice to the same chain, so processing many users that went through
        the same organization migration yields a single
        ``org-1 -> org-2 -> org-3`` entry instead of a repeating one.
        """
        for chain in self.organization_uid_migrations.values():
            if original_org_uid in chain:
                if new_org_uid not in chain:
                    chain.append(new_org_uid)
                return

        chain = self.organization_uid_migrations.setdefault(original_org_uid, [])
        if new_org_uid not in chain:
            chain.append(new_org_uid)

    def mark_user_for_deletion(self, username):
        """Mark a user for deletion."""
        self.deleted_user_accounts.append(username)

    def mark_account_for_deletion(self, uid):
        """Mark an account for deletion."""
        self.deleted_sso_accounts.append(uid)

    def get_or_init_transferred_orgs(self, user):
        """Get or initialize a transferred organizations list for a user.

        On first use the list is seeded with the names of the organizations
        the user currently belongs to.
        """
        if user.email not in self.transferred_organizations:
            self.transferred_organizations[user.email] = list(
                user.organization_set.all().values_list("name", flat=True)
            )
        return self.transferred_organizations[user.email]

    def mark_transferred_orgs(self, user, organizations):
        """Mark organizations transferred to a user."""
        self.get_or_init_transferred_orgs(user).extend(
            organization.name for organization in organizations
        )

    def get_or_init_transferred_lti_associations(self, user):
        """Get or initialize a transferred LTI associations list for a user.

        On first use the list is seeded with the consumer site names of the
        user's current LTI associations.
        """
        if user.email not in self.transferred_lti_associations:
            self.transferred_lti_associations[user.email] = list(
                user.lti_user_associations.all().values_list(
                    "consumer_site__name", flat=True
                )
            )
        return self.transferred_lti_associations[user.email]

    def mark_transferred_lti_association(self, user, lti_user_associations):
        """Track LTI associations transferred from duplicate user to original user."""
        transferred_associations = self.get_or_init_transferred_lti_associations(user)
        transferred_associations.extend(
            lti_user_associations.values_list("consumer_site__name", flat=True)
        )

    def mark_transferred_lti_passports(self, user, count):
        """Track LTI passports transferred to a user."""
        self._add_count(self.transferred_lti_passports, user, count)

    def mark_transferred_consumersite_accesses(self, user, count):
        """Track consumer site accesses transferred to a user."""
        self._add_count(self.transferred_consumersite_accesses, user, count)

    def mark_transferred_playlist_accesses(self, user, count):
        """Track playlist accesses transferred to a user."""
        self._add_count(self.transferred_playlist_accesses, user, count)

    def mark_transferred_created_playlists(self, user, count):
        """Track created playlists transferred to a user."""
        self._add_count(self.transferred_created_playlists, user, count)

    def mark_transferred_created_videos(self, user, count):
        """Track created videos transferred to a user."""
        self._add_count(self.transferred_created_videos, user, count)

    def mark_transferred_created_documents(self, user, count):
        """Track created documents transferred to a user."""
        self._add_count(self.transferred_created_documents, user, count)

    def mark_transferred_created_markdown_documents(self, user, count):
        """Track created markdown documents transferred to a user."""
        self._add_count(self.transferred_created_markdown_documents, user, count)

    def mark_transferred_portability_requests(self, user, count):
        """Track portability requests transferred to a user."""
        self._add_count(self.transferred_portability_requests, user, count)

    def log_results(self, dry_run=False):
        """Log deduplication results.

        Every section is logged as a count line, one line per entry and a
        separator; a final summary repeats the count of every section. When
        ``dry_run`` is true a closing reminder that nothing changed is logged.
        """
        logger.info("-" * 80)
        logger.info("Deduping complete")
        logger.info("- " * 40)

        # Sections with dict[str, List[str]] - key_sep and value_sep
        dict_list_sections = [
            (
                "organization UID migrations",
                self.organization_uid_migrations,
                "->",
                "->",
            ),
            ("same account migrations", self.same_account_migrations, "->", "->"),
            ("different account merges", self.different_account_merges, ":", "+"),
            ("organizations transferred", self.transferred_organizations, ":", "+"),
            (
                "LTI associations transferred",
                self.transferred_lti_associations,
                ":",
                "+",
            ),
        ]

        for title, data, key_sep, val_sep in dict_list_sections:
            logger.info("%d %s:", len(data), title)
            for key, values in data.items():
                logger.info("  - %s %s %s", key, key_sep, f" {val_sep} ".join(values))
            logger.info("- " * 40)

        # Sections with dict[str, int]
        dict_count_sections = [
            ("LTI passports transferred", self.transferred_lti_passports),
            (
                "consumer site accesses transferred",
                self.transferred_consumersite_accesses,
            ),
            ("playlist accesses transferred", self.transferred_playlist_accesses),
            ("created playlists transferred", self.transferred_created_playlists),
            ("created videos transferred", self.transferred_created_videos),
            ("created documents transferred", self.transferred_created_documents),
            (
                "created markdown documents transferred",
                self.transferred_created_markdown_documents,
            ),
            ("portability requests transferred", self.transferred_portability_requests),
        ]

        for title, data in dict_count_sections:
            logger.info("%d %s:", len(data), title)
            for key, count in data.items():
                logger.info("  - %s : %d", key, count)
            logger.info("- " * 40)

        # Sections with List[str]
        list_sections = [
            ("user accounts deleted", self.deleted_user_accounts),
            ("SSO accounts deleted", self.deleted_sso_accounts),
        ]

        for title, data in list_sections:
            logger.info("%d %s:", len(data), title)
            for item in data:
                logger.info("  - %s", item)
            logger.info("- " * 40)

        # Summary: the tuples have different lengths, only title and data matter.
        logger.info("Summary:")
        summary_sections = dict_list_sections + dict_count_sections + list_sections
        for title, data, *_ in summary_sections:
            logger.info("  %d %s", len(data), title)

        if dry_run:
            logger.info("[DRY-RUN] No changes made.")
class UserDeduplicator:  # pylint: disable=too-many-public-methods
    """Merge users sharing the same email into the oldest one.

    For every duplicated email the users are ordered by ``date_joined``; the
    first one is kept and every other user has its social auth and relations
    moved to it before being deleted. Every operation is recorded in
    ``self.tracker``; with ``dry_run`` set, nothing is written to the database
    and only the tracker is filled.
    """

    def __init__(self, dry_run=False):
        self.dry_run = dry_run
        self.tracker = DedupeTracker()

    @staticmethod
    def get_duplicate_emails(duplicate_email=None):
        """Get list of duplicate email addresses.

        Returns a one-item list for ``duplicate_email`` when given, otherwise
        a queryset of ``{"email": ..., "count": ...}`` dicts for every email
        used by more than one user, ordered by email.
        """
        if duplicate_email:
            return [{"email": duplicate_email}]

        # pylint: disable=invalid-name
        User = get_user_model()

        return (
            User.objects.values("email")
            .annotate(count=Count("id"))
            .filter(count__gt=1)
            .order_by("email")
        )

    @staticmethod
    def parse_social_uid(social_auth):
        """Extract organization UID and account email from social auth UID.

        The UID is split on its first ``:`` only, so an account identifier
        that itself contains a colon is kept whole instead of being truncated.

        Returns:
            ``(organization_uid, account_email)``; ``(None, None)`` when
            ``social_auth`` is falsy, and ``account_email`` is ``None`` when
            the UID holds no ``:``.
        """
        if not social_auth:
            return None, None

        organization_uid, separator, account_email = social_auth.uid.partition(":")
        return organization_uid, (account_email if separator else None)

    def transfer_playlists(self, original_user, duplicate_user):
        """Transfer playlists from duplicate user to original user."""
        if not self.dry_run:
            for playlist_access in duplicate_user.playlist_accesses.all():
                # NOTE(review): created_by is reassigned for every playlist the
                # duplicate has an access to, not only for the ones it created
                # - confirm this is intended.
                playlist_access.playlist.created_by = original_user
                playlist_access.playlist.save()
                if playlist_access.playlist not in original_user.playlists.all():
                    playlist_access.user = original_user
                    playlist_access.save()
                else:
                    playlist_access.delete()

    def transfer_organizations(self, original_user, duplicate_user):
        """Transfer organizations from duplicate user to original user."""
        new_organizations = duplicate_user.organization_set.exclude(
            id__in=original_user.organization_set.values_list("id", flat=True)
        )

        if new_organizations:
            self.tracker.mark_transferred_orgs(original_user, new_organizations)

        if not self.dry_run:
            for organization_access in duplicate_user.organization_accesses.all():
                if (
                    organization_access.organization
                    not in original_user.organization_set.all()
                ):
                    organization_access.user = original_user
                    organization_access.save()

    def transfer_lti_associations(self, original_user, duplicate_user):
        """Transfer LTI associations from duplicate user to original user."""
        lti_count = duplicate_user.lti_user_associations.count()
        if lti_count > 0:
            self.tracker.mark_transferred_lti_association(
                original_user, duplicate_user.lti_user_associations
            )
            if not self.dry_run:
                duplicate_user.lti_user_associations.update(user=original_user)

    def transfer_lti_passports(self, original_user, duplicate_user):
        """Transfer LTI passports from duplicate user to original user."""
        count = duplicate_user.lti_passports.count()
        if count > 0:
            self.tracker.mark_transferred_lti_passports(original_user, count)
            if not self.dry_run:
                duplicate_user.lti_passports.update(created_by=original_user)

    def transfer_consumersite_accesses(self, original_user, duplicate_user):
        """Transfer consumer site accesses from duplicate user to original user.

        Accesses to consumer sites the original user cannot access yet are
        moved; the remaining accesses of the duplicate are deleted.
        """
        new_accesses = duplicate_user.consumersite_accesses.exclude(
            consumer_site_id__in=original_user.consumersite_accesses.values_list(
                "consumer_site_id", flat=True
            )
        )
        count = new_accesses.count()
        if count > 0:
            self.tracker.mark_transferred_consumersite_accesses(original_user, count)
            if not self.dry_run:
                new_accesses.update(user=original_user)

        # Delete duplicate accesses (only if not dry run)
        if not self.dry_run:
            duplicate_user.consumersite_accesses.exclude(
                id__in=new_accesses.values_list("id", flat=True)
            ).delete()

    def transfer_playlist_accesses(self, original_user, duplicate_user):
        """Transfer playlist accesses from duplicate user to original user.

        Accesses to playlists the original user cannot access yet are moved;
        the remaining accesses of the duplicate are deleted.
        """
        new_accesses = duplicate_user.playlist_accesses.exclude(
            playlist_id__in=original_user.playlist_accesses.values_list(
                "playlist_id", flat=True
            )
        )
        count = new_accesses.count()
        if count > 0:
            self.tracker.mark_transferred_playlist_accesses(original_user, count)
            if not self.dry_run:
                new_accesses.update(user=original_user)

        # Delete duplicate accesses (only if not dry run)
        if not self.dry_run:
            duplicate_user.playlist_accesses.exclude(
                id__in=new_accesses.values_list("id", flat=True)
            ).delete()

    def transfer_created_playlists(self, original_user, duplicate_user):
        """Transfer created playlists from duplicate user to original user."""
        count = duplicate_user.created_playlists.count()
        if count > 0:
            self.tracker.mark_transferred_created_playlists(original_user, count)
            if not self.dry_run:
                duplicate_user.created_playlists.update(created_by=original_user)

    def transfer_created_videos(self, original_user, duplicate_user):
        """Transfer created videos from duplicate user to original user."""
        count = duplicate_user.created_video.count()
        if count > 0:
            self.tracker.mark_transferred_created_videos(original_user, count)
            if not self.dry_run:
                duplicate_user.created_video.update(created_by=original_user)

    def transfer_created_documents(self, original_user, duplicate_user):
        """Transfer created documents from duplicate user to original user."""
        count = duplicate_user.created_document.count()
        if count > 0:
            self.tracker.mark_transferred_created_documents(original_user, count)
            if not self.dry_run:
                duplicate_user.created_document.update(created_by=original_user)

    def transfer_created_markdown_documents(self, original_user, duplicate_user):
        """Transfer created markdown documents from duplicate user to original user."""
        count = duplicate_user.created_markdowndocument.count()
        if count > 0:
            self.tracker.mark_transferred_created_markdown_documents(
                original_user, count
            )
            if not self.dry_run:
                duplicate_user.created_markdowndocument.update(created_by=original_user)

    def transfer_portability_requests(self, original_user, duplicate_user):
        """Transfer portability requests from duplicate user to original user.

        Both the requests issued by the duplicate and the ones it actioned
        are counted and moved.
        """
        count = (
            duplicate_user.portability_requests.count()
            + duplicate_user.actioned_portability_requests.count()
        )
        if count > 0:
            self.tracker.mark_transferred_portability_requests(original_user, count)
            if not self.dry_run:
                duplicate_user.portability_requests.update(from_user=original_user)
                duplicate_user.actioned_portability_requests.update(
                    updated_by_user=original_user
                )

    def process_relations(self, original_user, duplicate_user):
        """Move every relation of the duplicate user to the original user."""
        self.transfer_playlists(original_user, duplicate_user)
        self.transfer_organizations(original_user, duplicate_user)
        self.transfer_lti_associations(original_user, duplicate_user)
        self.transfer_lti_passports(original_user, duplicate_user)
        self.transfer_consumersite_accesses(original_user, duplicate_user)
        self.transfer_playlist_accesses(original_user, duplicate_user)
        self.transfer_created_playlists(original_user, duplicate_user)
        self.transfer_created_videos(original_user, duplicate_user)
        self.transfer_created_documents(original_user, duplicate_user)
        self.transfer_created_markdown_documents(original_user, duplicate_user)
        self.transfer_portability_requests(original_user, duplicate_user)

    def delete_user(self, user):
        """Record the user as deleted and delete it unless in dry-run."""
        self.tracker.mark_user_for_deletion(user.username)
        if not self.dry_run:
            user.delete()

    def delete_account(self, account):
        """Record the SSO account as deleted and delete it unless in dry-run."""
        self.tracker.mark_account_for_deletion(account.uid)
        if not self.dry_run:
            account.delete()

    def add_social_auth(self, user, social_auth):
        """Add social auth to a user."""
        if not self.dry_run:
            user.social_auth.add(social_auth)

    def set_social_auth(self, user, social_auth):
        """Set social auth for a user."""
        if not self.dry_run:
            user.social_auth.set([social_auth])

    def handle_different_accounts(
        self, original_user, duplicate_user, original_social, new_social
    ):
        """Handle case where users have different account emails.

        The original user keeps its own SSO account and additionally receives
        the one of the duplicate.
        """
        self.tracker.track_different_account_merge(
            original_user.email, original_social.uid, new_social.uid
        )

        self.add_social_auth(original_user, new_social)
        self.process_relations(original_user, duplicate_user)
        self.delete_user(duplicate_user)

    # pylint: disable=too-many-arguments, too-many-positional-arguments
    def handle_same_account(
        self,
        original_user,
        duplicate_user,
        original_social,
        new_social,
        original_org_uid,
        new_org_uid,
    ):
        """Handle case where users have the same account email but different organizations.

        The original SSO account is deleted and replaced by the one of the
        duplicate.
        """
        if new_org_uid:
            self.tracker.track_organization_uid_migration(original_org_uid, new_org_uid)
        self.tracker.track_same_account_migration(
            original_user.email, original_social.uid, new_social.uid
        )

        self.delete_account(original_social)
        self.set_social_auth(original_user, new_social)
        self.process_relations(original_user, duplicate_user)
        self.delete_user(duplicate_user)

    def process_duplicate_user(self, original_user, duplicate_user):
        """Process a single duplicate user against the original.

        The merge runs inside a transaction so a failure cannot leave the
        original user half merged (for instance with its SSO account deleted
        but the new one not attached yet).
        """
        # Imported here so the block carries its own dependency.
        from django.db import transaction  # pylint: disable=import-outside-toplevel

        original_social = original_user.social_auth.first()
        new_social = duplicate_user.social_auth.first()

        with transaction.atomic():
            if not new_social:
                self.process_relations(original_user, duplicate_user)
                self.delete_user(duplicate_user)
                return

            if not original_social:
                # Nothing to compare with or to replace: the original simply
                # adopts the SSO account of the duplicate.
                logger.info(
                    "%s has no SSO account, adopting %s",
                    original_user.email,
                    new_social.uid,
                )
                self.add_social_auth(original_user, new_social)
                self.process_relations(original_user, duplicate_user)
                self.delete_user(duplicate_user)
                return

            original_org_uid, original_account_email = self.parse_social_uid(
                original_social
            )
            new_org_uid, new_account_email = self.parse_social_uid(new_social)

            if original_account_email != new_account_email:
                self.handle_different_accounts(
                    original_user, duplicate_user, original_social, new_social
                )
            else:
                self.handle_same_account(
                    original_user,
                    duplicate_user,
                    original_social,
                    new_social,
                    original_org_uid,
                    new_org_uid,
                )

    def deduplicate(self, duplicate_email=None):
        """Execute the deduplication process.

        Args:
            duplicate_email: restrict the run to this email; every duplicated
                email is processed when omitted.

        Entries with an empty email, and emails matching no user, are skipped.
        The tracker report is logged once everything has been processed.
        """
        User = get_user_model()  # pylint: disable=invalid-name
        duplicates = self.get_duplicate_emails(duplicate_email)

        processed = 0
        total = len(duplicates)

        for duplicate in duplicates:
            email = duplicate["email"]
            if not email:
                continue

            processed += 1
            print(f"Deduping {processed}/{total}", end="\r", flush=True)

            users = list(User.objects.filter(email=email).order_by("date_joined"))
            if not users:
                # Happens when an explicit email matching no user is requested.
                logger.warning("No user found with email %s, nothing to dedupe", email)
                continue

            original_user, *duplicate_users = users

            for duplicate_user in duplicate_users:
                self.process_duplicate_user(original_user, duplicate_user)

        self.tracker.log_results(self.dry_run)