diff --git a/.sqlx/query-02d6faf0ffd8fb96031f7a8e824de4e4552ea72176c5e885cea3e517316e7774.json b/.sqlx/query-77054c7c20195a2a7766eb83bf39add56ec84bb502ce39d26cd9394703bdd71a.json similarity index 75% rename from .sqlx/query-02d6faf0ffd8fb96031f7a8e824de4e4552ea72176c5e885cea3e517316e7774.json rename to .sqlx/query-77054c7c20195a2a7766eb83bf39add56ec84bb502ce39d26cd9394703bdd71a.json index 08fe186..681ebd0 100644 --- a/.sqlx/query-02d6faf0ffd8fb96031f7a8e824de4e4552ea72176c5e885cea3e517316e7774.json +++ b/.sqlx/query-77054c7c20195a2a7766eb83bf39add56ec84bb502ce39d26cd9394703bdd71a.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT * FROM beamline", + "query": "INSERT INTO beamline\n (name, scan_number, visit, scan, detector, fallback_extension)\n VALUES\n (?,?,?,?,?,?)\n RETURNING *", "describe": { "columns": [ { @@ -33,19 +33,14 @@ "ordinal": 5, "type_info": "Text" }, - { - "name": "fallback_directory", - "ordinal": 6, - "type_info": "Text" - }, { "name": "fallback_extension", - "ordinal": 7, + "ordinal": 6, "type_info": "Text" } ], "parameters": { - "Right": 0 + "Right": 6 }, "nullable": [ false, @@ -54,9 +49,8 @@ false, false, false, - true, true ] }, - "hash": "02d6faf0ffd8fb96031f7a8e824de4e4552ea72176c5e885cea3e517316e7774" + "hash": "77054c7c20195a2a7766eb83bf39add56ec84bb502ce39d26cd9394703bdd71a" } diff --git a/.sqlx/query-7b769dea685f49e8ff6d24a69e69e7037c1dc197db8bd273ec53e63a85e85351.json b/.sqlx/query-7b769dea685f49e8ff6d24a69e69e7037c1dc197db8bd273ec53e63a85e85351.json index 9e5570b..7c487d0 100644 --- a/.sqlx/query-7b769dea685f49e8ff6d24a69e69e7037c1dc197db8bd273ec53e63a85e85351.json +++ b/.sqlx/query-7b769dea685f49e8ff6d24a69e69e7037c1dc197db8bd273ec53e63a85e85351.json @@ -33,14 +33,9 @@ "ordinal": 5, "type_info": "Text" }, - { - "name": "fallback_directory", - "ordinal": 6, - "type_info": "Text" - }, { "name": "fallback_extension", - "ordinal": 7, + "ordinal": 6, "type_info": "Text" } ], @@ -54,7 +49,6 @@ false, false, false, - true, true ] }, diff --git a/.sqlx/query-e45d346b58374c69e4f3bb59935719177993012eb03ce8f2d9bcca00290690be.json b/.sqlx/query-e45d346b58374c69e4f3bb59935719177993012eb03ce8f2d9bcca00290690be.json index d1d76d8..565f1f3 100644 --- a/.sqlx/query-e45d346b58374c69e4f3bb59935719177993012eb03ce8f2d9bcca00290690be.json +++ b/.sqlx/query-e45d346b58374c69e4f3bb59935719177993012eb03ce8f2d9bcca00290690be.json @@ -33,14 +33,9 @@ "ordinal": 5, "type_info": "Text" }, - { - "name": "fallback_directory", - "ordinal": 6, - "type_info": "Text" - }, { "name": "fallback_extension", - "ordinal": 7, + "ordinal": 6, "type_info": "Text" } ], @@ -54,7 +49,6 @@ false, false, false, - true, true ] }, diff --git a/.sqlx/query-e671001b8f5b99f025043d31fcc036c58692913de6a6fd1ef4bdb8177f1235fa.json b/.sqlx/query-e671001b8f5b99f025043d31fcc036c58692913de6a6fd1ef4bdb8177f1235fa.json deleted file mode 100644 index 7dac74f..0000000 --- a/.sqlx/query-e671001b8f5b99f025043d31fcc036c58692913de6a6fd1ef4bdb8177f1235fa.json +++ /dev/null @@ -1,62 +0,0 @@ -{ - "db_name": "SQLite", - "query": "INSERT INTO beamline\n (name, scan_number, visit, scan, detector, fallback_directory, fallback_extension)\n VALUES\n (?,?,?,?,?,?,?)\n RETURNING *", - "describe": { - "columns": [ - { - "name": "id", - "ordinal": 0, - "type_info": "Integer" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "scan_number", - "ordinal": 2, - "type_info": "Integer" - }, - { - "name": "visit", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "scan", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "detector", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "fallback_directory", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "fallback_extension", - "ordinal": 7, - "type_info": "Text" - } - ], - "parameters": { - "Right": 7 - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - true, - true - ] - }, - "hash": "e671001b8f5b99f025043d31fcc036c58692913de6a6fd1ef4bdb8177f1235fa" -} diff --git a/migrations/0001_init.up.sql b/migrations/0001_init.up.sql index d4121a3..c9936a2 100644 --- a/migrations/0001_init.up.sql +++ b/migrations/0001_init.up.sql @@ -9,12 +9,6 @@ CREATE TABLE beamline ( scan TEXT NOT NULL CHECK (length(scan) > 0), detector TEXT NOT NULL CHECK (length(detector) > 0), - fallback_directory TEXT, - fallback_extension TEXT, - - -- Ensure fallback number files don't collide - UNIQUE(fallback_directory, fallback_extension), - - -- Require a directory to be set if the extension is present - CHECK (fallback_extension ISNULL OR fallback_directory NOTNULL) + -- Override file tracker extension - defaults to beamline name + fallback_extension TEXT ); diff --git a/src/cli.rs b/src/cli.rs index d781b89..4f25e66 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -58,6 +58,9 @@ pub struct ServeOptions { /// The port to open for requests #[clap(short, long, default_value_t = 8000, env = "NUMTRACKER_PORT")] port: u16, + /// The root directory for external number tracking + #[clap(long, env = "NUMTRACKER_ROOT_DIRECTORY")] + root_directory: Option, } #[derive(Debug, Args)] @@ -110,6 +113,9 @@ impl ServeOptions { pub(crate) fn addr(&self) -> (Ipv4Addr, u16) { (self.host, self.port) } + pub(crate) fn root_directory(&self) -> Option { + self.root_directory.clone() + } } impl TracingOptions { diff --git a/src/db_service.rs b/src/db_service.rs index ff60c18..1350965 100644 --- a/src/db_service.rs +++ b/src/db_service.rs @@ -65,7 +65,7 @@ pub struct BeamlineConfiguration { visit: RawPathTemplate, scan: RawPathTemplate, detector: RawPathTemplate, - fallback: Option, + extension: Option, } impl BeamlineConfiguration { @@ -77,8 +77,8 @@ impl BeamlineConfiguration { self.scan_number } - pub fn fallback(&self) -> Option<&NumtrackerConfig> { - self.fallback.as_ref() + pub fn extension(&self) -> Option<&str> { + self.extension.as_deref() } pub fn visit(&self) -> SqliteTemplateResult { @@ -104,7 +104,6 @@ impl<'r> FromRow<'r, SqliteRow> for BeamlineConfiguration { scan: row.try_get::("scan")?, detector: row.try_get::("detector")?, fallback_extension: row.try_get::, _>("fallback_extension")?, - fallback_directory: row.try_get::, _>("fallback_directory")?, } .into()) } @@ -117,7 +116,6 @@ pub struct BeamlineConfigurationUpdate { pub visit: Option>, pub scan: Option>, pub detector: Option>, - pub directory: Option, pub extension: Option, } @@ -127,7 +125,6 @@ impl BeamlineConfigurationUpdate { && self.visit.is_none() && self.scan.is_none() && self.detector.is_none() - && self.directory.is_none() && self.extension.is_none() } @@ -160,10 +157,6 @@ impl BeamlineConfigurationUpdate { fields.push("detector="); fields.push_bind_unseparated(detector.to_string()); } - if let Some(dir) = &self.directory { - fields.push("fallback_directory="); - fields.push_bind_unseparated(dir); - } if let Some(ext) = &self.extension { if ext != &self.name { // extension defaults to beamline name @@ -194,7 +187,6 @@ impl BeamlineConfigurationUpdate { visit: self.visit.ok_or("visit")?.to_string(), scan: self.scan.ok_or("scan")?.to_string(), detector: self.detector.ok_or("detector")?.to_string(), - fallback_directory: self.directory, fallback_extension: self.extension, }; Ok(dbc.insert_into(db).await?) @@ -207,7 +199,6 @@ impl BeamlineConfigurationUpdate { visit: None, scan: None, detector: None, - directory: None, extension: None, } } @@ -222,7 +213,6 @@ struct DbBeamlineConfig { visit: String, scan: String, detector: String, - fallback_directory: Option, fallback_extension: Option, } @@ -234,16 +224,15 @@ impl DbBeamlineConfig { let bc = query_as!( DbBeamlineConfig, "INSERT INTO beamline - (name, scan_number, visit, scan, detector, fallback_directory, fallback_extension) + (name, scan_number, visit, scan, detector, fallback_extension) VALUES - (?,?,?,?,?,?,?) + (?,?,?,?,?,?) RETURNING *", self.name, self.scan_number, self.visit, self.scan, self.detector, - self.fallback_directory, self.fallback_extension ) .fetch_one(&db.pool) @@ -254,24 +243,13 @@ impl DbBeamlineConfig { impl From for BeamlineConfiguration { fn from(value: DbBeamlineConfig) -> Self { - let fallback = match (value.fallback_directory, value.fallback_extension) { - (None, _) => None, - (Some(dir), None) => Some(NumtrackerConfig { - directory: dir, - extension: value.name.clone(), - }), - (Some(dir), Some(ext)) => Some(NumtrackerConfig { - directory: dir, - extension: ext, - }), - }; Self { name: value.name, - scan_number: u32::try_from(value.scan_number).expect("Run out of scan numbers"), + scan_number: u32::try_from(value.scan_number).expect("Out of scan numbers"), visit: value.visit.into(), scan: value.scan.into(), detector: value.detector.into(), - fallback, + extension: value.fallback_extension, } } } @@ -461,7 +439,6 @@ mod db_tests { "{subdirectory}/{instrument}-{scan_number}-{detector}", ) .ok(), - directory: Some("/tmp/trackers".into()), extension: Some("ext".into()), } } @@ -506,16 +483,6 @@ mod db_tests { assert_eq!(bc.name(), "i22"); } - #[rstest] - #[test] - async fn extension_requires_directory(mut update: BeamlineConfigurationUpdate) { - let db = SqliteScanPathService::memory().await; - update.directory = None; - let e = err!(NewConfigurationError::Db, update.insert_new(&db)); - let e = *e.into_database_error().unwrap().downcast::(); - assert_eq!(e.kind(), ErrorKind::CheckViolation) - } - #[rstest] #[test] async fn read_only_db_propagates_errors(update: BeamlineConfigurationUpdate) { @@ -597,11 +564,10 @@ mod db_tests { conf.detector().unwrap().to_string(), "{subdirectory}/{instrument}-{scan_number}-{detector}" ); - let Some(fb) = conf.fallback() else { - panic!("Missing fallback configuration"); + let Some(ext) = conf.extension() else { + panic!("Missing extension"); }; - assert_eq!(fb.directory, "/tmp/trackers"); - assert_eq!(fb.extension, "ext"); + assert_eq!(ext, "ext"); } type Update = BeamlineConfigurationUpdate; @@ -619,12 +585,9 @@ mod db_tests { #[case::scan_number( |u: &mut Update| u.scan_number = Some(42), |u: BeamlineConfiguration| assert_eq!(u.scan_number(), 42))] - #[case::directory( - |u: &mut Update| u.directory = Some("/new_trackers".into()), - |u: BeamlineConfiguration| assert_eq!(u.fallback().unwrap().directory, "/new_trackers"))] #[case::extension( |u: &mut Update| u.extension = Some("new".into()), - |u: BeamlineConfiguration| assert_eq!(u.fallback().unwrap().extension, "new"))] + |u: BeamlineConfiguration| assert_eq!(u.extension().unwrap(), "new"))] #[tokio::test] async fn update_existing( #[future(awt)] db: SqliteScanPathService, @@ -636,4 +599,11 @@ mod db_tests { let bc = ok!(upd.update_beamline(&db)).expect("Updated beamline missing"); check(bc) } + + #[rstest] + #[tokio::test] + async fn empty_update(#[future(awt)] db: SqliteScanPathService) { + let upd = BeamlineConfigurationUpdate::empty("b21"); + assert!(ok!(upd.update_beamline(&db)).is_none()); + } } diff --git a/src/graphql.rs b/src/graphql.rs index 7867f52..cd67258 100644 --- a/src/graphql.rs +++ b/src/graphql.rs @@ -38,7 +38,7 @@ use crate::cli::ServeOptions; use crate::db_service::{ BeamlineConfiguration, BeamlineConfigurationUpdate, SqliteScanPathService, }; -use crate::numtracker::GdaNumTracker; +use crate::numtracker::NumTracker; use crate::paths::{ BeamlineField, DetectorField, DetectorTemplate, PathSpec, ScanField, ScanTemplate, VisitTemplate, @@ -49,10 +49,13 @@ pub async fn serve_graphql(db: &Path, opts: ServeOptions) { let db = SqliteScanPathService::connect(db) .await .expect("Unable to open DB"); + let directory_numtracker = NumTracker::for_root_directory(opts.root_directory()) + .expect("Could not read external directories"); info!("Serving graphql endpoints on {:?}", opts.addr()); let schema = Schema::build(Query, Mutation, EmptySubscription) .extension(Tracing) .data(db) + .data(directory_numtracker) .finish(); let app = Router::new() .route("/graphql", post(graphql_handler)) @@ -277,23 +280,19 @@ impl Mutation { sub: Option, ) -> async_graphql::Result { let db = ctx.data::()?; + let nt = ctx.data::()?; // There is a race condition here if a process increments the file // while the DB is being queried or between the two queries but there // isn't much we can do from here. let current = db.current_configuration(&beamline).await?; - let fallback = current - .fallback() - .and_then(|fb| GdaNumTracker::new(&fb.directory, &fb.extension).ok()); - let prev = match &fallback { - Some(nt) => Some(nt.latest_scan_number().await?), - None => None, - }; + let dir = nt.for_beamline(&beamline, current.extension()).await?; - let next_scan = db.next_scan_configuration(&beamline, prev).await?; - if let Some(nt) = &fallback { - if let Err(e) = nt.create_num_file(next_scan.scan_number()).await { - warn!("Failed to increment fallback tracker directory: {e}"); - } + let next_scan = db + .next_scan_configuration(&beamline, dir.prev().await?) + .await?; + + if let Err(e) = dir.set(next_scan.scan_number()).await { + warn!("Failed to increment fallback tracker directory: {e}"); } Ok(ScanPaths { @@ -328,7 +327,6 @@ struct ConfigurationUpdates { scan: Option>, detector: Option>, scan_number: Option, - directory: Option, extension: Option, } @@ -340,7 +338,6 @@ impl ConfigurationUpdates { visit: self.visit.map(|t| t.0), scan: self.scan.map(|t| t.0), detector: self.detector.map(|t| t.0), - directory: self.directory, extension: self.extension, } } diff --git a/src/numtracker.rs b/src/numtracker.rs index ea9d619..a0f1de8 100644 --- a/src/numtracker.rs +++ b/src/numtracker.rs @@ -12,32 +12,94 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; +use std::fmt::{self, Display}; use std::io::Error; use std::path::{Path, PathBuf}; use tokio::fs as async_fs; +use tokio::sync::{Mutex, MutexGuard}; use tracing::{instrument, trace}; -#[derive(Debug)] -pub struct GdaNumTracker<'e> { - ext: &'e str, - directory: &'e Path, +/// Central controller to access external directory trackers. Prevents concurrent access to the same +/// beamline's directory. +pub struct NumTracker { + bl_locks: HashMap>, +} + +impl NumTracker { + /// Build a numtracker than will provide locked access to subdirectories that exists and no-op + /// trackers for beamlines that do not have subdirectories. + pub fn for_root_directory(root: Option) -> Result { + let mut bl_locks: HashMap> = Default::default(); + if let Some(dir) = root { + for entry in dir.read_dir()? { + let dir = entry?; + if dir.file_type()?.is_dir() { + if let Ok(name) = dir.file_name().into_string() { + bl_locks.insert(name, Mutex::new(dir.path())); + } + } + } + } + + Ok(Self { bl_locks }) + } + + /// Create a wrapper around a subdirectory if one exists for the given beamline, or a no-op + /// tracker if a directory does not exist. + pub async fn for_beamline<'nt, 'bl>( + &'nt self, + bl: &'bl str, + ext: Option<&'bl str>, + ) -> Result, InvalidExtension> { + if !ext.is_none_or(Self::valid_extension) { + return Err(InvalidExtension); + } + Ok(match self.bl_locks.get(bl) { + Some(dir) => DirectoryTracker::GdaDirectory(GdaNumTracker { + ext: ext.unwrap_or(bl), + directory: dir.lock().await, + }), + None => DirectoryTracker::NoDirectory, + }) + } + + fn valid_extension(name: &str) -> bool { + name.chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + } +} + +/// Number tracker for a directory that may or may not exist +pub enum DirectoryTracker<'nt, 'bl> { + NoDirectory, + GdaDirectory(GdaNumTracker<'nt, 'bl>), } -impl<'e> GdaNumTracker<'e> { - /// Create a new num tracker for the given directory and extension - pub fn new>(directory: &'e P, ext: &'e str) -> Result { - let directory = directory.as_ref(); - if ext.chars().all(char::is_alphanumeric) { - Ok(Self { ext, directory }) - } else { - Err(Error::new( - std::io::ErrorKind::InvalidInput, - format!("{ext:?} is not a valid extension"), - )) +impl DirectoryTracker<'_, '_> { + pub async fn prev(&self) -> Result, Error> { + match self { + DirectoryTracker::NoDirectory => Ok(None), + DirectoryTracker::GdaDirectory(gnt) => Some(gnt.latest_scan_number().await).transpose(), } } + pub async fn set(&self, num: u32) -> Result<(), Error> { + match self { + DirectoryTracker::NoDirectory => Ok(()), + DirectoryTracker::GdaDirectory(gnt) => gnt.create_num_file(num).await, + } + } +} + +#[derive(Debug)] +pub struct GdaNumTracker<'nt, 'bl> { + ext: &'bl str, + directory: MutexGuard<'nt, PathBuf>, +} + +impl GdaNumTracker<'_, '_> { /// Build the path of the file that would correspond to the given number fn file_name(&self, num: u32) -> PathBuf { self.directory @@ -48,7 +110,7 @@ impl<'e> GdaNumTracker<'e> { /// Create a file named for the given number and, if present, remove the file for the previous /// number. #[instrument] - pub async fn create_num_file(&self, num: u32) -> Result<(), Error> { + async fn create_num_file(&self, num: u32) -> Result<(), Error> { trace!("Creating new scan number file: {num}.{}", self.ext); let next = self.file_name(num); async_fs::OpenOptions::new() @@ -77,9 +139,9 @@ impl<'e> GdaNumTracker<'e> { } /// Find the highest number that has a corresponding number file in this tracker's directory - pub async fn latest_scan_number(&self) -> Result { + async fn latest_scan_number(&self) -> Result { let mut high = 0; - let mut dir = async_fs::read_dir(&self.directory).await?; + let mut dir = async_fs::read_dir(&*self.directory).await?; while let Some(file) = dir.next_entry().await? { if !file.file_type().await?.is_file() { continue; @@ -91,3 +153,15 @@ impl<'e> GdaNumTracker<'e> { Ok(high) } } + +/// Error returned when an extension would result in directory traversal - eg '.foo/../../bar' +#[derive(Debug, Clone, Copy)] +pub struct InvalidExtension; + +impl Display for InvalidExtension { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Extension is not valid") + } +} + +impl std::error::Error for InvalidExtension {}