Skip to content

Commit ce37ef5

Browse files
committed
Improve OpenAPI 3.2.0 compliance
1 parent 9f49b09 commit ce37ef5

23 files changed

+2603
-1193
lines changed

core/src/codegen.rs

Lines changed: 117 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,19 @@
44
//!
55
//! Utilities for generating Rust source code from internal Intermediate Representations (IR).
66
//!
7-
//! This module facilitates the transformation of `ParsedStruct` definitions—derived from OpenAPI
8-
//! schemas or other sources—into valid, compilable Rust code. It handles:
7+
//! This module facilitates the transformation of `ParsedStruct` and `ParsedEnum` definitions
8+
//! into valid, compilable Rust code. It handles:
99
//! - Dependency analysis (auto-injecting imports like `Uuid`, `chrono`, `serde`).
1010
//! - Attribute injection (`derive`, `serde` options).
1111
//! - Formatting and comments preservation.
1212
1313
use crate::error::{AppError, AppResult};
14-
use crate::parser::ParsedStruct;
14+
use crate::parser::{ParsedEnum, ParsedModel, ParsedStruct};
1515
use ra_ap_edition::Edition;
1616
use ra_ap_syntax::{ast, AstNode, SourceFile};
1717
use std::collections::BTreeSet;
1818

1919
/// Creates a new AST `RecordField` node from strings.
20-
///
21-
/// This is used primarily when patching existing source code to insert new fields.
22-
/// By parsing a small wrapper struct, we ensure the generated field syntax is strictly valid.
23-
///
24-
/// # Arguments
25-
///
26-
/// * `name` - The name of the field (e.g., "email").
27-
/// * `ty` - The Rust type string (e.g., `String`, `Option<i32>`).
28-
/// * `pub_vis` - Whether the field should be public.
29-
/// * `indent_size` - Indentation level (spaces) for formatting context.
30-
///
31-
/// # Returns
32-
///
33-
/// * `AppResult<ast::RecordField>` - The parsed AST node.
3420
pub fn make_record_field(
3521
name: &str,
3622
ty: &str,
@@ -76,28 +62,20 @@ pub fn make_record_field(
7662
.ok_or_else(|| AppError::General("Internal generation error: Field node not found".into()))
7763
}
7864

79-
/// Generates a complete Rust source string for multiple DTOs.
65+
/// Generates a complete Rust source string for multiple Models (Structs or Enums).
8066
///
81-
/// This function aggregates all necessary imports for the set of structs
82-
/// and writes them sequentially into a single string, suitable for writing to a `.rs` file.
83-
///
84-
/// # Arguments
85-
///
86-
/// * `dtos` - A slice of parsed struct definitions.
87-
///
88-
/// # Returns
89-
///
90-
/// * `String` - The complete source file content.
91-
pub fn generate_dtos(dtos: &[ParsedStruct]) -> String {
67+
/// This function aggregates all necessary imports for the set of models
68+
/// and writes them sequentially into a single string.
69+
pub fn generate_dtos(models: &[ParsedModel]) -> String {
9270
let mut code = String::new();
9371
let mut imports = BTreeSet::new();
9472

95-
// 1. Analyze imports for all structs
73+
// 1. Analyze imports for all models
9674
imports.insert("use serde::{Deserialize, Serialize};".to_string());
9775
imports.insert("use utoipa::ToSchema;".to_string());
9876

99-
for dto in dtos {
100-
collect_imports(dto, &mut imports);
77+
for model in models {
78+
collect_imports(model, &mut imports);
10179
}
10280

10381
// 2. Write Imports
@@ -107,22 +85,23 @@ pub fn generate_dtos(dtos: &[ParsedStruct]) -> String {
10785
}
10886
code.push('\n');
10987

110-
// 3. Write Structs
111-
for (i, dto) in dtos.iter().enumerate() {
112-
code.push_str(&generate_dto_body(dto));
113-
if i < dtos.len() - 1 {
88+
// 3. Write Definitions
89+
for (i, model) in models.iter().enumerate() {
90+
match model {
91+
ParsedModel::Struct(s) => code.push_str(&generate_dto_body(s)),
92+
ParsedModel::Enum(e) => code.push_str(&generate_enum_body(e)),
93+
}
94+
if i < models.len() - 1 {
11495
code.push('\n');
11596
}
11697
}
11798

11899
code
119100
}
120101

121-
/// Generates a Rust source string for a single DTO, including imports.
122-
///
123-
/// Useful for generating individual snippets or single-struct files.
102+
/// Generates a Rust source string for a single struct, including imports.
124103
pub fn generate_dto(dto: &ParsedStruct) -> String {
125-
generate_dtos(std::slice::from_ref(dto))
104+
generate_dtos(&[ParsedModel::Struct(dto.clone())])
126105
}
127106

128107
/// Helper to generate the body of a single struct (without file-level imports).
@@ -175,22 +154,80 @@ fn generate_dto_body(dto: &ParsedStruct) -> String {
175154
code
176155
}
177156

178-
/// Analyzes a struct's fields to determine required imports.
179-
fn collect_imports(dto: &ParsedStruct, imports: &mut BTreeSet<String>) {
180-
for field in &dto.fields {
181-
if field.ty.contains("Uuid") {
157+
/// Helper to generate the body of a single enum.
158+
fn generate_enum_body(en: &ParsedEnum) -> String {
159+
let mut code = String::new();
160+
161+
if let Some(desc) = &en.description {
162+
for line in desc.lines() {
163+
code.push_str(&format!("/// {}\n", line));
164+
}
165+
}
166+
167+
code.push_str("#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]\n");
168+
169+
// Attributes
170+
let mut serde_attrs = Vec::new();
171+
if let Some(rename) = &en.rename {
172+
serde_attrs.push(format!("rename = \"{}\"", rename));
173+
}
174+
if let Some(tag) = &en.tag {
175+
serde_attrs.push(format!("tag = \"{}\"", tag));
176+
}
177+
if en.untagged {
178+
serde_attrs.push("untagged".to_string());
179+
}
180+
181+
if !serde_attrs.is_empty() {
182+
code.push_str(&format!("#[serde({})]\n", serde_attrs.join(", ")));
183+
}
184+
185+
code.push_str(&format!("pub enum {} {{\n", en.name));
186+
187+
for variant in &en.variants {
188+
if let Some(desc) = &variant.description {
189+
for line in desc.lines() {
190+
code.push_str(&format!(" /// {}\n", line));
191+
}
192+
}
193+
194+
if let Some(r) = &variant.rename {
195+
code.push_str(&format!(" #[serde(rename = \"{}\")]\n", r));
196+
}
197+
198+
if let Some(ty) = &variant.ty {
199+
code.push_str(&format!(" {}({}),\n", variant.name, ty));
200+
} else {
201+
code.push_str(&format!(" {},\n", variant.name));
202+
}
203+
}
204+
205+
code.push_str("}\n");
206+
code
207+
}
208+
209+
/// Analyzes a model's fields to determine required imports.
210+
/// This handles flattened composition structs by checking all fields contained within.
211+
fn collect_imports(model: &ParsedModel, imports: &mut BTreeSet<String>) {
212+
let types: Vec<&String> = match model {
213+
ParsedModel::Struct(s) => s.fields.iter().map(|f| &f.ty).collect(),
214+
ParsedModel::Enum(e) => e.variants.iter().filter_map(|v| v.ty.as_ref()).collect(),
215+
};
216+
217+
for ty in types {
218+
if ty.contains("Uuid") {
182219
imports.insert("use uuid::Uuid;".to_string());
183220
}
184-
if field.ty.contains("DateTime") || field.ty.contains("NaiveDateTime") {
221+
if ty.contains("DateTime") || ty.contains("NaiveDateTime") {
185222
imports.insert("use chrono::{DateTime, NaiveDateTime, Utc};".to_string());
186223
}
187-
if field.ty.contains("NaiveDate") && !field.ty.contains("NaiveDateTime") {
224+
if ty.contains("NaiveDate") && !ty.contains("NaiveDateTime") {
188225
imports.insert("use chrono::NaiveDate;".to_string());
189226
}
190-
if field.ty.contains("Value") {
227+
if ty.contains("Value") {
191228
imports.insert("use serde_json::Value;".to_string());
192229
}
193-
if field.ty.contains("Decimal") {
230+
if ty.contains("Decimal") {
194231
imports.insert("use rust_decimal::Decimal;".to_string());
195232
}
196233
}
@@ -199,9 +236,8 @@ fn collect_imports(dto: &ParsedStruct, imports: &mut BTreeSet<String>) {
199236
#[cfg(test)]
200237
mod tests {
201238
use super::*;
202-
use crate::parser::ParsedField;
239+
use crate::parser::{ParsedField, ParsedVariant};
203240

204-
// Helper to create a basic parsed field
205241
fn field(name: &str, ty: &str) -> ParsedField {
206242
ParsedField {
207243
name: name.into(),
@@ -212,33 +248,6 @@ mod tests {
212248
}
213249
}
214250

215-
#[test]
216-
fn test_make_record_field_basic() {
217-
let f = make_record_field("foo", "i32", true, 4).unwrap();
218-
assert_eq!(f.to_string().trim(), "pub foo: i32");
219-
}
220-
221-
#[test]
222-
fn test_make_record_field_private() {
223-
let f = make_record_field("bar", "String", false, 2).unwrap();
224-
assert_eq!(f.to_string().trim(), "bar: String");
225-
}
226-
227-
#[test]
228-
fn test_make_record_field_invalid_syntax() {
229-
let result = make_record_field("bad", "::", true, 4);
230-
assert!(result.is_err());
231-
}
232-
233-
#[test]
234-
fn test_make_record_field_internal_error() {
235-
// This simulates a scenario where parsing passes but structure is wrong.
236-
// Hard to trigger with `SourceFile` unless input is crafted to parse as non-struct.
237-
// `struct Wrapper` template forces struct.
238-
// We trust basic syntax tests cover the AST validity.
239-
assert!(make_record_field("ok", "i32", true, 4).is_ok());
240-
}
241-
242251
#[test]
243252
fn test_generate_dto_simple() {
244253
let dto = ParsedStruct {
@@ -251,90 +260,56 @@ mod tests {
251260
let code = generate_dto(&dto);
252261
assert!(code.contains("struct Simple"));
253262
assert!(code.contains("/// A simple struct"));
254-
assert!(code.contains("use serde"));
255-
assert!(code.contains("pub id: i32"));
256263
assert!(code.contains("#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]"));
257264
}
258265

259266
#[test]
260-
fn test_generate_dto_imports() {
261-
let dto = ParsedStruct {
262-
name: "Complex".into(),
263-
description: None,
267+
fn test_generate_enum_tagged() {
268+
let en = ParsedEnum {
269+
name: "Pet".into(),
270+
description: Some("Polymorphic pet".into()),
264271
rename: None,
265-
fields: vec![
266-
field("id", "Uuid"),
267-
field("t", "DateTime<Utc>"),
268-
field("d", "NaiveDate"),
269-
field("v", "Option<serde_json::Value>"),
270-
field("num", "rust_decimal::Decimal"),
271-
],
272-
};
273-
274-
let code = generate_dto(&dto);
275-
assert!(code.contains("use uuid::Uuid;"));
276-
assert!(code.contains("use chrono::{DateTime, NaiveDateTime, Utc};"));
277-
assert!(code.contains("use chrono::NaiveDate;"));
278-
assert!(code.contains("use serde_json::Value;"));
279-
assert!(code.contains("use rust_decimal::Decimal;"));
280-
}
281-
282-
#[test]
283-
fn test_generate_dto_attributes() {
284-
let dto = ParsedStruct {
285-
name: "Renamed".into(),
286-
description: None,
287-
rename: Some("api_renamed".into()),
288-
fields: vec![
289-
ParsedField {
290-
name: "f1".into(),
291-
ty: "i32".into(),
292-
description: Some("Field doc".into()),
293-
rename: Some("f_one".into()),
294-
is_skipped: false,
272+
tag: Some("type".into()),
273+
untagged: false,
274+
variants: vec![
275+
ParsedVariant {
276+
name: "Cat".into(),
277+
ty: Some("CatInfo".into()),
278+
description: None,
279+
rename: Some("cat".into()),
295280
},
296-
ParsedField {
297-
name: "f2".into(),
298-
ty: "i32".into(),
281+
ParsedVariant {
282+
name: "Dog".into(),
283+
ty: Some("DogInfo".into()),
299284
description: None,
300-
rename: None,
301-
is_skipped: true,
285+
rename: Some("dog".into()),
302286
},
303287
],
304288
};
305289

306-
let code = generate_dto(&dto);
307-
assert!(code.contains("#[serde(rename = \"api_renamed\")]"));
308-
assert!(code.contains("/// Field doc"));
309-
// Field 1: rename
310-
assert!(code.contains("#[serde(rename = \"f_one\")]"));
311-
// Field 2: skip
312-
assert!(code.contains("#[serde(skip)]"));
290+
let code = generate_dtos(&[ParsedModel::Enum(en)]);
291+
assert!(code.contains("pub enum Pet"));
292+
assert!(code.contains("#[serde(tag = \"type\")]"));
293+
assert!(code.contains(" #[serde(rename = \"cat\")]"));
294+
assert!(code.contains(" Cat(CatInfo),"));
313295
}
314296

315297
#[test]
316-
fn test_generate_dtos_multiple() {
317-
let dto1 = ParsedStruct {
318-
name: "A".into(),
319-
description: None,
320-
rename: None,
321-
fields: vec![field("u", "Uuid")],
322-
};
323-
let dto2 = ParsedStruct {
324-
name: "B".into(),
298+
fn test_flattened_imports() {
299+
// Simulating a struct that resulted from allOf flattening
300+
// It has a Uuid field (from Base) and a Value field (from Extension)
301+
let dto = ParsedStruct {
302+
name: "Merged".into(),
325303
description: None,
326304
rename: None,
327-
fields: vec![field("v", "Value")],
305+
fields: vec![field("id", "Uuid"), field("meta", "serde_json::Value")],
328306
};
329307

330-
let code = generate_dtos(&[dto1, dto2]);
331-
332-
// Imports should be unified at the top
308+
let code = generate_dto(&dto);
333309
assert!(code.contains("use uuid::Uuid;"));
334310
assert!(code.contains("use serde_json::Value;"));
335-
336-
// Both structs should exist
337-
assert!(code.contains("pub struct A"));
338-
assert!(code.contains("pub struct B"));
311+
// Ensure struct body is valid
312+
assert!(code.contains("pub id: Uuid,"));
313+
assert!(code.contains("pub meta: serde_json::Value,"));
339314
}
340315
}

0 commit comments

Comments
 (0)