From aa3af157a359e23414cf28d26afc02db748f04ec Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Feb 2025 11:58:45 +0100 Subject: [PATCH 1/7] storage: store dynamic policy data in the database --- crates/data-model/src/lib.rs | 2 + crates/data-model/src/policy_data.rs | 15 ++ ...73e81d9b75e90929af80961f8b5910873a43e.json | 14 ++ ...76eb57ccb6e4053ab8f4450dd4a9d1f6ba108.json | 32 +++ ...c05a7ed582c9e5c1f65d27b0686f843ccfe42.json | 16 ++ .../20250225091000_dynamic_policy_data.sql | 11 + crates/storage-pg/src/lib.rs | 1 + crates/storage-pg/src/policy_data.rs | 204 ++++++++++++++++++ crates/storage-pg/src/repository.rs | 6 + crates/storage/src/lib.rs | 1 + crates/storage/src/policy_data.rs | 76 +++++++ crates/storage/src/repository.rs | 17 ++ 12 files changed, 395 insertions(+) create mode 100644 crates/data-model/src/policy_data.rs create mode 100644 crates/storage-pg/.sqlx/query-5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e.json create mode 100644 crates/storage-pg/.sqlx/query-9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108.json create mode 100644 crates/storage-pg/.sqlx/query-b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42.json create mode 100644 crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql create mode 100644 crates/storage-pg/src/policy_data.rs create mode 100644 crates/storage/src/policy_data.rs diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index b26f74f1b..67d46e1e8 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -10,6 +10,7 @@ use thiserror::Error; pub(crate) mod compat; pub mod oauth2; +pub(crate) mod policy_data; mod site_config; pub(crate) mod tokens; pub(crate) mod upstream_oauth2; @@ -32,6 +33,7 @@ pub use self::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, DeviceCodeGrant, DeviceCodeGrantState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, }, + policy_data::PolicyData, site_config::{CaptchaConfig, CaptchaService, SessionExpirationConfig, SiteConfig}, tokens::{ AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType, diff --git a/crates/data-model/src/policy_data.rs b/crates/data-model/src/policy_data.rs new file mode 100644 index 000000000..8836c2c0c --- /dev/null +++ b/crates/data-model/src/policy_data.rs @@ -0,0 +1,15 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct PolicyData { + pub id: Ulid, + pub created_at: DateTime, + pub data: serde_json::Value, +} diff --git a/crates/storage-pg/.sqlx/query-5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e.json b/crates/storage-pg/.sqlx/query-5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e.json new file mode 100644 index 000000000..d13316a5f --- /dev/null +++ b/crates/storage-pg/.sqlx/query-5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM policy_data\n WHERE policy_data_id IN (\n SELECT policy_data_id\n FROM policy_data\n ORDER BY policy_data_id DESC\n OFFSET $1\n )\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + }, + "hash": "5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e" +} diff --git a/crates/storage-pg/.sqlx/query-9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108.json b/crates/storage-pg/.sqlx/query-9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108.json new file mode 100644 index 000000000..03d162afe --- /dev/null +++ b/crates/storage-pg/.sqlx/query-9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108.json @@ -0,0 +1,32 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT policy_data_id, created_at, data\n FROM policy_data\n ORDER BY policy_data_id DESC\n LIMIT 1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "policy_data_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "data", + "type_info": "Jsonb" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108" +} diff --git a/crates/storage-pg/.sqlx/query-b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42.json b/crates/storage-pg/.sqlx/query-b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42.json new file mode 100644 index 000000000..304f9c96a --- /dev/null +++ b/crates/storage-pg/.sqlx/query-b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO policy_data (policy_data_id, created_at, data)\n VALUES ($1, $2, $3)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Jsonb" + ] + }, + "nullable": [] + }, + "hash": "b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42" +} diff --git a/crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql b/crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql new file mode 100644 index 000000000..4a0984925 --- /dev/null +++ b/crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql @@ -0,0 +1,11 @@ +-- Copyright 2025 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a table which stores the latest policy data +CREATE TABLE IF NOT EXISTS policy_data ( + policy_data_id UUID PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL, + data JSONB NOT NULL +); diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 3312d73b8..8971488a5 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -173,6 +173,7 @@ mod errors; pub(crate) mod filter; pub(crate) mod iden; pub(crate) mod pagination; +pub(crate) mod policy_data; pub(crate) mod repository; pub(crate) mod tracing; diff --git a/crates/storage-pg/src/policy_data.rs b/crates/storage-pg/src/policy_data.rs new file mode 100644 index 000000000..65615b348 --- /dev/null +++ b/crates/storage-pg/src/policy_data.rs @@ -0,0 +1,204 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +//! A module containing the PostgreSQL implementation of the policy data +//! storage. + +use async_trait::async_trait; +use mas_data_model::PolicyData; +use mas_storage::{Clock, policy_data::PolicyDataRepository}; +use rand::RngCore; +use serde_json::Value; +use sqlx::{PgConnection, types::Json}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{DatabaseError, ExecuteExt}; + +/// An implementation of [`PolicyDataRepository`] for a PostgreSQL connection. +pub struct PgPolicyDataRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgPolicyDataRepository<'c> { + /// Create a new [`PgPolicyDataRepository`] from an active PostgreSQL + /// connection. + #[must_use] + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct PolicyDataLookup { + policy_data_id: Uuid, + created_at: chrono::DateTime, + data: Json, +} + +impl From for PolicyData { + fn from(value: PolicyDataLookup) -> Self { + PolicyData { + id: value.policy_data_id.into(), + created_at: value.created_at, + data: value.data.0, + } + } +} + +#[async_trait] +impl PolicyDataRepository for PgPolicyDataRepository<'_> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.policy_data.get", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn get(&mut self) -> Result, Self::Error> { + let row = sqlx::query_as!( + PolicyDataLookup, + r#" + SELECT policy_data_id, created_at, data + FROM policy_data + ORDER BY policy_data_id DESC + LIMIT 1 + "# + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(row) = row else { + return Ok(None); + }; + + Ok(Some(row.into())) + } + + #[tracing::instrument( + name = "db.policy_data.set", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn set( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + data: Value, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + + sqlx::query!( + r#" + INSERT INTO policy_data (policy_data_id, created_at, data) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + created_at, + data, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(PolicyData { + id, + created_at, + data, + }) + } + + #[tracing::instrument( + name = "db.policy_data.prune", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn prune(&mut self, keep: usize) -> Result { + let res = sqlx::query!( + r#" + DELETE FROM policy_data + WHERE policy_data_id IN ( + SELECT policy_data_id + FROM policy_data + ORDER BY policy_data_id DESC + OFFSET $1 + ) + "#, + i64::try_from(keep).map_err(DatabaseError::to_invalid_operation)? + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(res + .rows_affected() + .try_into() + .map_err(DatabaseError::to_invalid_operation)?) + } +} + +#[cfg(test)] +mod tests { + use mas_storage::{clock::MockClock, policy_data::PolicyDataRepository}; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use serde_json::json; + use sqlx::PgPool; + + use crate::policy_data::PgPolicyDataRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_policy_data(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut conn = pool.acquire().await.unwrap(); + let mut repo = PgPolicyDataRepository::new(&mut conn); + + // Get an empty state at first + let data = repo.get().await.unwrap(); + assert_eq!(data, None); + + // Set some data + let value1 = json!({"hello": "world"}); + let policy_data1 = repo.set(&mut rng, &clock, value1.clone()).await.unwrap(); + assert_eq!(policy_data1.data, value1); + + let data_fetched1 = repo.get().await.unwrap().unwrap(); + assert_eq!(policy_data1, data_fetched1); + + // Set some new data + clock.advance(chrono::Duration::seconds(1)); + let value2 = json!({"foo": "bar"}); + let policy_data2 = repo.set(&mut rng, &clock, value2.clone()).await.unwrap(); + assert_eq!(policy_data2.data, value2); + + // Check the new data is fetched + let data_fetched2 = repo.get().await.unwrap().unwrap(); + assert_eq!(data_fetched2, policy_data2); + + // Prune until the first entry + let affected = repo.prune(1).await.unwrap(); + let data_fetched3 = repo.get().await.unwrap().unwrap(); + assert_eq!(data_fetched3, policy_data2); + assert_eq!(affected, 1); + + // Do a raw query to check the other rows were pruned + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM policy_data") + .fetch_one(&mut *conn) + .await + .unwrap(); + assert_eq!(count, 1); + } +} diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 69ec9980b..901f1fd45 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -18,6 +18,7 @@ use mas_storage::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, + policy_data::PolicyDataRepository, queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository}, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, @@ -40,6 +41,7 @@ use crate::{ PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, }, + policy_data::PgPolicyDataRepository, queue::{ job::PgQueueJobRepository, schedule::PgQueueScheduleRepository, worker::PgQueueWorkerRepository, @@ -283,4 +285,8 @@ where ) -> Box + 'c> { Box::new(PgQueueScheduleRepository::new(self.conn.as_mut())) } + + fn policy_data<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgPolicyDataRepository::new(self.conn.as_mut())) + } } diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index cd0d646c3..923113a6a 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -119,6 +119,7 @@ mod utils; pub mod app_session; pub mod compat; pub mod oauth2; +pub mod policy_data; pub mod queue; pub mod upstream_oauth2; pub mod user; diff --git a/crates/storage/src/policy_data.rs b/crates/storage/src/policy_data.rs new file mode 100644 index 000000000..6c7e5d89f --- /dev/null +++ b/crates/storage/src/policy_data.rs @@ -0,0 +1,76 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +//! Repositories to interact with the policy data saved in the storage backend. + +use async_trait::async_trait; +use mas_data_model::PolicyData; +use rand_core::RngCore; + +use crate::{Clock, repository_impl}; + +/// A [`PolicyDataRepository`] helps interacting with the policy data saved in +/// the storage backend. +#[async_trait] +pub trait PolicyDataRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Get the latest policy data + /// + /// Returns the latest policy data, or `None` if no policy data is + /// available. + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn get(&mut self) -> Result, Self::Error>; + + /// Set the latest policy data + /// + /// Returns the newly created policy data. + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate the timestamps + /// * `data`: The policy data to set + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn set( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + data: serde_json::Value, + ) -> Result; + + /// Prune old policy data + /// + /// Returns the number of entries pruned. + /// + /// # Parameters + /// + /// * `keep`: the number of old entries to keep + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn prune(&mut self, keep: usize) -> Result; +} + +repository_impl!(PolicyDataRepository: + async fn get(&mut self) -> Result, Self::Error>; + + async fn set( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + data: serde_json::Value, + ) -> Result; + + async fn prune(&mut self, keep: usize) -> Result; +); diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 8d1c501ec..2f051493c 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -17,6 +17,7 @@ use crate::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, + policy_data::PolicyDataRepository, queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository}, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, @@ -204,6 +205,9 @@ pub trait RepositoryAccess: Send { fn queue_schedule<'c>( &'c mut self, ) -> Box + 'c>; + + /// Get a [`PolicyDataRepository`] + fn policy_data<'c>(&'c mut self) -> Box + 'c>; } /// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and @@ -224,6 +228,7 @@ mod impls { OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, + policy_data::PolicyDataRepository, queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository}, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, @@ -439,6 +444,12 @@ mod impls { ) -> Box + 'c> { Box::new(MapErr::new(self.inner.queue_schedule(), &mut self.mapper)) } + + fn policy_data<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.policy_data(), &mut self.mapper)) + } } impl RepositoryAccess for Box { @@ -579,5 +590,11 @@ mod impls { ) -> Box + 'c> { (**self).queue_schedule() } + + fn policy_data<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).policy_data() + } } } From d393494e76345b382c2497213984255b1062efc5 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Feb 2025 13:06:44 +0100 Subject: [PATCH 2/7] Admin API to get and set policy data --- crates/handlers/src/admin/mod.rs | 5 + crates/handlers/src/admin/model.rs | 47 ++++ crates/handlers/src/admin/v1/mod.rs | 16 ++ .../handlers/src/admin/v1/policy_data/get.rs | 153 ++++++++++ .../src/admin/v1/policy_data/get_latest.rs | 149 ++++++++++ .../handlers/src/admin/v1/policy_data/mod.rs | 14 + .../handlers/src/admin/v1/policy_data/set.rs | 133 +++++++++ docs/api/spec.json | 263 ++++++++++++++++++ 8 files changed, 780 insertions(+) create mode 100644 crates/handlers/src/admin/v1/policy_data/get.rs create mode 100644 crates/handlers/src/admin/v1/policy_data/get_latest.rs create mode 100644 crates/handlers/src/admin/v1/policy_data/mod.rs create mode 100644 crates/handlers/src/admin/v1/policy_data/set.rs diff --git a/crates/handlers/src/admin/mod.rs b/crates/handlers/src/admin/mod.rs index 4edc8b2ca..4e65ba742 100644 --- a/crates/handlers/src/admin/mod.rs +++ b/crates/handlers/src/admin/mod.rs @@ -45,6 +45,11 @@ fn finish(t: TransformOpenApi) -> TransformOpenApi { description: Some("Manage compatibility sessions from legacy clients".to_owned()), ..Tag::default() }) + .tag(Tag { + name: "policy-data".to_owned(), + description: Some("Manage the dynamic policy data".to_owned()), + ..Tag::default() + }) .tag(Tag { name: "oauth2-session".to_owned(), description: Some("Manage OAuth2 sessions".to_owned()), diff --git a/crates/handlers/src/admin/model.rs b/crates/handlers/src/admin/model.rs index b98770379..ec4f8cbc2 100644 --- a/crates/handlers/src/admin/model.rs +++ b/crates/handlers/src/admin/model.rs @@ -534,3 +534,50 @@ impl UpstreamOAuthLink { ] } } + +/// The policy data +#[derive(Serialize, JsonSchema)] +pub struct PolicyData { + #[serde(skip)] + id: Ulid, + + /// The creation date of the policy data + created_at: DateTime, + + /// The policy data content + data: serde_json::Value, +} + +impl From for PolicyData { + fn from(policy_data: mas_data_model::PolicyData) -> Self { + Self { + id: policy_data.id, + created_at: policy_data.created_at, + data: policy_data.data, + } + } +} + +impl Resource for PolicyData { + const KIND: &'static str = "policy-data"; + const PATH: &'static str = "/api/admin/v1/policy-data"; + + fn id(&self) -> Ulid { + self.id + } +} + +impl PolicyData { + /// Samples of policy data + pub fn samples() -> [Self; 1] { + [Self { + id: Ulid::from_bytes([0x01; 16]), + created_at: DateTime::default(), + data: serde_json::json!({ + "hello": "world", + "foo": 42, + "bar": true + }), + }] + } +} diff --git a/crates/handlers/src/admin/v1/mod.rs b/crates/handlers/src/admin/v1/mod.rs index 02c63e2a5..ae258c842 100644 --- a/crates/handlers/src/admin/v1/mod.rs +++ b/crates/handlers/src/admin/v1/mod.rs @@ -17,6 +17,7 @@ use crate::passwords::PasswordManager; mod compat_sessions; mod oauth2_sessions; +mod policy_data; mod upstream_oauth_links; mod user_emails; mod user_sessions; @@ -47,6 +48,21 @@ where "/oauth2-sessions/{id}", get_with(self::oauth2_sessions::get, self::oauth2_sessions::get_doc), ) + .api_route( + "/policy-data", + post_with(self::policy_data::set, self::policy_data::set_doc), + ) + .api_route( + "/policy-data/latest", + get_with( + self::policy_data::get_latest, + self::policy_data::get_latest_doc, + ), + ) + .api_route( + "/policy-data/{id}", + get_with(self::policy_data::get, self::policy_data::get_doc), + ) .api_route( "/users", get_with(self::users::list, self::users::list_doc) diff --git a/crates/handlers/src/admin/v1/policy_data/get.rs b/crates/handlers/src/admin/v1/policy_data/get.rs new file mode 100644 index 000000000..338c999b3 --- /dev/null +++ b/crates/handlers/src/admin/v1/policy_data/get.rs @@ -0,0 +1,153 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only + +use aide::{OperationIo, transform::TransformOperation}; +use axum::{Json, response::IntoResponse}; +use hyper::StatusCode; +use ulid::Ulid; + +use crate::{ + admin::{ + call_context::CallContext, + model::PolicyData, + params::UlidPathParam, + response::{ErrorResponse, SingleResponse}, + }, + impl_from_error_for_route, +}; + +#[derive(Debug, thiserror::Error, OperationIo)] +#[aide(output_with = "Json")] +pub enum RouteError { + #[error(transparent)] + Internal(Box), + + #[error("Policy data with ID {0} not found")] + NotFound(Ulid), +} + +impl_from_error_for_route!(mas_storage::RepositoryError); + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + let error = ErrorResponse::from_error(&self); + let status = match self { + Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NotFound(_) => StatusCode::NOT_FOUND, + }; + (status, Json(error)).into_response() + } +} + +pub fn doc(operation: TransformOperation) -> TransformOperation { + operation + .id("getPolicyData") + .summary("Get policy data by ID") + .tag("policy-data") + .response_with::<200, Json>, _>(|t| { + let [sample, ..] = PolicyData::samples(); + let response = SingleResponse::new_canonical(sample); + t.description("Policy data was found").example(response) + }) + .response_with::<404, RouteError, _>(|t| { + let response = ErrorResponse::from_error(&RouteError::NotFound(Ulid::nil())); + t.description("Policy data was not found").example(response) + }) +} + +#[tracing::instrument(name = "handler.admin.v1.policy_data.get", skip_all, err)] +pub async fn handler( + CallContext { mut repo, .. }: CallContext, + id: UlidPathParam, +) -> Result>, RouteError> { + let policy_data = repo + .policy_data() + .get() + .await? + .ok_or(RouteError::NotFound(*id))?; + + Ok(Json(SingleResponse::new_canonical(policy_data.into()))) +} + +#[cfg(test)] +mod tests { + use hyper::{Request, StatusCode}; + use insta::assert_json_snapshot; + use sqlx::PgPool; + use ulid::Ulid; + + use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_get(pool: PgPool) { + setup(); + let mut state = TestState::from_pool(pool).await.unwrap(); + let token = state.token_with_scope("urn:mas:admin").await; + + let mut rng = state.rng(); + let mut repo = state.repository().await.unwrap(); + + let policy_data = repo + .policy_data() + .set( + &mut rng, + &state.clock, + serde_json::json!({"hello": "world"}), + ) + .await + .unwrap(); + + repo.save().await.unwrap(); + + let request = Request::get(format!("/api/admin/v1/policy-data/{}", policy_data.id)) + .bearer(&token) + .empty(); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let body: serde_json::Value = response.json(); + assert_json_snapshot!(body, @r###" + { + "data": { + "type": "policy-data", + "id": "01FSHN9AG0MZAA6S4AF7CTV32E", + "attributes": { + "created_at": "2022-01-16T14:40:00Z", + "data": { + "hello": "world" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E" + } + } + "###); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_get_not_found(pool: PgPool) { + setup(); + let mut state = TestState::from_pool(pool).await.unwrap(); + let token = state.token_with_scope("urn:mas:admin").await; + + let request = Request::get(format!("/api/admin/v1/policy-data/{}", Ulid::nil())) + .bearer(&token) + .empty(); + let response = state.request(request).await; + response.assert_status(StatusCode::NOT_FOUND); + let body: serde_json::Value = response.json(); + assert_json_snapshot!(body, @r###" + { + "errors": [ + { + "title": "Policy data with ID 00000000000000000000000000 not found" + } + ] + } + "###); + } +} diff --git a/crates/handlers/src/admin/v1/policy_data/get_latest.rs b/crates/handlers/src/admin/v1/policy_data/get_latest.rs new file mode 100644 index 000000000..7b4c0654f --- /dev/null +++ b/crates/handlers/src/admin/v1/policy_data/get_latest.rs @@ -0,0 +1,149 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only + +use aide::{OperationIo, transform::TransformOperation}; +use axum::{Json, response::IntoResponse}; +use hyper::StatusCode; + +use crate::{ + admin::{ + call_context::CallContext, + model::PolicyData, + response::{ErrorResponse, SingleResponse}, + }, + impl_from_error_for_route, +}; + +#[derive(Debug, thiserror::Error, OperationIo)] +#[aide(output_with = "Json")] +pub enum RouteError { + #[error(transparent)] + Internal(Box), + + #[error("No policy data found")] + NotFound, +} + +impl_from_error_for_route!(mas_storage::RepositoryError); + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + let error = ErrorResponse::from_error(&self); + let status = match self { + Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NotFound => StatusCode::NOT_FOUND, + }; + (status, Json(error)).into_response() + } +} + +pub fn doc(operation: TransformOperation) -> TransformOperation { + operation + .id("getLatestPolicyData") + .summary("Get the latest policy data") + .tag("policy-data") + .response_with::<200, Json>, _>(|t| { + let [sample, ..] = PolicyData::samples(); + let response = SingleResponse::new_canonical(sample); + t.description("Latest policy data was found") + .example(response) + }) + .response_with::<404, RouteError, _>(|t| { + let response = ErrorResponse::from_error(&RouteError::NotFound); + t.description("No policy data was found").example(response) + }) +} + +#[tracing::instrument(name = "handler.admin.v1.policy_data.get_latest", skip_all, err)] +pub async fn handler( + CallContext { mut repo, .. }: CallContext, +) -> Result>, RouteError> { + let policy_data = repo + .policy_data() + .get() + .await? + .ok_or(RouteError::NotFound)?; + + Ok(Json(SingleResponse::new_canonical(policy_data.into()))) +} + +#[cfg(test)] +mod tests { + use hyper::{Request, StatusCode}; + use insta::assert_json_snapshot; + use sqlx::PgPool; + + use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_get_latest(pool: PgPool) { + setup(); + let mut state = TestState::from_pool(pool).await.unwrap(); + let token = state.token_with_scope("urn:mas:admin").await; + + let mut rng = state.rng(); + let mut repo = state.repository().await.unwrap(); + + repo.policy_data() + .set( + &mut rng, + &state.clock, + serde_json::json!({"hello": "world"}), + ) + .await + .unwrap(); + + repo.save().await.unwrap(); + + let request = Request::get("/api/admin/v1/policy-data/latest") + .bearer(&token) + .empty(); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let body: serde_json::Value = response.json(); + assert_json_snapshot!(body, @r###" + { + "data": { + "type": "policy-data", + "id": "01FSHN9AG0MZAA6S4AF7CTV32E", + "attributes": { + "created_at": "2022-01-16T14:40:00Z", + "data": { + "hello": "world" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E" + } + } + "###); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_get_no_latest(pool: PgPool) { + setup(); + let mut state = TestState::from_pool(pool).await.unwrap(); + let token = state.token_with_scope("urn:mas:admin").await; + + let request = Request::get("/api/admin/v1/policy-data/latest") + .bearer(&token) + .empty(); + let response = state.request(request).await; + response.assert_status(StatusCode::NOT_FOUND); + let body: serde_json::Value = response.json(); + assert_json_snapshot!(body, @r###" + { + "errors": [ + { + "title": "No policy data found" + } + ] + } + "###); + } +} diff --git a/crates/handlers/src/admin/v1/policy_data/mod.rs b/crates/handlers/src/admin/v1/policy_data/mod.rs new file mode 100644 index 000000000..9143a2e11 --- /dev/null +++ b/crates/handlers/src/admin/v1/policy_data/mod.rs @@ -0,0 +1,14 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +mod get; +mod get_latest; +mod set; + +pub use self::{ + get::{doc as get_doc, handler as get}, + get_latest::{doc as get_latest_doc, handler as get_latest}, + set::{doc as set_doc, handler as set}, +}; diff --git a/crates/handlers/src/admin/v1/policy_data/set.rs b/crates/handlers/src/admin/v1/policy_data/set.rs new file mode 100644 index 000000000..aa3edfe5a --- /dev/null +++ b/crates/handlers/src/admin/v1/policy_data/set.rs @@ -0,0 +1,133 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only + +use aide::{NoApi, OperationIo, transform::TransformOperation}; +use axum::{Json, response::IntoResponse}; +use hyper::StatusCode; +use mas_storage::BoxRng; +use schemars::JsonSchema; +use serde::Deserialize; + +use crate::{ + admin::{ + call_context::CallContext, + model::PolicyData, + response::{ErrorResponse, SingleResponse}, + }, + impl_from_error_for_route, +}; + +#[derive(Debug, thiserror::Error, OperationIo)] +#[aide(output_with = "Json")] +pub enum RouteError { + #[error(transparent)] + Internal(Box), +} + +impl_from_error_for_route!(mas_storage::RepositoryError); + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + let error = ErrorResponse::from_error(&self); + let status = StatusCode::INTERNAL_SERVER_ERROR; + (status, Json(error)).into_response() + } +} + +fn data_example() -> serde_json::Value { + serde_json::json!({ + "hello": "world", + "foo": 42, + "bar": true + }) +} + +/// # JSON payload for the `POST /api/admin/v1/policy-data` +#[derive(Deserialize, JsonSchema)] +#[serde(rename = "SetPolicyDataRequest")] +pub struct SetPolicyDataRequest { + #[schemars(example = "data_example")] + pub data: serde_json::Value, +} + +pub fn doc(operation: TransformOperation) -> TransformOperation { + operation + .id("setPolicyData") + .summary("Set the current policy data") + .tag("policy-data") + .response_with::<201, Json>, _>(|t| { + let [sample, ..] = PolicyData::samples(); + let response = SingleResponse::new_canonical(sample); + t.description("Policy data was successfully set") + .example(response) + }) +} + +#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all, err)] +pub async fn handler( + CallContext { + mut repo, clock, .. + }: CallContext, + NoApi(mut rng): NoApi, + Json(request): Json, +) -> Result<(StatusCode, Json>), RouteError> { + let policy_data = repo + .policy_data() + .set(&mut rng, &clock, request.data) + .await?; + + repo.save().await?; + + Ok(( + StatusCode::CREATED, + Json(SingleResponse::new_canonical(policy_data.into())), + )) +} + +#[cfg(test)] +mod tests { + use hyper::{Request, StatusCode}; + use insta::assert_json_snapshot; + use sqlx::PgPool; + + use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_create(pool: PgPool) { + setup(); + let mut state = TestState::from_pool(pool).await.unwrap(); + let token = state.token_with_scope("urn:mas:admin").await; + + let request = Request::post("/api/admin/v1/policy-data") + .bearer(&token) + .json(serde_json::json!({ + "data": { + "hello": "world" + } + })); + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + let body: serde_json::Value = response.json(); + assert_json_snapshot!(body, @r###" + { + "data": { + "type": "policy-data", + "id": "01FSHN9AG0MZAA6S4AF7CTV32E", + "attributes": { + "created_at": "2022-01-16T14:40:00Z", + "data": { + "hello": "world" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E" + } + } + "###); + } +} diff --git a/docs/api/spec.json b/docs/api/spec.json index ec593ba9e..1fcc17116 100644 --- a/docs/api/spec.json +++ b/docs/api/spec.json @@ -593,6 +593,185 @@ } } }, + "/api/admin/v1/policy-data": { + "post": { + "tags": [ + "policy-data" + ], + "summary": "Set the current policy data", + "operationId": "setPolicyData", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetPolicyDataRequest" + } + } + }, + "required": true + }, + "responses": { + "201": { + "description": "Policy data was successfully set", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SingleResponse_for_PolicyData" + }, + "example": { + "data": { + "type": "policy-data", + "id": "01040G2081040G2081040G2081", + "attributes": { + "created_at": "1970-01-01T00:00:00Z", + "data": { + "hello": "world", + "foo": 42, + "bar": true + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081" + } + } + } + } + } + } + } + }, + "/api/admin/v1/policy-data/latest": { + "get": { + "tags": [ + "policy-data" + ], + "summary": "Get the latest policy data", + "operationId": "getLatestPolicyData", + "responses": { + "200": { + "description": "Latest policy data was found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SingleResponse_for_PolicyData" + }, + "example": { + "data": { + "type": "policy-data", + "id": "01040G2081040G2081040G2081", + "attributes": { + "created_at": "1970-01-01T00:00:00Z", + "data": { + "hello": "world", + "foo": 42, + "bar": true + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081" + } + } + } + } + }, + "404": { + "description": "No policy data was found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "errors": [ + { + "title": "No policy data found" + } + ] + } + } + } + } + } + } + }, + "/api/admin/v1/policy-data/{id}": { + "get": { + "tags": [ + "policy-data" + ], + "summary": "Get policy data by ID", + "operationId": "getPolicyData", + "parameters": [ + { + "in": "path", + "name": "id", + "required": true, + "schema": { + "title": "The ID of the resource", + "$ref": "#/components/schemas/ULID" + }, + "style": "simple" + } + ], + "responses": { + "200": { + "description": "Policy data was found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SingleResponse_for_PolicyData" + }, + "example": { + "data": { + "type": "policy-data", + "id": "01040G2081040G2081040G2081", + "attributes": { + "created_at": "1970-01-01T00:00:00Z", + "data": { + "hello": "world", + "foo": 42, + "bar": true + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081" + } + }, + "links": { + "self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081" + } + } + } + } + }, + "404": { + "description": "Policy data was not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "errors": [ + { + "title": "Policy data with ID 00000000000000000000000000 not found" + } + ] + } + } + } + } + } + } + }, "/api/admin/v1/users": { "get": { "tags": [ @@ -2666,6 +2845,86 @@ } } }, + "SetPolicyDataRequest": { + "title": "JSON payload for the `POST /api/admin/v1/policy-data`", + "type": "object", + "required": [ + "data" + ], + "properties": { + "data": { + "examples": [ + { + "hello": "world", + "foo": 42, + "bar": true + } + ] + } + } + }, + "SingleResponse_for_PolicyData": { + "description": "A top-level response with a single resource", + "type": "object", + "required": [ + "data", + "links" + ], + "properties": { + "data": { + "$ref": "#/components/schemas/SingleResource_for_PolicyData" + }, + "links": { + "$ref": "#/components/schemas/SelfLinks" + } + } + }, + "SingleResource_for_PolicyData": { + "description": "A single resource, with its type, ID, attributes and related links", + "type": "object", + "required": [ + "attributes", + "id", + "links", + "type" + ], + "properties": { + "type": { + "description": "The type of the resource", + "type": "string" + }, + "id": { + "description": "The ID of the resource", + "$ref": "#/components/schemas/ULID" + }, + "attributes": { + "description": "The attributes of the resource", + "$ref": "#/components/schemas/PolicyData" + }, + "links": { + "description": "Related links", + "$ref": "#/components/schemas/SelfLinks" + } + } + }, + "PolicyData": { + "description": "The policy data", + "type": "object", + "required": [ + "created_at", + "data" + ], + "properties": { + "created_at": { + "description": "The creation date of the policy data", + "type": "string", + "format": "date-time" + }, + "data": { + "description": "The policy data content" + } + } + }, "UserFilter": { "type": "object", "properties": { @@ -3252,6 +3511,10 @@ "name": "compat-session", "description": "Manage compatibility sessions from legacy clients" }, + { + "name": "policy-data", + "description": "Manage the dynamic policy data" + }, { "name": "oauth2-session", "description": "Manage OAuth2 sessions" From 756922342a8bd637934bdca8395f8dde61cfa748 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Feb 2025 16:21:54 +0100 Subject: [PATCH 3/7] policy: allow dynamically setting policy data --- Cargo.lock | 1 + Cargo.toml | 4 + crates/policy/Cargo.toml | 1 + crates/policy/src/lib.rs | 324 +++++++++++++++++++++++++++++++++++- crates/templates/Cargo.toml | 2 +- 5 files changed, 327 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 041b54c00..131f3b74e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3588,6 +3588,7 @@ name = "mas-policy" version = "0.14.1" dependencies = [ "anyhow", + "arc-swap", "mas-data-model", "oauth2-types", "opa-wasm", diff --git a/Cargo.toml b/Cargo.toml index c1dc7569c..72f96c6bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,10 @@ syn2mas = { path = "./crates/syn2mas", version = "=0.14.1" } version = "0.14.1" features = ["axum", "axum-extra", "axum-json", "axum-query", "macros"] +# An `Arc` that can be atomically updated +[workspace.dependencies.arc-swap] +version = "1.7.1" + # GraphQL server [workspace.dependencies.async-graphql] version = "7.0.15" diff --git a/crates/policy/Cargo.toml b/crates/policy/Cargo.toml index 7212b4991..f1ceabfab 100644 --- a/crates/policy/Cargo.toml +++ b/crates/policy/Cargo.toml @@ -13,6 +13,7 @@ workspace = true [dependencies] anyhow.workspace = true +arc-swap.workspace = true opa-wasm = "0.1.4" serde.workspace = true serde_json.workspace = true diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 47d54f468..c8b771b6d 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -6,11 +6,14 @@ pub mod model; +use std::sync::Arc; + +use arc_swap::ArcSwap; +use mas_data_model::Ulid; use opa_wasm::{ Runtime, wasmtime::{Config, Engine, Module, OptLevel, Store}, }; -use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -33,10 +36,23 @@ pub enum LoadError { #[error("failed to compile WASM module")] Compilation(#[source] anyhow::Error), + #[error("invalid policy data")] + InvalidData(#[source] anyhow::Error), + #[error("failed to instantiate a test instance")] Instantiate(#[source] InstantiateError), } +impl LoadError { + /// Creates an example of an invalid data error, used for API response + /// documentation + #[doc(hidden)] + #[must_use] + pub fn invalid_data_example() -> Self { + Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects")) + } +} + #[derive(Debug, Error)] pub enum InstantiateError { #[error("failed to create WASM runtime")] @@ -69,11 +85,10 @@ impl Entrypoints { } } -#[derive(Serialize, Debug)] +#[derive(Debug)] pub struct Data { server_name: String, - #[serde(flatten)] rest: Option, } @@ -91,12 +106,93 @@ impl Data { self.rest = Some(rest); self } + + fn to_value(&self) -> Result { + let base = serde_json::json!({ + "server_name": self.server_name, + }); + + if let Some(rest) = &self.rest { + merge_data(base, rest.clone()) + } else { + Ok(base) + } + } +} + +fn value_kind(value: &serde_json::Value) -> &'static str { + match value { + serde_json::Value::Object(_) => "object", + serde_json::Value::Array(_) => "array", + serde_json::Value::String(_) => "string", + serde_json::Value::Number(_) => "number", + serde_json::Value::Bool(_) => "boolean", + serde_json::Value::Null => "null", + } +} + +fn merge_data( + mut left: serde_json::Value, + right: serde_json::Value, +) -> Result { + merge_data_rec(&mut left, right)?; + Ok(left) +} + +fn merge_data_rec( + left: &mut serde_json::Value, + right: serde_json::Value, +) -> Result<(), anyhow::Error> { + match (left, right) { + (serde_json::Value::Object(left), serde_json::Value::Object(right)) => { + for (key, value) in right { + if let Some(left_value) = left.get_mut(&key) { + merge_data_rec(left_value, value)?; + } else { + left.insert(key, value); + } + } + } + (serde_json::Value::Array(left), serde_json::Value::Array(right)) => { + left.extend(right); + } + // Other values override + (serde_json::Value::Number(left), serde_json::Value::Number(right)) => { + *left = right; + } + (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => { + *left = right; + } + (serde_json::Value::String(left), serde_json::Value::String(right)) => { + *left = right; + } + + // Null gets overridden by anything + (left, right) if left.is_null() => *left = right, + + // Null on the right makes the left value null + (left, right) if right.is_null() => *left = right, + + (left, right) => anyhow::bail!( + "Cannot merge a {} into a {}", + value_kind(&right), + value_kind(left), + ), + } + + Ok(()) +} + +struct DynamicData { + version: Option, + merged: serde_json::Value, } pub struct PolicyFactory { engine: Engine, module: Module, data: Data, + dynamic_data: ArcSwap, entrypoints: Entrypoints, } @@ -124,10 +220,17 @@ impl PolicyFactory { .await? .map_err(LoadError::Compilation)?; + let merged = data.to_value().map_err(LoadError::InvalidData)?; + let dynamic_data = ArcSwap::new(Arc::new(DynamicData { + version: None, + merged, + })); + let factory = Self { engine, module, data, + dynamic_data, entrypoints, }; @@ -140,8 +243,56 @@ impl PolicyFactory { Ok(factory) } + /// Set the dynamic data for the policy. + /// + /// The `dynamic_data` object is merged with the static data given when the + /// policy was loaded. + /// + /// Returns `true` if the data was updated, `false` if the version + /// of the dynamic data was the same as the one we already have. + /// + /// # Errors + /// + /// Returns an error if the data can't be merged with the static data, or if + /// the policy can't be instantiated with the new data. + pub async fn set_dynamic_data( + &self, + dynamic_data: mas_data_model::PolicyData, + ) -> Result { + // Check if the version of the dynamic data we have is the same as the one we're + // trying to set + if self.dynamic_data.load().version == Some(dynamic_data.id) { + // Don't do anything if the version is the same + return Ok(false); + } + + let static_data = self.data.to_value().map_err(LoadError::InvalidData)?; + let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?; + + // Try to instantiate with the new data + self.instantiate_with_data(&merged) + .await + .map_err(LoadError::Instantiate)?; + + // If instantiation succeeds, swap the data + self.dynamic_data.store(Arc::new(DynamicData { + version: Some(dynamic_data.id), + merged, + })); + + Ok(true) + } + #[tracing::instrument(name = "policy.instantiate", skip_all, err)] pub async fn instantiate(&self) -> Result { + let data = self.dynamic_data.load(); + self.instantiate_with_data(&data.merged).await + } + + async fn instantiate_with_data( + &self, + data: &serde_json::Value, + ) -> Result { let mut store = Store::new(&self.engine, ()); let runtime = Runtime::new(&mut store, &self.module) .await @@ -159,7 +310,7 @@ impl PolicyFactory { } let instance = runtime - .with_data(&mut store, &self.data) + .with_data(&mut store, data) .await .map_err(InstantiateError::LoadData)?; @@ -273,6 +424,8 @@ impl Policy { #[cfg(test)] mod tests { + use std::time::SystemTime; + use super::*; #[tokio::test] @@ -344,4 +497,167 @@ mod tests { .unwrap(); assert!(!res.valid()); } + + #[tokio::test] + async fn test_dynamic_data() { + let data = Data::new("example.com".to_owned()); + + #[allow(clippy::disallowed_types)] + let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("..") + .join("policies") + .join("policy.wasm"); + + let file = tokio::fs::File::open(path).await.unwrap(); + + let entrypoints = Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + }; + + let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + + let mut policy = factory.instantiate().await.unwrap(); + + let res = policy + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@example.com"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) + .await + .unwrap(); + assert!(res.valid()); + + // Update the policy data + factory + .set_dynamic_data(mas_data_model::PolicyData { + id: Ulid::nil(), + created_at: SystemTime::now().into(), + data: serde_json::json!({ + "emails": { + "banned_addresses": { + "substrings": ["hello"] + } + } + }), + }) + .await + .unwrap(); + let mut policy = factory.instantiate().await.unwrap(); + let res = policy + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@example.com"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) + .await + .unwrap(); + assert!(!res.valid()); + } + + #[tokio::test] + async fn test_big_dynamic_data() { + let data = Data::new("example.com".to_owned()); + + #[allow(clippy::disallowed_types)] + let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("..") + .join("policies") + .join("policy.wasm"); + + let file = tokio::fs::File::open(path).await.unwrap(); + + let entrypoints = Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + }; + + let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + + // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8 + // characters including the quotes and a comma. + let data: Vec = (0..(1024 * 1024 / 8)) + .map(|i| format!("{:05}", i % 100_000)) + .collect(); + let json = serde_json::json!({ "emails": { "banned_addresses": { "substrings": data } } }); + factory + .set_dynamic_data(mas_data_model::PolicyData { + id: Ulid::nil(), + created_at: SystemTime::now().into(), + data: json, + }) + .await + .unwrap(); + + // Try instantiating the policy, make sure 5-digit numbers are banned from email + // addresses + let mut policy = factory.instantiate().await.unwrap(); + let res = policy + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("12345@example.com"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) + .await + .unwrap(); + assert!(!res.valid()); + } + + #[test] + fn test_merge() { + use serde_json::json as j; + + // Merging objects + let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap(); + assert_eq!(res, j!({"hello": "world", "foo": "bar"})); + + // Override a value of the same type + let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap(); + assert_eq!(res, j!({"hello": "john"})); + + let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap(); + assert_eq!(res, j!({"hello": false})); + + let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap(); + assert_eq!(res, j!({"hello": 42})); + + // Override a value of a different type + merge_data(j!({"hello": "world"}), j!({"hello": 123})) + .expect_err("Can't merge different types"); + + // Merge arrays + let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap(); + assert_eq!(res, j!({"hello": ["world", "john"]})); + + // Null overrides a value + let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap(); + assert_eq!(res, j!({"hello": null})); + + // Null gets overridden by a value + let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap(); + assert_eq!(res, j!({"hello": "world"})); + + // Objects get deeply merged + let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap(); + assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}})); + } } diff --git a/crates/templates/Cargo.toml b/crates/templates/Cargo.toml index 696a517b9..68bbadb1d 100644 --- a/crates/templates/Cargo.toml +++ b/crates/templates/Cargo.toml @@ -12,7 +12,7 @@ publish = false workspace = true [dependencies] -arc-swap = "1.7.1" +arc-swap.workspace = true tracing.workspace = true tokio.workspace = true walkdir = "2.5.0" From c3296a2e22f6aaea331da88ee815d37532f849f1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Feb 2025 16:22:42 +0100 Subject: [PATCH 4/7] Make the admin API update the local policy data --- crates/cli/src/app_state.rs | 6 +++++ crates/handlers/src/admin/mod.rs | 4 ++++ crates/handlers/src/admin/v1/mod.rs | 4 ++++ .../handlers/src/admin/v1/policy_data/set.rs | 23 +++++++++++++++++-- crates/handlers/src/bin/api-schema.rs | 3 ++- crates/handlers/src/test_utils.rs | 6 +++++ docs/api/spec.json | 23 +++++++++++++++++++ 7 files changed, 66 insertions(+), 3 deletions(-) diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index 6f8f059f0..4a1f89dee 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -204,6 +204,12 @@ impl FromRef for Limiter { } } +impl FromRef for Arc { + fn from_ref(input: &AppState) -> Self { + input.policy_factory.clone() + } +} + impl FromRef for BoxHomeserverConnection { fn from_ref(input: &AppState) -> Self { Box::new(input.homeserver_connection.clone()) diff --git a/crates/handlers/src/admin/mod.rs b/crates/handlers/src/admin/mod.rs index 4e65ba742..71e723a0c 100644 --- a/crates/handlers/src/admin/mod.rs +++ b/crates/handlers/src/admin/mod.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use aide::{ axum::ApiRouter, openapi::{OAuth2Flow, OAuth2Flows, OpenApi, SecurityScheme, Server, Tag}, @@ -20,6 +22,7 @@ use indexmap::IndexMap; use mas_axum_utils::FancyError; use mas_http::CorsLayerExt; use mas_matrix::BoxHomeserverConnection; +use mas_policy::PolicyFactory; use mas_router::{ ApiDoc, ApiDocCallback, OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, Route, SimpleRoute, UrlBuilder, @@ -118,6 +121,7 @@ where CallContext: FromRequestParts, Templates: FromRef, UrlBuilder: FromRef, + Arc: FromRef, { // We *always* want to explicitly set the possible responses, beacuse the // infered ones are not necessarily correct diff --git a/crates/handlers/src/admin/v1/mod.rs b/crates/handlers/src/admin/v1/mod.rs index ae258c842..590f65b58 100644 --- a/crates/handlers/src/admin/v1/mod.rs +++ b/crates/handlers/src/admin/v1/mod.rs @@ -4,12 +4,15 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use aide::axum::{ ApiRouter, routing::{get_with, post_with}, }; use axum::extract::{FromRef, FromRequestParts}; use mas_matrix::BoxHomeserverConnection; +use mas_policy::PolicyFactory; use mas_storage::BoxRng; use super::call_context::CallContext; @@ -28,6 +31,7 @@ where S: Clone + Send + Sync + 'static, BoxHomeserverConnection: FromRef, PasswordManager: FromRef, + Arc: FromRef, BoxRng: FromRequestParts, CallContext: FromRequestParts, { diff --git a/crates/handlers/src/admin/v1/policy_data/set.rs b/crates/handlers/src/admin/v1/policy_data/set.rs index aa3edfe5a..b857b488b 100644 --- a/crates/handlers/src/admin/v1/policy_data/set.rs +++ b/crates/handlers/src/admin/v1/policy_data/set.rs @@ -2,9 +2,12 @@ // // SPDX-License-Identifier: AGPL-3.0-only +use std::sync::Arc; + use aide::{NoApi, OperationIo, transform::TransformOperation}; -use axum::{Json, response::IntoResponse}; +use axum::{Json, extract::State, response::IntoResponse}; use hyper::StatusCode; +use mas_policy::PolicyFactory; use mas_storage::BoxRng; use schemars::JsonSchema; use serde::Deserialize; @@ -21,6 +24,9 @@ use crate::{ #[derive(Debug, thiserror::Error, OperationIo)] #[aide(output_with = "Json")] pub enum RouteError { + #[error("Failed to instanciate policy with the provided data")] + InvalidPolicyData(#[from] mas_policy::LoadError), + #[error(transparent)] Internal(Box), } @@ -30,7 +36,10 @@ impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { let error = ErrorResponse::from_error(&self); - let status = StatusCode::INTERNAL_SERVER_ERROR; + let status = match self { + RouteError::InvalidPolicyData(_) => StatusCode::BAD_REQUEST, + RouteError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR, + }; (status, Json(error)).into_response() } } @@ -62,6 +71,12 @@ pub fn doc(operation: TransformOperation) -> TransformOperation { t.description("Policy data was successfully set") .example(response) }) + .response_with::<400, Json, _>(|t| { + let error = ErrorResponse::from_error(&RouteError::InvalidPolicyData( + mas_policy::LoadError::invalid_data_example(), + )); + t.description("Invalid policy data").example(error) + }) } #[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all, err)] @@ -70,6 +85,7 @@ pub async fn handler( mut repo, clock, .. }: CallContext, NoApi(mut rng): NoApi, + State(policy_factory): State>, Json(request): Json, ) -> Result<(StatusCode, Json>), RouteError> { let policy_data = repo @@ -77,6 +93,9 @@ pub async fn handler( .set(&mut rng, &clock, request.data) .await?; + // Swap the policy data. This will fail if the policy data is invalid + policy_factory.set_dynamic_data(policy_data.clone()).await?; + repo.save().await?; Ok(( diff --git a/crates/handlers/src/bin/api-schema.rs b/crates/handlers/src/bin/api-schema.rs index 993719564..8005f5809 100644 --- a/crates/handlers/src/bin/api-schema.rs +++ b/crates/handlers/src/bin/api-schema.rs @@ -13,7 +13,7 @@ )] #![warn(clippy::pedantic)] -use std::io::Write; +use std::{io::Write, sync::Arc}; use aide::openapi::{Server, ServerVariable}; use indexmap::IndexMap; @@ -58,6 +58,7 @@ impl_from_ref!(mas_templates::Templates); impl_from_ref!(mas_matrix::BoxHomeserverConnection); impl_from_ref!(mas_keystore::Keystore); impl_from_ref!(mas_handlers::passwords::PasswordManager); +impl_from_ref!(Arc); fn main() -> Result<(), Box> { let (mut api, _) = mas_handlers::admin_api_router::(); diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 96073301a..457f9bdca 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -513,6 +513,12 @@ impl FromRef for SiteConfig { } } +impl FromRef for Arc { + fn from_ref(input: &TestState) -> Self { + input.policy_factory.clone() + } +} + impl FromRef for BoxHomeserverConnection { fn from_ref(input: &TestState) -> Self { Box::new(input.homeserver_connection.clone()) diff --git a/docs/api/spec.json b/docs/api/spec.json index 1fcc17116..d14b2c3f8 100644 --- a/docs/api/spec.json +++ b/docs/api/spec.json @@ -640,6 +640,29 @@ } } } + }, + "400": { + "description": "Invalid policy data", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "errors": [ + { + "title": "Failed to instanciate policy with the provided data" + }, + { + "title": "invalid policy data" + }, + { + "title": "Failed to merge policy data objects" + } + ] + } + } + } } } } From c8a33f00d3198cabd414ece567a8f073bdd87d4f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Feb 2025 16:42:37 +0100 Subject: [PATCH 5/7] Regularly load the latest dynamic policy data from the database --- crates/cli/src/commands/debug.rs | 24 +++++++++--- crates/cli/src/commands/server.rs | 14 +++++-- crates/cli/src/util.rs | 65 ++++++++++++++++++++++++++++++- 3 files changed, 94 insertions(+), 9 deletions(-) diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 8768ae7b1..a82d8f059 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -1,4 +1,4 @@ -// Copyright 2024 New Vector Ltd. +// Copyright 2024, 2025 New Vector Ltd. // Copyright 2022-2024 The Matrix.org Foundation C.I.C. // // SPDX-License-Identifier: AGPL-3.0-only @@ -8,10 +8,14 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; -use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig}; +use mas_config::{ + ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig, +}; use tracing::{info, info_span}; -use crate::util::policy_factory_from_config; +use crate::util::{ + database_pool_from_config, load_policy_factory_dynamic_data, policy_factory_from_config, +}; #[derive(Parser, Debug)] pub(super) struct Options { @@ -22,7 +26,11 @@ pub(super) struct Options { #[derive(Parser, Debug)] enum Subcommand { /// Check that the policies compile - Policy, + Policy { + /// With dynamic data loaded + #[arg(long)] + with_dynamic_data: bool, + }, } impl Options { @@ -30,13 +38,19 @@ impl Options { pub async fn run(self, figment: &Figment) -> anyhow::Result { use Subcommand as SC; match self.subcommand { - SC::Policy => { + SC::Policy { with_dynamic_data } => { let _span = info_span!("cli.debug.policy").entered(); let config = PolicyConfig::extract_or_default(figment)?; let matrix_config = MatrixConfig::extract(figment)?; info!("Loading and compiling the policy module"); let policy_factory = policy_factory_from_config(&config, &matrix_config).await?; + if with_dynamic_data { + let database_config = DatabaseConfig::extract(figment)?; + let pool = database_pool_from_config(&database_config).await?; + load_policy_factory_dynamic_data(&policy_factory, &pool).await?; + } + let _instance = policy_factory.instantiate().await?; } } diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index d58fcb9da..811027594 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -26,9 +26,9 @@ use crate::{ app_state::AppState, lifecycle::LifecycleManager, util::{ - database_pool_from_config, mailer_from_config, password_manager_from_config, - policy_factory_from_config, site_config_from_config, templates_from_config, - test_mailer_in_background, + database_pool_from_config, load_policy_factory_dynamic_data_continuously, + mailer_from_config, password_manager_from_config, policy_factory_from_config, + site_config_from_config, templates_from_config, test_mailer_in_background, }, }; @@ -130,6 +130,14 @@ impl Options { let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?; let policy_factory = Arc::new(policy_factory); + load_policy_factory_dynamic_data_continuously( + &policy_factory, + &pool, + shutdown.soft_shutdown_token(), + shutdown.task_tracker(), + ) + .await?; + let url_builder = UrlBuilder::new( config.http.public_base.clone(), config.http.issuer.clone(), diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 0f155afc7..118cc4a1b 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -4,7 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use anyhow::Context; use mas_config::{ @@ -17,11 +17,14 @@ use mas_email::{MailTransport, Mailer}; use mas_handlers::passwords::PasswordManager; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; +use mas_storage::RepositoryAccess; +use mas_storage_pg::PgRepository; use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates}; use sqlx::{ ConnectOptions, PgConnection, PgPool, postgres::{PgConnectOptions, PgPoolOptions}, }; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tracing::{Instrument, log::LevelFilter}; pub async fn password_manager_from_config( @@ -346,6 +349,66 @@ pub async fn database_connection_from_config( .context("could not connect to the database") } +/// Update the policy factory dynamic data from the database and spawn a task to +/// periodically update it +// XXX: this could be put somewhere else? +pub async fn load_policy_factory_dynamic_data_continuously( + policy_factory: &Arc, + pool: &PgPool, + cancellation_token: CancellationToken, + task_tracker: &TaskTracker, +) -> Result<(), anyhow::Error> { + let policy_factory = policy_factory.clone(); + let pool = pool.clone(); + + load_policy_factory_dynamic_data(&policy_factory, &pool).await?; + + task_tracker.spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + + loop { + tokio::select! { + () = cancellation_token.cancelled() => { + return; + } + _ = interval.tick() => {} + } + + if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await { + tracing::error!( + error = ?err, + "Failed to load policy factory dynamic data" + ); + cancellation_token.cancel(); + return; + } + } + }); + + Ok(()) +} + +/// Update the policy factory dynamic data from the database +#[tracing::instrument(name = "policy.load_dynamic_data", skip_all, err(Debug))] +pub async fn load_policy_factory_dynamic_data( + policy_factory: &PolicyFactory, + pool: &PgPool, +) -> Result<(), anyhow::Error> { + let mut repo = PgRepository::from_pool(pool) + .await + .context("Failed to acquire database connection")?; + + if let Some(data) = repo.policy_data().get().await? { + let id = data.id; + let updated = policy_factory.set_dynamic_data(data).await?; + if updated { + tracing::info!(policy_data.id = %id, "Loaded dynamic policy data from the database"); + } + } + + Ok(()) +} + #[cfg(test)] mod tests { use rand::SeedableRng; From 97d2b757c776585c90a60df283ae45eb023a2e3a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 13 Mar 2025 13:27:32 +0100 Subject: [PATCH 6/7] Add a comment on the migration stating that we keep an history of the policy data --- .../migrations/20250225091000_dynamic_policy_data.sql | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql b/crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql index 4a0984925..38f87d100 100644 --- a/crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql +++ b/crates/storage-pg/migrations/20250225091000_dynamic_policy_data.sql @@ -4,6 +4,10 @@ -- Please see LICENSE in the repository root for full details. -- Add a table which stores the latest policy data +-- +-- Every time the policy data is updated, it creates a new row, so that we keep +-- an history of the policy data, trace back which version of the data was used +-- on each evaluation. CREATE TABLE IF NOT EXISTS policy_data ( policy_data_id UUID PRIMARY KEY, created_at TIMESTAMP WITH TIME ZONE NOT NULL, From 8581ca19efb56461862805f3ed9eb3a9a6bdb63b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 13 Mar 2025 13:40:59 +0100 Subject: [PATCH 7/7] Prune stale policy data once a day --- crates/storage/src/queue/tasks.rs | 8 ++++++++ crates/tasks/src/database.rs | 27 ++++++++++++++++++++++++++- crates/tasks/src/lib.rs | 7 +++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/crates/storage/src/queue/tasks.rs b/crates/storage/src/queue/tasks.rs index 2172edfdc..b0075f319 100644 --- a/crates/storage/src/queue/tasks.rs +++ b/crates/storage/src/queue/tasks.rs @@ -506,3 +506,11 @@ impl ExpireInactiveUserSessionsJob { impl InsertableJob for ExpireInactiveUserSessionsJob { const QUEUE_NAME: &'static str = "expire-inactive-user-sessions"; } + +/// Prune stale policy data +#[derive(Debug, Serialize, Deserialize)] +pub struct PruneStalePolicyDataJob; + +impl InsertableJob for PruneStalePolicyDataJob { + const QUEUE_NAME: &'static str = "prune-stale-policy-data"; +} diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index b68a603b8..24cfd43d8 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -7,7 +7,7 @@ //! Database-related tasks use async_trait::async_trait; -use mas_storage::queue::CleanupExpiredTokensJob; +use mas_storage::queue::{CleanupExpiredTokensJob, PruneStalePolicyDataJob}; use tracing::{debug, info}; use crate::{ @@ -38,3 +38,28 @@ impl RunnableJob for CleanupExpiredTokensJob { Ok(()) } } + +#[async_trait] +impl RunnableJob for PruneStalePolicyDataJob { + #[tracing::instrument(name = "job.prune_stale_policy_data", skip_all, err)] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { + let mut repo = state.repository().await.map_err(JobError::retry)?; + + // Keep the last 10 policy data + let count = repo + .policy_data() + .prune(10) + .await + .map_err(JobError::retry)?; + + repo.save().await.map_err(JobError::retry)?; + + if count == 0 { + debug!("no stale policy data to prune"); + } else { + info!(count, "pruned stale policy data"); + } + + Ok(()) + } +} diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 41ec78fc8..5ef7f5e84 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -143,6 +143,7 @@ pub async fn init( .register_handler::() .register_handler::() .register_handler::() + .register_handler::() .add_schedule( "cleanup-expired-tokens", "0 0 * * * *".parse()?, @@ -153,6 +154,12 @@ pub async fn init( // Run this job every 15 minutes "30 */15 * * * *".parse()?, mas_storage::queue::ExpireInactiveSessionsJob, + ) + .add_schedule( + "prune-stale-policy-data", + // Run once a day + "0 0 2 * * *".parse()?, + mas_storage::queue::PruneStalePolicyDataJob, ); task_tracker.spawn(worker.run());