diff --git a/Cargo.toml b/Cargo.toml index 5eac625..2d52a70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,5 @@ assert_matches = "1.5.0" async-std = { version = "1.13.0", features = ["attributes"], default-features = false } httpmock = { version = "0.7.0", default-features = false } rstest = "0.23.0" +serde_json = "1.0.133" tempfile = "3.14.0" diff --git a/src/db_service.rs b/src/db_service.rs index a1bc855..c0e7a36 100644 --- a/src/db_service.rs +++ b/src/db_service.rs @@ -16,7 +16,8 @@ use std::fmt; use std::marker::PhantomData; use std::path::Path; -use error::{ConfigurationError, NewConfigurationError}; +pub use error::ConfigurationError; +use error::NewConfigurationError; use sqlx::sqlite::{SqliteConnectOptions, SqliteRow}; use sqlx::{query_as, FromRow, QueryBuilder, Row, Sqlite, SqlitePool}; use tracing::{info, instrument, trace}; @@ -300,7 +301,7 @@ impl SqliteScanPathService { } #[cfg(test)] - async fn ro_memory() -> Self { + pub(crate) async fn ro_memory() -> Self { let db = Self::memory().await; db.pool .set_connect_options(SqliteConnectOptions::new().read_only(true)); @@ -308,7 +309,7 @@ impl SqliteScanPathService { } #[cfg(test)] - async fn memory() -> Self { + pub(crate) async fn memory() -> Self { let pool = SqlitePool::connect(":memory:").await.unwrap(); sqlx::migrate!().run(&pool).await.unwrap(); Self { pool } diff --git a/src/graphql.rs b/src/graphql.rs index d3bd2d3..bac1ddd 100644 --- a/src/graphql.rs +++ b/src/graphql.rs @@ -17,6 +17,7 @@ use std::borrow::Cow; use std::error::Error; use std::fmt::Display; use std::future::Future; +use std::io::Write; use std::path::{Component, Path, PathBuf}; use std::sync::Arc; @@ -79,9 +80,9 @@ pub async fn serve_graphql(db: &Path, opts: ServeOptions) { .expect("Can't serve graphql endpoint"); } -pub fn graphql_schema() { +pub fn graphql_schema(mut out: W) -> Result<(), std::io::Error> { let schema = Schema::new(Query, Mutation, EmptySubscription); - println!("{}", schema.sdl()); + write!(out, "{}", schema.sdl()) } async fn graphiql() -> impl IntoResponse { @@ -526,6 +527,425 @@ impl Detector { } } +#[cfg(test)] +mod tests { + use std::error::Error; + use std::fs; + + use async_graphql::{EmptySubscription, InputType as _, Request, Schema, SchemaBuilder, Value}; + use axum::http::HeaderValue; + use axum_extra::headers::authorization::{Bearer, Credentials}; + use axum_extra::headers::Authorization; + use httpmock::MockServer; + use rstest::{fixture, rstest}; + use serde_json::json; + use tempfile::TempDir; + + use super::auth::PolicyCheck; + use super::{ConfigurationUpdates, InputTemplate, Mutation, Query}; + use crate::cli::PolicyOptions; + use crate::db_service::SqliteScanPathService; + use crate::graphql::graphql_schema; + use crate::numtracker::TempTracker; + + type NtSchema = Schema; + type NtBuilder = SchemaBuilder; + + struct TestEnv { + schema: NtSchema, + dir: TempDir, + db: SqliteScanPathService, + } + + struct TestAuthEnv { + schema: NtSchema, + dir: TempDir, + db: SqliteScanPathService, + server: MockServer, + } + + fn updates( + visit: Option<&str>, + scan: Option<&str>, + det: Option<&str>, + num: Option, + ext: Option<&str>, + ) -> ConfigurationUpdates { + ConfigurationUpdates { + visit: visit.map(|v| InputTemplate::parse(Some(Value::String(v.into()))).unwrap()), + scan: scan.map(|s| InputTemplate::parse(Some(Value::String(s.into()))).unwrap()), + detector: det.map(|d| InputTemplate::parse(Some(Value::String(d.into()))).unwrap()), + scan_number: num, + extension: ext.map(|e| e.into()), + } + } + + /// Helper for creating graphql values from literals + macro_rules! value { + ($tree:tt) => { + Value::from_json(json!($tree)).unwrap() + }; + } + + #[fixture] + async fn db() -> SqliteScanPathService { + let db = SqliteScanPathService::memory().await; + let cfg = updates( + Some("/tmp/{instrument}/data/{visit}/"), + Some("{subdirectory}/{instrument}-{scan_number}"), + Some("{subdirectory}/{instrument}-{scan_number}-{detector}"), + Some(122), + None, + ); + cfg.into_update("i22".into()).insert_new(&db).await.unwrap(); + let cfg = updates( + Some("/tmp/{instrument}/data/{visit}/"), + Some("{subdirectory}/{instrument}-{scan_number}"), + Some("{subdirectory}/{scan_number}/{instrument}-{scan_number}-{detector}"), + Some(621), + Some("b21_ext"), + ); + cfg.into_update("b21".into()).insert_new(&db).await.unwrap(); + db + } + + #[fixture] + async fn components( + #[future(awt)] db: SqliteScanPathService, + ) -> (NtBuilder, TempDir, SqliteScanPathService) { + let TempTracker(nt, dir) = TempTracker::new(|p| { + fs::create_dir(p.join("i22"))?; + fs::File::create_new(p.join("i22").join("122.i22"))?; + fs::create_dir(p.join("b21"))?; + fs::File::create_new(p.join("b21").join("211.b21_ext"))?; + Ok(()) + }); + ( + Schema::build(Query, Mutation, EmptySubscription) + .data(db.clone()) + .data(nt), + dir, + db, + ) + } + + #[fixture] + async fn env( + #[future(awt)] components: (NtBuilder, TempDir, SqliteScanPathService), + ) -> TestEnv { + TestEnv { + schema: components.0.data(Option::::None).finish(), + dir: components.1, + db: components.2, + } + } + + #[fixture] + async fn auth_env( + #[future(awt)] components: (NtBuilder, TempDir, SqliteScanPathService), + ) -> TestAuthEnv { + let server = MockServer::start(); + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + TestAuthEnv { + schema: components.0.data(Some(check)).finish(), + dir: components.1, + db: components.2, + server, + } + } + + #[rstest] + #[tokio::test] + async fn missing_config(#[future(awt)] env: TestEnv) { + let result = env + .schema + .execute(r#"{paths(beamline: "i11", visit: "cm1234-5") {directory}}"#) + .await; + + assert_eq!(result.data, Value::Null); + println!("{result:?}"); + assert_eq!( + result.errors[0].message, + r#"No configuration available for beamline "i11""# + ); + } + + #[rstest] + #[tokio::test] + async fn paths(#[future(awt)] env: TestEnv) { + let result = env + .schema + .execute(r#"{paths(beamline: "i22", visit: "cm12345-3") {directory visit}}"#) + .await; + println!("{result:#?}"); + let exp = value!({"paths": {"visit": "cm12345-3", "directory": "/tmp/i22/data/cm12345-3"}}); + assert!(result.errors.is_empty()); + assert_eq!(result.data, exp); + } + + #[rstest] + #[tokio::test] + async fn scan(#[future(awt)] env: TestEnv) { + let query = r#"mutation { + scan(beamline: "i22", visit: "cm12345-3", sub: "foo/bar") { + visit { beamline directory visit} scanFile scanNumber + detectors(names: ["det_one", "det_two"]) { name path } + } + }"#; + let result = env.schema.execute(query).await; + + println!("{result:#?}"); + assert!(result.errors.is_empty()); + let exp = value!({ + "scan": { + "visit": {"visit": "cm12345-3", "beamline": "i22", "directory": "/tmp/i22/data/cm12345-3"}, + "scanFile": "foo/bar/i22-123", + "scanNumber": 123, + "detectors": [ + {"path": "foo/bar/i22-123-det_one", "name": "det_one"}, + {"path": "foo/bar/i22-123-det_two", "name": "det_two"} + ] + }}); + assert_eq!(result.data, exp); + } + + #[rstest] + #[tokio::test] + async fn configuration(#[future(awt)] env: TestEnv) { + let query = r#"{ + configuration(beamline: "i22") { + visitTemplate scanTemplate detectorTemplate latestScanNumber + }}"#; + let result = env.schema.execute(query).await; + let exp = value!({ + "configuration": { + "visitTemplate": "/tmp/{instrument}/data/{visit}", + "scanTemplate": "{subdirectory}/{instrument}-{scan_number}", + "detectorTemplate": "{subdirectory}/{instrument}-{scan_number}-{detector}", + "latestScanNumber": 122 + }}); + assert!(result.errors.is_empty()); + assert_eq!(result.data, exp); + } + + #[rstest] + #[tokio::test] + async fn empty_configure_for_existing(#[future(awt)] env: TestEnv) { + let query = r#"mutation { + configure(beamline: "i22", config: {}) { + visitTemplate scanTemplate detectorTemplate latestScanNumber + } + }"#; + let result = env.schema.execute(query).await; + let exp = value!({ + "configure": { + "visitTemplate": "/tmp/{instrument}/data/{visit}", + "scanTemplate": "{subdirectory}/{instrument}-{scan_number}", + "detectorTemplate": "{subdirectory}/{instrument}-{scan_number}-{detector}", + "latestScanNumber": 122 + } + }); + println!("{result:#?}"); + assert!(result.errors.is_empty()); + assert_eq!(result.data, exp); + } + + #[rstest] + #[tokio::test] + async fn configure_template_for_existing( + #[future(awt)] env: TestEnv, + ) -> Result<(), Box> { + let query = r#"mutation { + configure(beamline: "i22", config: { scan: "{instrument}-{scan_number}"}) { + scanTemplate + }}"#; + let result = env.schema.execute(query).await; + let exp = value!({"configure": { "scanTemplate": "{instrument}-{scan_number}"}}); + println!("{result:#?}"); + assert!(result.errors.is_empty()); + assert_eq!(result.data, exp); + let new = env + .db + .current_configuration("i22") + .await? + .scan()? + .to_string(); + assert_eq!(new, "{instrument}-{scan_number}"); + Ok(()) + } + + #[rstest] + #[tokio::test] + async fn configure_new_beamline(#[future(awt)] env: TestEnv) { + assert_matches::assert_matches!( + env.db.current_configuration("i16").await, + Err(crate::db_service::ConfigurationError::MissingBeamline(bl)) if bl == "i16" + ); + + let result = env + .schema + .execute( + r#"mutation { + configure(beamline: "i16", config: { + visit: "/tmp/{instrument}/{year}/{visit}" + scan: "{instrument}-{scan_number}" + detector: "{scan_number}-{detector}" + }) { + scanTemplate visitTemplate detectorTemplate latestScanNumber + } + }"#, + ) + .await; + let exp = value!({ "configure": { + "visitTemplate": "/tmp/{instrument}/{year}/{visit}", + "scanTemplate": "{instrument}-{scan_number}", + "detectorTemplate": "{scan_number}-{detector}", + "latestScanNumber": 0 + } }); + assert!(result.errors.is_empty()); + assert_eq!(result.data, exp); + _ = env.db.current_configuration("i16").await.unwrap(); + } + + #[rstest] + #[tokio::test] + async fn unauthorised_scan_request(#[future(awt)] auth_env: TestAuthEnv) { + let query = r#"mutation {scan(beamline: "i22", visit: "cm12345-3") { scanNumber }}"#; + let result = auth_env + .schema + .execute(Request::new(query).data(Option::>::None)) + .await; + + println!("{result:#?}"); + assert_eq!( + result.errors[0].message, + "No authentication token was provided" + ); + assert_eq!(result.data, Value::Null); + } + + #[rstest] + #[tokio::test] + async fn denied_scan_request(#[future(awt)] auth_env: TestAuthEnv) { + let query = r#"mutation{ scan(beamline: "i22", visit: "cm12345-3") { scanNumber }}"#; + let token = Some(Authorization( + Bearer::decode(&HeaderValue::from_str("Bearer token_value").unwrap()).unwrap(), + )); + let auth = auth_env + .server + .mock_async(|when, then| { + when.method("POST").path("/demo/access"); + then.status(200).body(r#"{"result": false}"#); + }) + .await; + let result = auth_env + .schema + .execute(Request::new(query).data(token)) + .await; + auth.assert(); + + println!("{result:#?}"); + assert_eq!(result.errors[0].message, "Authentication failed"); + assert_eq!(result.data, Value::Null); + + // Ensure that the number wasn't incremented + assert_eq!( + auth_env + .db + .current_configuration("i22") + .await + .unwrap() + .scan_number(), + 122 + ); + } + + #[rstest] + #[tokio::test] + async fn authorized_scan_request(#[future(awt)] auth_env: TestAuthEnv) { + let query = r#"mutation{ scan(beamline: "i22", visit: "cm12345-3") { scanNumber }}"#; + let token = Some(Authorization( + Bearer::decode(&HeaderValue::from_str("Bearer token_value").unwrap()).unwrap(), + )); + let auth = auth_env + .server + .mock_async(|when, then| { + when.method("POST").path("/demo/access"); + then.status(200).body(r#"{"result": true}"#); + }) + .await; + let result = auth_env + .schema + .execute(Request::new(query).data(token)) + .await; + auth.assert(); + + println!("{result:#?}"); + assert!(result.errors.is_empty()); + assert_eq!(result.data, value!({"scan": {"scanNumber": 123}})); + // Ensure that the number was incremented + assert_eq!( + auth_env + .db + .current_configuration("i22") + .await + .unwrap() + .scan_number(), + 123 + ); + assert!( + tokio::fs::try_exists(auth_env.dir.as_ref().join("i22").join("123.i22")) + .await + .unwrap() + ); + } + + #[rstest] + #[tokio::test] + async fn scan_numbers_synced_with_external(#[future(awt)] env: TestEnv) { + tokio::fs::File::create_new(env.dir.as_ref().join("i22").join("5678.i22")) + .await + .unwrap(); + let query = r#"mutation { scan(beamline: "i22", visit:"cm12345-3") { scanNumber }}"#; + let result = env.schema.execute(query).await; + let exp = value!({"scan": {"scanNumber": 5679}}); + + assert_eq!(result.data, exp); + + // DB number has been updated + assert_eq!( + env.db + .current_configuration("i22") + .await + .unwrap() + .scan_number(), + 5679 + ); + + // File has been updated + assert!( + tokio::fs::try_exists(env.dir.as_ref().join("i22").join("5679.i22")) + .await + .unwrap() + ); + } + + /// Ensure that the schema has not changed unintentionally. Might end up being a pain to + /// maintain but should hopefully be fairly stable once the API has stabilised. + #[test] + fn schema_sdl() { + let mut buf = Vec::new(); + graphql_schema(&mut buf).unwrap(); + assert_eq!( + String::from_utf8(buf).unwrap(), + include_str!("../static/service_schema") + ); + } +} #[cfg(test)] mod subdirectory_tests { use async_graphql::{InputType as _, InputValueResult, Number, Value}; @@ -584,17 +1004,25 @@ mod detector_tests { #[cfg(test)] mod input_template_tests { - use async_graphql::{InputType as _, Value}; + use async_graphql::{InputType, Value}; use super::InputTemplate; use crate::paths::{DetectorTemplate, ScanTemplate, VisitTemplate}; #[test] fn valid_visit_template() { - InputTemplate::::parse(Some(Value::String( + let template = InputTemplate::::parse(Some(Value::String( "/tmp/{instrument}/data/{visit}".into(), ))) .unwrap(); + assert_eq!( + template.as_raw_value().unwrap().to_string(), + "/tmp/{instrument}/data/{visit}" + ); + assert_eq!( + template.to_value(), + Value::String("/tmp/{instrument}/data/{visit}".into()) + ) } #[rstest::rstest] @@ -622,4 +1050,12 @@ mod input_template_tests { fn invalid_detector_template(#[case] path: String) { InputTemplate::::parse(Some(Value::String(path))).unwrap_err(); } + + #[rstest::rstest] + #[case::integer(Some(Value::Number(42.into())))] + #[case::list(Some(Value::List(vec![Value::Number(211.into())])))] + #[case::none(None)] + fn invalid_value_type(#[case] value: Option) { + InputTemplate::::parse(value).unwrap_err(); + } } diff --git a/src/main.rs b/src/main.rs index 26faae1..e7dedb8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,7 +32,9 @@ async fn main() -> Result<(), Box> { debug!(?args, "Starting numtracker service"); match args.command { Command::Serve(opts) => graphql::serve_graphql(&args.db, opts).await, - Command::Schema => graphql::graphql_schema(), + Command::Schema => { + graphql::graphql_schema(std::io::stdout()).expect("Failed to write schema") + } } Ok(()) } diff --git a/src/numtracker.rs b/src/numtracker.rs index d58dec1..1bc666d 100644 --- a/src/numtracker.rs +++ b/src/numtracker.rs @@ -17,6 +17,8 @@ use std::fmt::{self, Display}; use std::io::Error; use std::path::{Path, PathBuf}; +#[cfg(test)] +pub use tests::TempTracker; use tokio::fs as async_fs; use tokio::sync::{Mutex, MutexGuard}; use tracing::{instrument, trace}; @@ -168,9 +170,10 @@ impl std::error::Error for InvalidExtension {} #[cfg(test)] mod tests { - use std::fs; use std::ops::Deref; + use std::path::Path; use std::time::Duration; + use std::{fs, io}; use rstest::{fixture, rstest}; use tempfile::{tempdir, TempDir}; @@ -179,7 +182,17 @@ mod tests { use super::{InvalidExtension, NumTracker}; /// Wrapper around a NumTracker to ensure the tempdir is not dropped while it is still required - struct TempTracker(NumTracker, TempDir); + pub struct TempTracker(pub NumTracker, pub TempDir); + impl TempTracker { + pub fn new(init: F) -> Self + where + F: for<'f> FnOnce(&'f Path) -> io::Result<()>, + { + let root = tempdir().unwrap(); + init(root.as_ref()).unwrap(); + Self(NumTracker::for_root_directory(Some(&root)).unwrap(), root) + } + } impl Deref for TempTracker { type Target = NumTracker; diff --git a/static/service_schema b/static/service_schema new file mode 100644 index 0000000..2bd269f --- /dev/null +++ b/static/service_schema @@ -0,0 +1,99 @@ +type BeamlineConfiguration { + visitTemplate: String! + scanTemplate: String! + detectorTemplate: String! + latestScanNumber: Int! +} + + +input ConfigurationUpdates { + visit: VisitTemplate + scan: ScanTemplate + detector: DetectorTemplate + scanNumber: Int + extension: String +} + +scalar Detector + +""" +GraphQL type to mimic a key-value pair from the map type that GraphQL doesn't have +""" +type DetectorPath { + name: String! + path: String! +} + +""" +A template describing the location within a visit directory where the data for a given detector should be written + +It should contain placeholders for {detector} and {scan_number} to ensure paths are unique between scans and for multiple detectors. +""" +scalar DetectorTemplate + + + + +type Mutation { + """ + Access scan file locations for the next scan + """ + scan(beamline: String!, visit: String!, sub: Subdirectory): ScanPaths! + configure(beamline: String!, config: ConfigurationUpdates!): BeamlineConfiguration! +} + +type Query { + paths(beamline: String!, visit: String!): VisitPath! + configuration(beamline: String!): BeamlineConfiguration! +} + +type ScanPaths { + """ + The visit used to generate this scan information. Should be the same as the visit passed in + """ + visit: VisitPath! + """ + The root scan file for this scan. The path has no extension so that the format can be + chosen by the client. + """ + scanFile: String! + """ + The scan number for this scan. This should be unique for the requested beamline. + """ + scanNumber: Int! + """ + The paths where the given detectors should write their files. + + Detector names are normalised before being used in file names by replacing any + non-alphanumeric characters with '_'. If there are duplicate names in the list + of detectors after this normalisation, there will be duplicate paths in the + results. + """ + detectors(names: [Detector!]!): [DetectorPath!]! +} + +""" +A template describing the location within a visit directory where the root scan file should be written +""" +scalar ScanTemplate + + +scalar Subdirectory + +type VisitPath { + visit: String! + beamline: String! + directory: String! +} + +""" +A template describing the path to the visit directory for a beamline +""" +scalar VisitTemplate + +directive @include(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT +directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT +schema { + query: Query + mutation: Mutation +}