Skip to content

Commit bfa0a6d

Browse files
committed
fixing flags (WIP)
1 parent 31f9e85 commit bfa0a6d

File tree

15 files changed

+343
-83
lines changed

15 files changed

+343
-83
lines changed

api/ee/src/apis/fastapi/organizations/router.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DomainVerificationService,
1717
SSOProviderService,
1818
)
19+
from ee.src.services import db_manager_ee
1920
from ee.src.utils.permissions import check_user_org_access
2021
from ee.src.services.selectors import get_user_org_and_workspace_id
2122

@@ -35,6 +36,36 @@ async def verify_user_org_access(user_id: str, organization_id: str) -> None:
3536
)
3637

3738

39+
async def require_email_or_social_or_root_enabled(organization_id: str) -> None:
40+
"""Block domain/provider changes when SSO is the only allowed method."""
41+
organization = await db_manager_ee.get_organization(organization_id)
42+
flags = organization.flags or {}
43+
allow_email = flags.get("allow_email", False)
44+
allow_social = flags.get("allow_social", False)
45+
allow_root = flags.get("allow_root", False)
46+
if not (allow_email or allow_social or allow_root):
47+
raise HTTPException(
48+
status_code=400,
49+
detail=(
50+
"To modify domains or SSO providers, enable email or social authentication "
51+
"for this organization, or enable root access for owners."
52+
),
53+
)
54+
55+
56+
async def require_domains_and_auto_join_disabled(organization_id: str) -> None:
57+
"""Block edits to verified domains when domains-only or auto-join is enabled."""
58+
organization = await db_manager_ee.get_organization(organization_id)
59+
flags = organization.flags or {}
60+
if flags.get("domains_only") or flags.get("auto_join"):
61+
raise HTTPException(
62+
status_code=400,
63+
detail=(
64+
"Disable domains-only and auto-join before modifying verified domains."
65+
),
66+
)
67+
68+
3869
# Domain Verification Endpoints
3970

4071

@@ -81,6 +112,7 @@ async def verify_domain(
81112
user_id = request.state.user_id
82113

83114
await verify_user_org_access(user_id, organization_id)
115+
await require_domains_and_auto_join_disabled(organization_id)
84116

85117
return await domain_service.verify_domain(
86118
organization_id, payload.domain_id, user_id
@@ -115,6 +147,7 @@ async def refresh_domain_token(
115147
user_id = request.state.user_id
116148

117149
await verify_user_org_access(user_id, organization_id)
150+
await require_domains_and_auto_join_disabled(organization_id)
118151

119152
return await domain_service.refresh_token(organization_id, domain_id, user_id)
120153

@@ -134,6 +167,7 @@ async def reset_domain(
134167
user_id = request.state.user_id
135168

136169
await verify_user_org_access(user_id, organization_id)
170+
await require_domains_and_auto_join_disabled(organization_id)
137171

138172
return await domain_service.reset_domain(organization_id, domain_id, user_id)
139173

@@ -148,6 +182,7 @@ async def delete_domain(
148182
user_id = request.state.user_id
149183

150184
await verify_user_org_access(user_id, organization_id)
185+
await require_domains_and_auto_join_disabled(organization_id)
151186

152187
await domain_service.delete_domain(organization_id, domain_id, user_id)
153188
return Response(status_code=204)
@@ -172,6 +207,7 @@ async def create_provider(
172207
user_id = request.state.user_id
173208

174209
await verify_user_org_access(user_id, organization_id)
210+
await require_email_or_social_or_root_enabled(organization_id)
175211

176212
return await provider_service.create_provider(organization_id, payload, user_id)
177213

@@ -187,6 +223,7 @@ async def update_provider(
187223
user_id = request.state.user_id
188224

189225
await verify_user_org_access(user_id, organization_id)
226+
await require_email_or_social_or_root_enabled(organization_id)
190227

191228
return await provider_service.update_provider(
192229
organization_id, provider_id, payload, user_id
@@ -225,6 +262,7 @@ async def test_provider(
225262
user_id = request.state.user_id
226263

227264
await verify_user_org_access(user_id, organization_id)
265+
await require_email_or_social_or_root_enabled(organization_id)
228266

229267
return await provider_service.test_provider(organization_id, provider_id, user_id)
230268

@@ -239,6 +277,7 @@ async def delete_provider(
239277
user_id = request.state.user_id
240278

241279
await verify_user_org_access(user_id, organization_id)
280+
await require_email_or_social_or_root_enabled(organization_id)
242281

243282
await provider_service.delete_provider(organization_id, provider_id, user_id)
244283
return Response(status_code=204)

api/ee/src/services/db_manager_ee.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
UserDB,
4343
InvitationDB,
4444
)
45+
from ee.src.dbs.postgres.organizations.dao import (
46+
OrganizationProvidersDAO,
47+
OrganizationDomainsDAO,
48+
)
4549
from ee.src.services.converters import get_workspace_in_format
4650
from ee.src.services.selectors import get_org_default_workspace
4751

@@ -1133,6 +1137,55 @@ async def update_organization(
11331137
allow_sso = merged_flags.get("allow_sso", False)
11341138
allow_root = merged_flags.get("allow_root", False)
11351139

1140+
changing_auth_flags = any(
1141+
key in new_flags for key in ("allow_email", "allow_social", "allow_sso")
1142+
)
1143+
changing_auto_join = "auto_join" in new_flags
1144+
changing_domains_only = "domains_only" in new_flags
1145+
1146+
if changing_auth_flags and allow_sso:
1147+
providers_dao = OrganizationProvidersDAO(session)
1148+
providers = await providers_dao.list_by_organization(organization_id)
1149+
active_valid = [
1150+
provider
1151+
for provider in providers
1152+
if (provider.flags or {}).get("is_active")
1153+
and (provider.flags or {}).get("is_valid")
1154+
]
1155+
if not active_valid:
1156+
raise ValueError(
1157+
"SSO cannot be enabled until at least one SSO provider is "
1158+
"active and verified."
1159+
)
1160+
if not allow_email and not allow_social:
1161+
if not active_valid:
1162+
raise ValueError(
1163+
"SSO-only authentication requires at least one SSO provider to "
1164+
"be active and verified."
1165+
)
1166+
1167+
if changing_auto_join and merged_flags.get("auto_join", False):
1168+
domains_dao = OrganizationDomainsDAO(session)
1169+
domains = await domains_dao.list_by_organization(organization_id)
1170+
has_verified_domain = any(
1171+
(domain.flags or {}).get("is_verified") for domain in domains
1172+
)
1173+
if not has_verified_domain:
1174+
raise ValueError(
1175+
"Auto-join requires at least one verified domain."
1176+
)
1177+
1178+
if changing_domains_only and merged_flags.get("domains_only", False):
1179+
domains_dao = OrganizationDomainsDAO(session)
1180+
domains = await domains_dao.list_by_organization(organization_id)
1181+
has_verified_domain = any(
1182+
(domain.flags or {}).get("is_verified") for domain in domains
1183+
)
1184+
if not has_verified_domain:
1185+
raise ValueError(
1186+
"Domains-only requires at least one verified domain."
1187+
)
1188+
11361189
# Check if all auth methods are disabled
11371190
all_auth_disabled = not (allow_email or allow_social or allow_sso)
11381191

api/ee/src/services/organization_security_service.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
OrganizationProviderUpdate,
3131
OrganizationProviderResponse,
3232
)
33+
from ee.src.services import db_manager_ee
3334

3435
logger = logging.getLogger(__name__)
3536

@@ -660,6 +661,26 @@ async def delete_provider(
660661
if not provider:
661662
raise HTTPException(status_code=404, detail="Provider not found")
662663

664+
organization = await db_manager_ee.get_organization(organization_id)
665+
flags = organization.flags or {}
666+
if flags.get("allow_sso"):
667+
providers = await dao.list_by_organization(organization_id)
668+
remaining = [
669+
p
670+
for p in providers
671+
if str(p.id) != str(provider_id)
672+
and (p.flags or {}).get("is_active")
673+
and (p.flags or {}).get("is_valid")
674+
]
675+
if not remaining:
676+
raise HTTPException(
677+
status_code=400,
678+
detail=(
679+
"Cannot delete the last active and verified SSO provider while "
680+
"SSO is enabled."
681+
),
682+
)
683+
663684
await self._vault_service().delete_secret(
664685
secret_id=provider.secret_id,
665686
organization_id=organization_id,

api/ee/src/services/workspace_manager.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
check_valid_invitation,
3131
)
3232
from ee.src.services.organization_service import send_invitation_email
33+
from ee.src.dbs.postgres.organizations.dao import OrganizationDomainsDAO
3334

3435
log = get_module_logger(__name__)
3536

@@ -155,6 +156,30 @@ async def invite_user_to_workspace(
155156
organization = await db_manager_ee.get_organization(organization_id)
156157
user_performing_action = await db_manager.get_user(user_uid)
157158

159+
# Check if domains_only is enabled for this organization
160+
org_flags = organization.flags or {}
161+
domains_only = org_flags.get("domains_only", False)
162+
163+
# If domains_only is enabled, get the list of verified domains
164+
verified_domain_slugs = set()
165+
if domains_only:
166+
domains_dao = OrganizationDomainsDAO()
167+
org_domains = await domains_dao.list_by_organization(organization_id)
168+
verified_domain_slugs = {
169+
d.slug.lower()
170+
for d in org_domains
171+
if d.flags and d.flags.get("is_verified", False)
172+
}
173+
174+
# If domains_only is enabled but no verified domains exist, block all invitations
175+
if not verified_domain_slugs:
176+
return JSONResponse(
177+
status_code=400,
178+
content={
179+
"error": "Cannot send invitations: domains_only is enabled but no verified domains exist"
180+
},
181+
)
182+
158183
for payload_invite in payload:
159184
# Check that the user is not inviting themselves
160185
if payload_invite.email == user_performing_action.email:
@@ -163,6 +188,17 @@ async def invite_user_to_workspace(
163188
content={"error": "You cannot invite yourself to a workspace"},
164189
)
165190

191+
# Check if domains_only is enabled and validate the email domain
192+
if domains_only:
193+
email_domain = payload_invite.email.split("@")[-1].lower()
194+
if email_domain not in verified_domain_slugs:
195+
return JSONResponse(
196+
status_code=400,
197+
content={
198+
"error": f"Cannot invite {payload_invite.email}: domain '{email_domain}' is not a verified domain for this organization"
199+
},
200+
)
201+
166202
# Check if the user is already a member of the workspace
167203
if await db_manager_ee.check_user_in_workspace_with_email(
168204
payload_invite.email, str(workspace.id)

api/oss/src/apis/fastapi/auth/router.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def check_organization_access(request: Request, organization_id: str):
7474

7575
if policy_error and policy_error.get("error") in {
7676
"AUTH_UPGRADE_REQUIRED",
77-
"AUTH_SSO_DISABLED",
77+
"AUTH_SSO_DENIED",
7878
}:
7979
detail = {
8080
"error": policy_error.get("error"),
@@ -90,9 +90,7 @@ async def check_organization_access(request: Request, organization_id: str):
9090

9191

9292
@auth_router.post("/session/identities")
93-
async def update_session_identities(
94-
request: Request, payload: SessionIdentitiesUpdate
95-
):
93+
async def update_session_identities(request: Request, payload: SessionIdentitiesUpdate):
9694
try:
9795
session = await get_session(request) # type: ignore
9896
except Exception:
@@ -117,7 +115,9 @@ async def update_session_identities(
117115
elif hasattr(session, "merge_into_access_token_payload"):
118116
await session.merge_into_access_token_payload({"session_identities": merged})
119117
else:
120-
raise HTTPException(status_code=500, detail="Session payload update not supported")
118+
raise HTTPException(
119+
status_code=500, detail="Session payload update not supported"
120+
)
121121
return {"session_identities": merged, "previous": current}
122122

123123

api/oss/src/core/auth/service.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -532,15 +532,9 @@ async def enforce_domain_policies(self, email: str, user_id: UUID) -> None:
532532
print(f"Error during auto-join: {e}")
533533

534534
# 2. Domains-only enforcement: Check if user has access
535-
# This is enforced at the organization level, not during login
536-
# It's checked when user tries to access organization resources
537-
# For now, we just validate that the domain matches
538-
domains_only = org_flags.get("domains_only", False)
539-
if domains_only:
540-
# If domains_only is enabled, user MUST have matching domain
541-
# Since we already verified the domain matches (domain_dto exists),
542-
# the user is allowed. If domain didn't match, domain_dto would be None.
543-
pass
535+
# This is enforced at the organization level via check_organization_access()
536+
# when the user tries to access organization resources through the middleware.
537+
# No action needed here during login - enforcement happens at access time.
544538

545539
# ============================================================================
546540
# AUTHORIZATION: Validate access based on policies
@@ -619,19 +613,27 @@ async def check_organization_access(
619613
# If the session used SSO but the org doesn't allow it (or provider inactive),
620614
# block and instruct user to re-auth with allowed methods.
621615
sso_identity = next(
622-
(identity for identity in session_identities if identity.startswith("sso:")),
616+
(
617+
identity
618+
for identity in session_identities
619+
if identity.startswith("sso:")
620+
),
623621
None,
624622
)
625623
if sso_identity and self.providers_dao:
626624
org_slug = await self._get_organization_slug(organization_id)
627625
provider_slug = (
628-
sso_identity.split(":")[2] if len(sso_identity.split(":")) > 2 else None
626+
sso_identity.split(":")[2]
627+
if len(sso_identity.split(":")) > 2
628+
else None
629629
)
630630
providers = await self.providers_dao.list_by_organization(
631631
str(organization_id)
632632
)
633633
active_provider_slugs = {
634-
p.slug for p in providers if p.flags and p.flags.get("is_active", False)
634+
p.slug
635+
for p in providers
636+
if p.flags and p.flags.get("is_active", False)
635637
}
636638
sso_matches_org = bool(
637639
org_slug and sso_identity.startswith(f"sso:{org_slug}:")
@@ -647,8 +649,8 @@ async def check_organization_access(
647649
if allow_social:
648650
required_methods.append("social:*")
649651
return {
650-
"error": "AUTH_SSO_DISABLED",
651-
"message": "SSO is not enabled for this organization",
652+
"error": "AUTH_SSO_DENIED",
653+
"message": "SSO is denied for this organization",
652654
"required_methods": required_methods,
653655
"current_identities": session_identities,
654656
}
@@ -663,6 +665,33 @@ async def check_organization_access(
663665
"sso_providers": sso_providers,
664666
}
665667

668+
# Check domains_only enforcement
669+
domains_only = org_flags.get("domains_only", False)
670+
if domains_only and self.domains_dao:
671+
# Get user's email to check domain
672+
user = await db_manager.get_user(str(user_id))
673+
if user and user.email:
674+
email_domain = user.email.split("@")[-1].lower()
675+
676+
# Get verified domains for this organization
677+
org_domains = await self.domains_dao.list_by_organization(
678+
str(organization_id)
679+
)
680+
verified_domain_slugs = {
681+
d.slug.lower()
682+
for d in org_domains
683+
if d.flags and d.flags.get("is_verified", False)
684+
}
685+
686+
# If user's domain is not in the verified domains, deny access
687+
if email_domain not in verified_domain_slugs:
688+
return {
689+
"error": "AUTH_DOMAIN_DENIED",
690+
"message": f"Your email domain '{email_domain}' is not allowed for this organization",
691+
"current_domain": email_domain,
692+
"allowed_domains": list(verified_domain_slugs),
693+
}
694+
666695
return None
667696

668697
def _matches_policy(

0 commit comments

Comments
 (0)