diff --git a/api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py b/api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py index e437b8a5a2..a47e1793df 100644 --- a/api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py +++ b/api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py @@ -7,15 +7,29 @@ class MockLoginGovOauthClient(BaseOauthClient): def __init__(self) -> None: self.responses: dict[str, OauthTokenResponse] = {} + # Used to control testing of retry behavior for Login.gov token lookup calls + self.retries: dict[str, int] = {} + def add_token_response(self, code: str, response: OauthTokenResponse) -> None: self.responses[code] = response def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse: - response = self.responses.get(request.code, None) + retries = self.retries.get(request.code) + # if we don't have retries enabled on the mock, behave as usual + if retries is not None: + self.retries[request.code] = retries - 1 + # retries would be one the last time through, as we've reduced it to zero but retries accounts for the data before that + if retries is None or retries == 1: + response = self.responses.get(request.code, None) + + if response is None: + response = OauthTokenResponse( + error="error", error_description="default mock error description" + ) - if response is None: - response = OauthTokenResponse( - error="error", error_description="default mock error description" - ) + return response - return response + # if we did turn on retries on the mock, do retry stuff + return OauthTokenResponse( + error="error", error_description="mock oauth token error description" + ) diff --git a/api/src/services/users/login_gov_callback_handler.py b/api/src/services/users/login_gov_callback_handler.py index 9c50549841..a2bf0ce99d 100644 --- a/api/src/services/users/login_gov_callback_handler.py +++ b/api/src/services/users/login_gov_callback_handler.py @@ -161,15 +161,26 @@ def handle_login_gov_token( # call the token endpoint (make a client) # https://developers.login.gov/oidc/token/ client = get_login_gov_client() - response = client.get_token( - OauthTokenRequest( - code=login_gov_data.code, client_assertion=get_login_gov_client_assertion() + limit = 3 + tries = 0 + while tries < limit: + tries += 1 + response = client.get_token( + OauthTokenRequest( + code=login_gov_data.code, client_assertion=get_login_gov_client_assertion() + ) ) - ) - # If this request failed, we'll assume we're the issue and 500 - if response.is_error_response(): - raise_flask_error(500, response.error_description) + # If this request failed, we'll assume we're the issue and 500 + if response.is_error_response(): + if tries == limit: + raise_flask_error(500, response.error_description) + else: + logger.info( + "Retrying call to Login.gov after receiving error", + extra={"tries": tries, "limit": limit}, + ) + continue # Process the token response from login.gov # which will create/update a user in the DB diff --git a/api/tests/src/api/users/test_user_route_login.py b/api/tests/src/api/users/test_user_route_login.py index af476e716c..e4bde5b2f0 100644 --- a/api/tests/src/api/users/test_user_route_login.py +++ b/api/tests/src/api/users/test_user_route_login.py @@ -690,3 +690,98 @@ def test_agency_user_without_piv_succeeds_when_not_required( resp_json = resp.get_json() assert resp_json["message"] == "success" assert resp_json["token"] is not None + + +def test_user_callback_retries_success( + client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key +): + # Create state so the callback gets past the check + login_gov_state = LoginGovStateFactory.create() + + code = str(uuid.uuid4()) + id_token = create_jwt( + user_id="bob-xyz", + nonce=str(login_gov_state.nonce), + private_key=private_rsa_key, + ) + mock_oauth_client.add_token_response( + code, + OauthTokenResponse( + id_token=id_token, access_token="fake_token", token_type="Bearer", expires_in=300 + ), + ) + mock_oauth_client.retries[code] = 3 + + resp = client.get( + f"/v1/users/login/callback?state={login_gov_state.login_gov_state_id}&code={code}", + follow_redirects=True, + ) + + assert resp.status_code == 200 + resp_json = resp.get_json() + assert resp_json["is_user_new"] == "0" + assert resp_json["message"] == "success" + assert resp_json["token"] is not None + + user_token_session = parse_jwt_for_user(resp_json["token"], db_session) + assert user_token_session.expires_at > datetime_util.utcnow() + assert user_token_session.is_valid is True + + # Make sure the external user record is created with expected IDs + external_user = ( + db_session.query(LinkExternalUser) + .filter( + LinkExternalUser.user_id == user_token_session.user_id, + LinkExternalUser.external_user_id == "bob-xyz", + ) + .one_or_none() + ) + assert external_user is not None + + # Make sure the login gov state was deleted + db_state = ( + db_session.query(LoginGovState) + .filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id) + .one_or_none() + ) + assert db_state is None + + +def test_user_callback_retries_failure( + client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key +): + # Create state so the callback gets past the check + login_gov_state = LoginGovStateFactory.create() + + code = str(uuid.uuid4()) + id_token = create_jwt( + user_id="bob-xyz", + nonce=str(login_gov_state.nonce), + private_key=private_rsa_key, + ) + mock_oauth_client.add_token_response( + code, + OauthTokenResponse( + id_token=id_token, access_token="fake_token", token_type="Bearer", expires_in=300 + ), + ) + mock_oauth_client.retries[code] = 4 + + resp = client.get( + f"/v1/users/login/callback?state={login_gov_state.login_gov_state_id}&code={code}", + follow_redirects=True, + ) + + assert resp.status_code == 200 + resp_json = resp.get_json() + + assert resp_json["message"] == "error" + assert resp_json["error_description"] == "internal error" + + # History contains each redirect, we redirected just once + assert len(resp.history) == 1 + redirect = resp.history[0] + + assert redirect.status_code == 302 + redirect_url = urllib.parse.urlparse(redirect.headers["Location"]) + assert redirect_url.path == "/v1/users/login/result"