Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ spacetimedb-lib = { workspace = true, features = ["serde"] }

anyhow.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_with.workspace = true
jsonwebtoken.workspace = true

Expand Down
20 changes: 19 additions & 1 deletion crates/auth/src/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,27 @@ use serde::{Deserialize, Serialize};
use spacetimedb_lib::Identity;
use std::time::SystemTime;

#[derive(Debug, Clone)]
pub struct ConnectionAuthCtx {
pub claims: SpacetimeIdentityClaims,
pub jwt_payload: String,
}

impl TryFrom<SpacetimeIdentityClaims> for ConnectionAuthCtx {
type Error = anyhow::Error;
fn try_from(claims: SpacetimeIdentityClaims) -> Result<Self, Self::Error> {
let payload =
serde_json::to_string(&claims).map_err(|e| anyhow::anyhow!("Failed to serialize claims: {}", e))?;
Ok(ConnectionAuthCtx {
claims,
jwt_payload: payload,
})
}
}

// These are the claims that can be attached to a request/connection.
#[serde_with::serde_as]
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SpacetimeIdentityClaims {
#[serde(rename = "hex_identity")]
pub identity: Identity,
Expand Down
1 change: 1 addition & 0 deletions crates/client-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ spacetimedb-lib = { workspace = true, features = ["serde"] }
spacetimedb-paths.workspace = true
spacetimedb-schema.workspace = true

base64.workspace = true
tokio = { version = "1.2", features = ["full"] }
lazy_static = "1.4.0"
log = "0.4.4"
Expand Down
122 changes: 95 additions & 27 deletions crates/client-api/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
use std::time::{Duration, SystemTime};

use anyhow::anyhow;
use axum::extract::{Query, Request, State};
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum_extra::typed_header::TypedHeader;
use headers::{authorization, HeaderMapExt};
use http::{request, HeaderValue, StatusCode};
use serde::{Deserialize, Serialize};
use spacetimedb::auth::identity::SpacetimeIdentityClaims;
use spacetimedb::auth::identity::{ConnectionAuthCtx, SpacetimeIdentityClaims};
use spacetimedb::auth::identity::{JwtError, JwtErrorKind};
use spacetimedb::auth::token_validation::{
new_validator, DefaultValidator, TokenSigner, TokenValidationError, TokenValidator,
};
use spacetimedb::auth::JwtKeys;
use spacetimedb::energy::EnergyQuanta;
use spacetimedb::identity::Identity;
use std::time::{Duration, SystemTime};
use uuid::Uuid;

use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
use base64::{engine::general_purpose, Engine};

/// Credentials for login for a spacetime identity, represented as a JWT.
///
Expand All @@ -41,6 +42,19 @@ impl SpacetimeCreds {
Self { token }
}

fn extract_jwt_payload_string(&self) -> Option<String> {
let parts: Vec<&str> = self.token.split('.').collect();
if parts.len() != 3 {
return None;
}

let payload_encoded = parts[1];
let decoded_bytes = general_purpose::URL_SAFE_NO_PAD.decode(payload_encoded).ok()?;
let json_str = String::from_utf8(decoded_bytes).ok()?;

Some(json_str)
}

pub fn to_header_value(&self) -> HeaderValue {
let mut val = HeaderValue::try_from(["Bearer ", self.token()].concat()).unwrap();
val.set_sensitive(true);
Expand Down Expand Up @@ -70,9 +84,18 @@ impl SpacetimeCreds {
#[derive(Clone)]
pub struct SpacetimeAuth {
pub creds: SpacetimeCreds,
pub identity: Identity,
pub subject: String,
pub issuer: String,
pub claims: SpacetimeIdentityClaims,
// The decoded JWT payload.
pub raw_payload: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// The decoded JWT payload.
/// The decoded JWT payload.

Use /// for doc comments.

}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dissonance between comment saying "decoded" and field name saying "raw" is weird. I assume this is decoded from base64 into a string which will contain a JSON object, but is raw in the sense that the JSON hasn't been parsed.


impl From<SpacetimeAuth> for ConnectionAuthCtx {
fn from(auth: SpacetimeAuth) -> Self {
ConnectionAuthCtx {
claims: auth.claims,
jwt_payload: auth.raw_payload.clone(),
}
}
}

use jsonwebtoken;
Expand All @@ -84,10 +107,10 @@ pub struct TokenClaims {
}

impl From<SpacetimeAuth> for TokenClaims {
fn from(claims: SpacetimeAuth) -> Self {
fn from(auth: SpacetimeAuth) -> Self {
Self {
issuer: claims.issuer,
subject: claims.subject,
issuer: auth.claims.issuer,
subject: auth.claims.subject,
// This will need to be changed when we care about audiencies.
audience: Vec::new(),
}
Expand All @@ -112,7 +135,7 @@ impl TokenClaims {
&self,
signer: &impl TokenSigner,
expiry: Option<Duration>,
) -> Result<String, JwtError> {
) -> Result<(SpacetimeIdentityClaims, String), JwtError> {
let iat = SystemTime::now();
let exp = expiry.map(|dur| iat + dur);
let claims = SpacetimeIdentityClaims {
Expand All @@ -123,10 +146,11 @@ impl TokenClaims {
iat,
exp,
};
signer.sign(&claims)
let token = signer.sign(&claims)?;
Ok((claims, token))
}

pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result<String, JwtError> {
pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result<(SpacetimeIdentityClaims, String), JwtError> {
self.encode_and_sign_with_expiry(signer, None)
}
}
Expand All @@ -143,32 +167,36 @@ impl SpacetimeAuth {
audience: vec!["spacetimedb".to_string()],
};

let identity = claims.id();
let creds = {
let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?;
SpacetimeCreds::from_signed_token(token)
};
let (claims, token) = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?;
let creds = SpacetimeCreds::from_signed_token(token);
// Pulling out the payload should never fail, since we just made it.
let payload = creds
.extract_jwt_payload_string()
.ok_or_else(|| log_and_500("internal error"))?;

Ok(Self {
creds,
identity,
subject,
issuer: ctx.jwt_auth_provider().local_issuer().to_string(),
claims,
raw_payload: payload,
})
}

/// Get the auth credentials as headers to be returned from an endpoint.
pub fn into_headers(self) -> (TypedHeader<SpacetimeIdentity>, TypedHeader<SpacetimeIdentityToken>) {
(
TypedHeader(SpacetimeIdentity(self.identity)),
TypedHeader(SpacetimeIdentity(self.claims.identity)),
TypedHeader(SpacetimeIdentityToken(self.creds)),
)
}

// Sign a new token with the same claims and a new expiry.
// Note that this will not change the issuer, so the private_key might not match.
// We do this to create short-lived tokens that we will be able to verify.
pub fn re_sign_with_expiry(&self, signer: &impl TokenSigner, expiry: Duration) -> Result<String, JwtError> {
pub fn re_sign_with_expiry(
&self,
signer: &impl TokenSigner,
expiry: Duration,
) -> Result<(SpacetimeIdentityClaims, String), JwtError> {
TokenClaims::from(self.clone()).encode_and_sign_with_expiry(signer, Some(expiry))
}
}
Expand Down Expand Up @@ -237,9 +265,11 @@ impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV

#[cfg(test)]
mod tests {
use crate::auth::TokenClaims;
use crate::auth::{SpacetimeCreds, TokenClaims};
use anyhow::Ok;

use spacetimedb::auth::{token_validation::TokenValidator, JwtKeys};
use std::collections::HashSet;

// Make sure that when we encode TokenClaims, we can decode to get the expected identity.
#[tokio::test]
Expand All @@ -252,12 +282,48 @@ mod tests {
audience: vec!["spacetimedb".to_string()],
};
let id = claims.id();
let token = claims.encode_and_sign(&kp.private)?;
let (_, token) = claims.encode_and_sign(&kp.private)?;
let decoded = kp.public.validate_token(&token).await?;

assert_eq!(decoded.identity, id);
Ok(())
}

// Test that extracting a JWT payload from a valid token gets the json representation.
#[tokio::test]
async fn extract_payload() -> Result<(), anyhow::Error> {
let kp = JwtKeys::generate()?;

let dummy_audience = "spacetimedb".to_string();
let claims = TokenClaims {
issuer: "localhost".to_string(),
subject: "test-subject".to_string(),
audience: vec![dummy_audience.clone()],
};
let (_, token) = claims.encode_and_sign(&kp.private)?;
let st_creds = SpacetimeCreds::from_signed_token(token);
let payload = st_creds
.extract_jwt_payload_string()
.ok_or_else(|| anyhow::anyhow!("Failed to extract JWT payload"))?;
// Make sure it is valid json.
let parsed: serde_json::Value = serde_json::from_str(&payload)?;
assert_eq!(parsed.get("iss").unwrap().as_str().unwrap(), claims.issuer);
assert_eq!(parsed.get("sub").unwrap().as_str().unwrap(), claims.subject);
assert_eq!(
parsed.get("aud").unwrap().as_array().unwrap()[0].as_str().unwrap(),
dummy_audience
);
let as_object = parsed
.as_object()
.ok_or_else(|| anyhow::anyhow!("Failed to parse JWT payload as object"))?;
let keys: HashSet<String> = as_object.keys().map(|s| s.to_string()).collect();
let expected_keys = vec!["iss", "sub", "aud", "iat", "exp", "hex_identity"]
.into_iter()
.map(|s| s.to_string())
.collect::<HashSet<String>>();
assert_eq!(keys, expected_keys);
Ok(())
}
}

pub struct SpacetimeAuthHeader {
Expand All @@ -279,11 +345,13 @@ impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for Space
.await
.map_err(AuthorizationRejection::Custom)?;

let payload = creds.extract_jwt_payload_string().ok_or_else(|| {
AuthorizationRejection::Custom(TokenValidationError::Other(anyhow!("Internal error parsing token")))
})?;
let auth = SpacetimeAuth {
creds,
identity: claims.identity,
subject: claims.subject,
issuer: claims.issuer,
claims,
raw_payload: payload,
};
Ok(Self { auth: Some(auth) })
}
Expand Down
Loading
Loading