11use codegen_sdk_ast_generator:: { HasQuery , Symbol } ;
22use codegen_sdk_common:: Language ;
33use codegen_sdk_cst:: CSTDatabase ;
4+ use convert_case:: { Case , Casing } ;
5+ use pluralizer:: pluralize;
46use proc_macro2:: Span ;
57use quote:: { format_ident, quote} ;
68use syn:: { parse_quote, parse_quote_spanned} ;
79
810use super :: { cst:: generate_cst, helpers} ;
11+ fn get_category ( group : & str ) -> Vec < syn:: Stmt > {
12+ let category = format_ident ! ( "{}" , pluralize( group, 2 , false ) ) ;
13+ parse_quote ! {
14+ let file = self . file( py) ?;
15+ let db = self . codebase. get( py) . db( ) ;
16+ let category = file. #category( db) ;
17+ }
18+ }
19+ fn filter_symbols (
20+ nodes : & Vec < & codegen_sdk_ast_generator:: Symbol > ,
21+ group : & str ,
22+ ) -> Vec < codegen_sdk_ast_generator:: Symbol > {
23+ nodes
24+ . iter ( )
25+ . filter ( |node| node. category == pluralize ( group, 2 , false ) )
26+ . cloned ( )
27+ . cloned ( )
28+ . collect ( )
29+ }
30+ fn get_symbol_method_name ( group : & str ) -> syn:: Ident {
31+ let symbol_name = codegen_sdk_ast_generator:: get_symbol_name ( group) ;
32+ syn:: Ident :: new (
33+ & format ! ( "get_{}" , symbol_name. to_string( ) . to_case( Case :: Snake ) ) ,
34+ Span :: call_site ( ) ,
35+ )
36+ }
37+ // Generates the get_symbol method
38+ fn get_symbols_method ( group : & str ) -> Vec < syn:: Stmt > {
39+ let mut output: Vec < syn:: Stmt > = Vec :: new ( ) ;
40+ let symbol_name = codegen_sdk_ast_generator:: get_symbol_name ( group) ;
41+ let category = get_category ( group) ;
42+ let symbols_method = codegen_sdk_ast_generator:: get_symbols_method ( & symbol_name) ;
43+ let get_symbols_method = get_symbol_method_name ( group) ;
44+ output. extend :: < Vec < syn:: Stmt > > ( parse_quote ! {
45+ #[ pyo3( signature = ( name, optional=false ) ) ]
46+ pub fn #get_symbols_method( & self , py: Python <' _>, name: String , optional: bool ) -> PyResult <Option <#symbol_name>> {
47+ #( #category) *
48+ let subcategory = category. #symbols_method( db) ;
49+ let res = subcategory. get( & name) ;
50+ if let Some ( nodes) = res {
51+ if nodes. len( ) == 1 {
52+ Ok ( Some ( #symbol_name:: new( py. clone( ) , nodes[ 0 ] . fully_qualified_name( db) , 0 , & nodes[ 0 ] , self . codebase. clone( ) ) ) )
53+ } else {
54+ Err ( pyo3:: exceptions:: PyValueError :: new_err( format!( "Ambiguous symbol {} found {} possible matches" , name, nodes. len( ) ) ) )
55+ }
56+ } else {
57+ if optional {
58+ Ok ( None )
59+ } else {
60+ Err ( pyo3:: exceptions:: PyValueError :: new_err( format!( "No symbol {} found" , name) ) )
61+ }
62+ }
63+ }
64+ } ) ;
65+ output
66+ }
67+ fn symbols_method ( group : & str ) -> Vec < syn:: Stmt > {
68+ let mut output: Vec < syn:: Stmt > = Vec :: new ( ) ;
69+ let symbol_name = codegen_sdk_ast_generator:: get_symbol_name ( group) ;
70+ let category = get_category ( group) ;
71+ let symbols_method = codegen_sdk_ast_generator:: get_symbols_method ( & symbol_name) ;
72+ output. extend :: < Vec < syn:: Stmt > > ( parse_quote ! {
73+ #[ getter]
74+ pub fn #symbols_method( & self , py: Python <' _>) -> PyResult <Vec <#symbol_name>> {
75+ #( #category) *
76+ let subcategory = category. #symbols_method( db) ;
77+ let nodes = subcategory. values( ) . map( |values| values. into_iter( ) . enumerate( ) . map( |( idx, node) | #symbol_name:: new( py. clone( ) , node. fully_qualified_name( db) , idx, node, self . codebase. clone( ) ) ) ) . flatten( ) . collect( ) ;
78+ Ok ( nodes)
79+ }
80+ } ) ;
81+ output
82+ }
983fn generate_file_struct (
1084 language : & Language ,
1185 symbols : Vec < & Symbol > ,
@@ -47,6 +121,13 @@ fn generate_file_struct(
47121 . filter ( |symbol| symbol. category != symbol. subcategory )
48122 . map ( |symbol| vec ! [ symbol. py_file_getter( ) , symbol. py_file_get( ) ] )
49123 . flatten ( ) ;
124+ let mut symbols_methods = Vec :: new ( ) ;
125+ for group in codegen_sdk_ast_generator:: GROUPS {
126+ if filter_symbols ( & symbols, group) . len ( ) > 0 {
127+ symbols_methods. extend ( symbols_method ( group) ) ;
128+ symbols_methods. extend ( get_symbols_method ( group) ) ;
129+ }
130+ }
50131 output. push ( parse_quote ! {
51132 #[ pymethods]
52133 impl #struct_name {
@@ -70,6 +151,7 @@ fn generate_file_struct(
70151 Ok ( self . content( py) ?. to_string( ) )
71152 }
72153 #( #methods) *
154+ #( #symbols_methods) *
73155 }
74156 } ) ;
75157 Ok ( output)
@@ -158,6 +240,54 @@ fn generate_symbol_struct(
158240 } ) ;
159241 Ok ( output)
160242}
243+ // Generate an enum for all possible symbols. (IE Symbol => Class, Function, etc)
244+ fn generate_symbol_enum (
245+ language : & Language ,
246+ symbols : Vec < & Symbol > ,
247+ group : & str ,
248+ ) -> anyhow:: Result < Vec < syn:: Stmt > > {
249+ let symbols: Vec < _ > = filter_symbols ( & symbols, group)
250+ . iter ( )
251+ . map ( |symbol| format_ident ! ( "{}" , symbol. name) )
252+ . collect ( ) ;
253+ if symbols. len ( ) == 0 {
254+ return Ok ( Vec :: new ( ) ) ;
255+ }
256+ let symbol_name = codegen_sdk_ast_generator:: get_symbol_name ( group) ;
257+ let span = Span :: call_site ( ) ;
258+ let mut output = Vec :: new ( ) ;
259+ let enum_name = format_ident ! ( "{}" , symbol_name) ;
260+ let package_name = syn:: Ident :: new ( & language. package_name ( ) , span) ;
261+ output. push ( parse_quote_spanned ! {
262+ span =>
263+ #[ derive( IntoPyObject ) ]
264+ pub enum #enum_name {
265+ #( #symbols( #symbols) , ) *
266+ }
267+ } ) ;
268+ let original_name = quote ! { codegen_sdk_analyzer:: #package_name:: ast:: #symbol_name } ;
269+ let matchers: Vec < syn:: Arm > = symbols
270+ . iter ( )
271+ . map ( |symbol| {
272+ let symbol_name = format_ident ! ( "{}" , symbol) ;
273+ parse_quote_spanned ! {
274+ span =>
275+ #original_name:: #symbol_name( _) => Self :: #symbol_name( #symbol_name:: new( id, idx, codebase_arc) ) ,
276+ }
277+ } )
278+ . collect ( ) ;
279+ output. push ( parse_quote_spanned ! {
280+ span =>
281+ impl #enum_name {
282+ pub fn new( py: Python <' _>, id: codegen_sdk_resolution:: FullyQualifiedName , idx: usize , node: & #original_name<' _>, codebase_arc: Arc <GILProtected <codegen_sdk_analyzer:: Codebase >>) -> Self {
283+ match node {
284+ #( #matchers) *
285+ }
286+ }
287+ }
288+ } ) ;
289+ Ok ( output)
290+ }
161291fn generate_module (
162292 language : & Language ,
163293 symbols : Vec < syn:: Ident > ,
@@ -167,6 +297,7 @@ fn generate_module(
167297 let register_name = format_ident ! ( "register_{}" , language_name) ;
168298 let struct_name = format_ident ! ( "{}" , language. file_struct_name( ) ) ;
169299 let module_name = format ! ( "codegen_sdk_pink.{}" , language_name) ;
300+
170301 output. push ( parse_quote ! {
171302 pub fn #register_name( py: Python <' _>, parent_module: & Bound <' _, PyModule >) -> PyResult <( ) > {
172303 let child_module = PyModule :: new( parent_module. py( ) , #language_name) ?;
@@ -194,6 +325,10 @@ pub(crate) fn generate_bindings(language: &Language) -> anyhow::Result<Vec<syn::
194325 let cst = generate_cst ( language, & state) ?;
195326 output. extend ( cst) ;
196327 let mut symbol_idents = Vec :: new ( ) ;
328+ for group in codegen_sdk_ast_generator:: GROUPS {
329+ let symbol_enum = generate_symbol_enum ( language, symbols. values ( ) . collect ( ) , group) ?;
330+ output. extend ( symbol_enum) ;
331+ }
197332 for ( _, symbol) in symbols {
198333 let symbol_struct = generate_symbol_struct ( language, & symbol) ?;
199334 output. extend ( symbol_struct) ;
0 commit comments