diff --git a/datatypes/src/machine_learning.rs b/datatypes/src/machine_learning.rs index 3d19ac1af..b4cdf687f 100644 --- a/datatypes/src/machine_learning.rs +++ b/datatypes/src/machine_learning.rs @@ -1,11 +1,12 @@ -use std::path::PathBuf; - -use serde::{de::Visitor, Deserialize, Serialize}; - use crate::{ dataset::{is_invalid_name_char, SYSTEM_NAMESPACE}, raster::RasterDataType, }; +use serde::{de::Visitor, Deserialize, Serialize}; +use snafu::Snafu; +use std::path::PathBuf; +use std::str::FromStr; +use strum::IntoStaticStr; const NAME_DELIMITER: char = ':'; @@ -15,6 +16,18 @@ pub struct MlModelName { pub name: String, } +#[derive(Snafu, IntoStaticStr, Debug)] +#[snafu(visibility(pub(crate)))] +#[snafu(context(suffix(false)))] // disables default `Snafu` suffix +pub enum MlModelNameError { + #[snafu(display("MlModelName is empty"))] + IsEmpty, + #[snafu(display("invalid character '{invalid_char}' in named model"))] + InvalidCharacter { invalid_char: String }, + #[snafu(display("ml model name must consist of at most two parts"))] + TooManyParts, +} + impl MlModelName { /// Canonicalize a name that reflects the system namespace and model. fn canonicalize + PartialEq<&'static str>>( @@ -62,40 +75,29 @@ impl<'de> Deserialize<'de> for MlModelName { } } -struct MlModelNameDeserializeVisitor; - -impl Visitor<'_> for MlModelNameDeserializeVisitor { - type Value = MlModelName; - - /// always keep in sync with [`is_allowed_name_char`] - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - formatter, - "a string consisting of a namespace and name name, separated by a colon, only using alphanumeric characters, underscores & dashes" - ) - } +impl FromStr for MlModelName { + type Err = MlModelNameError; - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { + fn from_str(s: &str) -> Result { let mut strings = [None, None]; let mut split = s.split(NAME_DELIMITER); for (buffer, part) in strings.iter_mut().zip(&mut split) { if part.is_empty() { - return Err(E::custom("empty part in named data")); + return Err(MlModelNameError::IsEmpty); } if let Some(c) = part.matches(is_invalid_name_char).next() { - return Err(E::custom(format!("invalid character '{c}' in named model"))); + return Err(MlModelNameError::InvalidCharacter { + invalid_char: c.to_string(), + }); } *buffer = Some(part.to_string()); } if split.next().is_some() { - return Err(E::custom("named model must consist of at most two parts")); + return Err(MlModelNameError::TooManyParts); } match strings { @@ -107,11 +109,32 @@ impl Visitor<'_> for MlModelNameDeserializeVisitor { namespace: None, name, }), - _ => Err(E::custom("empty named data")), + _ => Err(MlModelNameError::IsEmpty), } } } +struct MlModelNameDeserializeVisitor; + +impl Visitor<'_> for MlModelNameDeserializeVisitor { + type Value = MlModelName; + + /// always keep in sync with [`is_allowed_name_char`] + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + formatter, + "a string consisting of a namespace and name name, separated by a colon, only using alphanumeric characters, underscores & dashes" + ) + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + MlModelName::from_str(s).map_err(|e| E::custom(e.to_string())) + } +} + // For now we assume all models are pixel-wise, i.e., they take a single pixel with multiple bands as input and produce a single output value. // To support different inputs, we would need a more sophisticated logic to produce the inputs for the model. #[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)] @@ -122,3 +145,35 @@ pub struct MlModelMetadata { pub output_type: RasterDataType, // TODO: support multiple outputs, e.g. one band for the probability of prediction // TODO: output measurement, e.g. classification or regression, label names for classification. This would have to be provided by the model creator along the model file as it cannot be extracted from the model file(?) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ml_model_name_from_str() { + const ML_MODEL_NAME: &str = "myModelName"; + let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap(); + assert_eq!(mln.name, ML_MODEL_NAME); + assert!(mln.namespace.is_none()); + } + + #[test] + fn ml_model_name_from_str_prefixed() { + const ML_MODEL_NAME: &str = "d5328854-6190-4af9-ad69-4e74b0961ac9:myModelName"; + let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap(); + assert_eq!(mln.name, "myModelName".to_string()); + assert_eq!( + mln.namespace, + Some("d5328854-6190-4af9-ad69-4e74b0961ac9".to_string()) + ); + } + + #[test] + fn ml_model_name_from_str_system() { + const ML_MODEL_NAME: &str = "_:myModelName"; + let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap(); + assert_eq!(mln.name, "myModelName".to_string()); + assert!(mln.namespace.is_none()); + } +} diff --git a/services/src/api/apidoc.rs b/services/src/api/apidoc.rs index 4c7ba529a..90ece1267 100644 --- a/services/src/api/apidoc.rs +++ b/services/src/api/apidoc.rs @@ -1,6 +1,8 @@ use crate::api::handlers; use crate::api::handlers::datasets::VolumeFileLayersResponse; -use crate::api::handlers::permissions::{PermissionListOptions, PermissionRequest, Resource}; +use crate::api::handlers::permissions::{ + PermissionListOptions, PermissionListing, PermissionRequest, Resource, +}; use crate::api::handlers::plots::WrappedPlotOutput; use crate::api::handlers::spatial_references::{AxisOrder, SpatialReferenceSpecification}; use crate::api::handlers::tasks::{TaskAbortOptions, TaskResponse}; @@ -31,6 +33,7 @@ use crate::api::model::operators::{ UnixTimeStampType, VectorColumnInfo, VectorResultDescriptor, }; use crate::api::model::responses::datasets::DatasetNameResponse; +use crate::api::model::responses::ml_models::MlModelNameResponse; use crate::api::model::responses::{ BadRequestQueryResponse, ErrorResponse, IdResponse, PayloadTooLargeResponse, PngResponse, UnauthorizedAdminResponse, UnauthorizedUserResponse, UnsupportedMediaTypeForJsonResponse, @@ -56,9 +59,7 @@ use crate::layers::listing::{ }; use crate::machine_learning::name::MlModelName; use crate::machine_learning::{MlModel, MlModelId, MlModelMetadata}; -use crate::permissions::{ - Permission, PermissionListing, ResourceId, Role, RoleDescription, RoleId, -}; +use crate::permissions::{Permission, ResourceId, Role, RoleDescription, RoleId}; use crate::projects::{ ColorParam, CreateProject, DerivedColor, DerivedNumber, LayerUpdate, LayerVisibility, LineSymbology, NumberParam, Plot, PlotUpdate, PointSymbology, PolygonSymbology, Project, @@ -423,7 +424,8 @@ use utoipa::{Modify, OpenApi}; MlModel, MlModelId, MlModelName, - MlModelMetadata + MlModelMetadata, + MlModelNameResponse ), ), modifiers(&SecurityAddon, &ApiDocInfo, &OpenApiServerInfo, &TransformSchemasWithTag), diff --git a/services/src/api/handlers/machine_learning.rs b/services/src/api/handlers/machine_learning.rs index 6db9bee6a..2b12c5497 100644 --- a/services/src/api/handlers/machine_learning.rs +++ b/services/src/api/handlers/machine_learning.rs @@ -1,7 +1,7 @@ use actix_web::{web, FromRequest, HttpResponse, ResponseError}; use crate::{ - api::model::responses::ErrorResponse, + api::model::responses::{ml_models::MlModelNameResponse, ErrorResponse}, contexts::{ApplicationContext, SessionContext}, machine_learning::{ error::MachineLearningError, name::MlModelName, MlModel, MlModelDb, MlModelListOptions, @@ -48,7 +48,7 @@ impl ResponseError for MachineLearningError { path = "/ml/models", request_body = MlModel, responses( - (status = 200) + (status = 200, body = MlModelNameResponse) ), security( ("session_token" = []) @@ -59,14 +59,14 @@ pub(crate) async fn add_ml_model( session: C::Session, app_ctx: web::Data, model: web::Json, -) -> Result { +) -> Result, MachineLearningError> { let model = model.into_inner(); - app_ctx + let id_and_name = app_ctx .session_context(session) .db() .add_model(model) .await?; - Ok(HttpResponse::Ok().finish()) + Ok(web::Json(id_and_name.name.into())) } /// List ml models. diff --git a/services/src/api/handlers/permissions.rs b/services/src/api/handlers/permissions.rs index 6612c2de8..e85e858d3 100644 --- a/services/src/api/handlers/permissions.rs +++ b/services/src/api/handlers/permissions.rs @@ -1,14 +1,22 @@ -use crate::api::model::datatypes::{DatasetId, LayerId}; -use crate::contexts::{ApplicationContext, SessionContext}; -use crate::error::Result; +use crate::api::model::datatypes::LayerId; +use crate::contexts::{ApplicationContext, GeoEngineDb, SessionContext}; +use crate::datasets::storage::DatasetDb; +use crate::datasets::DatasetName; +use crate::error::{self, Error, Result}; use crate::layers::listing::LayerCollectionId; -use crate::permissions::{Permission, PermissionListing}; -use crate::permissions::{PermissionDb, ResourceId, RoleId}; +use crate::machine_learning::MlModelDb; +use crate::permissions::{ + Permission, PermissionDb, PermissionListing as DbPermissionListing, ResourceId, Role, RoleId, +}; use crate::projects::ProjectId; use actix_web::{web, FromRequest, HttpResponse}; use geoengine_datatypes::error::BoxedResultExt; -use serde::Deserialize; +use geoengine_datatypes::machine_learning::MlModelName; +use serde::{Deserialize, Serialize}; +use snafu::ResultExt; +use std::str::FromStr; use utoipa::{IntoParams, ToSchema}; +use uuid::Uuid; pub(crate) fn init_permissions_routes(cfg: &mut web::ServiceConfig) where @@ -39,33 +47,105 @@ pub struct PermissionRequest { permission: Permission, } +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PermissionListing { + resource: Resource, + role: Role, + permission: Permission, +} + +impl PermissionListing { + fn wrap_permission_listing_and_resource( + resource: Resource, + db_permission_listing: DbPermissionListing, + ) -> PermissionListing { + Self { + resource, + role: db_permission_listing.role, + permission: db_permission_listing.permission, + } + } +} + /// A resource that is affected by a permission. -#[derive(Debug, PartialEq, Eq, Deserialize, Clone, ToSchema)] +#[derive(Debug, PartialEq, Eq, Deserialize, Clone, ToSchema, Serialize)] #[serde(rename_all = "camelCase", tag = "type", content = "id")] pub enum Resource { #[schema(title = "LayerResource")] - Layer(LayerId), + Layer(LayerId), // TODO: check model #[schema(title = "LayerCollectionResource")] LayerCollection(LayerCollectionId), #[schema(title = "ProjectResource")] Project(ProjectId), #[schema(title = "DatasetResource")] - Dataset(DatasetId), + Dataset(DatasetName), // TODO: add a DatasetName to model! + #[schema(title = "MlModelResource")] + MlModel(MlModelName), } -impl From for ResourceId { - fn from(resource: Resource) -> Self { - match resource { - Resource::Layer(layer_id) => ResourceId::Layer(layer_id.into()), - Resource::LayerCollection(layer_collection_id) => { - ResourceId::LayerCollection(layer_collection_id) +impl Resource { + pub async fn resolve_resource_id( + &self, + db: &D, + ) -> Result { + match self { + Resource::Layer(layer) => Ok(ResourceId::Layer(layer.clone().into())), + Resource::LayerCollection(layer_collection) => { + Ok(ResourceId::LayerCollection(layer_collection.clone())) + } + Resource::Project(project_id) => Ok(ResourceId::Project(*project_id)), + Resource::Dataset(dataset_name) => { + let dataset_id_option = db.resolve_dataset_name_to_id(dataset_name).await?; + dataset_id_option + .ok_or(error::Error::UnknownResource { + kind: "Dataset".to_owned(), + name: dataset_name.to_string(), + }) + .map(ResourceId::DatasetId) + } + Resource::MlModel(model_name) => { + let actual_name = model_name.clone().into(); + let model_id_option = + db.resolve_model_name_to_id(&actual_name) + .await + .map_err(|e| error::Error::MachineLearning { + source: Box::new(e), + })?; // should prob. also map to UnknownResource oder something like that + model_id_option + .ok_or(error::Error::UnknownResource { + kind: "MlModel".to_owned(), + name: actual_name.to_string(), + }) + .map(ResourceId::MlModel) } - Resource::Project(project_id) => ResourceId::Project(project_id), - Resource::Dataset(dataset_id) => ResourceId::DatasetId(dataset_id.into()), } } } +impl TryFrom<(String, String)> for Resource { + type Error = Error; + + /// Transform a tuple of `String` into a `Resource`. The first element is used as type and the second element as the id / name. + fn try_from(value: (String, String)) -> Result { + Ok(match value.0.as_str() { + "layer" => Resource::Layer(LayerId(value.1)), + "layerCollection" => Resource::LayerCollection(LayerCollectionId(value.1)), + "project" => { + Resource::Project(ProjectId(Uuid::from_str(&value.1).context(error::Uuid)?)) + } + "dataset" => Resource::Dataset(DatasetName::from_str(&value.1)?), + "mlModel" => Resource::MlModel(MlModelName::from_str(&value.1)?), + _ => { + return Err(Error::InvalidResourceId { + resource_type: value.0, + resource_id: value.1, + }) + } + }) + } +} + #[derive(Debug, PartialEq, Eq, Deserialize, Clone, IntoParams, ToSchema)] pub struct PermissionListOptions { pub limit: u32, @@ -94,16 +174,25 @@ async fn get_resource_permissions_handler( app_ctx: web::Data, resource_id: web::Path<(String, String)>, options: web::Query, -) -> Result>> { - let resource_id = ResourceId::try_from(resource_id.into_inner())?; +) -> Result>> +where + <::SessionContext as SessionContext>::GeoEngineDB: GeoEngineDb, +{ + let resource = Resource::try_from(resource_id.into_inner())?; + let db = app_ctx.session_context(session).db(); + let resource_id = resource.resolve_resource_id(&db).await?; let options = options.into_inner(); - let db = app_ctx.session_context(session).db(); let permissions = db .list_permissions(resource_id, options.offset, options.limit) .await .boxed_context(crate::error::PermissionDb)?; + let permissions = permissions + .into_iter() + .map(|p| PermissionListing::wrap_permission_listing_and_resource(resource.clone(), p)) + .collect(); + Ok(web::Json(permissions)) } @@ -137,13 +226,11 @@ async fn add_permission_handler( let permission = permission.into_inner(); let db = app_ctx.session_context(session).db(); - db.add_permission::( - permission.role_id, - permission.resource.into(), - permission.permission, - ) - .await - .boxed_context(crate::error::PermissionDb)?; + let permission_id = permission.resource.resolve_resource_id(&db).await?; + + db.add_permission::(permission.role_id, permission_id, permission.permission) + .await + .boxed_context(crate::error::PermissionDb)?; Ok(HttpResponse::Ok().finish()) } @@ -178,13 +265,11 @@ async fn remove_permission_handler( let permission = permission.into_inner(); let db = app_ctx.session_context(session).db(); - db.remove_permission::( - permission.role_id, - permission.resource.into(), - permission.permission, - ) - .await - .boxed_context(crate::error::PermissionDb)?; + let permission_id = permission.resource.resolve_resource_id(&db).await?; + + db.remove_permission::(permission.role_id, permission_id, permission.permission) + .await + .boxed_context(crate::error::PermissionDb)?; Ok(HttpResponse::Ok().finish()) } @@ -194,18 +279,25 @@ mod tests { use super::*; use crate::{ + api::model::datatypes::RasterDataType as ApiRasterDataType, contexts::PostgresContext, + datasets::upload::{Upload, UploadDb, UploadId}, ge_context, + layers::{layer::AddLayer, listing::LayerCollectionProvider, storage::LayerDb}, + machine_learning::{MlModel, MlModelIdAndName, MlModelMetadata}, users::{UserAuth, UserCredentials, UserRegistration}, util::tests::{ add_ndvi_to_datasets2, add_ports_to_datasets, admin_login, read_body_string, send_test_request, }, + workflows::workflow::Workflow, }; use actix_http::header; use actix_web_httpauth::headers::authorization::Bearer; + use geoengine_datatypes::{primitives::Coordinate2D, util::Identifier}; use geoengine_operators::{ - engine::{RasterOperator, VectorOperator, WorkflowOperatorPath}, + engine::{RasterOperator, TypedOperator, VectorOperator, WorkflowOperatorPath}, + mock::{MockPointSource, MockPointSourceParams}, source::{GdalSource, GdalSourceParameters, OgrSource, OgrSourceParameters}, }; use serde_json::{json, Value}; @@ -309,14 +401,14 @@ mod tests { #[ge_context::test] #[allow(clippy::too_many_lines)] - async fn it_lists_permissions(app_ctx: PostgresContext) { + async fn it_lists_dataset_permissions(app_ctx: PostgresContext) { let admin_session = admin_login(&app_ctx).await; - let (gdal_dataset_id, _) = add_ndvi_to_datasets2(&app_ctx, true, true).await; + let (_dataset_id, dataset_name) = add_ndvi_to_datasets2(&app_ctx, true, true).await; let req = actix_web::test::TestRequest::get() .uri(&format!( - "/permissions/resources/dataset/{gdal_dataset_id}?offset=0&limit=10", + "/permissions/resources/dataset/{dataset_name}?offset=0&limit=10", )) .append_header((header::CONTENT_LENGTH, 0)) .append_header(( @@ -333,9 +425,9 @@ mod tests { res_body, json!([{ "permission":"Owner", - "resourceId": { - "id": gdal_dataset_id.to_string(), - "type": "DatasetId" + "resource": { + "id": dataset_name.to_string(), + "type": "dataset" }, "role": { "id": "d5328854-6190-4af9-ad69-4e74b0961ac9", @@ -344,9 +436,9 @@ mod tests { } }, { "permission": "Read", - "resourceId": { - "id": gdal_dataset_id.to_string(), - "type": "DatasetId" + "resource": { + "id": dataset_name.to_string(), + "type": "dataset" }, "role": { "id": "fd8e87bf-515c-4f36-8da6-1a53702ff102", @@ -354,17 +446,234 @@ mod tests { } }, { "permission": "Read", - "resourceId": { - "id": gdal_dataset_id.to_string(), - "type": "DatasetId" + "resource": { + "id": dataset_name.to_string(), + "type": "dataset", }, "role": { "id": "4e8081b6-8aa6-4275-af0c-2fa2da557d28", + "name": "user" + } + }] + ) + ); + } + + #[ge_context::test] + #[allow(clippy::too_many_lines)] + async fn it_lists_ml_model_permissions(app_ctx: PostgresContext) { + let admin_session = admin_login(&app_ctx).await; + + let db = app_ctx.session_context(admin_session.clone()).db(); + + let upload_id = UploadId::new(); + let upload = Upload { + id: upload_id, + files: vec![], + }; + db.create_upload(upload).await.unwrap(); + + let model = MlModel { + description: "No real model here".to_owned(), + display_name: "my unreal model".to_owned(), + metadata: MlModelMetadata { + file_name: "myUnrealmodel.onnx".to_owned(), + input_type: ApiRasterDataType::F32, + num_input_bands: 17, + output_type: ApiRasterDataType::F64, + }, + name: MlModelName::new(None, "myUnrealModel").into(), + upload: upload_id, + }; + + let MlModelIdAndName { + id: _model_id, + name: model_name, + } = db.add_model(model).await.unwrap(); + + let req = actix_web::test::TestRequest::get() + .uri(&format!( + "/permissions/resources/mlModel/{model_name}?offset=0&limit=10", + )) + .append_header((header::CONTENT_LENGTH, 0)) + .append_header(( + header::AUTHORIZATION, + Bearer::new(admin_session.id.to_string()), + )); + let res = send_test_request(req, app_ctx).await; + + let res_status = res.status(); + let res_body = serde_json::from_str::(&read_body_string(res).await).unwrap(); + assert_eq!(res_status, 200, "{res_body}"); + + assert_eq!( + res_body, + json!([{ + "permission":"Owner", + "resource": { + "id": model_name.to_string(), + "type": "mlModel" + }, + "role": { + "id": "d5328854-6190-4af9-ad69-4e74b0961ac9", + "name": "admin" + } + }] + ) + ); + } + + #[ge_context::test] + #[allow(clippy::too_many_lines)] + async fn it_lists_layer_collection_permissions(app_ctx: PostgresContext) { + let admin_session = admin_login(&app_ctx).await; + + let db = app_ctx.session_context(admin_session.clone()).db(); + + let root_collection = &db.get_root_layer_collection_id().await.unwrap(); + + let req = actix_web::test::TestRequest::get() + .uri(&format!( + "/permissions/resources/layerCollection/{root_collection}?offset=0&limit=10", + )) + .append_header((header::CONTENT_LENGTH, 0)) + .append_header(( + header::AUTHORIZATION, + Bearer::new(admin_session.id.to_string()), + )); + let res = send_test_request(req, app_ctx).await; + + let res_status = res.status(); + let res_body = serde_json::from_str::(&read_body_string(res).await).unwrap(); + assert_eq!(res_status, 200, "{res_body}"); + + assert_eq!( + res_body, + json!([{ + "permission":"Owner", + "resource": { + "id": root_collection.to_string(), + "type": "layerCollection" + }, + "role": { + "id": "d5328854-6190-4af9-ad69-4e74b0961ac9", "name": - "user" + "admin" + } + }, { + "permission": "Read", + "resource": { + "id": root_collection.to_string(), + "type": "layerCollection" + }, + "role": { + "id": "fd8e87bf-515c-4f36-8da6-1a53702ff102", + "name": "anonymous" + } + }, { + "permission": "Read", + "resource": { + "id": root_collection.to_string(), + "type": "layerCollection", + }, + "role": { + "id": "4e8081b6-8aa6-4275-af0c-2fa2da557d28", + "name": "user" } }] ) ); } + + #[ge_context::test] + #[allow(clippy::too_many_lines)] + async fn it_lists_layer_permissions(app_ctx: PostgresContext) { + let admin_session = admin_login(&app_ctx).await; + + let db = app_ctx.session_context(admin_session.clone()).db(); + + let root_collection = &db.get_root_layer_collection_id().await.unwrap(); + + let layer = AddLayer { + name: "layer".to_string(), + description: "description".to_string(), + workflow: Workflow { + operator: TypedOperator::Vector( + MockPointSource { + params: MockPointSourceParams { + points: vec![Coordinate2D::new(1., 2.); 3], + }, + } + .boxed(), + ), + }, + symbology: None, + metadata: Default::default(), + properties: Default::default(), + }; + + let l_id = db.add_layer(layer, root_collection).await.unwrap(); + + let req = actix_web::test::TestRequest::get() + .uri(&format!( + "/permissions/resources/layer/{l_id}?offset=0&limit=10", + )) + .append_header((header::CONTENT_LENGTH, 0)) + .append_header(( + header::AUTHORIZATION, + Bearer::new(admin_session.id.to_string()), + )); + let res = send_test_request(req, app_ctx).await; + + let res_status = res.status(); + let res_body = serde_json::from_str::(&read_body_string(res).await).unwrap(); + assert_eq!(res_status, 200, "{res_body}"); + + assert_eq!( + res_body, + json!([{ + "permission":"Owner", + "resource": { + "id": l_id.to_string(), + "type": "layer" + }, + "role": { + "id": "d5328854-6190-4af9-ad69-4e74b0961ac9", + "name": + "admin" + } + } ] + ) + ); + } + + #[test] + fn resource_from_str_tuple() { + let test_uuid = Uuid::new_v4(); + + let layer_res = Resource::try_from(("layer".to_owned(), "cats".to_owned())).unwrap(); + assert_eq!(layer_res, Resource::Layer(LayerId("cats".to_owned()))); + + let layer_col_res = + Resource::try_from(("layerCollection".to_owned(), "cats".to_owned())).unwrap(); + assert_eq!( + layer_col_res, + Resource::LayerCollection(LayerCollectionId("cats".to_owned())) + ); + + let project_res = Resource::try_from(("project".to_owned(), test_uuid.into())).unwrap(); + assert_eq!(project_res, Resource::Project(ProjectId(test_uuid))); + + let dataset_res = Resource::try_from(("dataset".to_owned(), "cats".to_owned())).unwrap(); + assert_eq!( + dataset_res, + Resource::Dataset(DatasetName::new(None, "cats".to_owned())) + ); + + let ml_model_res = Resource::try_from(("mlModel".to_owned(), "cats".to_owned())).unwrap(); + assert_eq!( + ml_model_res, + Resource::MlModel(MlModelName::new(None, "cats".to_owned())) + ); + } } diff --git a/services/src/api/model/responses/ml_models/mod.rs b/services/src/api/model/responses/ml_models/mod.rs new file mode 100644 index 000000000..05f5d8582 --- /dev/null +++ b/services/src/api/model/responses/ml_models/mod.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; +use utoipa::{ToResponse, ToSchema}; + +use crate::machine_learning::name::MlModelName; + +#[derive(Debug, Serialize, Deserialize, Clone, ToResponse, ToSchema)] +#[serde(rename_all = "camelCase")] +#[response(description = "Name of generated resource", example = json!({ + "name": "ns:name" +}))] +pub struct MlModelNameResponse { + pub ml_model_name: MlModelName, +} + +impl From for MlModelNameResponse { + fn from(ml_model_name: MlModelName) -> Self { + Self { ml_model_name } + } +} diff --git a/services/src/api/model/responses/mod.rs b/services/src/api/model/responses/mod.rs index 553477ec8..de558f88c 100644 --- a/services/src/api/model/responses/mod.rs +++ b/services/src/api/model/responses/mod.rs @@ -1,4 +1,5 @@ pub mod datasets; +pub mod ml_models; use actix_http::StatusCode; use actix_web::{dev::ServiceResponse, HttpResponse}; diff --git a/services/src/contexts/postgres.rs b/services/src/contexts/postgres.rs index 96aeb90c8..0bf204ab9 100644 --- a/services/src/contexts/postgres.rs +++ b/services/src/contexts/postgres.rs @@ -462,6 +462,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::api::model::datatypes::RasterDataType as ApiRasterDataType; use crate::config::QuotaTrackingMode; use crate::datasets::external::netcdfcf::NetCdfCfDataProviderDefinition; use crate::datasets::listing::{DatasetListOptions, DatasetListing, ProvenanceOutput}; @@ -483,6 +484,7 @@ mod tests { LayerDb, LayerProviderDb, LayerProviderListing, LayerProviderListingOptions, INTERNAL_PROVIDER_ID, }; + use crate::machine_learning::{MlModel, MlModelDb, MlModelIdAndName, MlModelMetadata}; use crate::permissions::{Permission, PermissionDb, Role, RoleDescription, RoleId}; use crate::projects::{ CreateProject, LayerUpdate, LoadVersion, OrderBy, Plot, PlotUpdate, PointSymbology, @@ -4882,4 +4884,44 @@ mod tests { async fn it_handles_oidc_tokens_with_encryption(app_ctx: PostgresContext) { it_handles_oidc_tokens(app_ctx).await; } + + #[ge_context::test] + #[allow(clippy::too_many_lines)] + async fn it_resolves_ml_model_names_to_ids(app_ctx: PostgresContext) { + let admin_session = UserSession::admin_session(); + let db = app_ctx.session_context(admin_session.clone()).db(); + + let upload_id = UploadId::new(); + let upload = Upload { + id: upload_id, + files: vec![], + }; + db.create_upload(upload).await.unwrap(); + + let model = MlModel { + description: "No real model here".to_owned(), + display_name: "my unreal model".to_owned(), + metadata: MlModelMetadata { + file_name: "myUnrealmodel.onnx".to_owned(), + input_type: ApiRasterDataType::F32, + num_input_bands: 17, + output_type: ApiRasterDataType::F64, + }, + name: MlModelName::new(None, "myUnrealModel"), + upload: upload_id, + }; + + let MlModelIdAndName { + id: model_id, + name: model_name, + } = db.add_model(model).await.unwrap(); + + assert_eq!( + db.resolve_model_name_to_id(&model_name) + .await + .unwrap() + .unwrap(), + model_id + ); + } } diff --git a/services/src/datasets/mod.rs b/services/src/datasets/mod.rs index b6f5b24bd..dd0a03575 100644 --- a/services/src/datasets/mod.rs +++ b/services/src/datasets/mod.rs @@ -11,5 +11,5 @@ pub(crate) use create_from_workflow::{ schedule_raster_dataset_from_workflow_task, RasterDatasetFromWorkflow, RasterDatasetFromWorkflowResult, }; -pub use name::{DatasetIdAndName, DatasetName}; +pub use name::{DatasetIdAndName, DatasetName, DatasetNameError}; pub use storage::AddDataset; diff --git a/services/src/datasets/name.rs b/services/src/datasets/name.rs index c9cd1df2b..0315a4033 100644 --- a/services/src/datasets/name.rs +++ b/services/src/datasets/name.rs @@ -1,6 +1,9 @@ use geoengine_datatypes::dataset::{DatasetId, NamedData}; use postgres_types::{FromSql, ToSql}; use serde::{de::Visitor, Deserialize, Serialize}; +use snafu::Snafu; +use std::str::FromStr; +use strum::IntoStaticStr; use utoipa::{IntoParams, ToSchema}; /// A (optionally namespaced) name for a `Dataset`. @@ -11,6 +14,18 @@ pub struct DatasetName { pub name: String, } +#[derive(Snafu, IntoStaticStr, Debug)] +#[snafu(visibility(pub(crate)))] +#[snafu(context(suffix(false)))] // disables default `Snafu` suffix +pub enum DatasetNameError { + #[snafu(display("DatasetName is empty"))] + IsEmpty, + #[snafu(display("invalid character '{invalid_char}' in named data"))] + InvalidCharacter { invalid_char: String }, + #[snafu(display("named data must consist of at most two parts"))] + TooManyParts, +} + impl DatasetName { /// Canonicalize a name that reflects the system namespace and provider. fn canonicalize + PartialEq<&'static str>>( @@ -44,6 +59,51 @@ impl std::fmt::Display for DatasetName { } } +impl FromStr for DatasetName { + type Err = DatasetNameError; + + fn from_str(s: &str) -> Result { + let mut strings = [None, None]; + let mut split = s.split(geoengine_datatypes::dataset::NAME_DELIMITER); + + for (buffer, part) in strings.iter_mut().zip(&mut split) { + if part.is_empty() { + return Err(DatasetNameError::IsEmpty); + } + + if let Some(c) = part + .matches(geoengine_datatypes::dataset::is_invalid_name_char) + .next() + { + return Err(DatasetNameError::InvalidCharacter { + invalid_char: c.to_string(), + }); + } + + *buffer = Some(part.to_string()); + } + + if split.next().is_some() { + return Err(DatasetNameError::TooManyParts); + } + + match strings { + [Some(namespace), Some(name)] => Ok(DatasetName { + namespace: DatasetName::canonicalize( + namespace, + geoengine_datatypes::dataset::SYSTEM_NAMESPACE, + ), + name, + }), + [Some(name), None] => Ok(DatasetName { + namespace: None, + name, + }), + _ => Err(DatasetNameError::IsEmpty), + } + } +} + impl Serialize for DatasetName { fn serialize(&self, serializer: S) -> Result where @@ -87,42 +147,7 @@ impl Visitor<'_> for DatasetNameDeserializeVisitor { where E: serde::de::Error, { - let mut strings = [None, None]; - let mut split = s.split(geoengine_datatypes::dataset::NAME_DELIMITER); - - for (buffer, part) in strings.iter_mut().zip(&mut split) { - if part.is_empty() { - return Err(E::custom("empty part in named data")); - } - - if let Some(c) = part - .matches(geoengine_datatypes::dataset::is_invalid_name_char) - .next() - { - return Err(E::custom(format!("invalid character '{c}' in named data"))); - } - - *buffer = Some(part.to_string()); - } - - if split.next().is_some() { - return Err(E::custom("named data must consist of at most two parts")); - } - - match strings { - [Some(namespace), Some(name)] => Ok(DatasetName { - namespace: DatasetName::canonicalize( - namespace, - geoengine_datatypes::dataset::SYSTEM_NAMESPACE, - ), - name, - }), - [Some(name), None] => Ok(DatasetName { - namespace: None, - name, - }), - _ => Err(E::custom("empty named data")), - } + DatasetName::from_str(s).map_err(|e| E::custom(e.to_string())) } } @@ -191,3 +216,35 @@ pub struct DatasetIdAndName { pub id: DatasetId, pub name: DatasetName, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dataset_name_from_str() { + const DATASET_NAME: &str = "myDatasetName"; + let mln = DatasetName::from_str(DATASET_NAME).unwrap(); + assert_eq!(mln.name, DATASET_NAME); + assert!(mln.namespace.is_none()); + } + + #[test] + fn dataset_name_from_str_prefixed() { + const DATASET_NAME: &str = "d5328854-6190-4af9-ad69-4e74b0961ac9:myDatasetName"; + let mln = DatasetName::from_str(DATASET_NAME).unwrap(); + assert_eq!(mln.name, "myDatasetName".to_string()); + assert_eq!( + mln.namespace, + Some("d5328854-6190-4af9-ad69-4e74b0961ac9".to_string()) + ); + } + + #[test] + fn dataset_name_from_str_system() { + const DATASET_NAME: &str = "_:myDatasetName"; + let mln = DatasetName::from_str(DATASET_NAME).unwrap(); + assert_eq!(mln.name, "myDatasetName".to_string()); + assert!(mln.namespace.is_none()); + } +} diff --git a/services/src/error.rs b/services/src/error.rs index ef9c8f46c..72adfb9fb 100644 --- a/services/src/error.rs +++ b/services/src/error.rs @@ -216,7 +216,6 @@ pub enum Error { UploadFieldMissingFileName, UnknownUploadId, - UnknownModelId, PathIsNotAFile, #[snafu(display("Failed loading multipart body: {reason}"))] Multipart { @@ -511,6 +510,26 @@ pub enum Error { CannotAccessVolumePath { volume_name: String, }, + + #[snafu(display("Unknown resource name {} of kind {}", name, kind))] + UnknownResource { + kind: String, + name: String, + }, + + #[snafu(display("MachineLearning error: {}", source))] + MachineLearning { + // TODO: make `source: MachineLearningError`, once pro features is removed + source: Box, + }, + + DatasetName { + source: crate::datasets::DatasetNameError, + }, + + MlModelName { + source: geoengine_datatypes::machine_learning::MlModelNameError, + }, } impl actix_web::error::ResponseError for Error { @@ -611,3 +630,15 @@ impl From for Error { Error::InvalidNotNanFloatKey { source } } } + +impl From for Error { + fn from(source: crate::datasets::DatasetNameError) -> Self { + Error::DatasetName { source } + } +} + +impl From for Error { + fn from(source: geoengine_datatypes::machine_learning::MlModelNameError) -> Self { + Error::MlModelName { source } + } +} diff --git a/services/src/machine_learning/mod.rs b/services/src/machine_learning/mod.rs index e0c0cf8ed..fcfe30a50 100644 --- a/services/src/machine_learning/mod.rs +++ b/services/src/machine_learning/mod.rs @@ -21,6 +21,13 @@ mod postgres; identifier!(MlModelId); +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema, FromSql, ToSql)] +#[serde(rename_all = "camelCase")] +pub struct MlModelIdAndName { + pub id: MlModelId, + pub name: MlModelName, +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema, FromSql, ToSql)] #[serde(rename_all = "camelCase")] pub struct MlModel { @@ -98,5 +105,10 @@ pub trait MlModelDb { name: &MlModelName, ) -> Result; - async fn add_model(&self, model: MlModel) -> Result<(), MachineLearningError>; + async fn add_model(&self, model: MlModel) -> Result; + + async fn resolve_model_name_to_id( + &self, + name: &MlModelName, + ) -> Result, MachineLearningError>; } diff --git a/services/src/machine_learning/postgres.rs b/services/src/machine_learning/postgres.rs index 52063ee20..53ea177a1 100644 --- a/services/src/machine_learning/postgres.rs +++ b/services/src/machine_learning/postgres.rs @@ -6,7 +6,7 @@ use crate::{ MachineLearningError, }, name::MlModelName, - MlModel, MlModelDb, MlModelId, MlModelListOptions, MlModelMetadata, + MlModel, MlModelDb, MlModelId, MlModelIdAndName, MlModelListOptions, MlModelMetadata, }, permissions::Permission, util::postgres::PostgresErrorExt, @@ -146,7 +146,7 @@ where Ok(row.get(1)) } - async fn add_model(&self, model: MlModel) -> Result<(), MachineLearningError> { + async fn add_model(&self, model: MlModel) -> Result { self.check_ml_model_namespace(&model.name)?; let mut conn = self @@ -201,6 +201,32 @@ where tx.commit().await.context(PostgresMachineLearningError)?; - Ok(()) + Ok(MlModelIdAndName { + id, + name: model.name, + }) + } + + async fn resolve_model_name_to_id( + &self, + model_name: &MlModelName, + ) -> Result, MachineLearningError> { + let conn = self + .conn_pool + .get() + .await + .context(Bb8MachineLearningError)?; + + let stmt = conn + .prepare( + "SELECT id + FROM ml_models + WHERE name = $1::\"MlModelName\"", + ) + .await?; + + let row_option = conn.query_opt(&stmt, &[&model_name]).await?; + + Ok(row_option.map(|row| row.get(0))) } } diff --git a/services/src/permissions/mod.rs b/services/src/permissions/mod.rs index 7e15ca4a0..a1a6b577f 100644 --- a/services/src/permissions/mod.rs +++ b/services/src/permissions/mod.rs @@ -1,4 +1,4 @@ -use crate::error::{self, Error, Result}; +use crate::error::Result; use crate::identifier; use crate::layers::listing::LayerCollectionId; use crate::machine_learning::MlModelId; @@ -8,11 +8,9 @@ use async_trait::async_trait; use geoengine_datatypes::dataset::{DatasetId, LayerId}; use postgres_types::{FromSql, ToSql}; use serde::{Deserialize, Serialize}; -use snafu::ResultExt; use snafu::Snafu; use std::str::FromStr; use utoipa::ToSchema; -use uuid::Uuid; mod postgres_permissiondb; @@ -142,29 +140,6 @@ impl From for ResourceId { } } -impl TryFrom<(String, String)> for ResourceId { - type Error = Error; - - fn try_from(value: (String, String)) -> Result { - Ok(match value.0.as_str() { - "layer" => ResourceId::Layer(LayerId(value.1)), - "layerCollection" => ResourceId::LayerCollection(LayerCollectionId(value.1)), - "project" => { - ResourceId::Project(ProjectId(Uuid::from_str(&value.1).context(error::Uuid)?)) - } - "dataset" => { - ResourceId::DatasetId(DatasetId(Uuid::from_str(&value.1).context(error::Uuid)?)) - } - _ => { - return Err(Error::InvalidResourceId { - resource_type: value.0, - resource_id: value.1, - }) - } - }) - } -} - #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, ToSchema)] #[serde(rename_all = "camelCase")] pub struct PermissionListing {