diff --git a/codegen-sdk-ast-generator/src/generator.rs b/codegen-sdk-ast-generator/src/generator.rs index fc1514ab..06fd2b7b 100644 --- a/codegen-sdk-ast-generator/src/generator.rs +++ b/codegen-sdk-ast-generator/src/generator.rs @@ -6,7 +6,8 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { pub struct {language_struct_name}File {{ node: {language_name}::{root_node_name}, path: PathBuf, - pub visitor: QueryExecutor + pub references: References, + pub definitions: Definitions }} impl File for {language_struct_name}File {{ fn path(&self) -> &PathBuf {{ @@ -15,9 +16,11 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { fn parse(path: &PathBuf) -> Result {{ log::debug!(\"Parsing {language_name} file: {{}}\", path.display()); let ast = {language_name}::{language_struct_name}::parse_file(path)?; - let mut visitor = QueryExecutor::default(); - ast.drive(&mut visitor); - Ok({language_struct_name}File {{ node: ast, path: path.clone(), visitor }}) + let mut references = References::default(); + let mut definitions = Definitions::default(); + ast.drive(&mut definitions); + ast.drive(&mut references); + Ok({language_struct_name}File {{ node: ast, path: path.clone(), references, definitions }}) }} }} impl HasNode for {language_struct_name}File {{ diff --git a/codegen-sdk-ast-generator/src/lib.rs b/codegen-sdk-ast-generator/src/lib.rs index 2230b033..5c3ff0b5 100644 --- a/codegen-sdk-ast-generator/src/lib.rs +++ b/codegen-sdk-ast-generator/src/lib.rs @@ -3,7 +3,6 @@ use codegen_sdk_common::{generator::format_code, language::Language}; use quote::quote; -use crate::query::HasQuery; mod generator; mod query; mod visitor; @@ -15,17 +14,11 @@ pub fn generate_ast(language: &Language) -> anyhow::Result<()> { use codegen_sdk_cst::CSTLanguage; }; let mut ast = generator::generate_ast(language)?; - let visitor = visitor::generate_visitor( - &language - .definitions() - .values() - .into_iter() - .flatten() - .collect(), - language, - ); - ast = imports.to_string() + &ast + &visitor.to_string(); - ast = format_code(&ast).unwrap(); + let definitions = visitor::generate_visitor(language, "definition"); + let references = visitor::generate_visitor(language, "reference"); + ast = imports.to_string() + &ast + &definitions.to_string() + &references.to_string(); + ast = format_code(&ast) + .unwrap_or_else(|_| panic!("Failed to format ast for {}", language.name())); let out_dir = std::env::var("OUT_DIR")?; let out_file = format!("{}/{}.rs", out_dir, language.name()); std::fs::write(out_file, ast)?; diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index 6170d5e2..455319cf 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -1,6 +1,9 @@ use std::{collections::BTreeMap, sync::Arc}; -use codegen_sdk_common::{CSTNode, HasChildren, Language, naming::normalize_type_name}; +use codegen_sdk_common::{ + CSTNode, HasChildren, Language, + naming::{normalize_field_name, normalize_type_name}, +}; use codegen_sdk_cst::{CSTLanguage, ts_query}; use codegen_sdk_cst_generator::{Field, State}; use derive_more::Debug; @@ -191,7 +194,7 @@ impl<'a> Query<'a> { field.children().into_iter().skip(2).next().unwrap().into(); for name in &field.name { if let ts_query::FieldDefinitionName::Identifier(identifier) = name { - let name = identifier.source(); + let name = normalize_field_name(&identifier.source()); if let Some(field) = self.get_field_for_field_name(&name, struct_name) { let field_name = format_ident!("{}", name); let new_identifier = format_ident!("field"); @@ -372,12 +375,15 @@ impl<'a> Query<'a> { } } } - unhandled => todo!( - "Unhandled definition in language {}: {:#?}, {:#?}", - self.language.name(), - unhandled.kind(), - unhandled.source() - ), + unhandled => { + log::warn!( + "Unhandled definition in language {}: {:#?}, {:#?}", + self.language.name(), + unhandled.kind(), + unhandled.source() + ); + self.get_default_matcher() + } } } @@ -412,12 +418,6 @@ pub trait HasQuery { } queries } - fn definitions(&self) -> BTreeMap>> { - self.queries_with_prefix("definition") - } - // fn references(&self) -> BTreeMap>> { - // self.queries_with_prefix("reference") - // } } impl HasQuery for Language { fn queries(&self) -> BTreeMap> { diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__generate_visitor.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__generate_visitor.snap index 51a74814..dec37b8b 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__generate_visitor.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__generate_visitor.snap @@ -10,14 +10,14 @@ expression: "codegen_sdk_common::generator::format_code(&visitor.to_string()).un typescript::AbstractMethodSignature(enter), typescript::Module(enter) )] -pub struct QueryExecutor { +pub struct Definitions { pub classes: Vec, pub functions: Vec, pub interfaces: Vec, pub methods: Vec, pub modules: Vec, } -impl QueryExecutor { +impl Definitions { fn enter_abstract_class_declaration( &mut self, node: &codegen_sdk_cst::typescript::AbstractClassDeclaration, diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 9b65c530..935aae11 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -8,8 +8,15 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote}; use super::query::Query; -pub fn generate_visitor(queries: &Vec<&Query>, language: &Language) -> TokenStream { - log::info!("Generating visitor for language: {}", language.name()); +use crate::query::HasQuery; +pub fn generate_visitor(language: &Language, name: &str) -> TokenStream { + log::info!( + "Generating visitor for language: {} for {}", + language.name(), + name + ); + let raw_queries = language.queries_with_prefix(&format!("{}", name)); + let queries: Vec<&Query> = raw_queries.values().flatten().collect(); let language_name = format_ident!("{}", language.name()); let mut names = Vec::new(); let mut types = Vec::new(); @@ -52,7 +59,7 @@ pub fn generate_visitor(queries: &Vec<&Query>, language: &Language) -> TokenStre } }); } - let name = format_ident!("QueryExecutor"); + let name = format_ident!("{}s", name.to_case(Case::Pascal)); quote! { #[derive(Visitor, Default, Debug, Clone)] #[visitor( @@ -72,14 +79,11 @@ mod tests { use codegen_sdk_common::language::typescript::Typescript; use super::*; - use crate::query::HasQuery; #[test_log::test] fn test_generate_visitor() { let language = &Typescript; - let queries = language.definitions(); - log::info!("Gathered {} queries", queries.len()); - let visitor = generate_visitor(&queries.values().into_iter().flatten().collect(), language); + let visitor = generate_visitor(language, "definition"); insta::assert_snapshot!( codegen_sdk_common::generator::format_code(&visitor.to_string()).unwrap() ); diff --git a/codegen-sdk-ast/tests/test_typescript.rs b/codegen-sdk-ast/tests/test_typescript.rs index 179d8220..1d113494 100644 --- a/codegen-sdk-ast/tests/test_typescript.rs +++ b/codegen-sdk-ast/tests/test_typescript.rs @@ -31,5 +31,5 @@ fn test_typescript_ast_interface() { let content = "interface Test { }"; let file_path = write_to_temp_file(content, &temp_dir); let file = TypescriptFile::parse(&file_path).unwrap(); - assert_eq!(file.visitor.interfaces.len(), 1); + assert_eq!(file.definitions.interfaces.len(), 1); } diff --git a/codegen-sdk-common/src/naming.rs b/codegen-sdk-common/src/naming.rs index 1e31138d..b04e194c 100644 --- a/codegen-sdk-common/src/naming.rs +++ b/codegen-sdk-common/src/naming.rs @@ -42,6 +42,9 @@ pub fn normalize_field_name(field_name: &str) -> String { if field_name == "type" { return "r#type".to_string(); } + if field_name == "macro" { + return "r#macro".to_string(); + } field_name.to_string() } fn get_char_mapping(c: char) -> String {