44//!
55//! Utilities for generating Rust source code from internal Intermediate Representations (IR).
66//!
7- //! This module facilitates the transformation of `ParsedStruct` definitions—derived from OpenAPI
8- //! schemas or other sources— into valid, compilable Rust code. It handles:
7+ //! This module facilitates the transformation of `ParsedStruct` and `ParsedEnum` definitions
8+ //! into valid, compilable Rust code. It handles:
99//! - Dependency analysis (auto-injecting imports like `Uuid`, `chrono`, `serde`).
1010//! - Attribute injection (`derive`, `serde` options).
1111//! - Formatting and comments preservation.
1212
1313use crate :: error:: { AppError , AppResult } ;
14- use crate :: parser:: ParsedStruct ;
14+ use crate :: parser:: { ParsedEnum , ParsedModel , ParsedStruct } ;
1515use ra_ap_edition:: Edition ;
1616use ra_ap_syntax:: { ast, AstNode , SourceFile } ;
1717use std:: collections:: BTreeSet ;
1818
1919/// Creates a new AST `RecordField` node from strings.
20- ///
21- /// This is used primarily when patching existing source code to insert new fields.
22- /// By parsing a small wrapper struct, we ensure the generated field syntax is strictly valid.
23- ///
24- /// # Arguments
25- ///
26- /// * `name` - The name of the field (e.g., "email").
27- /// * `ty` - The Rust type string (e.g., `String`, `Option<i32>`).
28- /// * `pub_vis` - Whether the field should be public.
29- /// * `indent_size` - Indentation level (spaces) for formatting context.
30- ///
31- /// # Returns
32- ///
33- /// * `AppResult<ast::RecordField>` - The parsed AST node.
3420pub fn make_record_field (
3521 name : & str ,
3622 ty : & str ,
@@ -76,28 +62,20 @@ pub fn make_record_field(
7662 . ok_or_else ( || AppError :: General ( "Internal generation error: Field node not found" . into ( ) ) )
7763}
7864
79- /// Generates a complete Rust source string for multiple DTOs .
65+ /// Generates a complete Rust source string for multiple Models (Structs or Enums) .
8066///
81- /// This function aggregates all necessary imports for the set of structs
82- /// and writes them sequentially into a single string, suitable for writing to a `.rs` file.
83- ///
84- /// # Arguments
85- ///
86- /// * `dtos` - A slice of parsed struct definitions.
87- ///
88- /// # Returns
89- ///
90- /// * `String` - The complete source file content.
91- pub fn generate_dtos ( dtos : & [ ParsedStruct ] ) -> String {
67+ /// This function aggregates all necessary imports for the set of models
68+ /// and writes them sequentially into a single string.
69+ pub fn generate_dtos ( models : & [ ParsedModel ] ) -> String {
9270 let mut code = String :: new ( ) ;
9371 let mut imports = BTreeSet :: new ( ) ;
9472
95- // 1. Analyze imports for all structs
73+ // 1. Analyze imports for all models
9674 imports. insert ( "use serde::{Deserialize, Serialize};" . to_string ( ) ) ;
9775 imports. insert ( "use utoipa::ToSchema;" . to_string ( ) ) ;
9876
99- for dto in dtos {
100- collect_imports ( dto , & mut imports) ;
77+ for model in models {
78+ collect_imports ( model , & mut imports) ;
10179 }
10280
10381 // 2. Write Imports
@@ -107,22 +85,23 @@ pub fn generate_dtos(dtos: &[ParsedStruct]) -> String {
10785 }
10886 code. push ( '\n' ) ;
10987
110- // 3. Write Structs
111- for ( i, dto) in dtos. iter ( ) . enumerate ( ) {
112- code. push_str ( & generate_dto_body ( dto) ) ;
113- if i < dtos. len ( ) - 1 {
88+ // 3. Write Definitions
89+ for ( i, model) in models. iter ( ) . enumerate ( ) {
90+ match model {
91+ ParsedModel :: Struct ( s) => code. push_str ( & generate_dto_body ( s) ) ,
92+ ParsedModel :: Enum ( e) => code. push_str ( & generate_enum_body ( e) ) ,
93+ }
94+ if i < models. len ( ) - 1 {
11495 code. push ( '\n' ) ;
11596 }
11697 }
11798
11899 code
119100}
120101
121- /// Generates a Rust source string for a single DTO, including imports.
122- ///
123- /// Useful for generating individual snippets or single-struct files.
102+ /// Generates a Rust source string for a single struct, including imports.
124103pub fn generate_dto ( dto : & ParsedStruct ) -> String {
125- generate_dtos ( std :: slice :: from_ref ( dto) )
104+ generate_dtos ( & [ ParsedModel :: Struct ( dto. clone ( ) ) ] )
126105}
127106
128107/// Helper to generate the body of a single struct (without file-level imports).
@@ -175,22 +154,80 @@ fn generate_dto_body(dto: &ParsedStruct) -> String {
175154 code
176155}
177156
178- /// Analyzes a struct's fields to determine required imports.
179- fn collect_imports ( dto : & ParsedStruct , imports : & mut BTreeSet < String > ) {
180- for field in & dto. fields {
181- if field. ty . contains ( "Uuid" ) {
157+ /// Helper to generate the body of a single enum.
158+ fn generate_enum_body ( en : & ParsedEnum ) -> String {
159+ let mut code = String :: new ( ) ;
160+
161+ if let Some ( desc) = & en. description {
162+ for line in desc. lines ( ) {
163+ code. push_str ( & format ! ( "/// {}\n " , line) ) ;
164+ }
165+ }
166+
167+ code. push_str ( "#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]\n " ) ;
168+
169+ // Attributes
170+ let mut serde_attrs = Vec :: new ( ) ;
171+ if let Some ( rename) = & en. rename {
172+ serde_attrs. push ( format ! ( "rename = \" {}\" " , rename) ) ;
173+ }
174+ if let Some ( tag) = & en. tag {
175+ serde_attrs. push ( format ! ( "tag = \" {}\" " , tag) ) ;
176+ }
177+ if en. untagged {
178+ serde_attrs. push ( "untagged" . to_string ( ) ) ;
179+ }
180+
181+ if !serde_attrs. is_empty ( ) {
182+ code. push_str ( & format ! ( "#[serde({})]\n " , serde_attrs. join( ", " ) ) ) ;
183+ }
184+
185+ code. push_str ( & format ! ( "pub enum {} {{\n " , en. name) ) ;
186+
187+ for variant in & en. variants {
188+ if let Some ( desc) = & variant. description {
189+ for line in desc. lines ( ) {
190+ code. push_str ( & format ! ( " /// {}\n " , line) ) ;
191+ }
192+ }
193+
194+ if let Some ( r) = & variant. rename {
195+ code. push_str ( & format ! ( " #[serde(rename = \" {}\" )]\n " , r) ) ;
196+ }
197+
198+ if let Some ( ty) = & variant. ty {
199+ code. push_str ( & format ! ( " {}({}),\n " , variant. name, ty) ) ;
200+ } else {
201+ code. push_str ( & format ! ( " {},\n " , variant. name) ) ;
202+ }
203+ }
204+
205+ code. push_str ( "}\n " ) ;
206+ code
207+ }
208+
209+ /// Analyzes a model's fields to determine required imports.
210+ /// This handles flattened composition structs by checking all fields contained within.
211+ fn collect_imports ( model : & ParsedModel , imports : & mut BTreeSet < String > ) {
212+ let types: Vec < & String > = match model {
213+ ParsedModel :: Struct ( s) => s. fields . iter ( ) . map ( |f| & f. ty ) . collect ( ) ,
214+ ParsedModel :: Enum ( e) => e. variants . iter ( ) . filter_map ( |v| v. ty . as_ref ( ) ) . collect ( ) ,
215+ } ;
216+
217+ for ty in types {
218+ if ty. contains ( "Uuid" ) {
182219 imports. insert ( "use uuid::Uuid;" . to_string ( ) ) ;
183220 }
184- if field . ty . contains ( "DateTime" ) || field . ty . contains ( "NaiveDateTime" ) {
221+ if ty. contains ( "DateTime" ) || ty. contains ( "NaiveDateTime" ) {
185222 imports. insert ( "use chrono::{DateTime, NaiveDateTime, Utc};" . to_string ( ) ) ;
186223 }
187- if field . ty . contains ( "NaiveDate" ) && !field . ty . contains ( "NaiveDateTime" ) {
224+ if ty. contains ( "NaiveDate" ) && !ty. contains ( "NaiveDateTime" ) {
188225 imports. insert ( "use chrono::NaiveDate;" . to_string ( ) ) ;
189226 }
190- if field . ty . contains ( "Value" ) {
227+ if ty. contains ( "Value" ) {
191228 imports. insert ( "use serde_json::Value;" . to_string ( ) ) ;
192229 }
193- if field . ty . contains ( "Decimal" ) {
230+ if ty. contains ( "Decimal" ) {
194231 imports. insert ( "use rust_decimal::Decimal;" . to_string ( ) ) ;
195232 }
196233 }
@@ -199,9 +236,8 @@ fn collect_imports(dto: &ParsedStruct, imports: &mut BTreeSet<String>) {
199236#[ cfg( test) ]
200237mod tests {
201238 use super :: * ;
202- use crate :: parser:: ParsedField ;
239+ use crate :: parser:: { ParsedField , ParsedVariant } ;
203240
204- // Helper to create a basic parsed field
205241 fn field ( name : & str , ty : & str ) -> ParsedField {
206242 ParsedField {
207243 name : name. into ( ) ,
@@ -212,33 +248,6 @@ mod tests {
212248 }
213249 }
214250
215- #[ test]
216- fn test_make_record_field_basic ( ) {
217- let f = make_record_field ( "foo" , "i32" , true , 4 ) . unwrap ( ) ;
218- assert_eq ! ( f. to_string( ) . trim( ) , "pub foo: i32" ) ;
219- }
220-
221- #[ test]
222- fn test_make_record_field_private ( ) {
223- let f = make_record_field ( "bar" , "String" , false , 2 ) . unwrap ( ) ;
224- assert_eq ! ( f. to_string( ) . trim( ) , "bar: String" ) ;
225- }
226-
227- #[ test]
228- fn test_make_record_field_invalid_syntax ( ) {
229- let result = make_record_field ( "bad" , "::" , true , 4 ) ;
230- assert ! ( result. is_err( ) ) ;
231- }
232-
233- #[ test]
234- fn test_make_record_field_internal_error ( ) {
235- // This simulates a scenario where parsing passes but structure is wrong.
236- // Hard to trigger with `SourceFile` unless input is crafted to parse as non-struct.
237- // `struct Wrapper` template forces struct.
238- // We trust basic syntax tests cover the AST validity.
239- assert ! ( make_record_field( "ok" , "i32" , true , 4 ) . is_ok( ) ) ;
240- }
241-
242251 #[ test]
243252 fn test_generate_dto_simple ( ) {
244253 let dto = ParsedStruct {
@@ -251,90 +260,56 @@ mod tests {
251260 let code = generate_dto ( & dto) ;
252261 assert ! ( code. contains( "struct Simple" ) ) ;
253262 assert ! ( code. contains( "/// A simple struct" ) ) ;
254- assert ! ( code. contains( "use serde" ) ) ;
255- assert ! ( code. contains( "pub id: i32" ) ) ;
256263 assert ! ( code. contains( "#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]" ) ) ;
257264 }
258265
259266 #[ test]
260- fn test_generate_dto_imports ( ) {
261- let dto = ParsedStruct {
262- name : "Complex " . into ( ) ,
263- description : None ,
267+ fn test_generate_enum_tagged ( ) {
268+ let en = ParsedEnum {
269+ name : "Pet " . into ( ) ,
270+ description : Some ( "Polymorphic pet" . into ( ) ) ,
264271 rename : None ,
265- fields : vec ! [
266- field( "id" , "Uuid" ) ,
267- field( "t" , "DateTime<Utc>" ) ,
268- field( "d" , "NaiveDate" ) ,
269- field( "v" , "Option<serde_json::Value>" ) ,
270- field( "num" , "rust_decimal::Decimal" ) ,
271- ] ,
272- } ;
273-
274- let code = generate_dto ( & dto) ;
275- assert ! ( code. contains( "use uuid::Uuid;" ) ) ;
276- assert ! ( code. contains( "use chrono::{DateTime, NaiveDateTime, Utc};" ) ) ;
277- assert ! ( code. contains( "use chrono::NaiveDate;" ) ) ;
278- assert ! ( code. contains( "use serde_json::Value;" ) ) ;
279- assert ! ( code. contains( "use rust_decimal::Decimal;" ) ) ;
280- }
281-
282- #[ test]
283- fn test_generate_dto_attributes ( ) {
284- let dto = ParsedStruct {
285- name : "Renamed" . into ( ) ,
286- description : None ,
287- rename : Some ( "api_renamed" . into ( ) ) ,
288- fields : vec ! [
289- ParsedField {
290- name: "f1" . into( ) ,
291- ty: "i32" . into( ) ,
292- description: Some ( "Field doc" . into( ) ) ,
293- rename: Some ( "f_one" . into( ) ) ,
294- is_skipped: false ,
272+ tag : Some ( "type" . into ( ) ) ,
273+ untagged : false ,
274+ variants : vec ! [
275+ ParsedVariant {
276+ name: "Cat" . into( ) ,
277+ ty: Some ( "CatInfo" . into( ) ) ,
278+ description: None ,
279+ rename: Some ( "cat" . into( ) ) ,
295280 } ,
296- ParsedField {
297- name: "f2 " . into( ) ,
298- ty: "i32 ". into( ) ,
281+ ParsedVariant {
282+ name: "Dog " . into( ) ,
283+ ty: Some ( "DogInfo ". into( ) ) ,
299284 description: None ,
300- rename: None ,
301- is_skipped: true ,
285+ rename: Some ( "dog" . into( ) ) ,
302286 } ,
303287 ] ,
304288 } ;
305289
306- let code = generate_dto ( & dto) ;
307- assert ! ( code. contains( "#[serde(rename = \" api_renamed\" )]" ) ) ;
308- assert ! ( code. contains( "/// Field doc" ) ) ;
309- // Field 1: rename
310- assert ! ( code. contains( "#[serde(rename = \" f_one\" )]" ) ) ;
311- // Field 2: skip
312- assert ! ( code. contains( "#[serde(skip)]" ) ) ;
290+ let code = generate_dtos ( & [ ParsedModel :: Enum ( en) ] ) ;
291+ assert ! ( code. contains( "pub enum Pet" ) ) ;
292+ assert ! ( code. contains( "#[serde(tag = \" type\" )]" ) ) ;
293+ assert ! ( code. contains( " #[serde(rename = \" cat\" )]" ) ) ;
294+ assert ! ( code. contains( " Cat(CatInfo)," ) ) ;
313295 }
314296
315297 #[ test]
316- fn test_generate_dtos_multiple ( ) {
317- let dto1 = ParsedStruct {
318- name : "A" . into ( ) ,
319- description : None ,
320- rename : None ,
321- fields : vec ! [ field( "u" , "Uuid" ) ] ,
322- } ;
323- let dto2 = ParsedStruct {
324- name : "B" . into ( ) ,
298+ fn test_flattened_imports ( ) {
299+ // Simulating a struct that resulted from allOf flattening
300+ // It has a Uuid field (from Base) and a Value field (from Extension)
301+ let dto = ParsedStruct {
302+ name : "Merged" . into ( ) ,
325303 description : None ,
326304 rename : None ,
327- fields : vec ! [ field( "v " , "Value" ) ] ,
305+ fields : vec ! [ field( "id " , "Uuid" ) , field ( "meta" , "serde_json:: Value") ] ,
328306 } ;
329307
330- let code = generate_dtos ( & [ dto1, dto2] ) ;
331-
332- // Imports should be unified at the top
308+ let code = generate_dto ( & dto) ;
333309 assert ! ( code. contains( "use uuid::Uuid;" ) ) ;
334310 assert ! ( code. contains( "use serde_json::Value;" ) ) ;
335-
336- // Both structs should exist
337- assert ! ( code. contains( "pub struct A" ) ) ;
338- assert ! ( code. contains( "pub struct B" ) ) ;
311+ // Ensure struct body is valid
312+ assert ! ( code. contains( "pub id: Uuid," ) ) ;
313+ assert ! ( code. contains( "pub meta: serde_json::Value," ) ) ;
339314 }
340315}
0 commit comments