Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 56 additions & 11 deletions ee/api/scim/test/test_users_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ee.api.scim.auth import generate_scim_token
from ee.api.test.base import APILicensedTest
from ee.models.rbac.role import RoleMembership
from ee.models.scim_provisioned_user import SCIMProvisionedUser


class TestSCIMUsersAPI(APILicensedTest):
Expand Down Expand Up @@ -129,9 +130,9 @@ def test_users_list_filter_unrecognized_returns_empty_list(self):
def test_create_user(self):
user_data = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"userName": "newuser@example.com",
"userName": "Newuser@example.com",
"name": {"givenName": "New", "familyName": "User"},
"emails": [{"value": "newuser@example.com", "primary": True}],
"emails": [{"value": "Newuser@example.com", "primary": True}],
"active": True,
}

Expand All @@ -141,11 +142,11 @@ def test_create_user(self):

assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["userName"] == "newuser@example.com"
assert data["userName"] == "Newuser@example.com"
assert data["name"]["givenName"] == "New"
assert data["name"]["familyName"] == "User"

# Verify user was created
# Verify user was created with lowercase email
user = User.objects.get(email="[email protected]")
assert user.first_name == "New"
assert user.last_name == "User"
Expand All @@ -155,6 +156,12 @@ def test_create_user(self):
membership = OrganizationMembership.objects.get(user=user, organization=self.organization)
assert membership.level == OrganizationMembership.Level.MEMBER

# Verify SCIM provisioned user record was created
scim_user = SCIMProvisionedUser.objects.get(user=user, organization_domain=self.domain)
assert scim_user.username == "[email protected]"
assert scim_user.active is True
assert scim_user.identity_provider == SCIMProvisionedUser.IdentityProvider.OTHER

def test_existing_user_is_added_to_org(self):
# Create user in different org
other_org = Organization.objects.create(name="Other Org")
Expand Down Expand Up @@ -184,9 +191,11 @@ def test_existing_user_is_added_to_org(self):
assert OrganizationMembership.objects.filter(user=existing_user, organization=self.organization).exists()
assert OrganizationMembership.objects.filter(user=existing_user, organization=other_org).exists()

def test_repeated_post_does_not_create_duplicate_user(self):
# In case the IdP failed to match user by id, it can send POST request to create a new user.
# The user should be merged with existing one by email, not create a duplicate.
# Verify SCIM provisioned user record was created for this domain
scim_user = SCIMProvisionedUser.objects.get(user=existing_user, organization_domain=self.domain)
assert scim_user.active is True

def test_repeated_post_returns_409_for_already_provisioned_user(self):
user_data_first = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"userName": "[email protected]",
Expand All @@ -202,7 +211,7 @@ def test_repeated_post_does_not_create_duplicate_user(self):
assert response.status_code == status.HTTP_201_CREATED
first_user = User.objects.get(email="[email protected]")

# IdP sends POST request again with same email
# IdP sends POST request again with same email - should fail with 409
user_data_second = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"userName": "[email protected]",
Expand All @@ -215,14 +224,14 @@ def test_repeated_post_does_not_create_duplicate_user(self):
f"/scim/v2/{self.domain.id}/Users", data=user_data_second, content_type="application/scim+json"
)

assert response.status_code == status.HTTP_201_CREATED
assert response.status_code == status.HTTP_409_CONFLICT

# Should NOT create duplicate user
assert User.objects.filter(email="[email protected]").count() == 1

# User should be updated with new data from second POST
# User should NOT be updated (still has first POST data)
first_user.refresh_from_db()
assert first_user.first_name == "Second"
assert first_user.first_name == "First"
assert first_user.last_name == "Time"

# User should have only one membership
Expand Down Expand Up @@ -261,6 +270,14 @@ def test_deactivate_user(self):
OrganizationMembership.objects.create(
user=user, organization=self.organization, level=OrganizationMembership.Level.MEMBER
)
# Create SCIM provisioned user record
SCIMProvisionedUser.objects.create(
user=user,
organization_domain=self.domain,
username="[email protected]",
identity_provider=SCIMProvisionedUser.IdentityProvider.OTHER,
active=True,
)

patch_data = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
Expand All @@ -280,13 +297,25 @@ def test_deactivate_user(self):
user.refresh_from_db()
assert user.is_active is True # User is still active globally

# Verify SCIM provisioned user record still exists but is marked inactive
scim_user = SCIMProvisionedUser.objects.get(user=user, organization_domain=self.domain)
assert scim_user.active is False

def test_delete_user(self):
user = User.objects.create_user(
email="[email protected]", password=None, first_name="Delete", is_email_verified=True
)
OrganizationMembership.objects.create(
user=user, organization=self.organization, level=OrganizationMembership.Level.MEMBER
)
# Create SCIM provisioned user record
SCIMProvisionedUser.objects.create(
user=user,
organization_domain=self.domain,
username="[email protected]",
identity_provider=SCIMProvisionedUser.IdentityProvider.OTHER,
active=True,
)

response = self.client.delete(f"/scim/v2/{self.domain.id}/Users/{user.id}")

Expand All @@ -295,13 +324,24 @@ def test_delete_user(self):
# Verify membership was removed
assert not OrganizationMembership.objects.filter(user=user, organization=self.organization).exists()

# Verify SCIM provisioned user record was deleted
assert not SCIMProvisionedUser.objects.filter(user=user, organization_domain=self.domain).exists()

def test_put_user(self):
user = User.objects.create_user(
email="[email protected]", password=None, first_name="Old", last_name="Name", is_email_verified=True
)
OrganizationMembership.objects.create(
user=user, organization=self.organization, level=OrganizationMembership.Level.MEMBER
)
# Create SCIM provisioned user record
SCIMProvisionedUser.objects.create(
user=user,
organization_domain=self.domain,
username="[email protected]",
identity_provider=SCIMProvisionedUser.IdentityProvider.OTHER,
active=True,
)

put_data = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
Expand All @@ -321,6 +361,11 @@ def test_put_user(self):
assert user.last_name == "User"
assert user.email == "[email protected]"

# Verify SCIM provisioned user was updated
scim_user = SCIMProvisionedUser.objects.get(user=user, organization_domain=self.domain)
assert scim_user.username == "[email protected]"
assert scim_user.active is True

def test_put_user_not_found(self):
put_data = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
Expand Down
94 changes: 84 additions & 10 deletions ee/api/scim/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from posthog.models.organization_domain import OrganizationDomain

from ee.models.rbac.role import RoleMembership
from ee.models.scim_provisioned_user import SCIMProvisionedUser


class PostHogSCIMUser(SCIMUser):
Expand Down Expand Up @@ -54,7 +55,10 @@ def name(self) -> dict:

@property
def user_name(self) -> str:
return self.obj.email
scim_user = SCIMProvisionedUser.objects.filter(
user=self.obj, organization_domain=self._organization_domain
).first()
return scim_user.username if scim_user else self.obj.email

@property
def active(self) -> bool:
Expand Down Expand Up @@ -124,7 +128,12 @@ def to_dict(self) -> dict:
return base_dict

@classmethod
def from_dict(cls, data: dict, organization_domain: OrganizationDomain) -> "PostHogSCIMUser":
def from_dict(
cls,
data: dict,
organization_domain: OrganizationDomain,
identity_provider: SCIMProvisionedUser.IdentityProvider = SCIMProvisionedUser.IdentityProvider.OTHER,
) -> "PostHogSCIMUser":
"""
Create or update a User from SCIM data.
"""
Expand All @@ -135,6 +144,8 @@ def from_dict(cls, data: dict, organization_domain: OrganizationDomain) -> "Post
name_data = data.get("name", {})
first_name = name_data.get("givenName", "")
last_name = name_data.get("familyName", "")
user_name = data.get("userName", email)
active = data.get("active", True)

with transaction.atomic():
user = User.objects.filter(email__iexact=email).first()
Expand Down Expand Up @@ -164,6 +175,16 @@ def from_dict(cls, data: dict, organization_domain: OrganizationDomain) -> "Post
user.current_team = organization_domain.organization.teams.first()
user.save()

SCIMProvisionedUser.objects.update_or_create(
user=user,
organization_domain=organization_domain,
defaults={
"identity_provider": identity_provider,
"username": user_name,
"active": active,
},
)

return cls(user, organization_domain)

def put(self, data: dict) -> None:
Expand All @@ -174,6 +195,8 @@ def put(self, data: dict) -> None:
"""
name_data = data.get("name", {})
email = self._extract_email_from_value(data.get("emails", []))
user_name = data.get("userName", email)
is_active = data.get("active", True)

if not email:
raise ValueError("Email is required")
Expand All @@ -189,18 +212,45 @@ def put(self, data: dict) -> None:
self.obj.email = email
self.obj.save()

SCIMProvisionedUser.objects.filter(
user=self.obj,
organization_domain=self._organization_domain,
).update(
username=user_name,
active=is_active,
)

# Deactivate user if active is false
is_active = data.get("active", True)
if not is_active:
self.delete()
self.deactivate()

def deactivate(self) -> None:
"""
Deactivate user by removing their membership and marking SCIM record as inactive.
"""
with transaction.atomic():
OrganizationMembership.objects.filter(
user=self.obj, organization=self._organization_domain.organization
).delete()

SCIMProvisionedUser.objects.filter(
user=self.obj,
organization_domain=self._organization_domain,
).update(active=False)

def delete(self) -> None:
"""
Deactivate user by removing their membership from this organization.
Delete user by removing their membership and SCIM provisioned user record.
"""
OrganizationMembership.objects.filter(
user=self.obj, organization=self._organization_domain.organization
).delete()
with transaction.atomic():
OrganizationMembership.objects.filter(
user=self.obj, organization=self._organization_domain.organization
).delete()

SCIMProvisionedUser.objects.filter(
user=self.obj,
organization_domain=self._organization_domain,
).delete()

def handle_replace(self, path: AttrPath, value: Union[str, list, dict], operation: dict) -> None:
"""
Expand All @@ -216,7 +266,13 @@ def handle_replace(self, path: AttrPath, value: Union[str, list, dict], operatio
with transaction.atomic():
if attr_name == "active":
if not value:
self.delete()
self.deactivate()
return
else:
SCIMProvisionedUser.objects.filter(
user=self.obj,
organization_domain=self._organization_domain,
).update(active=True)

elif attr_name == "name":
if sub_attr == "givenName" and isinstance(value, str):
Expand All @@ -239,6 +295,12 @@ def handle_replace(self, path: AttrPath, value: Union[str, list, dict], operatio
if email:
self.obj.email = email

elif attr_name == "userName" and isinstance(value, str):
SCIMProvisionedUser.objects.filter(
user=self.obj,
organization_domain=self._organization_domain,
).update(username=value)

self.obj.save()

def handle_add(self, path: AttrPath, value: Union[str, list, dict], operation: dict) -> None:
Expand All @@ -257,6 +319,11 @@ def handle_add(self, path: AttrPath, value: Union[str, list, dict], operation: d
defaults={"level": OrganizationMembership.Level.MEMBER},
)

SCIMProvisionedUser.objects.filter(
user=self.obj,
organization_domain=self._organization_domain,
).update(active=True)

elif attr_name == "name":
if sub_attr == "givenName" and isinstance(value, str):
self.obj.first_name = value
Expand All @@ -280,6 +347,12 @@ def handle_add(self, path: AttrPath, value: Union[str, list, dict], operation: d
self.obj.email = email
self.obj.save()

elif attr_name == "userName" and isinstance(value, str):
SCIMProvisionedUser.objects.filter(
user=self.obj,
organization_domain=self._organization_domain,
).update(username=value)

def handle_remove(self, path: AttrPath, value: Union[str, list, dict], operation: dict) -> None:
"""
Handle SCIM PATCH remove operations (called by django-scim2 handle_operations).
Expand All @@ -290,7 +363,8 @@ def handle_remove(self, path: AttrPath, value: Union[str, list, dict], operation

with transaction.atomic():
if attr_name == "active":
self.delete()
self.deactivate()
return

elif attr_name == "name":
if sub_attr == "givenName":
Expand Down
22 changes: 22 additions & 0 deletions ee/api/scim/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from rest_framework.request import Request

from posthog.models.organization_domain import OrganizationDomain

from ee.models.scim_provisioned_user import SCIMProvisionedUser

from .auth import generate_scim_token


Expand Down Expand Up @@ -47,3 +51,21 @@ def get_scim_base_url(domain: OrganizationDomain, request=None) -> str:

base_url = settings.SITE_URL
return f"{base_url}/scim/v2/{domain.id}"


def detect_identity_provider(request: Request) -> str:
"""
Detect identity provider from request User-Agent header.
"""
user_agent = request.META.get("HTTP_USER_AGENT", "").lower()

if "okta" in user_agent:
return SCIMProvisionedUser.IdentityProvider.OKTA
elif "entra" in user_agent or "microsoft" in user_agent:
return SCIMProvisionedUser.IdentityProvider.ENTRA_ID
elif "google" in user_agent:
return SCIMProvisionedUser.IdentityProvider.GOOGLE
elif "onelogin" in user_agent:
return SCIMProvisionedUser.IdentityProvider.ONELOGIN

return SCIMProvisionedUser.IdentityProvider.OTHER
Loading
Loading