Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 79 additions & 24 deletions datatypes/src/machine_learning.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::path::PathBuf;

use serde::{de::Visitor, Deserialize, Serialize};

use crate::{
dataset::{is_invalid_name_char, SYSTEM_NAMESPACE},
raster::RasterDataType,
};
use serde::{de::Visitor, Deserialize, Serialize};
use snafu::Snafu;
use std::path::PathBuf;
use std::str::FromStr;
use strum::IntoStaticStr;

const NAME_DELIMITER: char = ':';

Expand All @@ -15,6 +16,18 @@ pub struct MlModelName {
pub name: String,
}

#[derive(Snafu, IntoStaticStr, Debug)]
#[snafu(visibility(pub(crate)))]
#[snafu(context(suffix(false)))] // disables default `Snafu` suffix
pub enum MlModelNameError {
#[snafu(display("MlModelName is empty"))]
IsEmpty,
#[snafu(display("invalid character '{invalid_char}' in named model"))]
InvalidCharacter { invalid_char: String },
#[snafu(display("ml model name must consist of at most two parts"))]
TooManyParts,
}

impl MlModelName {
/// Canonicalize a name that reflects the system namespace and model.
fn canonicalize<S: Into<String> + PartialEq<&'static str>>(
Expand Down Expand Up @@ -62,40 +75,29 @@ impl<'de> Deserialize<'de> for MlModelName {
}
}

struct MlModelNameDeserializeVisitor;

impl Visitor<'_> for MlModelNameDeserializeVisitor {
type Value = MlModelName;

/// always keep in sync with [`is_allowed_name_char`]
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
formatter,
"a string consisting of a namespace and name name, separated by a colon, only using alphanumeric characters, underscores & dashes"
)
}
impl FromStr for MlModelName {
type Err = MlModelNameError;

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut strings = [None, None];
let mut split = s.split(NAME_DELIMITER);

for (buffer, part) in strings.iter_mut().zip(&mut split) {
if part.is_empty() {
return Err(E::custom("empty part in named data"));
return Err(MlModelNameError::IsEmpty);
}

if let Some(c) = part.matches(is_invalid_name_char).next() {
return Err(E::custom(format!("invalid character '{c}' in named model")));
return Err(MlModelNameError::InvalidCharacter {
invalid_char: c.to_string(),
});
}

*buffer = Some(part.to_string());
}

if split.next().is_some() {
return Err(E::custom("named model must consist of at most two parts"));
return Err(MlModelNameError::TooManyParts);
}

match strings {
Expand All @@ -107,11 +109,32 @@ impl Visitor<'_> for MlModelNameDeserializeVisitor {
namespace: None,
name,
}),
_ => Err(E::custom("empty named data")),
_ => Err(MlModelNameError::IsEmpty),
}
}
}

struct MlModelNameDeserializeVisitor;

impl Visitor<'_> for MlModelNameDeserializeVisitor {
type Value = MlModelName;

/// always keep in sync with [`is_allowed_name_char`]
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
formatter,
"a string consisting of a namespace and name name, separated by a colon, only using alphanumeric characters, underscores & dashes"
)
}

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
MlModelName::from_str(s).map_err(|e| E::custom(e.to_string()))
}
}

// For now we assume all models are pixel-wise, i.e., they take a single pixel with multiple bands as input and produce a single output value.
// To support different inputs, we would need a more sophisticated logic to produce the inputs for the model.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)]
Expand All @@ -122,3 +145,35 @@ pub struct MlModelMetadata {
pub output_type: RasterDataType, // TODO: support multiple outputs, e.g. one band for the probability of prediction
// TODO: output measurement, e.g. classification or regression, label names for classification. This would have to be provided by the model creator along the model file as it cannot be extracted from the model file(?)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn ml_model_name_from_str() {
const ML_MODEL_NAME: &str = "myModelName";
let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
assert_eq!(mln.name, ML_MODEL_NAME);
assert!(mln.namespace.is_none());
}

#[test]
fn ml_model_name_from_str_prefixed() {
const ML_MODEL_NAME: &str = "d5328854-6190-4af9-ad69-4e74b0961ac9:myModelName";
let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
assert_eq!(mln.name, "myModelName".to_string());
assert_eq!(
mln.namespace,
Some("d5328854-6190-4af9-ad69-4e74b0961ac9".to_string())
);
}

#[test]
fn ml_model_name_from_str_system() {
const ML_MODEL_NAME: &str = "_:myModelName";
let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
assert_eq!(mln.name, "myModelName".to_string());
assert!(mln.namespace.is_none());
}
}
12 changes: 7 additions & 5 deletions services/src/api/apidoc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::api::handlers;
use crate::api::handlers::datasets::VolumeFileLayersResponse;
use crate::api::handlers::permissions::{PermissionListOptions, PermissionRequest, Resource};
use crate::api::handlers::permissions::{
PermissionListOptions, PermissionListing, PermissionRequest, Resource,
};
use crate::api::handlers::plots::WrappedPlotOutput;
use crate::api::handlers::spatial_references::{AxisOrder, SpatialReferenceSpecification};
use crate::api::handlers::tasks::{TaskAbortOptions, TaskResponse};
Expand Down Expand Up @@ -31,6 +33,7 @@ use crate::api::model::operators::{
UnixTimeStampType, VectorColumnInfo, VectorResultDescriptor,
};
use crate::api::model::responses::datasets::DatasetNameResponse;
use crate::api::model::responses::ml_models::MlModelNameResponse;
use crate::api::model::responses::{
BadRequestQueryResponse, ErrorResponse, IdResponse, PayloadTooLargeResponse, PngResponse,
UnauthorizedAdminResponse, UnauthorizedUserResponse, UnsupportedMediaTypeForJsonResponse,
Expand All @@ -56,9 +59,7 @@ use crate::layers::listing::{
};
use crate::machine_learning::name::MlModelName;
use crate::machine_learning::{MlModel, MlModelId, MlModelMetadata};
use crate::permissions::{
Permission, PermissionListing, ResourceId, Role, RoleDescription, RoleId,
};
use crate::permissions::{Permission, ResourceId, Role, RoleDescription, RoleId};
use crate::projects::{
ColorParam, CreateProject, DerivedColor, DerivedNumber, LayerUpdate, LayerVisibility,
LineSymbology, NumberParam, Plot, PlotUpdate, PointSymbology, PolygonSymbology, Project,
Expand Down Expand Up @@ -423,7 +424,8 @@ use utoipa::{Modify, OpenApi};
MlModel,
MlModelId,
MlModelName,
MlModelMetadata
MlModelMetadata,
MlModelNameResponse
),
),
modifiers(&SecurityAddon, &ApiDocInfo, &OpenApiServerInfo, &TransformSchemasWithTag),
Expand Down
10 changes: 5 additions & 5 deletions services/src/api/handlers/machine_learning.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use actix_web::{web, FromRequest, HttpResponse, ResponseError};

use crate::{
api::model::responses::ErrorResponse,
api::model::responses::{ml_models::MlModelNameResponse, ErrorResponse},
contexts::{ApplicationContext, SessionContext},
machine_learning::{
error::MachineLearningError, name::MlModelName, MlModel, MlModelDb, MlModelListOptions,
Expand Down Expand Up @@ -48,7 +48,7 @@ impl ResponseError for MachineLearningError {
path = "/ml/models",
request_body = MlModel,
responses(
(status = 200)
(status = 200, body = MlModelNameResponse)
),
security(
("session_token" = [])
Expand All @@ -59,14 +59,14 @@ pub(crate) async fn add_ml_model<C: ApplicationContext>(
session: C::Session,
app_ctx: web::Data<C>,
model: web::Json<MlModel>,
) -> Result<HttpResponse, MachineLearningError> {
) -> Result<web::Json<MlModelNameResponse>, MachineLearningError> {
let model = model.into_inner();
app_ctx
let id_and_name = app_ctx
.session_context(session)
.db()
.add_model(model)
.await?;
Ok(HttpResponse::Ok().finish())
Ok(web::Json(id_and_name.name.into()))
}

/// List ml models.
Expand Down
Loading