Skip to content

Commit c8a33f0

Browse files
committed
Regularly load the latest dynamic policy data from the database
1 parent c3296a2 commit c8a33f0

File tree

3 files changed

+94
-9
lines changed

3 files changed

+94
-9
lines changed

crates/cli/src/commands/debug.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 New Vector Ltd.
1+
// Copyright 2024, 2025 New Vector Ltd.
22
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
33
//
44
// SPDX-License-Identifier: AGPL-3.0-only
@@ -8,10 +8,14 @@ use std::process::ExitCode;
88

99
use clap::Parser;
1010
use figment::Figment;
11-
use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig};
11+
use mas_config::{
12+
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
13+
};
1214
use tracing::{info, info_span};
1315

14-
use crate::util::policy_factory_from_config;
16+
use crate::util::{
17+
database_pool_from_config, load_policy_factory_dynamic_data, policy_factory_from_config,
18+
};
1519

1620
#[derive(Parser, Debug)]
1721
pub(super) struct Options {
@@ -22,21 +26,31 @@ pub(super) struct Options {
2226
#[derive(Parser, Debug)]
2327
enum Subcommand {
2428
/// Check that the policies compile
25-
Policy,
29+
Policy {
30+
/// With dynamic data loaded
31+
#[arg(long)]
32+
with_dynamic_data: bool,
33+
},
2634
}
2735

2836
impl Options {
2937
#[tracing::instrument(skip_all)]
3038
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
3139
use Subcommand as SC;
3240
match self.subcommand {
33-
SC::Policy => {
41+
SC::Policy { with_dynamic_data } => {
3442
let _span = info_span!("cli.debug.policy").entered();
3543
let config = PolicyConfig::extract_or_default(figment)?;
3644
let matrix_config = MatrixConfig::extract(figment)?;
3745
info!("Loading and compiling the policy module");
3846
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
3947

48+
if with_dynamic_data {
49+
let database_config = DatabaseConfig::extract(figment)?;
50+
let pool = database_pool_from_config(&database_config).await?;
51+
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
52+
}
53+
4054
let _instance = policy_factory.instantiate().await?;
4155
}
4256
}

crates/cli/src/commands/server.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ use crate::{
2626
app_state::AppState,
2727
lifecycle::LifecycleManager,
2828
util::{
29-
database_pool_from_config, mailer_from_config, password_manager_from_config,
30-
policy_factory_from_config, site_config_from_config, templates_from_config,
31-
test_mailer_in_background,
29+
database_pool_from_config, load_policy_factory_dynamic_data_continuously,
30+
mailer_from_config, password_manager_from_config, policy_factory_from_config,
31+
site_config_from_config, templates_from_config, test_mailer_in_background,
3232
},
3333
};
3434

@@ -130,6 +130,14 @@ impl Options {
130130
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
131131
let policy_factory = Arc::new(policy_factory);
132132

133+
load_policy_factory_dynamic_data_continuously(
134+
&policy_factory,
135+
&pool,
136+
shutdown.soft_shutdown_token(),
137+
shutdown.task_tracker(),
138+
)
139+
.await?;
140+
133141
let url_builder = UrlBuilder::new(
134142
config.http.public_base.clone(),
135143
config.http.issuer.clone(),

crates/cli/src/util.rs

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// SPDX-License-Identifier: AGPL-3.0-only
55
// Please see LICENSE in the repository root for full details.
66

7-
use std::time::Duration;
7+
use std::{sync::Arc, time::Duration};
88

99
use anyhow::Context;
1010
use mas_config::{
@@ -17,11 +17,14 @@ use mas_email::{MailTransport, Mailer};
1717
use mas_handlers::passwords::PasswordManager;
1818
use mas_policy::PolicyFactory;
1919
use mas_router::UrlBuilder;
20+
use mas_storage::RepositoryAccess;
21+
use mas_storage_pg::PgRepository;
2022
use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates};
2123
use sqlx::{
2224
ConnectOptions, PgConnection, PgPool,
2325
postgres::{PgConnectOptions, PgPoolOptions},
2426
};
27+
use tokio_util::{sync::CancellationToken, task::TaskTracker};
2528
use tracing::{Instrument, log::LevelFilter};
2629

2730
pub async fn password_manager_from_config(
@@ -346,6 +349,66 @@ pub async fn database_connection_from_config(
346349
.context("could not connect to the database")
347350
}
348351

352+
/// Update the policy factory dynamic data from the database and spawn a task to
353+
/// periodically update it
354+
// XXX: this could be put somewhere else?
355+
pub async fn load_policy_factory_dynamic_data_continuously(
356+
policy_factory: &Arc<PolicyFactory>,
357+
pool: &PgPool,
358+
cancellation_token: CancellationToken,
359+
task_tracker: &TaskTracker,
360+
) -> Result<(), anyhow::Error> {
361+
let policy_factory = policy_factory.clone();
362+
let pool = pool.clone();
363+
364+
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
365+
366+
task_tracker.spawn(async move {
367+
let mut interval = tokio::time::interval(Duration::from_secs(60));
368+
369+
loop {
370+
tokio::select! {
371+
() = cancellation_token.cancelled() => {
372+
return;
373+
}
374+
_ = interval.tick() => {}
375+
}
376+
377+
if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await {
378+
tracing::error!(
379+
error = ?err,
380+
"Failed to load policy factory dynamic data"
381+
);
382+
cancellation_token.cancel();
383+
return;
384+
}
385+
}
386+
});
387+
388+
Ok(())
389+
}
390+
391+
/// Update the policy factory dynamic data from the database
392+
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all, err(Debug))]
393+
pub async fn load_policy_factory_dynamic_data(
394+
policy_factory: &PolicyFactory,
395+
pool: &PgPool,
396+
) -> Result<(), anyhow::Error> {
397+
let mut repo = PgRepository::from_pool(pool)
398+
.await
399+
.context("Failed to acquire database connection")?;
400+
401+
if let Some(data) = repo.policy_data().get().await? {
402+
let id = data.id;
403+
let updated = policy_factory.set_dynamic_data(data).await?;
404+
if updated {
405+
tracing::info!(policy_data.id = %id, "Loaded dynamic policy data from the database");
406+
}
407+
}
408+
409+
Ok(())
410+
}
411+
349412
#[cfg(test)]
350413
mod tests {
351414
use rand::SeedableRng;

0 commit comments

Comments
 (0)