Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 266 additions & 9 deletions crates/handlers/src/oauth2/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use oauth2_types::{
scope,
};
use thiserror::Error;
use tracing::debug;
use tracing::{debug, info};
use ulid::Ulid;

use super::{generate_id_token, generate_token_pair};
Expand Down Expand Up @@ -98,6 +98,20 @@ pub(crate) enum RouteError {
#[error("failed to load oauth session")]
NoSuchOAuthSession,

#[error(
"failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
)]
NoSuchNextRefreshToken { next: Ulid, previous: Ulid },

#[error("failed to load the access token ({access_token:?}) associated with the next refresh token ({refresh_token:?})")]
NoSuchNextAccessToken {
access_token: Ulid,
refresh_token: Ulid,
},

#[error("no access token associated with the refresh token {refresh_token:?}")]
NoAccessTokenOnRefreshToken { refresh_token: Ulid },

#[error("device code grant expired")]
DeviceCodeExpired,

Expand All @@ -122,7 +136,10 @@ impl IntoResponse for RouteError {
Self::Internal(_)
| Self::NoSuchBrowserSession
| Self::NoSuchOAuthSession
| Self::ProvisionDeviceFailed(_) => (
| Self::ProvisionDeviceFailed(_)
| Self::NoSuchNextRefreshToken { .. }
| Self::NoSuchNextAccessToken { .. }
| Self::NoAccessTokenOnRefreshToken { .. } => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
),
Expand Down Expand Up @@ -482,6 +499,7 @@ async fn authorization_code_grant(
Ok((params, repo))
}

#[allow(clippy::too_many_lines)]
async fn refresh_token_grant(
rng: &mut BoxRng,
clock: &impl Clock,
Expand Down Expand Up @@ -518,10 +536,6 @@ async fn refresh_token_grant(
.await?;
}

if !refresh_token.is_valid() {
return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
}

if !session.is_valid() {
return Err(RouteError::SessionInvalid(session.id));
}
Expand All @@ -534,6 +548,77 @@ async fn refresh_token_grant(
});
}

if !refresh_token.is_valid() {
// We're seing a refresh token that already has been consumed, this might be a
// double-refresh or a replay attack

// First, get the next refresh token
let Some(next_refresh_token_id) = refresh_token.next_refresh_token_id() else {
// If we don't have a 'next' refresh token, it may just be because this was
// before we were recording those. Let's just treat it as a replay.
return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
};

let Some(next_refresh_token) = repo
.oauth2_refresh_token()
.lookup(next_refresh_token_id)
.await?
else {
return Err(RouteError::NoSuchNextRefreshToken {
next: next_refresh_token_id,
previous: refresh_token.id,
});
};

// Check if the next refresh token was already consumed or not
if !next_refresh_token.is_valid() {
// XXX: This is a replay, we *may* want to invalidate the session
return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
}

// Check if the associated access token was already used
let Some(access_token_id) = next_refresh_token.access_token_id else {
// This should in theory not happen: this means an access token got cleaned up,
// but the refresh token was still valid.
return Err(RouteError::NoAccessTokenOnRefreshToken {
refresh_token: next_refresh_token.id,
});
};

// Load it
let next_access_token = repo
.oauth2_access_token()
.lookup(access_token_id)
.await?
.ok_or(RouteError::NoSuchNextAccessToken {
access_token: access_token_id,
refresh_token: next_refresh_token_id,
})?;

if next_access_token.is_used() {
// XXX: This is a replay, we *may* want to invalidate the session
return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
}

// Looks like it's a double-refresh, client lost their refresh token on
// the way back. Let's revoke the two new access and refresh tokens, and
// issue new ones
info!(
oauth2_session.id = %session.id,
oauth2_client.id = %client.id,
%refresh_token.id,
"A refresh token was used twice, but the new refresh token was lost. Revoking the old ones and issuing new ones."
);

repo.oauth2_access_token()
.revoke(clock, next_access_token)
.await?;

repo.oauth2_refresh_token()
.revoke(clock, next_refresh_token)
.await?;
}

activity_tracker
.record_oauth2_session(clock, &session)
.await;
Expand All @@ -550,9 +635,12 @@ async fn refresh_token_grant(
if let Some(access_token_id) = refresh_token.access_token_id {
let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
if let Some(access_token) = access_token {
repo.oauth2_access_token()
.revoke(clock, access_token)
.await?;
// If it is a double-refresh, it might already be revoked
if !access_token.state.is_revoked() {
repo.oauth2_access_token()
.revoke(clock, access_token)
.await?;
}
}
}

Expand Down Expand Up @@ -1123,6 +1211,175 @@ mod tests {
let _: AccessTokenResponse = response.json();
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_double_refresh(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

// Provision a client
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/callback"],
"token_endpoint_auth_method": "none",
"response_types": ["code"],
"grant_types": ["authorization_code", "refresh_token"],
}));

let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);

let ClientRegistrationResponse { client_id, .. } = response.json();

// Let's provision a user and create a session for them. This part is hard to
// test with just HTTP requests, so we'll use the repository directly.
let mut repo = state.repository().await.unwrap();

let user = repo
.user()
.add(&mut state.rng(), &state.clock, "alice".to_owned())
.await
.unwrap();

let browser_session = repo
.browser_session()
.add(&mut state.rng(), &state.clock, &user, None)
.await
.unwrap();

// Lookup the client in the database.
let client = repo
.oauth2_client()
.find_by_client_id(&client_id)
.await
.unwrap()
.unwrap();

// Get a token pair
let session = repo
.oauth2_session()
.add_from_browser_session(
&mut state.rng(),
&state.clock,
&client,
&browser_session,
Scope::from_iter([OPENID]),
)
.await
.unwrap();

let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
generate_token_pair(
&mut state.rng(),
&state.clock,
&mut repo,
&session,
Duration::microseconds(5 * 60 * 1000 * 1000),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there really no way to do Duration::seconds(5 * 60) lol? :D

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's the only one that can't panic, and clippy will complain that we don't document that potential panic I think? 😬

)
.await
.unwrap();

repo.save().await.unwrap();

// First check that the token is valid
assert!(state.is_access_token_valid(&access_token).await);

// Now call the token endpoint to get an access token.
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client.client_id,
}));

let first_response = state.request(request).await;
first_response.assert_status(StatusCode::OK);
let first_response: AccessTokenResponse = first_response.json();

// Call a second time, it should work, as we haven't done anything yet with the
// token
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client.client_id,
}));

let second_response = state.request(request).await;
second_response.assert_status(StatusCode::OK);
let second_response: AccessTokenResponse = second_response.json();

// Check that we got new tokens
assert_ne!(first_response.access_token, second_response.access_token);
assert_ne!(first_response.refresh_token, second_response.refresh_token);

// Check that the old-new token is invalid
assert!(
!state
.is_access_token_valid(&first_response.access_token)
.await
);

// Check that the new-new token is valid
assert!(
state
.is_access_token_valid(&second_response.access_token)
.await
);

// Do a third refresh, this one should not work, as we've used the new
// access token
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client.client_id,
}));

let third_response = state.request(request).await;
third_response.assert_status(StatusCode::BAD_REQUEST);

// The other reason we consider a new refresh token to be 'used' is if
// it was already used in a refresh
// So, if we do a refresh with the second_response.refresh_token, then
// another refresh with the result, redoing one with
// second_response.refresh_token again should fail
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": second_response.refresh_token,
"client_id": client.client_id,
}));

// This one is fine
let fourth_response = state.request(request).await;
fourth_response.assert_status(StatusCode::OK);
let fourth_response: AccessTokenResponse = fourth_response.json();

// Do another one, it should be fine as well
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": fourth_response.refresh_token,
"client_id": client.client_id,
}));

let fifth_response = state.request(request).await;
fifth_response.assert_status(StatusCode::OK);

// But now, if we re-do with the second_response.refresh_token, it should
// fail
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": second_response.refresh_token,
"client_id": client.client_id,
}));

let sixth_response = state.request(request).await;
sixth_response.assert_status(StatusCode::BAD_REQUEST);
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_client_credentials(pool: PgPool) {
setup();
Expand Down
Loading