Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ unwrap_used = "deny"
async-graphql = { version = "7.0.13", features = ["tracing"] }
async-graphql-axum = "7.0.13"
axum = "0.7.9"
axum-extra = { version = "0.9.3", features = ["typed-header"] }
chrono = "0.4.39"
clap = { version = "4.5.23", features = ["cargo", "derive", "env"] }
futures = "0.3.31"
Expand All @@ -19,6 +20,8 @@ opentelemetry-otlp = "0.27.0"
opentelemetry-semantic-conventions = "0.27.0"
opentelemetry-stdout = "0.27.0"
opentelemetry_sdk = { version = "0.27.1", features = ["rt-tokio"] }
reqwest = { version = "0.12.7", features = ["json", "rustls-tls-native-roots"], default-features = false }
serde = { version = "1.0.210", features = ["derive"] }
sqlx = { version = "0.8.2", features = ["runtime-tokio", "sqlite"] }
tokio = { version = "1.42.0", features = ["full"] }
tracing = "0.1.41"
Expand All @@ -27,6 +30,8 @@ tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
url = "2.5.4"

[dev-dependencies]
assert_matches = "1.5.0"
async-std = { version = "1.13.0", features = ["attributes"], default-features = false }
httpmock = { version = "0.7.0", default-features = false }
rstest = "0.23.0"
tempfile = "3.14.0"
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
FROM rust:1.82.0-slim AS build

RUN rustup target add x86_64-unknown-linux-musl && \
apt update && \
apt install -y musl-tools musl-dev && \
apt-get update && \
apt-get install -y musl-tools musl-dev && \
update-ca-certificates

COPY ./Cargo.toml ./Cargo.toml
Expand Down
89 changes: 89 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,28 @@ pub struct ServeOptions {
/// The root directory for external number tracking
#[clap(long, env = "NUMTRACKER_ROOT_DIRECTORY")]
root_directory: Option<PathBuf>,
#[clap(flatten, next_help_heading = "Authorization")]
pub policy: Option<PolicyOptions>,
}

#[derive(Debug, Default, Parser)]
#[group(requires = "policy_host")]
pub struct PolicyOptions {
/// Beamline Policy Endpoint
///
/// eg, https://authz.diamond.ac.uk
#[clap(long = "policy", required = false)]
pub policy_host: String,
/// The Rego rule used to generate visit access data
///
/// eg. v1/data/diamond/policy/session/write_to_beamline_visit
#[clap(long, required = false)]
pub access_query: String,
/// The Rego rule used to generate admin access data
///
/// eg. v1/data/diamond/policy/admin/configure_beamline
#[clap(long, required = false)]
pub admin_query: String,
}

#[derive(Debug, Args)]
Expand Down Expand Up @@ -132,6 +154,7 @@ impl TracingOptions {
mod tests {
use std::path::PathBuf;

use assert_matches::assert_matches;
use clap::error::ErrorKind;
use clap::Parser;
use tracing::Level;
Expand All @@ -154,6 +177,8 @@ mod tests {
};
assert_eq!(cmd.addr(), ("0.0.0.0".parse().unwrap(), 8000));
assert_eq!(cmd.root_directory(), None);

assert_matches!(cmd.policy, None);
}

#[test]
Expand All @@ -174,6 +199,70 @@ mod tests {
};
assert_eq!(cmd.addr(), ("127.0.0.1".parse().unwrap(), 8765));
assert_eq!(cmd.root_directory, Some("/tmp/trackers".into()));
assert_matches!(cmd.policy, None);
}

#[test]
fn policy_arguments() {
let cli = Cli::try_parse_from([
APP,
"serve",
"--policy",
"opa.example.com",
"--admin-query",
"demo/admin_check",
"--access-query",
"demo/access_check",
])
.unwrap();
let cmd = assert_matches!(cli.command, Command::Serve(cmd) => cmd);
let policy = assert_matches!(cmd.policy, Some(plc) => plc);

assert_eq!(policy.policy_host, "opa.example.com");
assert_eq!(policy.admin_query, "demo/admin_check");
assert_eq!(policy.access_query, "demo/access_check");
}

#[test]
fn missing_admin_query() {
let err = Cli::try_parse_from([
APP,
"serve",
"--policy",
"opa.example.com",
"--access-query",
"demo/access-query",
])
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument);
}

#[test]
fn missing_access_query() {
let err = Cli::try_parse_from([
APP,
"serve",
"--policy",
"opa.example.com",
"--admin-query",
"demo/admin-query",
])
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument);
}

#[test]
fn policy_queries_without_host() {
let err = Cli::try_parse_from([
APP,
"serve",
"--access-query",
"demo/access-query",
"--admin-query",
"demo/admin-query",
])
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument);
}

#[test]
Expand Down
43 changes: 39 additions & 4 deletions src/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::any;
use std::borrow::Cow;
use std::error::Error;
use std::fmt::Display;
use std::future::Future;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;

Expand All @@ -27,9 +28,13 @@ use async_graphql::{
Scalar, ScalarType, Schema, SimpleObject, Value,
};
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use auth::{AuthError, PolicyCheck};
use axum::response::{Html, IntoResponse};
use axum::routing::{get, post};
use axum::{Extension, Router};
use axum_extra::headers::authorization::Bearer;
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader;
use chrono::{Datelike, Local};
use tokio::net::TcpListener;
use tracing::{info, instrument, trace, warn};
Expand All @@ -45,26 +50,30 @@ use crate::paths::{
};
use crate::template::{FieldSource, PathTemplate};

mod auth;

pub async fn serve_graphql(db: &Path, opts: ServeOptions) {
let db = SqliteScanPathService::connect(db)
.await
.expect("Unable to open DB");
let directory_numtracker = NumTracker::for_root_directory(opts.root_directory())
.expect("Could not read external directories");
info!("Serving graphql endpoints on {:?}", opts.addr());
let addr = opts.addr();
let schema = Schema::build(Query, Mutation, EmptySubscription)
.extension(Tracing)
.limit_directives(32)
.data(db)
.data(directory_numtracker)
.data(opts.policy.map(PolicyCheck::new))
.finish();
let app = Router::new()
.route("/graphql", post(graphql_handler))
.route("/graphiql", get(graphiql))
.layer(Extension(schema));
let listener = TcpListener::bind(opts.addr())
let listener = TcpListener::bind(addr)
.await
.unwrap_or_else(|_| panic!("Port {:?} in use", opts.addr()));
.unwrap_or_else(|_| panic!("Port {:?} in use", addr));
axum::serve(listener, app)
.await
.expect("Can't serve graphql endpoint");
Expand All @@ -82,10 +91,13 @@ async fn graphiql() -> impl IntoResponse {
#[instrument(skip_all)]
async fn graphql_handler(
schema: Extension<Schema<Query, Mutation, EmptySubscription>>,
auth_token: Option<TypedHeader<Authorization<Bearer>>>,
req: GraphQLRequest,
) -> GraphQLResponse {
let inner = req.into_inner();
schema.execute(inner).await.into()
schema
.execute(req.into_inner().data(auth_token.map(|header| header.0)))
.await
.into()
}

/// Read-only API for GraphQL
Expand Down Expand Up @@ -263,6 +275,7 @@ impl Query {
ctx: &Context<'_>,
beamline: String,
) -> async_graphql::Result<BeamlineConfiguration> {
check_auth(ctx, |policy, token| policy.check_admin(token, &beamline)).await?;
let db = ctx.data::<SqliteScanPathService>()?;
trace!("Getting config for {beamline:?}");
Ok(db.current_configuration(&beamline).await?)
Expand All @@ -280,6 +293,10 @@ impl Mutation {
visit: String,
sub: Option<Subdirectory>,
) -> async_graphql::Result<ScanPaths> {
check_auth(ctx, |policy, token| {
policy.check_access(token, &beamline, &visit)
})
.await?;
let db = ctx.data::<SqliteScanPathService>()?;
let nt = ctx.data::<NumTracker>()?;
// There is a race condition here if a process increments the file
Expand Down Expand Up @@ -312,6 +329,7 @@ impl Mutation {
beamline: String,
config: ConfigurationUpdates,
) -> async_graphql::Result<BeamlineConfiguration> {
check_auth(ctx, |pc, token| pc.check_admin(token, &beamline)).await?;
let db = ctx.data::<SqliteScanPathService>()?;
trace!("Configuring: {beamline}: {config:?}");
let upd = config.into_update(beamline);
Expand All @@ -322,6 +340,23 @@ impl Mutation {
}
}

async fn check_auth<'ctx, Check, R>(ctx: &Context<'ctx>, check: Check) -> async_graphql::Result<()>
where
Check: Fn(&'ctx PolicyCheck, Option<&'ctx Authorization<Bearer>>) -> R,
R: Future<Output = Result<(), AuthError>>,
{
if let Some(policy) = ctx.data::<Option<PolicyCheck>>()? {
trace!("Auth enabled: checking token");
let token = ctx.data::<Authorization<Bearer>>().ok();
check(policy, token)
.await
.inspect_err(|e| info!("Authorization failed: {e:?}"))
.map_err(async_graphql::Error::from)
} else {
Ok(())
}
}

#[derive(Debug, InputObject)]
struct ConfigurationUpdates {
visit: Option<InputTemplate<VisitTemplate>>,
Expand Down
Loading
Loading