diff --git a/users/adapters.py b/users/adapters.py index ce2daea980..8f0f5db86d 100644 --- a/users/adapters.py +++ b/users/adapters.py @@ -2,6 +2,7 @@ from mitol.scim.adapters import UserAdapter +from b2b.models import ContractPage from openedx.models import OpenEdxUser from users.models import LegalAddress, UserProfile @@ -25,6 +26,7 @@ class LearnUserAdapter(UserAdapter): user_profile: UserProfile legal_address: LegalAddress openedx_user: OpenEdxUser + b2b_contracts: ContractPage def __init__(self, obj, request=None): super().__init__(obj, request=request) @@ -33,9 +35,12 @@ def __init__(self, obj, request=None): self.obj, "user_profile", UserProfile() ) - self.legal_address = self.obj.legal_address = getattr( - self.obj, "legal_address", LegalAddress() - ) + try: + self.legal_address = self.obj.legal_address # triggers DB fetch if needed + except LegalAddress.DoesNotExist: + self.legal_address = LegalAddress() + + self.b2b_contracts = self.obj.b2b_contracts self.openedx_user = self.obj.openedx_user if self.openedx_user is None: @@ -54,24 +59,28 @@ def from_dict(self, d): Consume a ``dict`` conforming to the SCIM User Schema, updating the internal user object with data from the ``dict``. - Please note, the user object is not saved within this method. To - persist the changes made by this method, please call ``.save()`` on the - adapter. Eg:: - - scim_user.from_dict(d) - scim_user.save() + Note: This method does NOT save the user object. To persist changes, + call ``.save()`` on the adapter. """ super().from_dict(d) self.obj.name = d.get("fullName", "") - first_name = d.get("name", {}).get("given_name", "") - if first_name: - self.legal_address.first_name = first_name + name_data = d.get("name", {}) - last_name = d.get("name", {}).get("last_name", "") - if last_name: - self.legal_address.last_name = last_name + self.legal_address.first_name = ( + name_data.get("given_name") or self.legal_address.first_name + ) + self.legal_address.last_name = ( + name_data.get("last_name") or self.legal_address.last_name + ) + + organization_name = d.get("organization") + if organization_name: + contract_pages = ContractPage.objects.filter( + organization__name=organization_name + ) + self.b2b_contracts.add(*contract_pages) def _save_related(self): self.user_profile.user = self.obj @@ -82,3 +91,5 @@ def _save_related(self): self.openedx_user.user = self.obj self.openedx_user.save() + + self.obj.b2b_contracts.add(*self.b2b_contracts) diff --git a/users/adapters_test.py b/users/adapters_test.py new file mode 100644 index 0000000000..21be68eabb --- /dev/null +++ b/users/adapters_test.py @@ -0,0 +1,84 @@ +from unittest import mock + +import pytest + +from b2b.factories import ContractPageFactory +from openedx.models import OpenEdxUser +from users.adapters import LearnUserAdapter +from users.factories import UserFactory +from users.models import LegalAddress, UserProfile + + +@pytest.mark.django_db +def test_init_sets_related_objects(): + user = UserFactory() + adapter = LearnUserAdapter(user) + + assert isinstance(adapter.user_profile, UserProfile) + assert isinstance(adapter.legal_address, LegalAddress) + assert isinstance(adapter.openedx_user, OpenEdxUser) + + +@pytest.mark.django_db +def test_display_name_returns_name(): + user = UserFactory(name="John Doe") + adapter = LearnUserAdapter(user) + + assert adapter.display_name == "John Doe" + + +@pytest.mark.django_db +def test_from_dict_updates_user_and_related(): + user = UserFactory.create(name="Old Name") + user.legal_address.first_name = "OldFirst" + user.legal_address.last_name = "OldLast" + user.legal_address.save() + adapter = LearnUserAdapter(user) + + contract_page = ContractPageFactory.create(organization__name="Acme Corp") + data = { + "fullName": "New Name", + "name": {"given_name": "NewFirst", "last_name": "NewLast"}, + "organization": "Acme Corp", + } + + adapter.from_dict(data) + adapter._save_related() # noqa: SLF001 + adapter.legal_address.refresh_from_db() + assert adapter.obj.name == "New Name" + assert adapter.legal_address.first_name == "NewFirst" + assert adapter.legal_address.last_name == "NewLast" + assert user.b2b_contracts.filter(id=contract_page.id).exists() + + +@pytest.mark.django_db +def test_from_dict_keeps_existing_names_if_missing(): + user = UserFactory.create(name="Old Name") + user.legal_address.first_name = "OldFirst" + user.legal_address.last_name = "OldLast" + user.legal_address.save() + adapter = LearnUserAdapter(user) + + data = {"fullName": "Another Name", "name": {}} + adapter.from_dict(data) + + adapter.legal_address.refresh_from_db() + + assert adapter.legal_address.first_name == "OldFirst" + assert adapter.legal_address.last_name == "OldLast" + + +@pytest.mark.django_db +def test_save_related_saves_all(): + user = UserFactory() + adapter = LearnUserAdapter(user) + + adapter.user_profile = mock.MagicMock() + adapter.legal_address = mock.MagicMock() + adapter.openedx_user = mock.MagicMock() + + adapter._save_related() # noqa: SLF001 + + adapter.user_profile.save.assert_called_once() + adapter.legal_address.save.assert_called_once() + adapter.openedx_user.save.assert_called_once()