Skip to content

Commit f484aea

Browse files
authored
feat: Add serde_validate support. (#22553)
In addition to existing rust-server validation support.
1 parent 833d1d3 commit f484aea

File tree

59 files changed

+1893
-149
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+1893
-149
lines changed

.github/workflows/samples-rust-server.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ jobs:
6060
cargo build --bin ${package##*/} --features cli
6161
target/debug/${package##*/} --help
6262
fi
63+
# Test the validate feature if it exists
64+
if cargo read-manifest | grep -q '"validate"'; then
65+
cargo build --features validate --all-targets
66+
fi
6367
cargo fmt
6468
cargo test
6569
cargo clippy

modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/RustServerCodegen.java

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ public class RustServerCodegen extends AbstractRustCodegen implements CodegenCon
8787
private static final String problemJsonMimeType = "application/problem+json";
8888
private static final String problemXmlMimeType = "application/problem+xml";
8989

90+
// Track if we have models with conflicting names (Ok/Err) that conflict with serde_valid
91+
private boolean hasConflictingModelNames = false;
92+
9093
public RustServerCodegen() {
9194
super();
9295

@@ -942,6 +945,11 @@ private void postProcessOperationWithModels(CodegenOperation op, List<ModelMap>
942945
if (param.contentType != null && isMimetypeJson(param.contentType)) {
943946
param.vendorExtensions.put("x-consumes-json", true);
944947
}
948+
949+
// Add a vendor extension to flag if this can have validate() run on it.
950+
if (!param.isUuid && !param.isPrimitiveType && !param.isEnum && (!param.isContainer || !languageSpecificPrimitives.contains(typeMapping.get(param.baseType)))) {
951+
param.vendorExtensions.put("x-can-validate", true);
952+
}
945953
}
946954

947955
for (CodegenParameter param : op.formParams) {
@@ -1455,8 +1463,20 @@ public String toAllOfName(List<String> names, Schema composedSchema) {
14551463
public void postProcessModelProperty(CodegenModel model, CodegenProperty property) {
14561464
super.postProcessModelProperty(model, property);
14571465

1466+
// Check for reserved field names that conflict with serde_valid macro internals
1467+
if ("ok".equalsIgnoreCase(property.name) || "err".equalsIgnoreCase(property.name)) {
1468+
model.vendorExtensions.put("x-skip-serde-valid", true);
1469+
}
1470+
1471+
// Mark properties that reference complex types (models) for nested validation
1472+
// Only add nested validation for types that reference generated models (contain "models::")
1473+
if (property.dataType != null && property.dataType.contains("models::")) {
1474+
property.vendorExtensions.put("x-needs-nested-validation", true);
1475+
}
1476+
14581477
// TODO: We should avoid reverse engineering primitive type status from the data type
1459-
if (!languageSpecificPrimitives.contains(stripNullable(property.dataType))) {
1478+
String strippedType = stripNullable(property.dataType);
1479+
if (!languageSpecificPrimitives.contains(strippedType)) {
14601480
// If we use a more qualified model name, then only camelize the actual type, not the qualifier.
14611481
if (property.dataType.contains(":")) {
14621482
int position = property.dataType.lastIndexOf(":");
@@ -1529,7 +1549,32 @@ public void postProcessModelProperty(CodegenModel model, CodegenProperty propert
15291549

15301550
@Override
15311551
public ModelsMap postProcessModels(ModelsMap objs) {
1532-
return super.postProcessModelsEnum(objs);
1552+
ModelsMap result = super.postProcessModelsEnum(objs);
1553+
1554+
// Check for model names that conflict with serde_valid macro internals
1555+
// Once we find one, set a class-level flag that persists across all model batches
1556+
if (!hasConflictingModelNames) {
1557+
for (ModelMap modelMap : result.getModels()) {
1558+
CodegenModel model = modelMap.getModel();
1559+
if ("Ok".equalsIgnoreCase(model.classname) || "Err".equalsIgnoreCase(model.classname)) {
1560+
hasConflictingModelNames = true;
1561+
additionalProperties.put("hasConflictingModelNames", true);
1562+
break;
1563+
}
1564+
}
1565+
}
1566+
1567+
// If there are conflicting names (detected in any batch), skip serde_valid for ALL models
1568+
if (hasConflictingModelNames) {
1569+
for (ModelMap modelMap : result.getModels()) {
1570+
CodegenModel model = modelMap.getModel();
1571+
model.vendorExtensions.put("x-skip-serde-valid", true);
1572+
}
1573+
// Set the flag for this batch's template context
1574+
result.put("hasConflictingModelNames", true);
1575+
}
1576+
1577+
return result;
15331578
}
15341579

15351580
private void processParam(CodegenParameter param, CodegenOperation op) {
@@ -1614,6 +1659,11 @@ private void processParam(CodegenParameter param, CodegenOperation op) {
16141659
String exampleString = (example != null) ? "Some(" + example + ")" : "None";
16151660
param.vendorExtensions.put("x-example", exampleString);
16161661
}
1662+
1663+
// Add a vendor extension to flag if this can have validate() run on it.
1664+
if (!param.isUuid && !param.isPrimitiveType && !param.isEnum && (!param.isContainer || !languageSpecificPrimitives.contains(typeMapping.get(param.baseType)))) {
1665+
param.vendorExtensions.put("x-can-validate", true);
1666+
}
16171667
}
16181668

16191669
@Override

modules/openapi-generator/src/main/resources/rust-server/Cargo.mustache

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ cli = [
7272
conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-enum-derive"]
7373

7474
mock = ["mockall"]
75+
validate = [{{^apiUsesByteArray}}"regex",{{/apiUsesByteArray}} "serde_valid", "swagger/serdevalid"]
7576

7677
[target.'cfg(any(target_os = "macos", target_os = "windows", target_os = "ios"))'.dependencies]
7778
native-tls = { version = "0.2", optional = true }
@@ -100,6 +101,8 @@ regex = "1.12"
100101

101102
serde = { version = "1.0", features = ["derive"] }
102103
serde_json = "1.0"
104+
serde_valid = { version = "0.16", optional = true }
105+
103106
validator = { version = "0.20", features = ["derive"] }
104107

105108
# Crates included if required by the API definition

modules/openapi-generator/src/main/resources/rust-server/README.mustache

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ The generated library has a few optional features that can be activated through
130130
* This defaults to disabled and creates extra derives on models to allow "transmogrification" between objects of structurally similar types.
131131
* `cli`
132132
* This defaults to disabled and is required for building the included CLI tool.
133+
* `validate`
134+
* This defaults to disabled and allows JSON Schema validation of received data using `MakeService::set_validation` or `Service::set_validation`.
135+
* Note, enabling validation will have a performance penalty, especially if the API heavily uses regex based checks.
133136

134137
See https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section for how to use features in your `Cargo.toml`.
135138

modules/openapi-generator/src/main/resources/rust-server/models.mustache

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
#![allow(unused_qualifications)]
2+
{{^hasConflictingModelNames}}
3+
#[cfg(not(feature = "validate"))]
4+
use validator::Validate;
25

6+
use crate::models;
7+
#[cfg(any(feature = "client", feature = "server"))]
8+
use crate::header;
9+
#[cfg(feature = "validate")]
10+
use serde_valid::Validate;
11+
{{/hasConflictingModelNames}}
12+
{{#hasConflictingModelNames}}
313
use validator::Validate;
414

515
use crate::models;
616
#[cfg(any(feature = "client", feature = "server"))]
717
use crate::header;
18+
{{/hasConflictingModelNames}}
819
{{! Don't "use" structs here - they can conflict with the names of models, and mean that the code won't compile }}
920
{{#models}}
1021
{{#model}}
@@ -19,6 +30,7 @@ use crate::header;
1930
#[allow(non_camel_case_types)]
2031
#[repr(C)]
2132
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize, Hash)]
33+
{{^hasConflictingModelNames}}{{^exts.x-skip-serde-valid}}#[cfg_attr(feature = "validate", derive(Validate))]{{/exts.x-skip-serde-valid}}{{/hasConflictingModelNames}}
2234
#[cfg_attr(feature = "conversion", derive(frunk_enum_derive::LabelledGenericEnum))]{{#xmlName}}
2335
#[serde(rename = "{{{.}}}")]{{/xmlName}}
2436
pub enum {{{classname}}} {
@@ -60,11 +72,14 @@ impl std::str::FromStr for {{{classname}}} {
6072
{{^isEnum}}
6173
{{#dataType}}
6274
#[derive(Debug, Clone, PartialEq, {{#exts.x-partial-ord}}PartialOrd, {{/exts.x-partial-ord}}serde::Serialize, serde::Deserialize)]
75+
{{^hasConflictingModelNames}}{{^exts.x-skip-serde-valid}}#[cfg_attr(feature = "validate", derive(Validate))]{{/exts.x-skip-serde-valid}}{{/hasConflictingModelNames}}
6376
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
6477
{{#xmlName}}
6578
#[serde(rename = "{{{.}}}")]
6679
{{/xmlName}}
67-
pub struct {{{classname}}}({{{dataType}}});
80+
pub struct {{{classname}}}(
81+
{{>validate}} {{{dataType}}}
82+
);
6883
6984
impl std::convert::From<{{{dataType}}}> for {{{classname}}} {
7085
fn from(x: {{{dataType}}}) -> Self {
@@ -176,6 +191,7 @@ where
176191
{{/exts}}
177192
{{! vec}}
178193
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
194+
{{^hasConflictingModelNames}}{{^exts.x-skip-serde-valid}}#[cfg_attr(feature = "validate", derive(Validate))]{{/exts.x-skip-serde-valid}}{{/hasConflictingModelNames}}
179195
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
180196
pub struct {{{classname}}}(
181197
{{#exts}}
@@ -272,7 +288,7 @@ impl std::str::FromStr for {{{classname}}} {
272288
{{/arrayModelType}}
273289
{{^arrayModelType}}
274290
{{! general struct}}
275-
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
291+
#[derive(Debug, Clone, PartialEq, Validate, serde::Serialize, serde::Deserialize)]
276292
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
277293
{{#xmlName}}
278294
#[serde(rename = "{{{.}}}")]
@@ -288,7 +304,12 @@ pub struct {{{classname}}} {
288304
{{/x-item-xml-name}}
289305
{{/exts}}
290306
{{#hasValidation}}
307+
{{^hasConflictingModelNames}}
308+
#[cfg_attr(not(feature = "validate"), validate(
309+
{{/hasConflictingModelNames}}
310+
{{#hasConflictingModelNames}}
291311
#[validate(
312+
{{/hasConflictingModelNames}}
292313
{{#maxLength}}
293314
{{#minLength}}
294315
length(min = {{minLength}}, max = {{maxLength}}),
@@ -336,8 +357,19 @@ pub struct {{{classname}}} {
336357
length(min = {{minItems}}),
337358
{{/minItems}}
338359
{{/maxItems}}
360+
{{^hasConflictingModelNames}}
361+
))]
362+
{{/hasConflictingModelNames}}
363+
{{#hasConflictingModelNames}}
339364
)]
365+
{{/hasConflictingModelNames}}
340366
{{/hasValidation}}
367+
{{^hasConflictingModelNames}}{{>validate}}{{/hasConflictingModelNames}}
368+
{{^hasConflictingModelNames}}
369+
{{#exts.x-needs-nested-validation}}
370+
#[cfg_attr(feature = "validate", validate)]
371+
{{/exts.x-needs-nested-validation}}
372+
{{/hasConflictingModelNames}}
341373
{{#required}}
342374
pub {{{name}}}: {{{dataType}}},
343375
{{/required}}
@@ -346,6 +378,11 @@ pub struct {{{classname}}} {
346378
#[serde(deserialize_with = "swagger::nullable_format::deserialize_optional_nullable")]
347379
#[serde(default = "swagger::nullable_format::default_optional_nullable")]
348380
{{/isNullable}}
381+
{{^hasConflictingModelNames}}
382+
{{#exts.x-needs-nested-validation}}
383+
#[cfg_attr(feature = "validate", validate)]
384+
{{/exts.x-needs-nested-validation}}
385+
{{/hasConflictingModelNames}}
349386
#[serde(skip_serializing_if="Option::is_none")]
350387
pub {{{name}}}: Option<{{{dataType}}}>,
351388
{{/required}}
@@ -365,6 +402,7 @@ lazy_static::lazy_static! {
365402
lazy_static::lazy_static! {
366403
static ref RE_{{#lambda.uppercase}}{{{classname}}}_{{{name}}}{{/lambda.uppercase}}: regex::bytes::Regex = regex::bytes::Regex::new(r"{{ pattern }}").unwrap();
367404
}
405+
#[cfg(not(feature = "validate"))]
368406
fn validate_byte_{{#lambda.lowercase}}{{{classname}}}_{{{name}}}{{/lambda.lowercase}}(
369407
b: &swagger::ByteArray
370408
) -> Result<(), validator::ValidationError> {

modules/openapi-generator/src/main/resources/rust-server/server-imports.mustache

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use http_body_util::{combinators::BoxBody, Full};
44
use hyper::{body::{Body, Incoming}, HeaderMap, Request, Response, StatusCode};
55
use hyper::header::{HeaderName, HeaderValue, CONTENT_TYPE};
66
use log::warn;
7+
#[cfg(feature = "validate")]
8+
use serde_valid::Validate;
79
#[allow(unused_imports)]
810
use std::convert::{TryFrom, TryInto};
911
use std::{convert::Infallible, error::Error};

modules/openapi-generator/src/main/resources/rust-server/server-make-service.mustache

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ where
99
multipart_form_size_limit: Option<u64>,
1010
{{/apiUsesMultipartFormData}}
1111
marker: PhantomData<C>,
12+
validation: bool
1213
}
1314

1415
impl<T, C> MakeService<T, C>
@@ -22,7 +23,8 @@ where
2223
{{#apiUsesMultipartFormData}}
2324
multipart_form_size_limit: Some(8 * 1024 * 1024),
2425
{{/apiUsesMultipartFormData}}
25-
marker: PhantomData
26+
marker: PhantomData,
27+
validation: false
2628
}
2729
}
2830
{{#apiUsesMultipartFormData}}
@@ -37,6 +39,12 @@ where
3739
self
3840
}
3941
{{/apiUsesMultipartFormData}}
42+
43+
// Turn on/off validation for the service being made.
44+
#[cfg(feature = "validate")]
45+
pub fn set_validation(&mut self, validation: bool) {
46+
self.validation = validation;
47+
}
4048
}
4149

4250
impl<T, C> Clone for MakeService<T, C>
@@ -51,6 +59,7 @@ where
5159
multipart_form_size_limit: Some(8 * 1024 * 1024),
5260
{{/apiUsesMultipartFormData}}
5361
marker: PhantomData,
62+
validation: self.validation
5463
}
5564
}
5665
}
@@ -65,10 +74,8 @@ where
6574
type Future = future::Ready<Result<Self::Response, Self::Error>>;
6675
6776
fn call(&self, target: Target) -> Self::Future {
68-
let service = Service::new(self.api_impl.clone()){{^apiUsesMultipartFormData}};{{/apiUsesMultipartFormData}}
69-
{{#apiUsesMultipartFormData}}
70-
.multipart_form_size_limit(self.multipart_form_size_limit);
71-
{{/apiUsesMultipartFormData}}
77+
let service = Service::new(self.api_impl.clone(), self.validation){{^apiUsesMultipartFormData}};{{/apiUsesMultipartFormData}}{{#apiUsesMultipartFormData}}
78+
.multipart_form_size_limit(self.multipart_form_size_limit);{{/apiUsesMultipartFormData}}
7279

7380
future::ok(service)
7481
}

modules/openapi-generator/src/main/resources/rust-server/server-operation.mustache

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,10 @@
187187
.expect("Unable to create Bad Request response for missing query parameter {{{baseName}}}")),
188188
};
189189
{{/required}}
190+
{{#exts.x-can-validate}}
191+
#[cfg(not(feature = "validate"))]
192+
run_validation!(param_{{{paramName}}}, "{{{baseName}}}", validation);
193+
{{/exts.x-can-validate}}
190194
{{/isArray}}
191195
{{#-last}}
192196

modules/openapi-generator/src/main/resources/rust-server/server-request-body-basic.mustache

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,7 @@
5757
.expect("Unable to create Bad Request response for missing body parameter {{{baseName}}}")),
5858
};
5959
{{/required}}
60+
{{#exts.x-can-validate}}
61+
#[cfg(not(feature = "validate"))]
62+
run_validation!(param_{{{paramName}}}, "{{{baseName}}}", validation);
63+
{{/exts.x-can-validate}}

modules/openapi-generator/src/main/resources/rust-server/server-service-footer.mustache

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
Box::pin(run(
77
self.api_impl.clone(),
88
req,
9-
{{#apiUsesMultipartFormData}}
10-
self.multipart_form_size_limit,
11-
{{/apiUsesMultipartFormData}}
9+
self.validation{{#apiUsesMultipartFormData}},
10+
self.multipart_form_size_limit{{/apiUsesMultipartFormData}}
1211
))
1312
}
1413
}

0 commit comments

Comments
 (0)