Skip to content

Commit 36215ca

Browse files
committed
Subenum generation
1 parent 0259ff5 commit 36215ca

File tree

5 files changed

+81
-16
lines changed

5 files changed

+81
-16
lines changed

codegen-bindings-generator/src/python/cst.rs

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use codegen_sdk_common::Language;
1+
use codegen_sdk_common::{Language, naming::normalize_type_name};
22
use proc_macro2::Span;
3-
use quote::format_ident;
3+
use quote::{format_ident, quote};
44
use syn::parse_quote;
55

66
use super::helpers;
@@ -19,7 +19,7 @@ fn generate_cst_struct(
1919
codebase: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>,
2020
}
2121
});
22-
let file_getter = helpers::get_file(language);
22+
let file_getter = helpers::get_file(language, quote! { self.id }, quote! { self.codebase });
2323
output.push(parse_quote! {
2424
impl #struct_name {
2525
pub fn new(id: codegen_sdk_common::CSTNodeTreeId, codebase: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>) -> Self {
@@ -87,6 +87,65 @@ fn generate_cst_struct(
8787
});
8888
Ok(output)
8989
}
90+
fn generate_cst_subenum(
91+
language: &Language,
92+
state: &codegen_sdk_cst_generator::State,
93+
name: &str,
94+
) -> anyhow::Result<Vec<syn::Stmt>> {
95+
let mut output = Vec::new();
96+
let struct_name = format_ident!("{}", normalize_type_name(name, true));
97+
let package_name = syn::Ident::new(&language.package_name(), Span::call_site());
98+
let module_name = format!("codegen_sdk_pink::{}.cst", language.name());
99+
let subenum_names = state
100+
.get_subenum_variants(&name, false)
101+
.iter()
102+
.map(|name| {
103+
let name = format_ident!("{}", name.normalize_name());
104+
parse_quote! {
105+
#name(#name)
106+
}
107+
})
108+
.collect::<Vec<syn::Variant>>();
109+
output.push(parse_quote! {
110+
#[derive(IntoPyObject)]
111+
pub enum #struct_name {
112+
#(#subenum_names,)*
113+
}
114+
});
115+
let package_name = syn::Ident::new(&language.package_name(), Span::call_site());
116+
let ref_name = syn::Ident::new(
117+
&format!("{}Ref", struct_name.to_string()),
118+
Span::call_site(),
119+
);
120+
let matchers: Vec<syn::Arm> = state
121+
.get_subenum_variants(&name, false)
122+
.iter()
123+
.map(|node| {
124+
let name = format_ident!("{}", node.normalize_name());
125+
parse_quote! {
126+
codegen_sdk_analyzer::#package_name::cst::#ref_name::#name(_) => Ok(Self::#name(#name::new(id, codebase_arc.clone()))),
127+
}
128+
})
129+
.collect();
130+
let get_file = helpers::get_file(language, quote! { id }, quote! { codebase_arc });
131+
output.push(parse_quote! {
132+
impl #struct_name {
133+
pub fn new(py: Python<'_>, id: codegen_sdk_common::CSTNodeTreeId, codebase_arc: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>) -> PyResult<Self> {
134+
#(#get_file)*
135+
let node = file.tree(codebase.db()).get(id.id(codebase.db()));
136+
if let Some(node) = node {
137+
match node.as_ref().try_into().unwrap() {
138+
#(#matchers)*
139+
}
140+
} else {
141+
Err(pyo3::exceptions::PyValueError::new_err("Node not found"))
142+
}
143+
}
144+
}
145+
});
146+
Ok(output)
147+
}
148+
90149
fn generate_module(state: &codegen_sdk_cst_generator::State) -> anyhow::Result<Vec<syn::Stmt>> {
91150
let mut output = Vec::new();
92151
let node_names = state.get_node_struct_names();
@@ -110,6 +169,10 @@ pub fn generate_cst(
110169
let cst_struct = generate_cst_struct(language, node)?;
111170
output.extend(cst_struct);
112171
}
172+
for subenum in &state.subenums {
173+
let cst_subenum = generate_cst_subenum(language, state, subenum)?;
174+
output.extend(cst_subenum);
175+
}
113176
output.extend(generate_module(state)?);
114177
Ok(parse_quote! {
115178
mod cst {

codegen-bindings-generator/src/python/generator.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use codegen_sdk_ast_generator::{HasQuery, Symbol};
22
use codegen_sdk_common::Language;
33
use codegen_sdk_cst::CSTDatabase;
44
use proc_macro2::Span;
5-
use quote::format_ident;
5+
use quote::{format_ident, quote};
66
use syn::{parse_quote, parse_quote_spanned};
77

88
use super::{cst::generate_cst, helpers};
@@ -92,7 +92,7 @@ fn generate_symbol_struct(
9292
codebase: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>,
9393
}
9494
});
95-
let file_getter = helpers::get_file(language);
95+
let file_getter = helpers::get_file(language, quote! { self.id }, quote! { self.codebase });
9696
let category = syn::Ident::new(&symbol.category, span);
9797
let subcategory = syn::Ident::new(&symbol.subcategory, span);
9898
output.push(parse_quote_spanned! {

codegen-bindings-generator/src/python/helpers.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use codegen_sdk_common::Language;
2-
use proc_macro2::Span;
2+
use proc_macro2::{Span, TokenStream};
33
use syn::parse_quote_spanned;
4-
pub fn get_file(language: &Language) -> Vec<syn::Stmt> {
4+
pub fn get_file(language: &Language, id: TokenStream, codebase: TokenStream) -> Vec<syn::Stmt> {
55
let span = Span::call_site();
66
let variant_name = syn::Ident::new(&language.struct_name, span);
77
parse_quote_spanned! {
88
span =>
9-
let codebase = self.codebase.get(py);
10-
let path = self.id.file(codebase.db());
9+
let codebase = #codebase.get(py);
10+
let path = #id.file(codebase.db());
1111
let file = codebase.get_file_for_id(path);
1212
let file = match file {
1313
Some(codegen_sdk_analyzer::ParsedFile::#variant_name(py)) => py,

codegen-sdk-ast-generator/src/query.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ impl<'a> Query<'a> {
619619
);
620620
matchers.extend_one(matcher);
621621
} else {
622-
let subenum = self.state.get_subenum_variants(&first_node.source());
622+
let subenum = self.state.get_subenum_variants(&first_node.source(), false);
623623
log::info!(
624624
"subenum {} with {} variants",
625625
first_node.source(),

codegen-sdk-cst-generator/src/generator/state.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,17 +153,19 @@ impl<'a> State<'a> {
153153
}
154154
pub fn get_variants(&self, subenum: &str, include_comment: bool) -> Vec<TypeDefinition> {
155155
let mut variants = Vec::new();
156-
if include_comment {
157-
let comment = get_comment_type();
158-
variants.push(comment);
159-
}
160156
for node in self.nodes.values() {
161157
log::debug!("Checking subenum: {} for {}", subenum, node.kind());
162158
if node.subenums.contains(&subenum.to_string()) {
163159
log::debug!("Found variant: {} for {}", node.kind(), subenum);
164160
variants.push(node.type_definition());
165161
}
166162
}
163+
if include_comment {
164+
let comment = get_comment_type();
165+
if !variants.iter().any(|v| v.type_name == comment.type_name) {
166+
variants.push(comment);
167+
}
168+
}
167169
variants
168170
}
169171
fn get_variant_map(&self, enum_name: &str) -> BTreeMap<u16, TokenStream> {
@@ -290,8 +292,8 @@ impl<'a> State<'a> {
290292
}
291293
None
292294
}
293-
pub fn get_subenum_variants(&self, name: &str) -> Vec<&Node<'a>> {
294-
let variants = self.get_variants(name, true);
295+
pub fn get_subenum_variants(&self, name: &str, include_comment: bool) -> Vec<&Node<'a>> {
296+
let variants = self.get_variants(name, include_comment);
295297
let mut nodes = Vec::new();
296298
for variant in variants {
297299
if let Some(node) = self.get_node_for_struct_name(&variant.normalize()) {

0 commit comments

Comments
 (0)