Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,29 @@ class MockLoginGovOauthClient(BaseOauthClient):
def __init__(self) -> None:
self.responses: dict[str, OauthTokenResponse] = {}

# Used to control testing of retry behavior for Login.gov token lookup calls
self.retries: dict[str, int] = {}

def add_token_response(self, code: str, response: OauthTokenResponse) -> None:
self.responses[code] = response
Comment on lines +11 to 14
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than setting the retries by using the dict directly, what if we add a param to the add token function to do it?

Suggested change
self.retries: dict[str, int] = {}
def add_token_response(self, code: str, response: OauthTokenResponse) -> None:
self.responses[code] = response
self.retries: dict[str, int] = {}
def add_token_response(self, code: str, response: OauthTokenResponse, retry_count: int = 0) -> None:
self.responses[code] = response
self.retries[code] = retry_count

Then that simplifies the logic below a bit as it'll always be set.


def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
response = self.responses.get(request.code, None)
retries = self.retries.get(request.code)
# if we don't have retries enabled on the mock, behave as usual
if retries is not None:
self.retries[request.code] = retries - 1
# retries would be one the last time through, as we've reduced it to zero but retries accounts for the data before that
if retries is None or retries == 1:
response = self.responses.get(request.code, None)

if response is None:
response = OauthTokenResponse(
error="error", error_description="default mock error description"
)

if response is None:
response = OauthTokenResponse(
error="error", error_description="default mock error description"
)
return response

return response
# if we did turn on retries on the mock, do retry stuff
return OauthTokenResponse(
error="error", error_description="mock oauth token error description"
)
25 changes: 18 additions & 7 deletions api/src/services/users/login_gov_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,26 @@ def handle_login_gov_token(
# call the token endpoint (make a client)
# https://developers.login.gov/oidc/token/
client = get_login_gov_client()
response = client.get_token(
OauthTokenRequest(
code=login_gov_data.code, client_assertion=get_login_gov_client_assertion()
limit = 3
tries = 0
while tries < limit:
tries += 1
response = client.get_token(
OauthTokenRequest(
code=login_gov_data.code, client_assertion=get_login_gov_client_assertion()
)
)
)

# If this request failed, we'll assume we're the issue and 500
if response.is_error_response():
raise_flask_error(500, response.error_description)
# If this request failed, we'll assume we're the issue and 500
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should update this as with the retries it isn't quite right?

if response.is_error_response():
if tries == limit:
raise_flask_error(500, response.error_description)
else:
logger.info(
"Retrying call to Login.gov after receiving error",
extra={"tries": tries, "limit": limit},
)
continue
Comment on lines +178 to +183
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you meant to do this - as written it would always call 3 times as the loop never breaks before 3 tries.

Suggested change
else:
logger.info(
"Retrying call to Login.gov after receiving error",
extra={"tries": tries, "limit": limit},
)
continue
else:
logger.info(
"Retrying call to Login.gov after receiving error",
extra={"tries": tries, "limit": limit},
)
break

Alternatively, if we wanted to use tenacity, seems they have a TryAgain exception we could use - although I don't think it's that big a deal.


# Process the token response from login.gov
# which will create/update a user in the DB
Expand Down
95 changes: 95 additions & 0 deletions api/tests/src/api/users/test_user_route_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,3 +690,98 @@ def test_agency_user_without_piv_succeeds_when_not_required(
resp_json = resp.get_json()
assert resp_json["message"] == "success"
assert resp_json["token"] is not None


def test_user_callback_retries_success(
client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key
):
# Create state so the callback gets past the check
login_gov_state = LoginGovStateFactory.create()

code = str(uuid.uuid4())
id_token = create_jwt(
user_id="bob-xyz",
nonce=str(login_gov_state.nonce),
private_key=private_rsa_key,
)
mock_oauth_client.add_token_response(
code,
OauthTokenResponse(
id_token=id_token, access_token="fake_token", token_type="Bearer", expires_in=300
),
)
mock_oauth_client.retries[code] = 3

resp = client.get(
f"/v1/users/login/callback?state={login_gov_state.login_gov_state_id}&code={code}",
follow_redirects=True,
)

assert resp.status_code == 200
resp_json = resp.get_json()
assert resp_json["is_user_new"] == "0"
assert resp_json["message"] == "success"
assert resp_json["token"] is not None

user_token_session = parse_jwt_for_user(resp_json["token"], db_session)
assert user_token_session.expires_at > datetime_util.utcnow()
assert user_token_session.is_valid is True

# Make sure the external user record is created with expected IDs
external_user = (
db_session.query(LinkExternalUser)
.filter(
LinkExternalUser.user_id == user_token_session.user_id,
LinkExternalUser.external_user_id == "bob-xyz",
)
.one_or_none()
)
assert external_user is not None

# Make sure the login gov state was deleted
db_state = (
db_session.query(LoginGovState)
.filter(LoginGovState.login_gov_state_id == login_gov_state.login_gov_state_id)
.one_or_none()
)
assert db_state is None


def test_user_callback_retries_failure(
client, db_session, enable_factory_create, mock_oauth_client, private_rsa_key
):
# Create state so the callback gets past the check
login_gov_state = LoginGovStateFactory.create()

code = str(uuid.uuid4())
id_token = create_jwt(
user_id="bob-xyz",
nonce=str(login_gov_state.nonce),
private_key=private_rsa_key,
)
mock_oauth_client.add_token_response(
code,
OauthTokenResponse(
id_token=id_token, access_token="fake_token", token_type="Bearer", expires_in=300
),
)
mock_oauth_client.retries[code] = 4

resp = client.get(
f"/v1/users/login/callback?state={login_gov_state.login_gov_state_id}&code={code}",
follow_redirects=True,
)

assert resp.status_code == 200
resp_json = resp.get_json()

assert resp_json["message"] == "error"
assert resp_json["error_description"] == "internal error"

# History contains each redirect, we redirected just once
assert len(resp.history) == 1
redirect = resp.history[0]

assert redirect.status_code == 302
redirect_url = urllib.parse.urlparse(redirect.headers["Location"])
assert redirect_url.path == "/v1/users/login/result"
Loading