1- use codegen_sdk_common:: Language ;
1+ use codegen_sdk_common:: { Language , naming :: normalize_type_name } ;
22use proc_macro2:: Span ;
3- use quote:: format_ident;
3+ use quote:: { format_ident, quote } ;
44use syn:: parse_quote;
55
66use super :: helpers;
@@ -19,7 +19,7 @@ fn generate_cst_struct(
1919 codebase: Arc <GILProtected <codegen_sdk_analyzer:: Codebase >>,
2020 }
2121 } ) ;
22- let file_getter = helpers:: get_file ( language) ;
22+ let file_getter = helpers:: get_file ( language, quote ! { self . id } , quote ! { self . codebase } ) ;
2323 output. push ( parse_quote ! {
2424 impl #struct_name {
2525 pub fn new( id: codegen_sdk_common:: CSTNodeTreeId , codebase: Arc <GILProtected <codegen_sdk_analyzer:: Codebase >>) -> Self {
@@ -87,6 +87,65 @@ fn generate_cst_struct(
8787 } ) ;
8888 Ok ( output)
8989}
90+ fn generate_cst_subenum (
91+ language : & Language ,
92+ state : & codegen_sdk_cst_generator:: State ,
93+ name : & str ,
94+ ) -> anyhow:: Result < Vec < syn:: Stmt > > {
95+ let mut output = Vec :: new ( ) ;
96+ let struct_name = format_ident ! ( "{}" , normalize_type_name( name, true ) ) ;
97+ let package_name = syn:: Ident :: new ( & language. package_name ( ) , Span :: call_site ( ) ) ;
98+ let module_name = format ! ( "codegen_sdk_pink::{}.cst" , language. name( ) ) ;
99+ let subenum_names = state
100+ . get_subenum_variants ( & name, false )
101+ . iter ( )
102+ . map ( |name| {
103+ let name = format_ident ! ( "{}" , name. normalize_name( ) ) ;
104+ parse_quote ! {
105+ #name( #name)
106+ }
107+ } )
108+ . collect :: < Vec < syn:: Variant > > ( ) ;
109+ output. push ( parse_quote ! {
110+ #[ derive( IntoPyObject ) ]
111+ pub enum #struct_name {
112+ #( #subenum_names, ) *
113+ }
114+ } ) ;
115+ let package_name = syn:: Ident :: new ( & language. package_name ( ) , Span :: call_site ( ) ) ;
116+ let ref_name = syn:: Ident :: new (
117+ & format ! ( "{}Ref" , struct_name. to_string( ) ) ,
118+ Span :: call_site ( ) ,
119+ ) ;
120+ let matchers: Vec < syn:: Arm > = state
121+ . get_subenum_variants ( & name, false )
122+ . iter ( )
123+ . map ( |node| {
124+ let name = format_ident ! ( "{}" , node. normalize_name( ) ) ;
125+ parse_quote ! {
126+ codegen_sdk_analyzer:: #package_name:: cst:: #ref_name:: #name( _) => Ok ( Self :: #name( #name:: new( id, codebase_arc. clone( ) ) ) ) ,
127+ }
128+ } )
129+ . collect ( ) ;
130+ let get_file = helpers:: get_file ( language, quote ! { id } , quote ! { codebase_arc } ) ;
131+ output. push ( parse_quote ! {
132+ impl #struct_name {
133+ pub fn new( py: Python <' _>, id: codegen_sdk_common:: CSTNodeTreeId , codebase_arc: Arc <GILProtected <codegen_sdk_analyzer:: Codebase >>) -> PyResult <Self > {
134+ #( #get_file) *
135+ let node = file. tree( codebase. db( ) ) . get( id. id( codebase. db( ) ) ) ;
136+ if let Some ( node) = node {
137+ match node. as_ref( ) . try_into( ) . unwrap( ) {
138+ #( #matchers) *
139+ }
140+ } else {
141+ Err ( pyo3:: exceptions:: PyValueError :: new_err( "Node not found" ) )
142+ }
143+ }
144+ }
145+ } ) ;
146+ Ok ( output)
147+ }
148+
90149fn generate_module ( state : & codegen_sdk_cst_generator:: State ) -> anyhow:: Result < Vec < syn:: Stmt > > {
91150 let mut output = Vec :: new ( ) ;
92151 let node_names = state. get_node_struct_names ( ) ;
@@ -110,6 +169,10 @@ pub fn generate_cst(
110169 let cst_struct = generate_cst_struct ( language, node) ?;
111170 output. extend ( cst_struct) ;
112171 }
172+ for subenum in & state. subenums {
173+ let cst_subenum = generate_cst_subenum ( language, state, subenum) ?;
174+ output. extend ( cst_subenum) ;
175+ }
113176 output. extend ( generate_module ( state) ?) ;
114177 Ok ( parse_quote ! {
115178 mod cst {
0 commit comments