Skip to content

Commit 5ae9bb7

Browse files
authored
Merge pull request #19 from golemcloud/multipart-fixes
Multipart fixes
2 parents 832f974 + 37dcb97 commit 5ae9bb7

File tree

3 files changed

+95
-38
lines changed

3 files changed

+95
-38
lines changed

src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,14 @@ pub fn gen(
9999
let mut known_refs = RefCache::new();
100100
let mut models = Vec::new();
101101

102+
let multipart_field_file = rust::model_gen::multipart_field_module()?;
103+
std::fs::write(
104+
model.join(multipart_field_file.def.name.file_name()),
105+
multipart_field_file.code,
106+
)
107+
.unwrap();
108+
models.push(multipart_field_file.def);
109+
102110
while !ref_cache.is_empty() {
103111
let mut next_ref_cache = RefCache::new();
104112

src/rust/client_gen.rs

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -334,35 +334,39 @@ fn request_body_params(
334334
ReferenceOr::Reference { reference } => Err(Error::unimplemented(
335335
format!("Unexpected ref multipart schema: '{reference}'."),
336336
)),
337-
ReferenceOr::Item(schema) => {
338-
match &schema.schema_kind {
339-
SchemaKind::Type(Type::Object(obj)) => {
340-
fn multipart_param(
341-
name: &str,
342-
schema: &ReferenceOr<Box<Schema>>,
343-
ref_cache: &mut RefCache,
344-
) -> Result<Param> {
345-
Ok(Param {
346-
original_name: name.to_string(),
347-
name: name.to_case(Case::Snake),
348-
tpe: ref_or_box_schema_type(schema, ref_cache)?,
349-
required: true, // TODO
350-
kind: ParamKind::Multipart,
351-
})
352-
}
353-
354-
obj.properties
355-
.iter()
356-
.map(|(name, schema)| {
357-
multipart_param(name, schema, ref_cache)
358-
})
359-
.collect()
337+
ReferenceOr::Item(schema) => match &schema.schema_kind {
338+
SchemaKind::Type(Type::Object(obj)) => {
339+
fn multipart_param(
340+
name: &str,
341+
required: bool,
342+
schema: &ReferenceOr<Box<Schema>>,
343+
ref_cache: &mut RefCache,
344+
) -> Result<Param> {
345+
Ok(Param {
346+
original_name: name.to_string(),
347+
name: name.to_case(Case::Snake),
348+
tpe: ref_or_box_schema_type(schema, ref_cache)?,
349+
required,
350+
kind: ParamKind::Multipart,
351+
})
360352
}
361-
_ => Err(Error::unimplemented(
362-
"Object schema expected for multipart request body.",
363-
)),
353+
354+
obj.properties
355+
.iter()
356+
.map(|(name, schema)| {
357+
multipart_param(
358+
name,
359+
body.required && obj.required.contains(name),
360+
schema,
361+
ref_cache,
362+
)
363+
})
364+
.collect()
364365
}
365-
}
366+
_ => Err(Error::unimplemented(
367+
"Object schema expected for multipart request body.",
368+
)),
369+
},
366370
},
367371
}
368372
} else {
@@ -657,14 +661,26 @@ fn header_setter(param: &Param) -> RustResult {
657661
fn make_part(param: &Param) -> RustResult {
658662
let part_type = rust_name("reqwest::multipart", "Part");
659663

660-
if param.tpe == DataType::Binary {
661-
Ok(indent() + r#".part(""# + &param.original_name + r#"", "# + part_type + "::stream(" + &param.name + r#").mime_str("application/octet-stream")?)"#)
662-
} else if param.tpe == DataType::String {
663-
Ok(indent() + r#".part(""# + &param.original_name + r#"", "# + part_type + "::text(" + &param.name + r#".to_string()).mime_str("text/plain; charset=utf-8")?)"#)
664-
} else if let DataType::Model(_) = param.tpe {
665-
Ok(indent() + r#".part(""# + &param.original_name + r#"", "# + part_type + "::text(serde_json::to_string(" + &param.name + r#")?).mime_str("application/json")?)"#)
664+
let inner =
665+
if param.tpe == DataType::Binary {
666+
Ok(indent() + r#"form = form.part(""# + &param.original_name + r#"", "# + part_type + "::stream(" + &param.name + r#").mime_str("application/octet-stream")?);"#)
667+
} else if param.tpe == DataType::String {
668+
Ok(indent() + r#"form = form.part(""# + &param.original_name + r#"", "# + part_type + "::text(" + &param.name + r#".to_string()).mime_str("text/plain; charset=utf-8")?);"#)
669+
}
670+
else if let DataType::Model(_) = param.tpe {
671+
Ok(indent() + r#"form = form.part(""# + &param.original_name + r#"", "# + part_type + "::text(crate::model::MultipartField::to_multipart_field(" + &param.name + r#")).mime_str(crate::model::MultipartField::mime_type("# + &param.name + r#"))?);"#)
672+
} else {
673+
Err(Error::unimplemented(format!("Unsupported multipart part type {:?}", param.tpe)))
674+
};
675+
676+
if param.required {
677+
inner
666678
} else {
667-
Err(Error::unimplemented(format!("Unsupported multipart part type {:?}", param.tpe)))
679+
Ok(
680+
indent() + line(unit() + r#"if let Some("# + &param.name + r#") = "# + &param.name + " {") +
681+
indented(inner?) +
682+
line("}")
683+
)
668684
}
669685
}
670686

@@ -837,10 +853,8 @@ fn render_method_implementation(method: &Method, error_kind: &ErrorKind) -> Rust
837853
let multipart_setter = if is_multipart {
838854
#[rustfmt::skip]
839855
let code = unit() +
840-
indent() + "let form = " + rust_name("reqwest::multipart", "Form") + "::new()" +
841-
indented(
842-
multipart_parts?.into_iter().map(|p| unit() + NewLine + p).reduce(|acc, e| acc + e). unwrap_or_else(unit) + ";" + NewLine
843-
) +
856+
indent() + "let mut form = " + rust_name("reqwest::multipart", "Form") + "::new();" +
857+
(multipart_parts?.into_iter().map(|p| unit() + NewLine + p).reduce(|acc, e| acc + e). unwrap_or_else(unit) + NewLine) +
844858
NewLine +
845859
line("request = request.multipart(form);");
846860

src/rust/model_gen.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,25 @@ fn extract_enum_cases(
348348
.collect()
349349
}
350350

351+
pub fn multipart_field_module() -> Result<Module> {
352+
let code = unit()
353+
+ line(unit() + "pub trait MultipartField {")
354+
+ indented(
355+
unit()
356+
+ line("fn to_multipart_field(&self) -> String;")
357+
+ line("fn mime_type(&self) -> &'static str;"),
358+
)
359+
+ line(unit() + "}");
360+
361+
Ok(Module {
362+
def: ModuleDef {
363+
name: ModuleName::new("multipart_field"),
364+
exports: vec!["MultipartField".to_string()],
365+
},
366+
code: RustContext::new().print_to_string(code),
367+
})
368+
}
369+
351370
pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache) -> Result<Module> {
352371
let schemas = &open_api
353372
.components
@@ -438,6 +457,22 @@ pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache)
438457
) +
439458
line("}")
440459
) +
460+
line("}") +
461+
NewLine +
462+
line(unit() + "impl " + rust_name("crate::model", "MultipartField") + " for " + &name + "{") +
463+
indented(
464+
line(unit() + "fn to_multipart_field(&self) -> String {") +
465+
indented(
466+
line("self.to_string()")
467+
) +
468+
line("}") +
469+
NewLine +
470+
line(unit() + "fn mime_type(&self) -> &'static str {") +
471+
indented(
472+
line(r#""text/plain; charset=utf-8""#)
473+
) +
474+
line("}")
475+
) +
441476
line("}");
442477

443478
Ok(code)

0 commit comments

Comments
 (0)