Skip to content

Commit 5cc2b4e

Browse files
committed
fix(auth): Preserve redirect state across SSO failures and detect provider mismatch
Two issues in the authentication flow are addressed: 1. Redirect state loss: The original `next` URL was cleared by `initiate_login()` when SSO failed and users retried. Added a persistent `_original_next` session key that survives multiple login attempts, with a 1-hour TTL for security. 2. Provider mismatch: When users authenticated with the wrong SSO provider (e.g., Google when Okta required), `build_identity()` would fail silently. Now we detect the mismatch before calling `build_identity()` and redirect users to the correct SSO flow with a clear warning message. Also fixes the 2FA flow setting `after_2fa` to the SSO callback URL instead of the user's original destination. URLs are validated with `is_valid_redirect()` to prevent open redirect attacks.
1 parent 7e3a022 commit 5cc2b4e

File tree

6 files changed

+412
-3
lines changed

6 files changed

+412
-3
lines changed

src/sentry/auth/helper.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,25 @@ def _login(self, user: Any) -> None:
133133
sample_rate=1.0,
134134
skip_internal=False,
135135
)
136+
137+
# Get the original intended destination for after_2fa, not the current SSO URL
138+
after_2fa_url = self.request.session.get("_next")
139+
if not after_2fa_url:
140+
after_2fa_url = auth.get_persistent_next(self.request)
141+
142+
# Validate the URL to prevent open redirect attacks
143+
if after_2fa_url and not auth.is_valid_redirect(
144+
after_2fa_url, allowed_hosts=(self.request.get_host(),)
145+
):
146+
after_2fa_url = None
147+
148+
if not after_2fa_url:
149+
after_2fa_url = self.request.build_absolute_uri()
150+
136151
user_was_logged_in = auth.login(
137152
self.request,
138153
user,
139-
after_2fa=self.request.build_absolute_uri(),
154+
after_2fa=after_2fa_url,
140155
organization_id=self.organization.id,
141156
)
142157
if not user_was_logged_in:
@@ -780,6 +795,16 @@ def finish_pipeline(self) -> HttpResponseBase:
780795
if not data:
781796
return self.error(ERR_INVALID_IDENTITY)
782797

798+
# Check for provider mismatch - user authenticated with a different provider
799+
# than what the organization requires
800+
actual_provider_key = data.get("actual_provider_key")
801+
expected_provider_key = self.provider.key
802+
if actual_provider_key and actual_provider_key != expected_provider_key:
803+
return self._handle_provider_mismatch(
804+
actual_provider_key=actual_provider_key,
805+
expected_provider_key=expected_provider_key,
806+
)
807+
783808
try:
784809
identity = self.provider.build_identity(data)
785810
except IdentityNotValid as error:
@@ -796,6 +821,49 @@ def finish_pipeline(self) -> HttpResponseBase:
796821

797822
return response
798823

824+
def _handle_provider_mismatch(
825+
self,
826+
actual_provider_key: str,
827+
expected_provider_key: str,
828+
) -> HttpResponseRedirect:
829+
"""
830+
Handle when user authenticated with a different provider than required.
831+
832+
This happens when a user starts SSO for an org (e.g., "sentry" which requires Okta)
833+
but authenticates with a different provider (e.g., Google). We redirect them back
834+
to the org's login page to use the correct SSO provider.
835+
"""
836+
logger.info(
837+
"sso.provider-mismatch",
838+
extra={
839+
"organization_id": self.organization.id,
840+
"expected_provider": expected_provider_key,
841+
"actual_provider": actual_provider_key,
842+
},
843+
)
844+
845+
metrics.incr(
846+
"sso.provider_mismatch",
847+
tags={
848+
"expected_provider": expected_provider_key,
849+
"actual_provider": actual_provider_key,
850+
},
851+
)
852+
853+
# Clear the invalid pipeline state
854+
self.clear_session()
855+
856+
expected_name = getattr(self.provider, "name", expected_provider_key)
857+
messages.add_message(
858+
self.request,
859+
messages.WARNING,
860+
f"This organization requires {expected_name} SSO. Please sign in with your organization's SSO provider.",
861+
)
862+
863+
# Redirect to org-specific login to initiate correct SSO
864+
redirect_uri = reverse("sentry-auth-organization", args=[self.organization.slug])
865+
return HttpResponseRedirect(redirect_uri)
866+
799867
def auth_handler(self, identity: Mapping[str, Any]) -> AuthIdentityHandler:
800868
assert self.provider_model is not None
801869
return AuthIdentityHandler(

src/sentry/auth/providers/oauth2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
151151
# hook here
152152
pipeline.bind_state("data", data)
153153

154+
# Store the actual provider key for mismatch detection
155+
pipeline.bind_state("actual_provider_key", pipeline.provider.key)
156+
154157
return pipeline.next_step()
155158

156159

src/sentry/auth/providers/saml2/provider.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
150150

151151
pipeline.bind_state("auth_attributes", auth.get_attributes())
152152

153+
# Store the actual provider key for mismatch detection
154+
pipeline.bind_state("actual_provider_key", pipeline.provider.key)
155+
153156
return pipeline.next_step()
154157

155158

src/sentry/utils/auth.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,57 @@ def get_login_url(reset: bool = False) -> str:
141141
return _LOGIN_URL
142142

143143

144+
# Persistent redirect URL storage - survives across multiple login attempts
145+
PERSISTENT_NEXT_KEY = "_original_next"
146+
PERSISTENT_NEXT_MAX_AGE = 60 * 60 # 1 hour
147+
148+
149+
def set_persistent_next(request: HttpRequest, next_url: str) -> None:
150+
"""
151+
Store the original next URL persistently in the session.
152+
Only sets if not already set, preserving the original destination
153+
across multiple login attempts and SSO failures.
154+
"""
155+
if not next_url:
156+
return
157+
if PERSISTENT_NEXT_KEY not in request.session:
158+
request.session[PERSISTENT_NEXT_KEY] = {
159+
"url": next_url,
160+
"timestamp": time(),
161+
}
162+
163+
164+
def get_persistent_next(request: HttpRequest, max_age: int | None = None) -> str | None:
165+
"""
166+
Retrieve the persistent next URL if it exists and is not stale.
167+
Returns None if the URL has expired or doesn't exist.
168+
"""
169+
data = request.session.get(PERSISTENT_NEXT_KEY)
170+
if not data:
171+
return None
172+
173+
if max_age is None:
174+
max_age = PERSISTENT_NEXT_MAX_AGE
175+
176+
if time() - data.get("timestamp", 0) > max_age:
177+
clear_persistent_next(request)
178+
return None
179+
180+
return data.get("url")
181+
182+
183+
def clear_persistent_next(request: HttpRequest) -> None:
184+
"""Clear the persistent next URL from session."""
185+
request.session.pop(PERSISTENT_NEXT_KEY, None)
186+
187+
144188
def initiate_login(
145189
request: HttpRequest, next_url: str | None = None, referrer: str | None = None
146190
) -> None:
147191
"""
148-
initiate_login simply clears session cache
149-
if provided a `next_url` will append to the session after clearing previous keys
192+
initiate_login clears transient session state for authentication.
193+
If provided a `next_url`, stores it both as the immediate redirect target
194+
and as a persistent fallback that survives across multiple login attempts.
150195
"""
151196
for key in ("_next", "_after_2fa", "_pending_2fa", "_referrer"):
152197
try:
@@ -156,6 +201,8 @@ def initiate_login(
156201

157202
if next_url:
158203
request.session["_next"] = next_url
204+
# Also store as persistent fallback (only sets if not already present)
205+
set_persistent_next(request, next_url)
159206
if referrer:
160207
request.session["_referrer"] = referrer
161208

@@ -186,7 +233,13 @@ def _get_login_redirect(request: HttpRequest, default: str | None = None) -> str
186233
if after_2fa is not None:
187234
return after_2fa
188235

236+
# Try the transient _next first
189237
login_url = request.session.pop("_next", None)
238+
239+
# Fall back to persistent next URL if transient is not set
240+
if not login_url:
241+
login_url = get_persistent_next(request)
242+
190243
if not login_url:
191244
return default
192245

@@ -361,6 +414,9 @@ def login(
361414
with outbox_context(flush=False):
362415
_login(request, user)
363416

417+
# Clear the persistent next URL now that login is complete
418+
clear_persistent_next(request)
419+
364420
log_auth_success(request, user.username, organization_id, source)
365421
return True
366422

tests/sentry/auth/test_helper.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,106 @@ def test_has_verified_account_fail_user_id(self) -> None:
590590
wrong_user = self.create_user()
591591
self.create_useremail(email=self.email, user=wrong_user)
592592
assert self.handler.has_verified_account(self.verification_value) is False
593+
594+
595+
@control_silo_test
596+
class ProviderMismatchTest(TestCase):
597+
"""Tests for provider mismatch detection when user auths with wrong SSO provider."""
598+
599+
def setUp(self) -> None:
600+
self.provider = "dummy"
601+
self.auth_provider_inst = AuthProvider.objects.create(
602+
organization_id=self.organization.id, provider=self.provider
603+
)
604+
605+
self.auth_key = "test_auth_key"
606+
self.request = _set_up_request()
607+
self.request.session["auth_key"] = self.auth_key
608+
609+
def _create_helper_with_state(self, actual_provider_key=None):
610+
"""Create an AuthHelper with initial state and optional provider key mismatch."""
611+
initial_state = {
612+
"org_id": self.organization.id,
613+
"flow": FLOW_LOGIN,
614+
"provider_model_id": self.auth_provider_inst.id,
615+
"provider_key": self.provider,
616+
"referrer": None,
617+
"step_index": 1,
618+
"signature": None,
619+
"config": {},
620+
"data": {"actual_provider_key": actual_provider_key} if actual_provider_key else {},
621+
}
622+
local_client = clusters.get("default").get_local_client_for_key(self.auth_key)
623+
local_client.set(self.auth_key, json.dumps(initial_state))
624+
625+
helper = AuthHelper.get_for_request(self.request)
626+
assert helper is not None
627+
return helper
628+
629+
@mock.patch("sentry.auth.helper.messages")
630+
@mock.patch("sentry.auth.helper.metrics")
631+
def test_provider_mismatch_redirects_to_correct_sso(
632+
self, mock_metrics: mock.MagicMock, mock_messages: mock.MagicMock
633+
) -> None:
634+
"""Test that authenticating with wrong provider redirects to correct SSO."""
635+
helper = self._create_helper_with_state(actual_provider_key="google")
636+
637+
# Mock the provider to have a build_identity that would fail
638+
with mock.patch.object(helper.provider, "build_identity") as mock_build:
639+
mock_build.side_effect = Exception("Should not be called")
640+
641+
response = helper.finish_pipeline()
642+
643+
# Should redirect to org SSO page
644+
assert response.status_code == 302
645+
assert f"/auth/login/{self.organization.slug}/" in response.url
646+
647+
# Should show warning message
648+
mock_messages.add_message.assert_called_once()
649+
call_args = mock_messages.add_message.call_args
650+
assert call_args[0][1] == mock_messages.WARNING
651+
652+
# Should log metric
653+
mock_metrics.incr.assert_called_with(
654+
"sso.provider_mismatch",
655+
tags={
656+
"expected_provider": self.provider,
657+
"actual_provider": "google",
658+
},
659+
)
660+
661+
@mock.patch("sentry.auth.helper.messages")
662+
def test_provider_match_continues_normally(self, mock_messages: mock.MagicMock) -> None:
663+
"""Test that matching provider continues with normal flow."""
664+
helper = self._create_helper_with_state(actual_provider_key=self.provider)
665+
666+
# Mock build_identity to return a valid identity
667+
with mock.patch.object(
668+
helper.provider,
669+
"build_identity",
670+
return_value={"id": "123", "email": "[email protected]", "name": "Test"},
671+
):
672+
# The flow will continue and eventually redirect
673+
response = helper.finish_pipeline()
674+
675+
# Should not have shown provider mismatch warning
676+
for call in mock_messages.add_message.call_args_list:
677+
assert "SSO" not in str(call)
678+
679+
@mock.patch("sentry.auth.helper.messages")
680+
def test_no_actual_provider_key_continues_normally(self, mock_messages: mock.MagicMock) -> None:
681+
"""Test that missing actual_provider_key doesn't trigger mismatch (backward compat)."""
682+
helper = self._create_helper_with_state(actual_provider_key=None)
683+
684+
# Mock build_identity to return a valid identity
685+
with mock.patch.object(
686+
helper.provider,
687+
"build_identity",
688+
return_value={"id": "123", "email": "[email protected]", "name": "Test"},
689+
):
690+
response = helper.finish_pipeline()
691+
692+
# Should not have shown provider mismatch warning
693+
for call in mock_messages.add_message.call_args_list:
694+
if call[0][1] == mock_messages.WARNING:
695+
assert "SSO" not in str(call)

0 commit comments

Comments
 (0)