Skip to content

Commit 5fcead8

Browse files
authored
fix(auth): Detect SSO provider mismatch and fix 2FA redirect (#106041)
Fixes two issues in the SSO authentication flow: **Provider mismatch detection**: When users authenticated with the wrong SSO provider (e.g., Google when the org requires Okta), `build_identity()` would fail with a confusing error. Now we detect the mismatch before calling `build_identity()` by comparing the callback's provider against the org's configured provider, and redirect users to the correct SSO flow with a clear warning message. **2FA redirect fix**: The 2FA flow was setting `after_2fa` to the SSO callback URL instead of the user's original destination. Now it uses `_next` from the session (validated with `is_valid_redirect()` to prevent open redirects). ## Changes - `src/sentry/auth/helper.py`: Add provider mismatch detection in `finish_pipeline()`, fix `after_2fa` URL in `_login()` - `src/sentry/auth/providers/oauth2.py`, `saml2/provider.py`: Store `provider_key` in pipeline state for mismatch detection
1 parent a520079 commit 5fcead8

File tree

7 files changed

+261
-5
lines changed

7 files changed

+261
-5
lines changed

src/sentry/auth/helper.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,19 @@ def _login(self, user: Any) -> None:
133133
sample_rate=1.0,
134134
skip_internal=False,
135135
)
136+
137+
# Use the user's original destination (from _next) for 2FA redirect,
138+
# falling back to current URL if not set or invalid
139+
after_2fa_url = self.request.session.get("_next")
140+
if not after_2fa_url or not auth.is_valid_redirect(
141+
after_2fa_url, allowed_hosts=(self.request.get_host(),)
142+
):
143+
after_2fa_url = self.request.build_absolute_uri()
144+
136145
user_was_logged_in = auth.login(
137146
self.request,
138147
user,
139-
after_2fa=self.request.build_absolute_uri(),
148+
after_2fa=after_2fa_url,
140149
organization_id=self.organization.id,
141150
)
142151
if not user_was_logged_in:
@@ -780,6 +789,21 @@ def finish_pipeline(self) -> HttpResponseBase:
780789
if not data:
781790
return self.error(ERR_INVALID_IDENTITY)
782791

792+
# Check for provider mismatch - user authenticated with a different provider
793+
# than what the organization requires. This can happen when a user has multiple
794+
# SSO sessions in different tabs and completes the wrong one.
795+
provider_key = data.get("provider_key")
796+
if (
797+
self.state.flow == FLOW_LOGIN
798+
and self.provider_model
799+
and provider_key
800+
and provider_key != self.provider_model.provider
801+
):
802+
return self._handle_provider_mismatch(
803+
provider_key=provider_key,
804+
expected_provider_key=self.provider_model.provider,
805+
)
806+
783807
try:
784808
identity = self.provider.build_identity(data)
785809
except IdentityNotValid as error:
@@ -796,6 +820,49 @@ def finish_pipeline(self) -> HttpResponseBase:
796820

797821
return response
798822

823+
def _handle_provider_mismatch(
824+
self,
825+
provider_key: str,
826+
expected_provider_key: str,
827+
) -> HttpResponseRedirect:
828+
"""
829+
Handle when user authenticated with a different provider than required.
830+
831+
This happens when a user starts SSO for an org (e.g., "sentry" which requires Okta)
832+
but authenticates with a different provider (e.g., Google). We redirect them back
833+
to the org's login page to use the correct SSO provider.
834+
"""
835+
logger.info(
836+
"sso.provider-mismatch",
837+
extra={
838+
"organization_id": self.organization.id,
839+
"expected_provider": expected_provider_key,
840+
"actual_provider": provider_key,
841+
},
842+
)
843+
844+
metrics.incr(
845+
"sso.provider_mismatch",
846+
tags={
847+
"expected_provider": expected_provider_key,
848+
"actual_provider": provider_key,
849+
},
850+
)
851+
852+
# Clear the invalid pipeline state
853+
self.clear_session()
854+
855+
expected_name = getattr(self.provider, "name", expected_provider_key)
856+
messages.add_message(
857+
self.request,
858+
messages.WARNING,
859+
f"This organization requires {expected_name} SSO. Please sign in with your organization's SSO provider.",
860+
)
861+
862+
# Redirect to org-specific login to initiate correct SSO
863+
redirect_uri = reverse("sentry-auth-organization", args=[self.organization.slug])
864+
return HttpResponseRedirect(redirect_uri)
865+
799866
def auth_handler(self, identity: Mapping[str, Any]) -> AuthIdentityHandler:
800867
assert self.provider_model is not None
801868
return AuthIdentityHandler(

src/sentry/auth/providers/oauth2.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
7070
if "code" in request.GET:
7171
return pipeline.next_step()
7272

73-
state = secrets.token_hex()
73+
# Encode provider key in the state parameter so it survives the OAuth redirect.
74+
# This allows detecting when a user completes an OAuth flow that was started
75+
# for a different provider (e.g., multiple SSO tabs open).
76+
nonce = secrets.token_hex()
77+
state = f"{nonce}:{pipeline.provider.key}"
7478

7579
params = self.get_authorize_params(state=state, redirect_uri=_get_redirect_url())
7680
authorization_url = f"{self.get_authorize_url()}?{urlencode(params)}"
@@ -151,6 +155,15 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
151155
# hook here
152156
pipeline.bind_state("data", data)
153157

158+
# Extract the provider key from the OAuth state parameter.
159+
# This was encoded when the OAuth flow started (in OAuth2Login) and survives
160+
# the redirect through the IdP, allowing us to detect if the user completed
161+
# an OAuth flow that was started for a different provider.
162+
provider_key = None
163+
if state and ":" in state:
164+
provider_key = state.split(":", 1)[1]
165+
pipeline.bind_state("provider_key", provider_key)
166+
154167
return pipeline.next_step()
155168

156169

src/sentry/auth/providers/saml2/provider.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,13 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
7878
saml_config = build_saml_config(provider.config, pipeline.organization.slug)
7979
auth = build_auth(request, saml_config)
8080

81-
return HttpResponseRedirect(auth.login())
81+
# Encode provider key in RelayState so it survives the SAML redirect.
82+
# This allows detecting when a user completes a SAML flow that was started
83+
# for a different provider (e.g., multiple SSO tabs open).
84+
# Format: "provider_key:{key}" or just return_to URL for backward compat
85+
relay_state = f"provider_key:{pipeline.provider.key}"
86+
87+
return HttpResponseRedirect(auth.login(return_to=relay_state))
8288

8389

8490
# With SAML, the SSO request can be initiated by both the Service Provider
@@ -150,6 +156,16 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
150156

151157
pipeline.bind_state("auth_attributes", auth.get_attributes())
152158

159+
# Extract the provider key from the RelayState parameter.
160+
# This was encoded when the SAML flow started (in SAML2LoginView) and survives
161+
# the redirect through the IdP, allowing us to detect if the user completed
162+
# a SAML flow that was started for a different provider.
163+
provider_key = None
164+
relay_state = request.POST.get("RelayState") or request.GET.get("RelayState")
165+
if relay_state and relay_state.startswith("provider_key:"):
166+
provider_key = relay_state.split(":", 1)[1]
167+
pipeline.bind_state("provider_key", provider_key)
168+
153169
return pipeline.next_step()
154170

155171

src/sentry/utils/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def initiate_login(
145145
request: HttpRequest, next_url: str | None = None, referrer: str | None = None
146146
) -> None:
147147
"""
148-
initiate_login simply clears session cache
149-
if provided a `next_url` will append to the session after clearing previous keys
148+
Clears existing login state and initializes a new login flow.
149+
Optionally sets the post-login redirect destination and referrer.
150150
"""
151151
for key in ("_next", "_after_2fa", "_pending_2fa", "_referrer"):
152152
try:

tests/sentry/auth/test_helper.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,3 +632,106 @@ def test_has_verified_account_fail_user_id(self) -> None:
632632
wrong_user = self.create_user()
633633
self.create_useremail(email=self.email, user=wrong_user)
634634
assert self.handler.has_verified_account(self.verification_value) is False
635+
636+
637+
@control_silo_test
638+
class ProviderMismatchTest(TestCase):
639+
"""Tests for provider mismatch detection when user auths with wrong SSO provider."""
640+
641+
def setUp(self) -> None:
642+
self.provider = "dummy"
643+
self.auth_provider_inst = AuthProvider.objects.create(
644+
organization_id=self.organization.id, provider=self.provider
645+
)
646+
647+
self.auth_key = "test_auth_key"
648+
self.request = _set_up_request()
649+
self.request.session["auth_key"] = self.auth_key
650+
651+
def _create_helper_with_state(self, provider_key=None):
652+
"""Create an AuthHelper with initial state and optional provider key mismatch."""
653+
initial_state = {
654+
"org_id": self.organization.id,
655+
"flow": FLOW_LOGIN,
656+
"provider_model_id": self.auth_provider_inst.id,
657+
"provider_key": self.provider,
658+
"referrer": None,
659+
"step_index": 1,
660+
"signature": None,
661+
"config": {},
662+
"data": {"provider_key": provider_key} if provider_key else {},
663+
}
664+
local_client = clusters.get("default").get_local_client_for_key(self.auth_key)
665+
local_client.set(self.auth_key, json.dumps(initial_state))
666+
667+
helper = AuthHelper.get_for_request(self.request)
668+
assert helper is not None
669+
return helper
670+
671+
@mock.patch("sentry.auth.helper.messages")
672+
@mock.patch("sentry.auth.helper.metrics")
673+
def test_provider_mismatch_redirects_to_correct_sso(
674+
self, mock_metrics: mock.MagicMock, mock_messages: mock.MagicMock
675+
) -> None:
676+
"""Test that authenticating with wrong provider redirects to correct SSO."""
677+
helper = self._create_helper_with_state(provider_key="google")
678+
679+
# Mock the provider to have a build_identity that would fail
680+
with mock.patch.object(helper.provider, "build_identity") as mock_build:
681+
mock_build.side_effect = Exception("Should not be called")
682+
683+
response = helper.finish_pipeline()
684+
685+
# Should redirect to org SSO page
686+
assert response.status_code == 302
687+
assert f"/auth/login/{self.organization.slug}/" in response.url
688+
689+
# Should show warning message
690+
mock_messages.add_message.assert_called_once()
691+
call_args = mock_messages.add_message.call_args
692+
assert call_args[0][1] == mock_messages.WARNING
693+
694+
# Should log metric
695+
mock_metrics.incr.assert_called_with(
696+
"sso.provider_mismatch",
697+
tags={
698+
"expected_provider": self.provider,
699+
"actual_provider": "google",
700+
},
701+
)
702+
703+
@mock.patch("sentry.auth.helper.messages")
704+
def test_provider_match_continues_normally(self, mock_messages: mock.MagicMock) -> None:
705+
"""Test that matching provider continues with normal flow."""
706+
helper = self._create_helper_with_state(provider_key=self.provider)
707+
708+
# Mock build_identity to return a valid identity
709+
with mock.patch.object(
710+
helper.provider,
711+
"build_identity",
712+
return_value={"id": "123", "email": "[email protected]", "name": "Test"},
713+
):
714+
# The flow will continue and eventually redirect
715+
helper.finish_pipeline()
716+
717+
# Should not have shown provider mismatch warning
718+
for call in mock_messages.add_message.call_args_list:
719+
assert "SSO" not in str(call)
720+
721+
@mock.patch("sentry.auth.helper.messages")
722+
def test_no_provider_key_continues_normally(self, mock_messages: mock.MagicMock) -> None:
723+
"""Test that missing provider_key doesn't trigger mismatch (backward compat)."""
724+
helper = self._create_helper_with_state(provider_key=None)
725+
726+
# Mock build_identity to return a valid identity
727+
with mock.patch.object(
728+
helper.provider,
729+
"build_identity",
730+
return_value={"id": "123", "email": "[email protected]", "name": "Test"},
731+
):
732+
helper.finish_pipeline()
733+
734+
# Should not have shown provider mismatch warning
735+
for call in mock_messages.add_message.call_args_list:
736+
if call[0][1] == mock_messages.WARNING:
737+
assert "SSO" not in str(call)

tests/sentry/web/frontend/test_auth_oauth2.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,37 @@ def test_response_errors(self) -> None:
229229
assert len(messages) == 1
230230
assert str(messages[0]).startswith("Authentication error")
231231
assert auth_resp.context["user"] != self.user
232+
233+
def test_state_contains_provider_key(self) -> None:
234+
"""Test that OAuth state parameter contains the provider key for mismatch detection."""
235+
state = self.initiate_oauth_flow()
236+
237+
# State format should be "{nonce}:{provider_key}"
238+
assert ":" in state
239+
nonce, provider_key = state.split(":", 1)
240+
assert len(nonce) > 0
241+
assert provider_key == self.provider_name
242+
243+
@mock.patch("sentry.auth.providers.oauth2.safe_urlopen")
244+
def test_provider_mismatch_detected(self, urlopen: mock.MagicMock) -> None:
245+
"""Test that authenticating with wrong provider in state triggers mismatch detection."""
246+
# Start a normal OAuth flow to get the session set up
247+
state = self.initiate_oauth_flow()
248+
249+
# Modify the state to have a different provider key (simulating multi-tab scenario
250+
# where user started auth with a different provider)
251+
nonce = state.split(":")[0]
252+
wrong_state = f"{nonce}:wrong_provider"
253+
254+
# The state validation should fail because the full state doesn't match
255+
headers = {"Content-Type": "application/json"}
256+
auth_data = {"id": "oauth_external_id_1234", "email": self.user.email}
257+
urlopen.return_value = MockResponse(headers, json.dumps(auth_data))
258+
259+
query = urlencode({"code": "1234", "state": wrong_state})
260+
auth_resp = self.client.get(f"{self.sso_path}?{query}", follow=True)
261+
262+
# Should fail with state mismatch error
263+
messages = list(auth_resp.context["messages"])
264+
assert len(messages) == 1
265+
assert str(messages[0]).startswith("Authentication error")

tests/sentry/web/frontend/test_auth_saml2.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,26 @@ def test_verify_email(self, follow=False, **kwargs) -> None:
308308

309309
# expect no linking before verification
310310
assert AuthIdentity.objects.filter(user_id=self.user.id).count() == 0
311+
312+
def test_relay_state_contains_provider_key(self) -> None:
313+
"""Test that SAML RelayState contains the provider key for mismatch detection."""
314+
resp = self.client.post(self.login_path, {"init": True})
315+
316+
assert resp.status_code == 302
317+
redirect = urlparse(resp.get("Location", ""))
318+
query = parse_qs(redirect.query)
319+
320+
# RelayState should contain the provider key
321+
assert "RelayState" in query
322+
relay_state = query["RelayState"][0]
323+
assert relay_state == f"provider_key:{self.provider_name}"
324+
325+
def test_idp_initiated_without_relay_state_continues(self) -> None:
326+
"""Test that IdP-initiated SAML without RelayState continues normally (backward compat)."""
327+
# IdP-initiated auth doesn't have RelayState from our side
328+
# This should still work - the provider_key will be None and the check will be skipped
329+
auth = self.accept_auth()
330+
331+
# Should continue to identity confirmation
332+
assert auth.status_code == 200
333+
assert auth.context["existing_user"] == self.user

0 commit comments

Comments
 (0)