diff --git a/Cargo.lock b/Cargo.lock index fb75d8d9438..0913492abe2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5424,6 +5424,7 @@ dependencies = [ "async-trait", "axum", "axum-extra", + "base64 0.21.7", "bytes", "bytestring", "chrono", diff --git a/crates/auth/Cargo.toml b/crates/auth/Cargo.toml index a5592c08f5f..f3db2f86af7 100644 --- a/crates/auth/Cargo.toml +++ b/crates/auth/Cargo.toml @@ -11,6 +11,7 @@ spacetimedb-lib = { workspace = true, features = ["serde"] } anyhow.workspace = true serde.workspace = true +serde_json.workspace = true serde_with.workspace = true jsonwebtoken.workspace = true diff --git a/crates/auth/src/identity.rs b/crates/auth/src/identity.rs index 286d28582dd..576e850d7ee 100644 --- a/crates/auth/src/identity.rs +++ b/crates/auth/src/identity.rs @@ -6,9 +6,27 @@ use serde::{Deserialize, Serialize}; use spacetimedb_lib::Identity; use std::time::SystemTime; +#[derive(Debug, Clone)] +pub struct ConnectionAuthCtx { + pub claims: SpacetimeIdentityClaims, + pub jwt_payload: String, +} + +impl TryFrom for ConnectionAuthCtx { + type Error = anyhow::Error; + fn try_from(claims: SpacetimeIdentityClaims) -> Result { + let payload = + serde_json::to_string(&claims).map_err(|e| anyhow::anyhow!("Failed to serialize claims: {}", e))?; + Ok(ConnectionAuthCtx { + claims, + jwt_payload: payload, + }) + } +} + // These are the claims that can be attached to a request/connection. #[serde_with::serde_as] -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct SpacetimeIdentityClaims { #[serde(rename = "hex_identity")] pub identity: Identity, diff --git a/crates/client-api/Cargo.toml b/crates/client-api/Cargo.toml index f19b1495dd8..13c6bb816da 100644 --- a/crates/client-api/Cargo.toml +++ b/crates/client-api/Cargo.toml @@ -15,6 +15,7 @@ spacetimedb-lib = { workspace = true, features = ["serde"] } spacetimedb-paths.workspace = true spacetimedb-schema.workspace = true +base64.workspace = true tokio = { version = "1.2", features = ["full"] } lazy_static = "1.4.0" log = "0.4.4" diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 61031625867..2f0ca67ef75 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -1,5 +1,4 @@ -use std::time::{Duration, SystemTime}; - +use anyhow::anyhow; use axum::extract::{Query, Request, State}; use axum::middleware::Next; use axum::response::IntoResponse; @@ -7,7 +6,7 @@ use axum_extra::typed_header::TypedHeader; use headers::{authorization, HeaderMapExt}; use http::{request, HeaderValue, StatusCode}; use serde::{Deserialize, Serialize}; -use spacetimedb::auth::identity::SpacetimeIdentityClaims; +use spacetimedb::auth::identity::{ConnectionAuthCtx, SpacetimeIdentityClaims}; use spacetimedb::auth::identity::{JwtError, JwtErrorKind}; use spacetimedb::auth::token_validation::{ new_validator, DefaultValidator, TokenSigner, TokenValidationError, TokenValidator, @@ -15,9 +14,11 @@ use spacetimedb::auth::token_validation::{ use spacetimedb::auth::JwtKeys; use spacetimedb::energy::EnergyQuanta; use spacetimedb::identity::Identity; +use std::time::{Duration, SystemTime}; use uuid::Uuid; use crate::{log_and_500, ControlStateDelegate, NodeDelegate}; +use base64::{engine::general_purpose, Engine}; /// Credentials for login for a spacetime identity, represented as a JWT. /// @@ -41,6 +42,19 @@ impl SpacetimeCreds { Self { token } } + fn extract_jwt_payload_string(&self) -> Option { + let parts: Vec<&str> = self.token.split('.').collect(); + if parts.len() != 3 { + return None; + } + + let payload_encoded = parts[1]; + let decoded_bytes = general_purpose::URL_SAFE_NO_PAD.decode(payload_encoded).ok()?; + let json_str = String::from_utf8(decoded_bytes).ok()?; + + Some(json_str) + } + pub fn to_header_value(&self) -> HeaderValue { let mut val = HeaderValue::try_from(["Bearer ", self.token()].concat()).unwrap(); val.set_sensitive(true); @@ -70,9 +84,18 @@ impl SpacetimeCreds { #[derive(Clone)] pub struct SpacetimeAuth { pub creds: SpacetimeCreds, - pub identity: Identity, - pub subject: String, - pub issuer: String, + pub claims: SpacetimeIdentityClaims, + /// The JWT payload as a json string (after base64 decoding). + pub jwt_payload: String, +} + +impl From for ConnectionAuthCtx { + fn from(auth: SpacetimeAuth) -> Self { + ConnectionAuthCtx { + claims: auth.claims, + jwt_payload: auth.jwt_payload.clone(), + } + } } use jsonwebtoken; @@ -84,10 +107,10 @@ pub struct TokenClaims { } impl From for TokenClaims { - fn from(claims: SpacetimeAuth) -> Self { + fn from(auth: SpacetimeAuth) -> Self { Self { - issuer: claims.issuer, - subject: claims.subject, + issuer: auth.claims.issuer, + subject: auth.claims.subject, // This will need to be changed when we care about audiencies. audience: Vec::new(), } @@ -108,11 +131,14 @@ impl TokenClaims { Identity::from_claims(&self.issuer, &self.subject) } + /// Encode the claims into a JWT token and sign it with the provided signer. + /// This also adds claims for expiry and issued at time. + /// Returns an object representing the claims and the signed token. pub fn encode_and_sign_with_expiry( &self, signer: &impl TokenSigner, expiry: Option, - ) -> Result { + ) -> Result<(SpacetimeIdentityClaims, String), JwtError> { let iat = SystemTime::now(); let exp = expiry.map(|dur| iat + dur); let claims = SpacetimeIdentityClaims { @@ -123,10 +149,14 @@ impl TokenClaims { iat, exp, }; - signer.sign(&claims) + let token = signer.sign(&claims)?; + Ok((claims, token)) } - pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result { + /// Encode the claims into a JWT token and sign it with the provided signer. + /// This also adds a claim for issued at time. + /// Returns an object representing the claims and the signed token. + pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result<(SpacetimeIdentityClaims, String), JwtError> { self.encode_and_sign_with_expiry(signer, None) } } @@ -143,24 +173,24 @@ impl SpacetimeAuth { audience: vec!["spacetimedb".to_string()], }; - let identity = claims.id(); - let creds = { - let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; - SpacetimeCreds::from_signed_token(token) - }; + let (claims, token) = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; + let creds = SpacetimeCreds::from_signed_token(token); + // Pulling out the payload should never fail, since we just made it. + let payload = creds + .extract_jwt_payload_string() + .ok_or_else(|| log_and_500("internal error"))?; Ok(Self { creds, - identity, - subject, - issuer: ctx.jwt_auth_provider().local_issuer().to_string(), + claims, + jwt_payload: payload, }) } /// Get the auth credentials as headers to be returned from an endpoint. pub fn into_headers(self) -> (TypedHeader, TypedHeader) { ( - TypedHeader(SpacetimeIdentity(self.identity)), + TypedHeader(SpacetimeIdentity(self.claims.identity)), TypedHeader(SpacetimeIdentityToken(self.creds)), ) } @@ -168,7 +198,11 @@ impl SpacetimeAuth { // Sign a new token with the same claims and a new expiry. // Note that this will not change the issuer, so the private_key might not match. // We do this to create short-lived tokens that we will be able to verify. - pub fn re_sign_with_expiry(&self, signer: &impl TokenSigner, expiry: Duration) -> Result { + pub fn re_sign_with_expiry( + &self, + signer: &impl TokenSigner, + expiry: Duration, + ) -> Result<(SpacetimeIdentityClaims, String), JwtError> { TokenClaims::from(self.clone()).encode_and_sign_with_expiry(signer, Some(expiry)) } } @@ -237,9 +271,11 @@ impl JwtAuthProvider for JwtKeyAuthProvider Result<(), anyhow::Error> { + let kp = JwtKeys::generate()?; + + let dummy_audience = "spacetimedb".to_string(); + let claims = TokenClaims { + issuer: "localhost".to_string(), + subject: "test-subject".to_string(), + audience: vec![dummy_audience.clone()], + }; + let (_, token) = claims.encode_and_sign(&kp.private)?; + let st_creds = SpacetimeCreds::from_signed_token(token); + let payload = st_creds + .extract_jwt_payload_string() + .ok_or_else(|| anyhow::anyhow!("Failed to extract JWT payload"))?; + // Make sure it is valid json. + let parsed: serde_json::Value = serde_json::from_str(&payload)?; + assert_eq!(parsed.get("iss").unwrap().as_str().unwrap(), claims.issuer); + assert_eq!(parsed.get("sub").unwrap().as_str().unwrap(), claims.subject); + assert_eq!( + parsed.get("aud").unwrap().as_array().unwrap()[0].as_str().unwrap(), + dummy_audience + ); + let as_object = parsed + .as_object() + .ok_or_else(|| anyhow::anyhow!("Failed to parse JWT payload as object"))?; + let keys: HashSet = as_object.keys().map(|s| s.to_string()).collect(); + let expected_keys = vec!["iss", "sub", "aud", "iat", "exp", "hex_identity"] + .into_iter() + .map(|s| s.to_string()) + .collect::>(); + assert_eq!(keys, expected_keys); + Ok(()) + } } pub struct SpacetimeAuthHeader { @@ -279,11 +351,13 @@ impl axum::extract::FromRequestParts for Space .await .map_err(AuthorizationRejection::Custom)?; + let payload = creds.extract_jwt_payload_string().ok_or_else(|| { + AuthorizationRejection::Custom(TokenValidationError::Other(anyhow!("Internal error parsing token"))) + })?; let auth = SpacetimeAuth { creds, - identity: claims.identity, - subject: claims.subject, - issuer: claims.issuer, + claims, + jwt_payload: payload, }; Ok(Self { auth: Some(auth) }) } diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 91a67f9c406..ff90b1e165c 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -54,7 +54,7 @@ pub async fn call( if content_type != headers::ContentType::json() { return Err(axum::extract::rejection::MissingJsonContentType::default().into()); } - let caller_identity = auth.identity; + let caller_identity = auth.claims.identity; let args = ReducerArgs::Json(body); @@ -78,7 +78,7 @@ pub async fn call( // so generate one. let connection_id = generate_random_connection_id(); - match module.call_identity_connected(caller_identity, connection_id).await { + match module.call_identity_connected(auth.into(), connection_id).await { // If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored, // meaning the connection was refused. Return 403 forbidden. Err(ClientConnectedError::Rejected(msg)) => return Err((StatusCode::FORBIDDEN, msg).into()), @@ -225,7 +225,7 @@ where }; Ok(( - TypedHeader(SpacetimeIdentity(auth.identity)), + TypedHeader(SpacetimeIdentity(auth.claims.identity)), TypedHeader(SpacetimeIdentityToken(auth.creds)), response_json, )) @@ -300,13 +300,13 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - if database.owner_identity != auth.identity { + if database.owner_identity != auth.claims.identity { return Err(( StatusCode::BAD_REQUEST, format!( "Identity does not own database, expected: {} got: {}", database.owner_identity.to_hex(), - auth.identity.to_hex() + auth.claims.identity.to_hex() ), ) .into()); @@ -402,7 +402,7 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - let auth = AuthCtx::new(database.owner_identity, auth.identity); + let auth = AuthCtx::new(database.owner_identity, auth.claims.identity); log::debug!("auth: {auth:?}"); let host = worker_ctx @@ -481,10 +481,13 @@ fn allow_creation(auth: &SpacetimeAuth) -> Result<(), ErrorResponse> { if !require_spacetime_auth_for_creation() { return Ok(()); } - if auth.issuer.trim_end_matches('/') == "https://auth.spacetimedb.com" { + if auth.claims.issuer.trim_end_matches('/') == "https://auth.spacetimedb.com" { Ok(()) } else { - log::trace!("Rejecting creation request because auth issuer is {}", auth.issuer); + log::trace!( + "Rejecting creation request because auth issuer is {}", + auth.claims.issuer + ); Err(( StatusCode::UNAUTHORIZED, "To create a database, you must be logged in with a SpacetimeDB account.", @@ -511,9 +514,13 @@ pub async fn publish( // exists yet. Create it now with a fresh identity. allow_creation(&auth)?; let database_auth = SpacetimeAuth::alloc(&ctx).await?; - let database_identity = database_auth.identity; + let database_identity = database_auth.claims.identity; let tld: name::Tld = name.clone().into(); - let tld = match ctx.register_tld(&auth.identity, tld).await.map_err(log_and_500)? { + let tld = match ctx + .register_tld(&auth.claims.identity, tld) + .await + .map_err(log_and_500)? + { name::RegisterTldResult::Success { domain } | name::RegisterTldResult::AlreadyRegistered { domain } => domain, name::RegisterTldResult::Unauthorized { .. } => { @@ -525,7 +532,7 @@ pub async fn publish( } }; let res = ctx - .create_dns_record(&auth.identity, &tld.into(), &database_identity) + .create_dns_record(&auth.claims.identity, &tld.into(), &database_identity) .await .map_err(log_and_500)?; match res { @@ -541,7 +548,7 @@ pub async fn publish( }, None => { let database_auth = SpacetimeAuth::alloc(&ctx).await?; - let database_identity = database_auth.identity; + let database_identity = database_auth.claims.identity; (database_identity, None) } }; @@ -558,7 +565,7 @@ pub async fn publish( } if clear && exists { - ctx.delete_database(&auth.identity, &database_identity) + ctx.delete_database(&auth.claims.identity, &database_identity) .await .map_err(log_and_500)?; } @@ -580,7 +587,7 @@ pub async fn publish( let maybe_updated = ctx .publish_database( - &auth.identity, + &auth.claims.identity, DatabaseDef { database_identity, program_bytes: body.into(), @@ -626,7 +633,7 @@ pub async fn delete_database( ) -> axum::response::Result { let database_identity = name_or_identity.resolve(&ctx).await?; - ctx.delete_database(&auth.identity, &database_identity) + ctx.delete_database(&auth.claims.identity, &database_identity) .await .map_err(log_and_500)?; @@ -648,7 +655,7 @@ pub async fn add_name( let database_identity = name_or_identity.resolve(&ctx).await?; let response = ctx - .create_dns_record(&auth.identity, &name.into(), &database_identity) + .create_dns_record(&auth.claims.identity, &name.into(), &database_identity) .await // TODO: better error code handling .map_err(log_and_500)?; @@ -691,7 +698,7 @@ pub async fn set_names( )); }; - if database.owner_identity != auth.identity { + if database.owner_identity != auth.claims.identity { return Ok(( StatusCode::UNAUTHORIZED, axum::Json(name::SetDomainsResult::NotYourDatabase { diff --git a/crates/client-api/src/routes/energy.rs b/crates/client-api/src/routes/energy.rs index 16e66963cf6..34098987c68 100644 --- a/crates/client-api/src/routes/energy.rs +++ b/crates/client-api/src/routes/energy.rs @@ -49,14 +49,14 @@ pub async fn add_energy( })?; if let Some(satoshi) = amount { - ctx.add_energy(&auth.identity, EnergyQuanta::new(satoshi)) + ctx.add_energy(&auth.claims.identity, EnergyQuanta::new(satoshi)) .await .map_err(log_and_500)?; } // TODO: is this guaranteed to pull the updated balance? let balance = ctx - .get_energy_balance(&auth.identity) + .get_energy_balance(&auth.claims.identity) .map_err(log_and_500)? .map_or(0, |quanta| quanta.get()); @@ -87,7 +87,7 @@ pub async fn set_energy_balance( // This will be a natural rate limiter until we can begin to sell energy. // No one is able to be the dummy identity so this always returns unauthorized. - if auth.identity != Identity::__dummy() { + if auth.claims.identity != Identity::__dummy() { return Err(StatusCode::UNAUTHORIZED.into()); } diff --git a/crates/client-api/src/routes/identity.rs b/crates/client-api/src/routes/identity.rs index 69b27661fbe..be9adde55f9 100644 --- a/crates/client-api/src/routes/identity.rs +++ b/crates/client-api/src/routes/identity.rs @@ -24,7 +24,7 @@ pub async fn create_identity( let auth = SpacetimeAuth::alloc(&ctx).await?; let identity_response = CreateIdentityResponse { - identity: auth.identity, + identity: auth.claims.identity, token: auth.creds.token().to_owned(), }; Ok(axum::Json(identity_response)) @@ -103,7 +103,7 @@ pub async fn create_websocket_token( SpacetimeAuthRequired(auth): SpacetimeAuthRequired, ) -> axum::response::Result { let expiry = Duration::from_secs(60); - let token = auth + let (_, token) = auth .re_sign_with_expiry(ctx.jwt_auth_provider(), expiry) .map_err(log_and_500)?; // let token = encode_token_with_expiry(ctx.private_key(), auth.identity, Some(expiry)).map_err(log_and_500)?; @@ -121,7 +121,7 @@ pub async fn validate_token( ) -> axum::response::Result { let identity = Identity::from(identity); - if auth.identity != identity { + if auth.claims.identity != identity { return Err(StatusCode::BAD_REQUEST.into()); } diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index f59f8cd68dd..07d444cd5fd 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -147,8 +147,9 @@ where let mut module_rx = leader.module_watcher().await.map_err(log_and_500)?; + let client_identity = auth.claims.identity; let client_id = ClientActorId { - identity: auth.identity, + identity: client_identity, connection_id, name: ctx.client_actor_index().next_client_name(), }; @@ -178,7 +179,13 @@ where log::debug!("websocket: New client connected from {client_log_string}"); - let connected = match ClientConnection::call_client_connected_maybe_reject(&mut module_rx, client_id).await { + let connected = match ClientConnection::call_client_connected_maybe_reject( + &mut module_rx, + client_id, + auth.clone().into(), + ) + .await + { Ok(connected) => { log::debug!("websocket: client_connected returned Ok for {client_log_string}"); connected @@ -200,8 +207,16 @@ where ); let actor = |client, sendrx| ws_client_actor(ws_opts, client, ws, sendrx); - let client = - ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor, connected).await; + let client = ClientConnection::spawn( + client_id, + auth.into(), + client_config, + leader.replica_id, + module_rx, + actor, + connected, + ) + .await; // Send the client their identity token message as the first message // NOTE: We're adding this to the protocol because some client libraries are @@ -209,7 +224,7 @@ where // Clients that receive the token from the response headers should ignore this // message. let message = IdentityTokenMessage { - identity: auth.identity, + identity: client_identity, token: identity_token, connection_id, }; diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index 11103824dca..1382b7882db 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -14,7 +14,8 @@ pub use client_connection_index::ClientActorIndex; pub use message_handlers::{MessageExecutionError, MessageHandleError}; use spacetimedb_lib::ConnectionId; -#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] +// #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] +#[derive(Clone, Debug, Copy)] pub struct ClientActorId { pub identity: Identity, pub connection_id: ConnectionId, diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index cf3e5511fff..f6bcbc43784 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -5,7 +5,7 @@ use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicBool, Ordering::Relaxed}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::time::Instant; +use std::time::{Instant, SystemTime}; use super::messages::{OneOffQueryResponseMessage, SerializableMessage}; use super::{message_handlers, ClientActorId, MessageHandleError}; @@ -21,6 +21,7 @@ use bytestring::ByteString; use derive_more::From; use futures::prelude::*; use prometheus::{Histogram, IntCounter, IntGauge}; +use spacetimedb_auth::identity::{ConnectionAuthCtx, SpacetimeIdentityClaims}; use spacetimedb_client_api_messages::websocket::{ BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, Unsubscribe, UnsubscribeMulti, @@ -78,6 +79,7 @@ impl ClientConfig { #[derive(Debug)] pub struct ClientConnectionSender { pub id: ClientActorId, + pub auth: ConnectionAuthCtx, pub config: ClientConfig, sendtx: mpsc::Sender, abort_handle: AbortHandle, @@ -146,8 +148,17 @@ impl ClientConnectionSender { let rx = MeteredReceiver::new(rx); let cancelled = AtomicBool::new(false); + let dummy_claims = SpacetimeIdentityClaims { + identity: id.identity, + subject: "".to_string(), + issuer: "".to_string(), + audience: vec![], + iat: SystemTime::now(), + exp: None, + }; let sender = Self { id, + auth: ConnectionAuthCtx::try_from(dummy_claims).expect("dummy claims should always be valid"), config, sendtx, abort_handle, @@ -415,9 +426,10 @@ impl ClientConnection { pub async fn call_client_connected_maybe_reject( module_rx: &mut watch::Receiver, id: ClientActorId, + auth: ConnectionAuthCtx, ) -> Result { let module = module_rx.borrow_and_update().clone(); - module.call_identity_connected(id.identity, id.connection_id).await?; + module.call_identity_connected(auth, id.connection_id).await?; Ok(Connected { _private: () }) } @@ -429,6 +441,7 @@ impl ClientConnection { /// and pass the returned [`Connected`] as `_proof_of_client_connected_call`. pub async fn spawn( id: ClientActorId, + auth: ConnectionAuthCtx, config: ClientConfig, replica_id: u64, mut module_rx: watch::Receiver, @@ -450,6 +463,7 @@ impl ClientConnection { // weird dance so that we can get an abort_handle into ClientConnection let module_info = module.info.clone(); let database_identity = module_info.database_identity; + let client_identity = id.identity; let abort_handle = tokio::spawn(async move { let Ok(fut) = fut_rx.await else { return }; @@ -457,7 +471,6 @@ impl ClientConnection { module_info.metrics.ws_clients_spawned.inc(); scopeguard::defer! { let database_identity = module_info.database_identity; - let client_identity = id.identity; log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`"); module_info.metrics.ws_clients_aborted.inc(); }; @@ -471,6 +484,7 @@ impl ClientConnection { let sender = Arc::new(ClientConnectionSender { id, + auth, config, sendtx, abort_handle, diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 4994ef69431..eff3b835f02 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -26,7 +26,7 @@ use spacetimedb_vm::expr::Crud; pub use spacetimedb_datastore::error::{DatastoreError, IndexError, SequenceError, TableError}; -#[derive(Error, Debug, PartialEq, Eq)] +#[derive(Error, Debug)] pub enum ClientError { #[error("Client not found: {0}")] NotFound(ClientActorId), diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 29749f651db..6c52e92c496 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -26,6 +26,8 @@ use derive_more::From; use indexmap::IndexSet; use itertools::Itertools; use prometheus::{Histogram, IntGauge}; +use scopeguard::ScopeGuard; +use spacetimedb_auth::identity::ConnectionAuthCtx; use spacetimedb_client_api_messages::websocket::{ByteListLen, Compression, OneOffTable, QueryUpdate}; use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; @@ -683,12 +685,36 @@ impl ModuleHost { /// In this case, the caller should terminate the connection. pub async fn call_identity_connected( &self, - caller_identity: Identity, + caller_auth: ConnectionAuthCtx, caller_connection_id: ConnectionId, ) -> Result<(), ClientConnectedError> { let me = self.clone(); self.call("call_identity_connected", move |inst| { let reducer_lookup = me.info.module_def.lifecycle_reducer(Lifecycle::OnConnect); + let stdb = &me.module.replica_ctx().relational_db; + let workload = Workload::Reducer(ReducerContext { + name: "call_identity_connected".to_owned(), + caller_identity: caller_auth.claims.identity, + caller_connection_id, + timestamp: Timestamp::now(), + arg_bsatn: Bytes::new(), + }); + let mut_tx = stdb.begin_mut_tx(IsolationLevel::Serializable, workload); + let mut mut_tx = scopeguard::guard(mut_tx, |mut_tx| { + // If we crash before committing, we need to ensure that the transaction is rolled back. + // This is necessary to avoid leaving the database in an inconsistent state. + log::debug!("call_identity_connected: rolling back transaction"); + let (metrics, reducer_name) = mut_tx.rollback(); + stdb.report_mut_tx_metrics(reducer_name, metrics, None); + }); + + mut_tx + .insert_st_client( + caller_auth.claims.identity, + caller_connection_id, + &caller_auth.jwt_payload, + ) + .map_err(DBError::from)?; if let Some((reducer_id, reducer_def)) = reducer_lookup { // The module defined a lifecycle reducer to handle new connections. @@ -696,7 +722,8 @@ impl ModuleHost { // If the call fails (as in, something unexpectedly goes wrong with WASM execution), // abort the connection: we can't really recover. let reducer_outcome = me.call_reducer_inner_with_inst( - caller_identity, + Some(ScopeGuard::into_inner(mut_tx)), + caller_auth.claims.identity, Some(caller_connection_id), None, None, @@ -727,35 +754,19 @@ impl ModuleHost { } } else { // The module doesn't define a client_connected reducer. - // Commit a transaction to update `st_clients` - // and to ensure we always have those events paired in the commitlog. + // We need to commit the transaction to update st_clients and st_connection_credentials. // // This is necessary to be able to disconnect clients after a server crash. - let reducer_name = reducer_lookup - .as_ref() - .map(|(_, def)| &*def.name) - .unwrap_or("__identity_connected__"); - let workload = Workload::Reducer(ReducerContext { - name: reducer_name.to_owned(), - caller_identity, - caller_connection_id, - timestamp: Timestamp::now(), - arg_bsatn: Bytes::new(), - }); - - let stdb = me.module.replica_ctx().relational_db.clone(); - stdb.with_auto_commit(workload, |mut_tx| { - mut_tx - .insert_st_client(caller_identity, caller_connection_id) - .map_err(DBError::from) - }) - .inspect_err(|e| { - log::error!( - "`call_identity_connected`: fallback transaction to insert into `st_client` failed: {e:#?}" - ) - }) - .map_err(Into::into) + // TODO: Is this being broadcast? Does it need to be, or are st_client table subscriptions + // not allowed? + // I don't think it was being broadcast previously. + stdb.finish_tx(ScopeGuard::into_inner(mut_tx), Ok(())) + .map_err(|e: DBError| { + log::error!("`call_identity_connected`: finish transaction failed: {e:#?}"); + ClientConnectedError::DBError(e) + })?; + Ok(()) } }) .await @@ -809,6 +820,7 @@ impl ModuleHost { // If it succeeds, `WasmModuleInstance::call_reducer_with_tx` has already ensured // that `st_client` is updated appropriately. let result = me.call_reducer_inner_with_inst( + None, caller_identity, Some(caller_connection_id), None, @@ -913,6 +925,7 @@ impl ModuleHost { } fn call_reducer_inner_with_inst( &self, + tx: Option, caller_identity: Identity, caller_connection_id: Option, client: Option>, @@ -928,7 +941,7 @@ impl ModuleHost { let caller_connection_id = caller_connection_id.unwrap_or(ConnectionId::ZERO); Ok(module_instance.call_reducer( - None, + tx, CallReducerParams { timestamp: Timestamp::now(), caller_identity, diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index c337140ac94..0b66ca197f1 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -457,7 +457,7 @@ impl WasmModuleInstance { // and conversely removing from `st_clients` on disconnect. Ok(Ok(())) => { let res = match reducer_def.lifecycle { - Some(Lifecycle::OnConnect) => tx.insert_st_client(caller_identity, caller_connection_id), + Some(Lifecycle::OnConnect) => Ok(()), Some(Lifecycle::OnDisconnect) => { tx.delete_st_client(caller_identity, caller_connection_id, database_identity) } diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 48c2af5fb95..65211305898 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -876,37 +876,16 @@ impl ModuleSubscriptions { return Ok(Err(WriteConflict)); }; *db_update = DatabaseUpdate::from_writes(&tx_data); - (read_tx, Some(tx_data), tx_metrics) + (read_tx, Arc::new(tx_data), tx_metrics) } EventStatus::Failed(_) | EventStatus::OutOfEnergy => { - let (tx_metrics, tx) = stdb.rollback_mut_tx_downgrade(tx, Workload::Update); - (tx, None, tx_metrics) - } - }; - - let tx_data = tx_data.map(Arc::new); - - // When we're done with this method, release the tx and report metrics. - let mut read_tx = scopeguard::guard(read_tx, |tx| { - let (tx_metrics_read, reducer) = self.relational_db.release_tx(tx); - self.relational_db - .report_tx_metrics(reducer, tx_data.clone(), Some(tx_metrics_mut), Some(tx_metrics_read)); - }); - // Create the delta transaction we'll use to eval updates against. - let delta_read_tx = tx_data - .as_ref() - .as_ref() - .map(|tx_data| DeltaTx::new(&read_tx, tx_data, subscriptions.index_ids_for_subscriptions())) - .unwrap_or_else(|| DeltaTx::from(&*read_tx)); + // If the transaction failed, we need to rollback the mutable tx. + // We don't need to do any subscription updates in this case, so we will exit early. - let event = Arc::new(event); - let mut update_metrics: ExecutionMetrics = ExecutionMetrics::default(); - - match &event.status { - EventStatus::Committed(_) => { - update_metrics = subscriptions.eval_updates_sequential(&delta_read_tx, event.clone(), caller); - } - EventStatus::Failed(_) => { + let event = Arc::new(event); + let (tx_metrics, reducer) = stdb.rollback_mut_tx(tx); + self.relational_db + .report_tx_metrics(reducer, None, Some(tx_metrics), None); if let Some(client) = caller { let message = TransactionUpdateMessage { event: Some(event.clone()), @@ -917,9 +896,25 @@ impl ModuleSubscriptions { } else { log::trace!("Reducer failed but there is no client to send the failure to!") } + return Ok(Ok((event, ExecutionMetrics::default()))); } - EventStatus::OutOfEnergy => {} // ? - } + }; + let event = Arc::new(event); + + // When we're done with this method, release the tx and report metrics. + let mut read_tx = scopeguard::guard(read_tx, |tx| { + let (tx_metrics_read, reducer) = self.relational_db.release_tx(tx); + self.relational_db.report_tx_metrics( + reducer, + Some(tx_data.clone()), + Some(tx_metrics_mut), + Some(tx_metrics_read), + ); + }); + // Create the delta transaction we'll use to eval updates against. + let delta_read_tx = DeltaTx::new(&read_tx, tx_data.as_ref(), subscriptions.index_ids_for_subscriptions()); + + let update_metrics = subscriptions.eval_updates_sequential(&delta_read_tx, event.clone(), caller); // Merge in the subscription evaluation metrics. read_tx.metrics.merge(update_metrics); diff --git a/crates/datastore/src/locking_tx_datastore/committed_state.rs b/crates/datastore/src/locking_tx_datastore/committed_state.rs index a074d43c124..c05cd558751 100644 --- a/crates/datastore/src/locking_tx_datastore/committed_state.rs +++ b/crates/datastore/src/locking_tx_datastore/committed_state.rs @@ -6,6 +6,7 @@ use super::{ tx_state::{IndexIdMap, PendingSchemaChange, TxState}, IterByColEqTx, }; +use crate::system_tables::{ST_CONNECTION_CREDENTIALS_ID, ST_CONNECTION_CREDENTIALS_IDX}; use crate::{ db_metrics::DB_METRICS, error::{IndexError, TableError}, @@ -24,10 +25,7 @@ use anyhow::anyhow; use core::{convert::Infallible, ops::RangeBounds}; use itertools::Itertools; use spacetimedb_data_structures::map::{HashSet, IntMap}; -use spacetimedb_lib::{ - db::auth::{StAccess, StTableType}, - Identity, -}; +use spacetimedb_lib::{db::auth::StTableType, Identity}; use spacetimedb_primitives::{ColList, ColSet, IndexId, TableId}; use spacetimedb_sats::memory_usage::MemoryUsage; use spacetimedb_sats::{AlgebraicValue, ProductValue}; @@ -183,7 +181,7 @@ impl CommittedState { table_id, table_name: schema.table_name.clone(), table_type: StTableType::System, - table_access: StAccess::Public, + table_access: schema.table_access, table_primary_key: schema.primary_key.map(Into::into), }; let row = ProductValue::from(row); @@ -272,6 +270,10 @@ impl CommittedState { self.create_table(ST_SCHEDULED_ID, schemas[ST_SCHEDULED_IDX].clone()); self.create_table(ST_ROW_LEVEL_SECURITY_ID, schemas[ST_ROW_LEVEL_SECURITY_IDX].clone()); + self.create_table( + ST_CONNECTION_CREDENTIALS_ID, + schemas[ST_CONNECTION_CREDENTIALS_IDX].clone(), + ); // IMPORTANT: It is crucial that the `st_sequences` table is created last diff --git a/crates/datastore/src/locking_tx_datastore/datastore.rs b/crates/datastore/src/locking_tx_datastore/datastore.rs index d9b2ac32fe2..392ec7bad22 100644 --- a/crates/datastore/src/locking_tx_datastore/datastore.rs +++ b/crates/datastore/src/locking_tx_datastore/datastore.rs @@ -1147,9 +1147,10 @@ mod tests { use crate::error::IndexError; use crate::locking_tx_datastore::tx_state::PendingSchemaChange; use crate::system_tables::{ - system_tables, StColumnRow, StConstraintData, StConstraintFields, StConstraintRow, StIndexAlgorithm, - StIndexFields, StIndexRow, StRowLevelSecurityFields, StScheduledFields, StSequenceFields, StSequenceRow, - StTableRow, StVarFields, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_NAME, + system_tables, StColumnRow, StConnectionCredentialsFields, StConstraintData, StConstraintFields, + StConstraintRow, StIndexAlgorithm, StIndexFields, StIndexRow, StRowLevelSecurityFields, StScheduledFields, + StSequenceFields, StSequenceRow, StTableRow, StVarFields, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, + ST_CONNECTION_CREDENTIALS_ID, ST_CONNECTION_CREDENTIALS_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_NAME, ST_INDEX_ID, ST_INDEX_NAME, ST_MODULE_NAME, ST_RESERVED_SEQUENCE_RANGE, ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_NAME, ST_SCHEDULED_ID, ST_SCHEDULED_NAME, ST_SEQUENCE_ID, ST_SEQUENCE_NAME, ST_TABLE_NAME, ST_VAR_ID, ST_VAR_NAME, @@ -1603,6 +1604,7 @@ mod tests { TableRow { id: ST_VAR_ID.into(), name: ST_VAR_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StVarFields::Name.into()) }, TableRow { id: ST_SCHEDULED_ID.into(), name: ST_SCHEDULED_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StScheduledFields::ScheduleId.into()) }, TableRow { id: ST_ROW_LEVEL_SECURITY_ID.into(), name: ST_ROW_LEVEL_SECURITY_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StRowLevelSecurityFields::Sql.into()) }, + TableRow { id: ST_CONNECTION_CREDENTIALS_ID.into(), name: ST_CONNECTION_CREDENTIALS_NAME, ty: StTableType::System, access: StAccess::Private, primary_key: Some(StConnectionCredentialsFields::ConnectionId.into()) }, ])); #[rustfmt::skip] assert_eq!(query.scan_st_columns()?, map_array([ @@ -1658,6 +1660,9 @@ mod tests { ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 0, name: "table_id", ty: TableId::get_type() }, ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 1, name: "sql", ty: AlgebraicType::String }, + + ColRow { table: ST_CONNECTION_CREDENTIALS_ID.into(), pos: 0, name: "connection_id", ty: AlgebraicType::U128 }, + ColRow { table: ST_CONNECTION_CREDENTIALS_ID.into(), pos: 1, name: "jwt_payload", ty: AlgebraicType::String }, ])); #[rustfmt::skip] assert_eq!(query.scan_st_indexes()?, map_array([ @@ -1673,6 +1678,7 @@ mod tests { IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "st_scheduled_table_id_idx_btree", }, IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "st_row_level_security_table_id_idx_btree", }, IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "st_row_level_security_sql_idx_btree", }, + IndexRow { id: 13, table: ST_CONNECTION_CREDENTIALS_ID.into(), col: col(0), name: "st_connection_credentials_connection_id_idx_btree", }, ])); let start = FIRST_NON_SYSTEM_ID as i128; #[rustfmt::skip] @@ -1702,6 +1708,7 @@ mod tests { ConstraintRow { constraint_id: 9, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(0), constraint_name: "st_scheduled_schedule_id_key", }, ConstraintRow { constraint_id: 10, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(1), constraint_name: "st_scheduled_table_id_key", }, ConstraintRow { constraint_id: 11, table_id: ST_ROW_LEVEL_SECURITY_ID.into(), unique_columns: col(1), constraint_name: "st_row_level_security_sql_key", }, + ConstraintRow { constraint_id: 12, table_id: ST_CONNECTION_CREDENTIALS_ID.into(), unique_columns: col(0), constraint_name: "st_connection_credentials_connection_id_key", }, ])); // Verify we get back the tables correctly with the proper ids... @@ -2099,6 +2106,7 @@ mod tests { IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "st_scheduled_table_id_idx_btree", }, IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "st_row_level_security_table_id_idx_btree", }, IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "st_row_level_security_sql_idx_btree", }, + IndexRow { id: 13, table: ST_CONNECTION_CREDENTIALS_ID.into(), col: col(0), name: "st_connection_credentials_connection_id_idx_btree", }, IndexRow { id: seq_start, table: FIRST_NON_SYSTEM_ID, col: col(0), name: "Foo_id_idx_btree", }, IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "Foo_name_idx_btree", }, IndexRow { id: seq_start + 2, table: FIRST_NON_SYSTEM_ID, col: col(2), name: "Foo_age_idx_btree", }, diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index e1c92e6d204..076c9a1c17d 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -10,6 +10,9 @@ use super::{ }; use crate::execution_context::ExecutionContext; use crate::execution_context::Workload; +use crate::system_tables::{ + ConnectionIdViaU128, StConnectionCredentialsFields, StConnectionCredentialsRow, ST_CONNECTION_CREDENTIALS_ID, +}; use crate::traits::{InsertFlags, RowTypeForTable, TxData, UpdateFlags}; use crate::{ error::{IndexError, SequenceError, TableError}, @@ -1164,13 +1167,14 @@ impl MutTxId { }) } - /// Commits this transaction, applying its changes to the committed state. + /// Commits this transaction in memory, applying its changes to the committed state. + /// This doesn't handle the persistence layer at all. /// /// Returns: /// - [`TxData`], the set of inserts and deletes performed by this transaction. /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran during this transaction. - pub fn commit(mut self) -> (TxData, TxMetrics, String) { + pub(super) fn commit(mut self) -> (TxData, TxMetrics, String) { let tx_data = self.committed_state_write_lock.merge(self.tx_state, &self.ctx); // Compute and keep enough info that we can @@ -1349,12 +1353,49 @@ impl<'a, I: Iterator>> Iterator for FilterDeleted<'a, I> { } impl MutTxId { - pub fn insert_st_client(&mut self, identity: Identity, connection_id: ConnectionId) -> Result<()> { + pub fn insert_st_client( + &mut self, + identity: Identity, + connection_id: ConnectionId, + jwt_payload: &str, + ) -> Result<()> { let row = &StClientRow { identity: identity.into(), connection_id: connection_id.into(), }; - self.insert_via_serialize_bsatn(ST_CLIENT_ID, row).map(|_| ()) + self.insert_via_serialize_bsatn(ST_CLIENT_ID, row) + .map(|_| ()) + .inspect_err(|e| { + log::error!( + "[{identity}]: insert_st_client: failed to insert client ({identity}, {connection_id}), error: {e}" + ); + })?; + self.insert_st_client_credentials(connection_id, jwt_payload) + } + + fn insert_st_client_credentials(&mut self, connection_id: ConnectionId, jwt_payload: &str) -> Result<()> { + let row = &StConnectionCredentialsRow { + connection_id: connection_id.into(), + jwt_payload: jwt_payload.to_owned(), + }; + self.insert_via_serialize_bsatn(ST_CONNECTION_CREDENTIALS_ID, row) + .map(|_| ()) + .inspect_err(|e| { + log::error!("[{connection_id}]: insert_st_client_credentials: failed to insert client credentials for connection id ({connection_id}), error: {e}"); + }) + } + + fn delete_st_client_credentials(&mut self, database_identity: Identity, connection_id: ConnectionId) -> Result<()> { + if let Err(e) = self.delete_col_eq( + ST_CONNECTION_CREDENTIALS_ID, + StConnectionCredentialsFields::ConnectionId.col_id(), + &ConnectionIdViaU128::from(connection_id).into(), + ) { + // This is possible on restart if the database was previously running a version + // before this system table was added. + log::error!("[{database_identity}]: delete_st_client_credentials: attempting to delete credentials for missing connection id ({connection_id}), error: {e}"); + } + Ok(()) } pub fn delete_st_client( @@ -1378,11 +1419,11 @@ impl MutTxId { .next() .map(|row| row.pointer()) { - self.delete(ST_CLIENT_ID, ptr).map(drop) + self.delete(ST_CLIENT_ID, ptr).map(drop)? } else { log::error!("[{database_identity}]: delete_st_client: attempting to delete client ({identity}, {connection_id}), but no st_client row for that client is resident"); - Ok(()) } + self.delete_st_client_credentials(database_identity, connection_id) } pub fn insert_via_serialize_bsatn<'a, T: Serialize>( diff --git a/crates/datastore/src/system_tables.rs b/crates/datastore/src/system_tables.rs index 1c56df19a65..781c8835b0a 100644 --- a/crates/datastore/src/system_tables.rs +++ b/crates/datastore/src/system_tables.rs @@ -60,6 +60,11 @@ pub const ST_SCHEDULED_ID: TableId = TableId(9); /// The static ID of the table that defines the row level security (RLS) policies pub const ST_ROW_LEVEL_SECURITY_ID: TableId = TableId(10); + +/// The static ID of the table that stores the credentials for each connection. +pub const ST_CONNECTION_CREDENTIALS_ID: TableId = TableId(11); + +pub(crate) const ST_CONNECTION_CREDENTIALS_NAME: &str = "st_connection_credentials"; pub const ST_TABLE_NAME: &str = "st_table"; pub const ST_COLUMN_NAME: &str = "st_column"; pub const ST_SEQUENCE_NAME: &str = "st_sequence"; @@ -97,7 +102,7 @@ pub enum SystemTable { st_row_level_security, } -pub fn system_tables() -> [TableSchema; 10] { +pub fn system_tables() -> [TableSchema; 11] { [ // The order should match the `id` of the system table, that start with [ST_TABLE_IDX]. st_table_schema(), @@ -109,6 +114,7 @@ pub fn system_tables() -> [TableSchema; 10] { st_var_schema(), st_scheduled_schema(), st_row_level_security_schema(), + st_connection_credential_schema(), // Is important this is always last, so the starting sequence for each // system table is correct. st_sequence_schema(), @@ -149,8 +155,9 @@ pub(crate) const ST_CLIENT_IDX: usize = 5; pub(crate) const ST_VAR_IDX: usize = 6; pub(crate) const ST_SCHEDULED_IDX: usize = 7; pub(crate) const ST_ROW_LEVEL_SECURITY_IDX: usize = 8; +pub(crate) const ST_CONNECTION_CREDENTIALS_IDX: usize = 9; // Must be the last index in the array. -pub(crate) const ST_SEQUENCE_IDX: usize = 9; +pub(crate) const ST_SEQUENCE_IDX: usize = 10; macro_rules! st_fields_enum { ($(#[$attr:meta])* enum $ty_name:ident { $($name:expr, $var:ident = $discr:expr,)* }) => { @@ -248,6 +255,13 @@ st_fields_enum!(enum StClientFields { "identity", Identity = 0, "connection_id", ConnectionId = 1, }); + +// WARNING: For a stable schema, don't change the field names and discriminants. +st_fields_enum!(enum StConnectionCredentialsFields { + "connection_id", ConnectionId = 0, + "jwt_payload", JwtPayload = 1, +}); + // WARNING: For a stable schema, don't change the field names and discriminants. st_fields_enum!(enum StVarFields { "name", Name = 0, @@ -341,6 +355,19 @@ fn system_module_def() -> ModuleDef { .with_type(TableType::System); // TODO: add empty unique constraint here, once we've implemented those. + let st_connection_credentials_type = builder.add_type::(); + // let st_connection_credentials_unique_cols = [StConnectionCredentialsFields::ConnectionId]; + builder + .build_table( + ST_CONNECTION_CREDENTIALS_NAME, + *st_connection_credentials_type.as_ref().expect("should be ref"), + ) + .with_type(TableType::System) + .with_unique_constraint(StConnectionCredentialsFields::ConnectionId) + .with_index_no_accessor_name(btree(StConnectionCredentialsFields::ConnectionId)) + .with_access(v9::TableAccess::Private) + .with_primary_key(StConnectionCredentialsFields::ConnectionId); + let st_client_type = builder.add_type::(); let st_client_unique_cols = [StClientFields::Identity, StClientFields::ConnectionId]; builder @@ -382,6 +409,7 @@ fn system_module_def() -> ModuleDef { validate_system_table::(&result, ST_CLIENT_NAME); validate_system_table::(&result, ST_VAR_NAME); validate_system_table::(&result, ST_SCHEDULED_NAME); + validate_system_table::(&result, ST_CONNECTION_CREDENTIALS_NAME); result } @@ -442,6 +470,10 @@ fn st_client_schema() -> TableSchema { st_schema(ST_CLIENT_NAME, ST_CLIENT_ID) } +fn st_connection_credential_schema() -> TableSchema { + st_schema(ST_CONNECTION_CREDENTIALS_NAME, ST_CONNECTION_CREDENTIALS_ID) +} + fn st_scheduled_schema() -> TableSchema { st_schema(ST_SCHEDULED_NAME, ST_SCHEDULED_ID) } @@ -466,6 +498,7 @@ pub(crate) fn system_table_schema(table_id: TableId) -> Option { ST_ROW_LEVEL_SECURITY_ID => Some(st_row_level_security_schema()), ST_MODULE_ID => Some(st_module_schema()), ST_CLIENT_ID => Some(st_client_schema()), + ST_CONNECTION_CREDENTIALS_ID => Some(st_connection_credential_schema()), ST_VAR_ID => Some(st_var_schema()), ST_SCHEDULED_ID => Some(st_scheduled_schema()), _ => None, @@ -836,6 +869,12 @@ impl From for ConnectionIdViaU128 { } } +impl From for AlgebraicValue { + fn from(val: ConnectionIdViaU128) -> Self { + AlgebraicValue::U128(val.0.to_u128().into()) + } +} + /// A wrapper for [`Identity`] that acts like [`AlgebraicType::U256`] for serialization purposes. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct IdentityViaU256(pub Identity); @@ -927,6 +966,18 @@ pub struct StClientRow { pub connection_id: ConnectionIdViaU128, } +/// System table [ST_CONNECTION_CREDENTIALS_NAME] +/// +/// | connection_id | jwt_payload | +/// |------------------------------------|---------------------------------------------------------| +/// | 0x6bdea3ab517f5857dc9b1b5fe99e1b14 | '{"iss":"issuer","sub":"user-id","iat":1629212345,...}' | +#[derive(Clone, Debug, Eq, PartialEq, SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct StConnectionCredentialsRow { + pub connection_id: ConnectionIdViaU128, + pub jwt_payload: String, +} + impl From for ProductValue { fn from(var: StClientRow) -> Self { to_product_value(&var) diff --git a/crates/testing/src/modules.rs b/crates/testing/src/modules.rs index 8a62076b64a..1df5f6e4005 100644 --- a/crates/testing/src/modules.rs +++ b/crates/testing/src/modules.rs @@ -192,7 +192,7 @@ impl CompiledModule { .unwrap(); // TODO: Fix this when we update identity generation. let identity = Identity::ZERO; - let db_identity = SpacetimeAuth::alloc(&env).await.unwrap().identity; + let db_identity = SpacetimeAuth::alloc(&env).await.unwrap().claims.identity; let connection_id = generate_random_connection_id(); let program_bytes = self.program_bytes().to_owned(); diff --git a/smoketests/config.toml b/smoketests/config.toml index b7c4ad31a45..5a37a8381b6 100644 --- a/smoketests/config.toml +++ b/smoketests/config.toml @@ -1,5 +1,4 @@ default_server = "127.0.0.1:3000" -spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwYzc3NDY1NTE5MDM2MTE4M2JiNjFmMWMxYzY3NDUzMzYzY2MxMTY4MmM1NTUwNWZiNjdlYzI0ZWMyMWViIiwic3ViIjoiOTJlMmNkOGQtNTk5Ny00NjZlLWIwNmYtZDNjOGQ1NzU3ODI4IiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc1MjA0NjgwMCwiZXhwIjpudWxsfQ.dgefoxC7eCOONVUufu2JTVFo9876zQ4Mqwm0ivZ0PQK7Hacm3Ip_xqyav4bilZ0vIEf8IM8AB0_xawk8WcbvMg" [[server_configs]] nickname = "localhost"