Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ jobs:
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v2
with:
subject-path: 'bindings/python/wheels-*/*'
subject-path: 'wheels-*/*'
- name: Publish to PyPI
if: ${{ startsWith(github.ref, 'refs/tags/') }}
uses: PyO3/maturin-action@v1
Expand Down
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ indextree = { version = "4.7.3", features = ["std"], default-features = false }
thiserror = "2.0.11"
indexmap = "2"
smallvec = "1.11.0"
pluralizer = "0.5.0"


[profile.dev]
Expand Down
2 changes: 2 additions & 0 deletions codegen-bindings-generator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ syn = { workspace = true }
anyhow = { workspace = true }
proc-macro2 = { workspace = true }
codegen-sdk-cst = { workspace = true }
pluralizer = "0.5.0"
convert_case = { workspace = true }
[dev-dependencies]
test-log = { workspace = true }
rstest = { workspace = true }
Expand Down
120 changes: 120 additions & 0 deletions codegen-bindings-generator/src/python/generator.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,74 @@
use codegen_sdk_ast_generator::{HasQuery, Symbol};
use codegen_sdk_common::Language;
use codegen_sdk_cst::CSTDatabase;
use convert_case::{Case, Casing};
use pluralizer::pluralize;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_quote, parse_quote_spanned};

use super::{cst::generate_cst, helpers};
fn get_category(group: &str) -> Vec<syn::Stmt> {
let category = format_ident!("{}", pluralize(group, 2, false));
parse_quote! {
let file = self.file(py)?;
let db = self.codebase.get(py).db();
let category = file.#category(db);
}
}
fn get_symbol_method_name(group: &str) -> syn::Ident {
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
syn::Ident::new(
&format!("get_{}", symbol_name.to_string().to_case(Case::Snake)),
Span::call_site(),
)
}
// Generates the get_symbol method
fn get_symbols_method(group: &str) -> Vec<syn::Stmt> {
let mut output: Vec<syn::Stmt> = Vec::new();
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
let category = get_category(group);
let symbols_method = codegen_sdk_ast_generator::get_symbols_method(&symbol_name);
let get_symbols_method = get_symbol_method_name(group);
output.extend::<Vec<syn::Stmt>>(parse_quote! {
#[pyo3(signature = (name,optional=false))]
pub fn #get_symbols_method(&self, py: Python<'_>, name: String, optional: bool) -> PyResult<Option<#symbol_name>> {
#(#category)*
let subcategory = category.#symbols_method(db);
let res = subcategory.get(&name);
if let Some(nodes) = res {
if nodes.len() == 1 {
Ok(Some(#symbol_name::new(py.clone(), nodes[0].fully_qualified_name(db), 0, &nodes[0], self.codebase.clone())))
} else {
Err(pyo3::exceptions::PyValueError::new_err(format!("Ambiguous symbol {} found {} possible matches", name, nodes.len())))
}
} else {
if optional {
Ok(None)
} else {
Err(pyo3::exceptions::PyValueError::new_err(format!("No symbol {} found", name)))
}
}
}
});
output
}
fn symbols_method(group: &str) -> Vec<syn::Stmt> {
let mut output: Vec<syn::Stmt> = Vec::new();
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
let category = get_category(group);
let symbols_method = codegen_sdk_ast_generator::get_symbols_method(&symbol_name);
output.extend::<Vec<syn::Stmt>>(parse_quote! {
#[getter]
pub fn #symbols_method(&self, py: Python<'_>) -> PyResult<Vec<#symbol_name>> {
#(#category)*
let subcategory = category.#symbols_method(db);
let nodes = subcategory.values().map(|values| values.into_iter().enumerate().map(|(idx, node)| #symbol_name::new(py.clone(), node.fully_qualified_name(db), idx, node, self.codebase.clone()))).flatten().collect();
Ok(nodes)
}
});
output
}
fn generate_file_struct(
language: &Language,
symbols: Vec<&Symbol>,
Expand Down Expand Up @@ -47,6 +110,11 @@ fn generate_file_struct(
.filter(|symbol| symbol.category != symbol.subcategory)
.map(|symbol| vec![symbol.py_file_getter(), symbol.py_file_get()])
.flatten();
let mut symbols_methods = Vec::new();
for group in codegen_sdk_ast_generator::GROUPS {
symbols_methods.extend(symbols_method(group));
symbols_methods.extend(get_symbols_method(group));
}
output.push(parse_quote! {
#[pymethods]
impl #struct_name {
Expand All @@ -70,6 +138,7 @@ fn generate_file_struct(
Ok(self.content(py)?.to_string())
}
#(#methods)*
#(#symbols_methods)*
}
});
Ok(output)
Expand Down Expand Up @@ -158,6 +227,52 @@ fn generate_symbol_struct(
});
Ok(output)
}
// Generate an enum for all possible symbols. (IE Symbol => Class, Function, etc)
fn generate_symbol_enum(
language: &Language,
symbols: Vec<&Symbol>,
group: &str,
) -> anyhow::Result<Vec<syn::Stmt>> {
let symbols: Vec<_> = symbols
.iter()
.filter(|symbol| symbol.category == pluralize(group, 2, false))
.map(|symbol| format_ident!("{}", symbol.name))
.collect();
let symbol_name = codegen_sdk_ast_generator::get_symbol_name(group);
let span = Span::call_site();
let mut output = Vec::new();
let enum_name = format_ident!("{}", symbol_name);
let package_name = syn::Ident::new(&language.package_name(), span);
output.push(parse_quote_spanned! {
span =>
#[derive(IntoPyObject)]
pub enum #enum_name {
#(#symbols(#symbols),)*
}
});
let original_name = quote! { codegen_sdk_analyzer::#package_name::ast::#symbol_name };
let matchers: Vec<syn::Arm> = symbols
.iter()
.map(|symbol| {
let symbol_name = format_ident!("{}", symbol);
parse_quote_spanned! {
span =>
#original_name::#symbol_name(_) => Self::#symbol_name(#symbol_name::new(id, idx, codebase_arc)),
}
})
.collect();
output.push(parse_quote_spanned! {
span =>
impl #enum_name {
pub fn new(py: Python<'_>, id: codegen_sdk_resolution::FullyQualifiedName, idx: usize, node: &#original_name<'_>, codebase_arc: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>) -> Self {
match node {
#(#matchers)*
}
}
}
});
Ok(output)
}
fn generate_module(
language: &Language,
symbols: Vec<syn::Ident>,
Expand Down Expand Up @@ -194,6 +309,11 @@ pub(crate) fn generate_bindings(language: &Language) -> anyhow::Result<Vec<syn::
let cst = generate_cst(language, &state)?;
output.extend(cst);
let mut symbol_idents = Vec::new();
for group in codegen_sdk_ast_generator::GROUPS {
let symbol_enum = generate_symbol_enum(language, symbols.values().collect(), group)?;
output.extend(symbol_enum);
}

for (_, symbol) in symbols {
let symbol_struct = generate_symbol_struct(language, &symbol)?;
output.extend(symbol_struct);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50576,6 +50576,94 @@ impl Call {
)
}
}
#[pyclass(module = "codegen_sdk_pink.python")]
pub enum Symbol {
Class(codegen_sdk_analyzer::codegen_sdk_python::ast::Class),
Constant(codegen_sdk_analyzer::codegen_sdk_python::ast::Constant),
Function(codegen_sdk_analyzer::codegen_sdk_python::ast::Function),
Import(codegen_sdk_analyzer::codegen_sdk_python::ast::Import),
Call(codegen_sdk_analyzer::codegen_sdk_python::ast::Call),
}
impl Symbol {
pub fn new(
py: Python<'_>,
id: codegen_sdk_common::FullyQualifiedName,
codebase_arc: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>,
) -> PyResult<Self> {
let codebase = codebase_arc.get(py);
let path = id.file(codebase.db());
let file = codebase.get_file_for_id(path);
let file = match file {
Some(codegen_sdk_analyzer::ParsedFile::Python(py)) => py,
_ => {
return Err(
pyo3::exceptions::PyValueError::new_err(
format!(
"File not found for path: {}", path.path(codebase.db())
.display()
),
),
);
}
};
let node = file.tree(codebase.db()).get(id.id(codebase.db()));
if let Some(node) = node {
match node.as_ref().try_into().unwrap() {
Class => Ok(Self::Class(Class::new(py, id, codebase_arc))),
Constant => Ok(Self::Constant(Constant::new(py, id, codebase_arc))),
Function => Ok(Self::Function(Function::new(py, id, codebase_arc))),
Import => Ok(Self::Import(Import::new(py, id, codebase_arc))),
Call => Ok(Self::Call(Call::new(py, id, codebase_arc))),
}
} else {
Err(pyo3::exceptions::PyValueError::new_err("Node not found"))
}
}
}
#[pyclass(module = "codegen_sdk_pink.python")]
pub enum Reference {
Class(codegen_sdk_analyzer::codegen_sdk_python::ast::Class),
Constant(codegen_sdk_analyzer::codegen_sdk_python::ast::Constant),
Function(codegen_sdk_analyzer::codegen_sdk_python::ast::Function),
Import(codegen_sdk_analyzer::codegen_sdk_python::ast::Import),
Call(codegen_sdk_analyzer::codegen_sdk_python::ast::Call),
}
impl Reference {
pub fn new(
py: Python<'_>,
id: codegen_sdk_common::FullyQualifiedName,
codebase_arc: Arc<GILProtected<codegen_sdk_analyzer::Codebase>>,
) -> PyResult<Self> {
let codebase = codebase_arc.get(py);
let path = id.file(codebase.db());
let file = codebase.get_file_for_id(path);
let file = match file {
Some(codegen_sdk_analyzer::ParsedFile::Python(py)) => py,
_ => {
return Err(
pyo3::exceptions::PyValueError::new_err(
format!(
"File not found for path: {}", path.path(codebase.db())
.display()
),
),
);
}
};
let node = file.tree(codebase.db()).get(id.id(codebase.db()));
if let Some(node) = node {
match node.as_ref().try_into().unwrap() {
Class => Ok(Self::Class(Class::new(py, id, codebase_arc))),
Constant => Ok(Self::Constant(Constant::new(py, id, codebase_arc))),
Function => Ok(Self::Function(Function::new(py, id, codebase_arc))),
Import => Ok(Self::Import(Import::new(py, id, codebase_arc))),
Call => Ok(Self::Call(Call::new(py, id, codebase_arc))),
}
} else {
Err(pyo3::exceptions::PyValueError::new_err("Node not found"))
}
}
}
pub fn register_python(
py: Python<'_>,
parent_module: &Bound<'_, PyModule>,
Expand Down
Loading
Loading