Skip to content

Commit 42f83f4

Browse files
authored
feat: file.symbols and file.get_symbol (#53)
* Use supertypes * Fix path to wheels * Refactor * Revert change * Add symbols * Generate get_symbol and symbols * Fix bugs * Don't add symbol enums to the module * Remove unused import * update snapshots
1 parent 96be6a4 commit 42f83f4

20 files changed

+940
-107
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ indextree = { version = "4.7.3", features = ["std"], default-features = false }
141141
thiserror = "2.0.11"
142142
indexmap = "2"
143143
smallvec = "1.11.0"
144+
pluralizer = "0.5.0"
144145

145146

146147
[profile.dev]

codegen-bindings-generator/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ syn = { workspace = true }
1919
anyhow = { workspace = true }
2020
proc-macro2 = { workspace = true }
2121
codegen-sdk-cst = { workspace = true }
22+
pluralizer = "0.5.0"
23+
convert_case = { workspace = true }
2224
[dev-dependencies]
2325
test-log = { workspace = true }
2426
rstest = { workspace = true }

codegen-bindings-generator/src/python/generator.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,85 @@
11
use codegen_sdk_ast_generator::{HasQuery, Symbol};
22
use codegen_sdk_common::Language;
33
use codegen_sdk_cst::CSTDatabase;
4+
use convert_case::{Case, Casing};
5+
use pluralizer::pluralize;
46
use proc_macro2::Span;
57
use quote::{format_ident, quote};
68
use syn::{parse_quote, parse_quote_spanned};
79

810
use super::{cst::generate_cst, helpers};
11+
fn get_category(group: &str) -> Vec<syn::Stmt> {
12+
let category = format_ident!("{}", pluralize(group, 2, false));
13+
parse_quote! {
14+
let file = self.file(py)?;
15+
let db = self.codebase.get(py).db();
16+
let category = file.#category(db);
17+
}
18+
}
19+
fn filter_symbols(
20+
nodes: &Vec<&codegen_sdk_ast_generator::Symbol>,
21+
group: &str,
22+
) -> Vec<codegen_sdk_ast_generator::Symbol> {
23+
nodes
24+
.iter()
25+
.filter(|node| node.category == pluralize(group, 2, false))
26+
.cloned()
27+
.cloned()
28+
.collect()
29+
}
30+
fn get_symbol_method_name(group: &str) -> syn::Ident {
31+
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
32+
syn::Ident::new(
33+
&format!("get_{}", symbol_name.to_string().to_case(Case::Snake)),
34+
Span::call_site(),
35+
)
36+
}
37+
// Generates the get_symbol method
38+
fn get_symbols_method(group: &str) -> Vec<syn::Stmt> {
39+
let mut output: Vec<syn::Stmt> = Vec::new();
40+
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
41+
let category = get_category(group);
42+
let symbols_method = codegen_sdk_ast_generator::get_symbols_method(&symbol_name);
43+
let get_symbols_method = get_symbol_method_name(group);
44+
output.extend::<Vec<syn::Stmt>>(parse_quote! {
45+
#[pyo3(signature = (name,optional=false))]
46+
pub fn #get_symbols_method(&self, py: Python<'_>, name: String, optional: bool) -> PyResult<Option<#symbol_name>> {
47+
#(#category)*
48+
let subcategory = category.#symbols_method(db);
49+
let res = subcategory.get(&name);
50+
if let Some(nodes) = res {
51+
if nodes.len() == 1 {
52+
Ok(Some(#symbol_name::new(py.clone(), nodes[0].fully_qualified_name(db), 0, &nodes[0], self.codebase.clone())))
53+
} else {
54+
Err(pyo3::exceptions::PyValueError::new_err(format!("Ambiguous symbol {} found {} possible matches", name, nodes.len())))
55+
}
56+
} else {
57+
if optional {
58+
Ok(None)
59+
} else {
60+
Err(pyo3::exceptions::PyValueError::new_err(format!("No symbol {} found", name)))
61+
}
62+
}
63+
}
64+
});
65+
output
66+
}
67+
fn symbols_method(group: &str) -> Vec<syn::Stmt> {
68+
let mut output: Vec<syn::Stmt> = Vec::new();
69+
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
70+
let category = get_category(group);
71+
let symbols_method = codegen_sdk_ast_generator::get_symbols_method(&symbol_name);
72+
output.extend::<Vec<syn::Stmt>>(parse_quote! {
73+
#[getter]
74+
pub fn #symbols_method(&self, py: Python<'_>) -> PyResult<Vec<#symbol_name>> {
75+
#(#category)*
76+
let subcategory = category.#symbols_method(db);
77+
let nodes = subcategory.values().map(|values| values.into_iter().enumerate().map(|(idx, node)| #symbol_name::new(py.clone(), node.fully_qualified_name(db), idx, node, self.codebase.clone()))).flatten().collect();
78+
Ok(nodes)
79+
}
80+
});
81+
output
82+
}
983
fn generate_file_struct(
1084
language: &Language,
1185
symbols: Vec<&Symbol>,
@@ -47,6 +121,13 @@ fn generate_file_struct(
47121
.filter(|symbol| symbol.category != symbol.subcategory)
48122
.map(|symbol| vec![symbol.py_file_getter(), symbol.py_file_get()])
49123
.flatten();
124+
let mut symbols_methods = Vec::new();
125+
for group in codegen_sdk_ast_generator::GROUPS {
126+
if filter_symbols(&symbols, group).len() > 0 {
127+
symbols_methods.extend(symbols_method(group));
128+
symbols_methods.extend(get_symbols_method(group));
129+
}
130+
}
50131
output.push(parse_quote! {
51132
#[pymethods]
52133
impl #struct_name {
@@ -70,6 +151,7 @@ fn generate_file_struct(
70151
Ok(self.content(py)?.to_string())
71152
}
72153
#(#methods)*
154+
#(#symbols_methods)*
73155
}
74156
});
75157
Ok(output)
@@ -158,6 +240,54 @@ fn generate_symbol_struct(
158240
});
159241
Ok(output)
160242
}
243+
// Generate an enum for all possible symbols. (IE Symbol => Class, Function, etc)
244+
fn generate_symbol_enum(
245+
language: &Language,
246+
symbols: Vec<&Symbol>,
247+
group: &str,
248+
) -> anyhow::Result<Vec<syn::Stmt>> {
249+
let symbols: Vec<_> = filter_symbols(&symbols, group)
250+
.iter()
251+
.map(|symbol| format_ident!("{}", symbol.name))
252+
.collect();
253+
if symbols.len() == 0 {
254+
return Ok(Vec::new());
255+
}
256+
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
257+
let span = Span::call_site();
258+
let mut output = Vec::new();
259+
let enum_name = format_ident!("{}", symbol_name);
260+
let package_name = syn::Ident::new(&language.package_name(), span);
261+
output.push(parse_quote_spanned! {
262+
span =>
263+
#[derive(IntoPyObject)]
264+
pub enum #enum_name {
265+
#(#symbols(#symbols),)*
266+
}
267+
});
268+
let original_name = quote! { codegen_sdk_analyzer::#package_name::ast::#symbol_name };
269+
let matchers: Vec<syn::Arm> = symbols
270+
.iter()
271+
.map(|symbol| {
272+
let symbol_name = format_ident!("{}", symbol);
273+
parse_quote_spanned! {
274+
span =>
275+
#original_name::#symbol_name(_) => Self::#symbol_name(#symbol_name::new(id, idx, codebase_arc)),
276+
}
277+
})
278+
.collect();
279+
output.push(parse_quote_spanned! {
280+
span =>
281+
impl #enum_name {
282+
pub fn new(py: Python<'_>, id: codegen_sdk_resolution::FullyQualifiedName, idx: usize, node: &#original_name<'_>, codebase_arc: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>) -> Self {
283+
match node {
284+
#(#matchers)*
285+
}
286+
}
287+
}
288+
});
289+
Ok(output)
290+
}
161291
fn generate_module(
162292
language: &Language,
163293
symbols: Vec<syn::Ident>,
@@ -167,6 +297,7 @@ fn generate_module(
167297
let register_name = format_ident!("register_{}", language_name);
168298
let struct_name = format_ident!("{}", language.file_struct_name());
169299
let module_name = format!("codegen_sdk_pink.{}", language_name);
300+
170301
output.push(parse_quote! {
171302
pub fn #register_name(py: Python<'_>, parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
172303
let child_module = PyModule::new(parent_module.py(), #language_name)?;
@@ -194,6 +325,10 @@ pub(crate) fn generate_bindings(language: &Language) -> anyhow::Result<Vec<syn::
194325
let cst = generate_cst(language, &state)?;
195326
output.extend(cst);
196327
let mut symbol_idents = Vec::new();
328+
for group in codegen_sdk_ast_generator::GROUPS {
329+
let symbol_enum = generate_symbol_enum(language, symbols.values().collect(), group)?;
330+
output.extend(symbol_enum);
331+
}
197332
for (_, symbol) in symbols {
198333
let symbol_struct = generate_symbol_struct(language, &symbol)?;
199334
output.extend(symbol_struct);

0 commit comments

Comments
 (0)