Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions rsky-pds/src/apis/app/bsky/feed/get_feed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,24 @@ impl<'r> FromRequest<'r> for GetFeedPipeThrough {
match limit {
Ok(limit) => match limit {
Some(limit) if limit > 100 => {
req.local_cache(|| {
Some(ApiError::InvalidRequest("`limit` is invalid".to_string()))
});
return Outcome::Error((
Status::BadRequest,
anyhow!("`limit` is invalid"),
))
));
}
_ => (),
},
_ => {
req.local_cache(|| {
Some(ApiError::InvalidRequest("`limit` is invalid".to_string()))
});
return Outcome::Error((
Status::BadRequest,
anyhow!("`limit` is invalid"),
))
));
}
}
}
Expand All @@ -77,10 +83,13 @@ impl<'r> FromRequest<'r> for GetFeedPipeThrough {
{
Ok(res) => res,
Err(error) => {
req.local_cache(|| {
Some(ApiError::InvalidRequest(error.to_string()))
});
return Outcome::Error((
Status::BadRequest,
anyhow!(error.to_string()),
))
));
}
};
let headers = req.headers().clone().into_iter().fold(
Expand All @@ -93,7 +102,7 @@ impl<'r> FromRequest<'r> for GetFeedPipeThrough {
acc
},
);
let req = ProxyRequest {
let proxy_req = ProxyRequest {
headers,
query: match req.uri().query() {
None => None,
Expand All @@ -108,7 +117,7 @@ impl<'r> FromRequest<'r> for GetFeedPipeThrough {
cfg: req.guard::<&State<ServerConfig>>().await.unwrap(),
};
match pipethrough(
&req,
&proxy_req,
requester,
OverrideOpts {
aud: Some(data.view.did.to_string()),
Expand All @@ -124,7 +133,12 @@ impl<'r> FromRequest<'r> for GetFeedPipeThrough {
buffer: res.buffer,
headers: res.headers,
}),
Err(error) => Outcome::Error((Status::BadRequest, error)),
Err(error) => {
req.local_cache(|| {
Some(ApiError::InvalidRequest(error.to_string()))
});
Outcome::Error((Status::BadRequest, error))
}
}
}
_ => Outcome::Error((
Expand All @@ -133,13 +147,21 @@ impl<'r> FromRequest<'r> for GetFeedPipeThrough {
)),
}
}
_ => Outcome::Error((Status::BadRequest, anyhow!("`feed` is invalid"))),
_ => {
req.local_cache(|| {
Some(ApiError::InvalidRequest("`feed` is invalid".to_string()))
});
Outcome::Error((Status::BadRequest, anyhow!("`feed` is invalid")))
}
}
}
Outcome::Error(err) => Outcome::Error((
Status::BadRequest,
anyhow::Error::new(InvalidRequestError::AuthError(err.1)),
)),
Outcome::Error(err) => {
req.local_cache(|| Some(ApiError::InvalidRequest(err.1.to_string())));
Outcome::Error((
Status::BadRequest,
anyhow::Error::new(InvalidRequestError::AuthError(err.1)),
))
}
_ => panic!("Unexpected outcome during Pipethrough"),
}
}
Expand Down
2 changes: 1 addition & 1 deletion rsky-pds/src/apis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub async fn bsky_api_forwarder(
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum ApiError {
RuntimeError,
InvalidLogin,
Expand Down
110 changes: 76 additions & 34 deletions rsky-pds/src/auth_verifier.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::account_manager::helpers::account::{ActorAccount, AvailabilityFlags};
use crate::account_manager::helpers::auth::CustomClaimObj;
use crate::account_manager::AccountManager;
use crate::apis::ApiError;
use crate::xrpc_server::auth::{verify_jwt as verify_service_jwt_server, ServiceJwtPayload};
use crate::SharedIdResolver;
use anyhow::{bail, Result};
Expand Down Expand Up @@ -166,15 +167,17 @@ impl<'r> FromRequest<'r> for Refresh {
match payload.jti {
Some(_) => result,
None => {
return Outcome::Error((
Status::BadRequest,
AuthError::BadJwt("Unexpected missing refresh token id".to_owned()),
));
let error =
AuthError::BadJwt("Unexpected missing refresh token id".to_owned());
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
return Outcome::Error((Status::BadRequest, error));
}
}
}
Err(error) => {
return Outcome::Error((Status::BadRequest, AuthError::BadJwt(error.to_string())));
let error = AuthError::BadJwt(error.to_string());
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
return Outcome::Error((Status::BadRequest, error));
}
};
Outcome::Success(Refresh {
Expand Down Expand Up @@ -231,7 +234,10 @@ impl<'r> FromRequest<'r> for AccessFull {
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match access_check(req, vec![AuthScope::Access], None).await {
Outcome::Success(access) => Outcome::Success(AccessFull { access }),
Outcome::Error(error) => Outcome::Error(error),
Outcome::Error(error) => {
req.local_cache(|| Some(ApiError::InvalidRequest(error.1.to_string())));
Outcome::Error(error)
}
Outcome::Forward(_) => panic!("Outcome::Forward returned"),
}
}
Expand All @@ -254,7 +260,10 @@ impl<'r> FromRequest<'r> for AccessPrivileged {
.await
{
Outcome::Success(access) => Outcome::Success(Self { access }),
Outcome::Error(error) => Outcome::Error(error),
Outcome::Error(error) => {
req.local_cache(|| Some(ApiError::InvalidRequest(error.1.to_string())));
Outcome::Error(error)
}
Outcome::Forward(_) => panic!("Outcome::Forward returned"),
}
}
Expand All @@ -281,7 +290,10 @@ impl<'r> FromRequest<'r> for AccessStandard {
.await
{
Outcome::Success(access) => Outcome::Success(AccessStandard { access }),
Outcome::Error(error) => Outcome::Error(error),
Outcome::Error(error) => {
req.local_cache(|| Some(ApiError::InvalidRequest(error.1.to_string())));
Outcome::Error(error)
}
Outcome::Forward(_) => panic!("Outcome::Forward returned"),
}
}
Expand Down Expand Up @@ -312,7 +324,10 @@ impl<'r> FromRequest<'r> for AccessStandardIncludeChecks {
.await
{
Outcome::Success(access) => Outcome::Success(AccessStandardIncludeChecks { access }),
Outcome::Error(error) => Outcome::Error(error),
Outcome::Error(error) => {
req.local_cache(|| Some(ApiError::InvalidRequest(error.1.to_string())));
Outcome::Error(error)
}
Outcome::Forward(_) => panic!("Outcome::Forward returned"),
}
}
Expand Down Expand Up @@ -343,7 +358,10 @@ impl<'r> FromRequest<'r> for AccessStandardCheckTakedown {
.await
{
Outcome::Success(access) => Outcome::Success(AccessStandardCheckTakedown { access }),
Outcome::Error(error) => Outcome::Error(error),
Outcome::Error(error) => {
req.local_cache(|| Some(ApiError::InvalidRequest(error.1.to_string())));
Outcome::Error(error)
}
Outcome::Forward(_) => panic!("Outcome::Forward returned"),
}
}
Expand Down Expand Up @@ -371,7 +389,10 @@ impl<'r> FromRequest<'r> for AccessStandardSignupQueued {
.await
{
Outcome::Success(access) => Outcome::Success(AccessStandardSignupQueued { access }),
Outcome::Error(error) => Outcome::Error(error),
Outcome::Error(error) => {
req.local_cache(|| Some(ApiError::InvalidRequest(error.1.to_string())));
Outcome::Error(error)
}
Outcome::Forward(_) => panic!("Outcome::Forward returned"),
}
}
Expand All @@ -391,12 +412,14 @@ impl<'r> FromRequest<'r> for RevokeRefreshToken {
match validate_bearer_token(req, vec![AuthScope::Refresh], Some(options)).await {
Ok(result) => match result.payload.jti {
Some(jti) => Outcome::Success(RevokeRefreshToken { id: jti }),
None => Outcome::Error((
Status::BadRequest,
AuthError::BadJwt("Unexpected missing refresh token id".to_owned()),
)),
None => {
let error = AuthError::BadJwt("Unexpected missing refresh token id".to_owned());
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
Outcome::Error((Status::BadRequest, error))
}
},
Err(error) => {
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
Outcome::Error((Status::BadRequest, AuthError::BadJwt(error.to_string())))
}
}
Expand Down Expand Up @@ -439,6 +462,11 @@ impl<'r> FromRequest<'r> for UserDidAuth {
},
}),
Err(error) => {
req.local_cache(|| {
Some(ApiError::InvalidRequest(
AuthError::BadJwt(error.to_string()).to_string(),
))
});
Outcome::Error((Status::BadRequest, AuthError::BadJwt(error.to_string())))
}
}
Expand All @@ -459,7 +487,10 @@ impl<'r> FromRequest<'r> for UserDidAuthOptional {
Outcome::Success(output) => Outcome::Success(UserDidAuthOptional {
access: Some(output.access),
}),
Outcome::Error(err) => Outcome::Error(err),
Outcome::Error(err) => {
req.local_cache(|| Some(ApiError::InvalidRequest(err.1.to_string())));
Outcome::Error(err)
}
_ => panic!("Unexpected outcome during UserDidAuthOptional"),
}
} else {
Expand Down Expand Up @@ -497,12 +528,11 @@ impl<'r> FromRequest<'r> for ModService {
&& (env_str("PDS_ENTRYWAY_DID").is_none()
|| Some(payload.aud.clone()) != env_str("PDS_ENTRYWAY_DID")) =>
{
Outcome::Error((
Status::BadRequest,
AuthError::BadJwtAudience(
"jwt audience does not match service did".to_string(),
),
))
let error = AuthError::BadJwtAudience(
"jwt audience does not match service did".to_string(),
);
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
Outcome::Error((Status::BadRequest, error))
}
Ok(payload) => Outcome::Success(ModService {
access: AccessOutput {
Expand All @@ -520,14 +550,15 @@ impl<'r> FromRequest<'r> for ModService {
},
}),
Err(error) => {
let error = AuthError::BadJwt(error.to_string());
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
Outcome::Error((Status::BadRequest, AuthError::BadJwt(error.to_string())))
}
}
} else {
Outcome::Error((
Status::BadRequest,
AuthError::UntrustedIss("Untrusted issuer".to_string()),
))
let error = AuthError::UntrustedIss("Untrusted issuer".to_string());
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
Outcome::Error((Status::BadRequest, error))
}
}
}
Expand All @@ -546,15 +577,21 @@ impl<'r> FromRequest<'r> for Moderator {
Outcome::Success(output) => Outcome::Success(Moderator {
access: output.access,
}),
Outcome::Error(err) => Outcome::Error(err),
Outcome::Error(err) => {
req.local_cache(|| Some(ApiError::InvalidRequest(err.1.to_string())));
Outcome::Error(err)
}
_ => panic!("Unexpected outcome during Moderator"),
}
} else {
match AdminToken::from_request(req).await {
Outcome::Success(output) => Outcome::Success(Moderator {
access: output.access,
}),
Outcome::Error(err) => Outcome::Error(err),
Outcome::Error(err) => {
req.local_cache(|| Some(ApiError::InvalidRequest(err.1.to_string())));
Outcome::Error(err)
}
_ => panic!("Unexpected outcome during Moderator"),
}
}
Expand All @@ -580,10 +617,9 @@ impl<'r> FromRequest<'r> for AdminToken {
let BasicAuth { username, password } = parsed;

if username != "admin" || password != env::var("PDS_ADMIN_PASS").unwrap() {
Outcome::Error((
Status::BadRequest,
AuthError::AuthRequired("BadAuth".to_string()),
))
let error = AuthError::AuthRequired("BadAuth".to_string());
req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
Outcome::Error((Status::BadRequest, error))
} else {
Outcome::Success(AdminToken {
access: AccessOutput {
Expand Down Expand Up @@ -621,15 +657,21 @@ impl<'r> FromRequest<'r> for OptionalAccessOrAdminToken {
Outcome::Success(output) => Outcome::Success(OptionalAccessOrAdminToken {
access: Some(output.access),
}),
Outcome::Error(err) => Outcome::Error(err),
Outcome::Error(err) => {
req.local_cache(|| Some(ApiError::InvalidRequest(err.1.to_string())));
Outcome::Error(err)
}
_ => panic!("Unexpected outcome during OptionalAccessOrAdminToken"),
}
} else if is_basic_token(req) {
match AdminToken::from_request(req).await {
Outcome::Success(output) => Outcome::Success(OptionalAccessOrAdminToken {
access: Some(output.access),
}),
Outcome::Error(err) => Outcome::Error(err),
Outcome::Error(err) => {
req.local_cache(|| Some(ApiError::InvalidRequest(err.1.to_string())));
Outcome::Error(err)
}
_ => panic!("Unexpected outcome during OptionalAccessOrAdminToken"),
}
} else {
Expand Down
14 changes: 8 additions & 6 deletions rsky-pds/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use rocket::figment::{
};
use rocket::http::Header;
use rocket::http::Status;
use rocket::request::local_cache;
use rocket::response::status;
use rocket::serde::json::Json;
use rocket::shield::{NoSniff, Shield};
Expand Down Expand Up @@ -100,13 +101,14 @@ async fn health(
}
}

#[tracing::instrument(skip_all)]
#[catch(default)]
async fn default_catcher() -> Json<rsky_pds::models::ErrorMessageResponse> {
let internal_error = rsky_pds::models::ErrorMessageResponse {
code: Some(rsky_pds::models::ErrorCode::InternalServerError),
message: Some("Internal error.".to_string()),
};
Json(internal_error)
async fn default_catcher(_status: Status, request: &Request<'_>) -> ApiError {
let api_error: &Option<ApiError> = request.local_cache(|| None);
match api_error {
None => ApiError::RuntimeError,
Some(error) => error.clone(),
}
}

/// Catches all OPTION requests in order to get the CORS related Fairing triggered.
Expand Down
Loading