Skip to content

Commit a8c96b6

Browse files
eacodegenbagel897
andauthored
Generate references queries (#28)
* also generate references * Fix warning * fix test * fix tests --------- Co-authored-by: bagel897 <[email protected]>
1 parent 6e9e535 commit a8c96b6

File tree

7 files changed

+43
-40
lines changed

7 files changed

+43
-40
lines changed

codegen-sdk-ast-generator/src/generator.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ pub fn generate_ast(language: &Language) -> anyhow::Result<String> {
66
pub struct {language_struct_name}File {{
77
node: {language_name}::{root_node_name},
88
path: PathBuf,
9-
pub visitor: QueryExecutor
9+
pub references: References,
10+
pub definitions: Definitions
1011
}}
1112
impl File for {language_struct_name}File {{
1213
fn path(&self) -> &PathBuf {{
@@ -15,9 +16,11 @@ pub fn generate_ast(language: &Language) -> anyhow::Result<String> {
1516
fn parse(path: &PathBuf) -> Result<Self, ParseError> {{
1617
log::debug!(\"Parsing {language_name} file: {{}}\", path.display());
1718
let ast = {language_name}::{language_struct_name}::parse_file(path)?;
18-
let mut visitor = QueryExecutor::default();
19-
ast.drive(&mut visitor);
20-
Ok({language_struct_name}File {{ node: ast, path: path.clone(), visitor }})
19+
let mut references = References::default();
20+
let mut definitions = Definitions::default();
21+
ast.drive(&mut definitions);
22+
ast.drive(&mut references);
23+
Ok({language_struct_name}File {{ node: ast, path: path.clone(), references, definitions }})
2124
}}
2225
}}
2326
impl HasNode for {language_struct_name}File {{

codegen-sdk-ast-generator/src/lib.rs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
use codegen_sdk_common::{generator::format_code, language::Language};
44
use quote::quote;
55

6-
use crate::query::HasQuery;
76
mod generator;
87
mod query;
98
mod visitor;
@@ -15,17 +14,11 @@ pub fn generate_ast(language: &Language) -> anyhow::Result<()> {
1514
use codegen_sdk_cst::CSTLanguage;
1615
};
1716
let mut ast = generator::generate_ast(language)?;
18-
let visitor = visitor::generate_visitor(
19-
&language
20-
.definitions()
21-
.values()
22-
.into_iter()
23-
.flatten()
24-
.collect(),
25-
language,
26-
);
27-
ast = imports.to_string() + &ast + &visitor.to_string();
28-
ast = format_code(&ast).unwrap();
17+
let definitions = visitor::generate_visitor(language, "definition");
18+
let references = visitor::generate_visitor(language, "reference");
19+
ast = imports.to_string() + &ast + &definitions.to_string() + &references.to_string();
20+
ast = format_code(&ast)
21+
.unwrap_or_else(|_| panic!("Failed to format ast for {}", language.name()));
2922
let out_dir = std::env::var("OUT_DIR")?;
3023
let out_file = format!("{}/{}.rs", out_dir, language.name());
3124
std::fs::write(out_file, ast)?;

codegen-sdk-ast-generator/src/query.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::{collections::BTreeMap, sync::Arc};
22

3-
use codegen_sdk_common::{CSTNode, HasChildren, Language, naming::normalize_type_name};
3+
use codegen_sdk_common::{
4+
CSTNode, HasChildren, Language,
5+
naming::{normalize_field_name, normalize_type_name},
6+
};
47
use codegen_sdk_cst::{CSTLanguage, ts_query};
58
use codegen_sdk_cst_generator::{Field, State};
69
use derive_more::Debug;
@@ -191,7 +194,7 @@ impl<'a> Query<'a> {
191194
field.children().into_iter().skip(2).next().unwrap().into();
192195
for name in &field.name {
193196
if let ts_query::FieldDefinitionName::Identifier(identifier) = name {
194-
let name = identifier.source();
197+
let name = normalize_field_name(&identifier.source());
195198
if let Some(field) = self.get_field_for_field_name(&name, struct_name) {
196199
let field_name = format_ident!("{}", name);
197200
let new_identifier = format_ident!("field");
@@ -372,12 +375,15 @@ impl<'a> Query<'a> {
372375
}
373376
}
374377
}
375-
unhandled => todo!(
376-
"Unhandled definition in language {}: {:#?}, {:#?}",
377-
self.language.name(),
378-
unhandled.kind(),
379-
unhandled.source()
380-
),
378+
unhandled => {
379+
log::warn!(
380+
"Unhandled definition in language {}: {:#?}, {:#?}",
381+
self.language.name(),
382+
unhandled.kind(),
383+
unhandled.source()
384+
);
385+
self.get_default_matcher()
386+
}
381387
}
382388
}
383389

@@ -412,12 +418,6 @@ pub trait HasQuery {
412418
}
413419
queries
414420
}
415-
fn definitions(&self) -> BTreeMap<String, Vec<Query<'_>>> {
416-
self.queries_with_prefix("definition")
417-
}
418-
// fn references(&self) -> BTreeMap<String, Vec<Query<'_>>> {
419-
// self.queries_with_prefix("reference")
420-
// }
421421
}
422422
impl HasQuery for Language {
423423
fn queries(&self) -> BTreeMap<String, Query<'_>> {

codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__generate_visitor.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ expression: "codegen_sdk_common::generator::format_code(&visitor.to_string()).un
1010
typescript::AbstractMethodSignature(enter),
1111
typescript::Module(enter)
1212
)]
13-
pub struct QueryExecutor {
13+
pub struct Definitions {
1414
pub classes: Vec<typescript::AbstractClassDeclaration>,
1515
pub functions: Vec<typescript::FunctionSignature>,
1616
pub interfaces: Vec<typescript::InterfaceDeclaration>,
1717
pub methods: Vec<typescript::AbstractMethodSignature>,
1818
pub modules: Vec<typescript::Module>,
1919
}
20-
impl QueryExecutor {
20+
impl Definitions {
2121
fn enter_abstract_class_declaration(
2222
&mut self,
2323
node: &codegen_sdk_cst::typescript::AbstractClassDeclaration,

codegen-sdk-ast-generator/src/visitor.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@ use proc_macro2::TokenStream;
88
use quote::{format_ident, quote};
99

1010
use super::query::Query;
11-
pub fn generate_visitor(queries: &Vec<&Query>, language: &Language) -> TokenStream {
12-
log::info!("Generating visitor for language: {}", language.name());
11+
use crate::query::HasQuery;
12+
pub fn generate_visitor(language: &Language, name: &str) -> TokenStream {
13+
log::info!(
14+
"Generating visitor for language: {} for {}",
15+
language.name(),
16+
name
17+
);
18+
let raw_queries = language.queries_with_prefix(&format!("{}", name));
19+
let queries: Vec<&Query> = raw_queries.values().flatten().collect();
1320
let language_name = format_ident!("{}", language.name());
1421
let mut names = Vec::new();
1522
let mut types = Vec::new();
@@ -52,7 +59,7 @@ pub fn generate_visitor(queries: &Vec<&Query>, language: &Language) -> TokenStre
5259
}
5360
});
5461
}
55-
let name = format_ident!("QueryExecutor");
62+
let name = format_ident!("{}s", name.to_case(Case::Pascal));
5663
quote! {
5764
#[derive(Visitor, Default, Debug, Clone)]
5865
#[visitor(
@@ -72,14 +79,11 @@ mod tests {
7279
use codegen_sdk_common::language::typescript::Typescript;
7380

7481
use super::*;
75-
use crate::query::HasQuery;
7682

7783
#[test_log::test]
7884
fn test_generate_visitor() {
7985
let language = &Typescript;
80-
let queries = language.definitions();
81-
log::info!("Gathered {} queries", queries.len());
82-
let visitor = generate_visitor(&queries.values().into_iter().flatten().collect(), language);
86+
let visitor = generate_visitor(language, "definition");
8387
insta::assert_snapshot!(
8488
codegen_sdk_common::generator::format_code(&visitor.to_string()).unwrap()
8589
);

codegen-sdk-ast/tests/test_typescript.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ fn test_typescript_ast_interface() {
3131
let content = "interface Test { }";
3232
let file_path = write_to_temp_file(content, &temp_dir);
3333
let file = TypescriptFile::parse(&file_path).unwrap();
34-
assert_eq!(file.visitor.interfaces.len(), 1);
34+
assert_eq!(file.definitions.interfaces.len(), 1);
3535
}

codegen-sdk-common/src/naming.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ pub fn normalize_field_name(field_name: &str) -> String {
4242
if field_name == "type" {
4343
return "r#type".to_string();
4444
}
45+
if field_name == "macro" {
46+
return "r#macro".to_string();
47+
}
4548
field_name.to_string()
4649
}
4750
fn get_char_mapping(c: char) -> String {

0 commit comments

Comments
 (0)