4
4
// SPDX-License-Identifier: AGPL-3.0-only
5
5
// Please see LICENSE in the repository root for full details.
6
6
7
- use std:: time:: Duration ;
7
+ use std:: { sync :: Arc , time:: Duration } ;
8
8
9
9
use anyhow:: Context ;
10
10
use mas_config:: {
@@ -17,11 +17,14 @@ use mas_email::{MailTransport, Mailer};
17
17
use mas_handlers:: passwords:: PasswordManager ;
18
18
use mas_policy:: PolicyFactory ;
19
19
use mas_router:: UrlBuilder ;
20
+ use mas_storage:: RepositoryAccess ;
21
+ use mas_storage_pg:: PgRepository ;
20
22
use mas_templates:: { SiteConfigExt , TemplateLoadingError , Templates } ;
21
23
use sqlx:: {
22
24
ConnectOptions , PgConnection , PgPool ,
23
25
postgres:: { PgConnectOptions , PgPoolOptions } ,
24
26
} ;
27
+ use tokio_util:: { sync:: CancellationToken , task:: TaskTracker } ;
25
28
use tracing:: { Instrument , log:: LevelFilter } ;
26
29
27
30
pub async fn password_manager_from_config (
@@ -346,6 +349,66 @@ pub async fn database_connection_from_config(
346
349
. context ( "could not connect to the database" )
347
350
}
348
351
352
+ /// Update the policy factory dynamic data from the database and spawn a task to
353
+ /// periodically update it
354
+ // XXX: this could be put somewhere else?
355
+ pub async fn load_policy_factory_dynamic_data_continuously (
356
+ policy_factory : & Arc < PolicyFactory > ,
357
+ pool : & PgPool ,
358
+ cancellation_token : CancellationToken ,
359
+ task_tracker : & TaskTracker ,
360
+ ) -> Result < ( ) , anyhow:: Error > {
361
+ let policy_factory = policy_factory. clone ( ) ;
362
+ let pool = pool. clone ( ) ;
363
+
364
+ load_policy_factory_dynamic_data ( & policy_factory, & pool) . await ?;
365
+
366
+ task_tracker. spawn ( async move {
367
+ let mut interval = tokio:: time:: interval ( Duration :: from_secs ( 60 ) ) ;
368
+
369
+ loop {
370
+ tokio:: select! {
371
+ ( ) = cancellation_token. cancelled( ) => {
372
+ return ;
373
+ }
374
+ _ = interval. tick( ) => { }
375
+ }
376
+
377
+ if let Err ( err) = load_policy_factory_dynamic_data ( & policy_factory, & pool) . await {
378
+ tracing:: error!(
379
+ error = ?err,
380
+ "Failed to load policy factory dynamic data"
381
+ ) ;
382
+ cancellation_token. cancel ( ) ;
383
+ return ;
384
+ }
385
+ }
386
+ } ) ;
387
+
388
+ Ok ( ( ) )
389
+ }
390
+
391
+ /// Update the policy factory dynamic data from the database
392
+ #[ tracing:: instrument( name = "policy.load_dynamic_data" , skip_all, err( Debug ) ) ]
393
+ pub async fn load_policy_factory_dynamic_data (
394
+ policy_factory : & PolicyFactory ,
395
+ pool : & PgPool ,
396
+ ) -> Result < ( ) , anyhow:: Error > {
397
+ let mut repo = PgRepository :: from_pool ( pool)
398
+ . await
399
+ . context ( "Failed to acquire database connection" ) ?;
400
+
401
+ if let Some ( data) = repo. policy_data ( ) . get ( ) . await ? {
402
+ let id = data. id ;
403
+ let updated = policy_factory. set_dynamic_data ( data) . await ?;
404
+ if updated {
405
+ tracing:: info!( policy_data. id = %id, "Loaded dynamic policy data from the database" ) ;
406
+ }
407
+ }
408
+
409
+ Ok ( ( ) )
410
+ }
411
+
349
412
#[ cfg( test) ]
350
413
mod tests {
351
414
use rand:: SeedableRng ;
0 commit comments