diff --git a/Cargo.lock b/Cargo.lock index 09021f2..485f765 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,6 +85,16 @@ dependencies = [ "backtrace", ] +[[package]] +name = "assert-tokenstreams-eq" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5aa2c884cf12de50cf3edd0b1e6f41056e5b32707028e1c2c9e9d4868ab5e9e" +dependencies = [ + "pretty_assertions", + "thiserror 1.0.69", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -168,6 +178,12 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.10.0" @@ -237,6 +253,7 @@ version = "0.1.0" dependencies = [ "codegen-sdk-ast", "salsa", + "test-log", ] [[package]] @@ -266,14 +283,16 @@ dependencies = [ "base64", "buildid", "bytes", - "convert_case", + "convert_case 0.7.1", + "enum_delegate", "lazy_static", "phf", "rkyv", "serde", "serde_json", "sha2", - "thiserror", + "test-log", + "thiserror 2.0.11", "tree-sitter", "tree-sitter-java", "tree-sitter-javascript", @@ -301,6 +320,7 @@ dependencies = [ "rayon", "rkyv", "sysinfo", + "test-log", ] [[package]] @@ -311,13 +331,17 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst-generator", "codegen-sdk-macros", - "convert_case", + "convert_case 0.7.1", + "derive-visitor", "derive_more", + "enum_delegate", "env_logger", "log", "rayon", "rkyv", + "subenum", "tempfile", + "test-log", "tree-sitter", ] @@ -326,14 +350,16 @@ name = "codegen-sdk-cst-generator" version = "0.1.0" dependencies = [ "anyhow", + "assert-tokenstreams-eq", "codegen-sdk-common", - "convert_case", + "convert_case 0.7.1", "log", "prettyplease", "proc-macro2", "quote", "syn 2.0.98", "tempfile", + "test-log", "tree-sitter", ] @@ -350,6 +376,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "convert_case" version = "0.7.1" @@ -440,6 +472,28 @@ dependencies = [ "typenum", ] +[[package]] +name = "derive-visitor" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d47165df83b9707cbada3216607a5d66125b6a66906de0bc1216c0669767ca9e" +dependencies = [ + "derive-visitor-macros", +] + +[[package]] +name = "derive-visitor-macros" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "427b39a85fecafea16b1a5f3f50437151022e35eb4fe038107f08adbf7f8def6" +dependencies = [ + "convert_case 0.4.0", + "itertools", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "2.0.1" @@ -461,6 +515,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -477,6 +537,30 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "enum_delegate" +version = "0.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8ea75f31022cba043afe037940d73684327e915f88f62478e778c3de914cd0a" +dependencies = [ + "enum_delegate_lib", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "enum_delegate_lib" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e1f6c3800b304a6be0012039e2a45a322a093539c45ab818d9e6895a39c90fe" +dependencies = [ + "proc-macro2", + "quote", + "rand", + "syn 1.0.109", +] + [[package]] name = "env_filter" version = "0.1.3" @@ -532,6 +616,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", +] + [[package]] name = "getrandom" version = "0.3.1" @@ -540,7 +635,7 @@ checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.13.3+wasi-0.2.2", "windows-targets", ] @@ -577,6 +672,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -624,6 +725,15 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -673,6 +783,15 @@ version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "memchr" version = "2.7.4" @@ -717,6 +836,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "object" version = "0.36.7" @@ -738,6 +867,12 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.11.2" @@ -805,12 +940,37 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + [[package]] name = "pkg-config" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "prettyplease" version = "0.2.29" @@ -874,6 +1034,18 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", "rand_core", ] @@ -882,6 +1054,9 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.15", +] [[package]] name = "rayon" @@ -920,8 +1095,17 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -932,9 +1116,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -1090,6 +1280,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1126,6 +1325,18 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subenum" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f5d5dfb8556dd04017db5e318bbeac8ab2b0c67b76bf197bfb79e9b29f18ecf" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "syn" version = "1.0.109" @@ -1170,19 +1381,61 @@ checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys", ] +[[package]] +name = "test-log" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.11", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", ] [[package]] @@ -1196,6 +1449,16 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "tinyvec" version = "1.8.1" @@ -1211,6 +1474,54 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + [[package]] name = "tree-sitter" version = "0.25.1" @@ -1219,7 +1530,7 @@ checksum = "1a802c93485fb6781d27e27cb5927f6b00ff8d26b56c70af87267be7e99def97" dependencies = [ "cc", "regex", - "regex-syntax", + "regex-syntax 0.8.5", "serde_json", "streaming-iterator", "tree-sitter-language", @@ -1326,12 +1637,24 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "wasi" version = "0.13.3+wasi-0.2.2" @@ -1504,6 +1827,33 @@ version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "213b7324336b53d2414b2db8537e56544d981803139155afa84f76eeebb7a546" +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", +] + [[package]] name = "zstd" version = "0.13.2" diff --git a/Cargo.toml b/Cargo.toml index 797438c..daf1bf6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,8 @@ log = { workspace = true } rayon = { workspace = true} sysinfo = "0.33.1" rkyv.workspace = true - +[dev-dependencies] +test-log = { workspace = true } [workspace] members = [ "codegen-sdk-analyzer", @@ -40,3 +41,5 @@ serde = { version = "1.0.217", features = ["derive"] } serde_json = "1.0.138" anyhow = { version = "1.0.95", features = ["backtrace"] } rkyv = { version = "0.8.10", features = ["bytes-1","pointer_width_64"] } +test-log = "0.2.17" +enum_delegate = "0.2.0" diff --git a/codegen-sdk-analyzer/Cargo.toml b/codegen-sdk-analyzer/Cargo.toml index 58a3475..3ca8dfe 100644 --- a/codegen-sdk-analyzer/Cargo.toml +++ b/codegen-sdk-analyzer/Cargo.toml @@ -6,3 +6,5 @@ edition = "2024" [dependencies] salsa = "0.16.1" codegen-sdk-ast = { path = "../codegen-sdk-ast" } +[dev-dependencies] +test-log = { workspace = true } diff --git a/codegen-sdk-analyzer/src/lib.rs b/codegen-sdk-analyzer/src/lib.rs index b93cf3f..8f1f58f 100644 --- a/codegen-sdk-analyzer/src/lib.rs +++ b/codegen-sdk-analyzer/src/lib.rs @@ -6,7 +6,7 @@ pub fn add(left: u64, right: u64) -> u64 { mod tests { use super::*; - #[test] + #[test_log::test] fn it_works() { let result = add(2, 2); assert_eq!(result, 4); diff --git a/codegen-sdk-ast-generator/src/generator.rs b/codegen-sdk-ast-generator/src/generator.rs index e37247e..910b726 100644 --- a/codegen-sdk-ast-generator/src/generator.rs +++ b/codegen-sdk-ast-generator/src/generator.rs @@ -2,7 +2,7 @@ use codegen_sdk_common::language::Language; pub fn generate_ast(language: &Language) -> anyhow::Result { let content = format!( " - #[derive(Debug)] + #[derive(Debug, Clone)] pub struct {language_struct_name}File {{ node: {language_name}::{root_node_name}, path: PathBuf diff --git a/codegen-sdk-ast/src/lib.rs b/codegen-sdk-ast/src/lib.rs index b5ae5af..a1a6ead 100644 --- a/codegen-sdk-ast/src/lib.rs +++ b/codegen-sdk-ast/src/lib.rs @@ -1,3 +1,4 @@ +#![recursion_limit = "512"] use codegen_sdk_common::{File, 
HasNode}; pub use codegen_sdk_cst::*; pub trait Named { diff --git a/codegen-sdk-common/Cargo.toml b/codegen-sdk-common/Cargo.toml index 65a8e32..2d1c939 100644 --- a/codegen-sdk-common/Cargo.toml +++ b/codegen-sdk-common/Cargo.toml @@ -26,6 +26,9 @@ base64 = "0.22.1" buildid = "1.0.3" sha2 = "0.10.8" zstd = { version = "0.13.2", features = ["zstdmt"] } +enum_delegate = { workspace = true } +[dev-dependencies] +test-log = { workspace = true } [features] python = ["dep:tree-sitter-python"] json = ["dep:tree-sitter-json"] diff --git a/codegen-sdk-common/src/language.rs b/codegen-sdk-common/src/language.rs index 2020633..b1b69d7 100644 --- a/codegen-sdk-common/src/language.rs +++ b/codegen-sdk-common/src/language.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU16; + use convert_case::{Case, Casing}; use tree_sitter::Parser; @@ -5,6 +7,7 @@ use crate::{ errors::ParseError, parser::{Node, parse_node_types}, }; +#[derive(Debug)] pub struct Language { pub name: &'static str, pub struct_name: &'static str, @@ -50,6 +53,18 @@ impl Language { .type_name .to_case(Case::Pascal) } + pub fn kind_id(&self, name: &str, named: bool) -> u16 { + self.tree_sitter_language.id_for_node_kind(name, named) + } + pub fn kind_name(&self, id: u16) -> Option<&str> { + self.tree_sitter_language.node_kind_for_id(id) + } + pub fn field_id(&self, name: &str) -> Option { + self.tree_sitter_language.field_id_for_name(name) + } + pub fn field_name(&self, id: u16) -> Option<&str> { + self.tree_sitter_language.field_name_for_id(id) + } } #[cfg(feature = "java")] pub mod java; diff --git a/codegen-sdk-common/src/language/go.rs b/codegen-sdk-common/src/language/go.rs new file mode 100644 index 0000000..65feae0 --- /dev/null +++ b/codegen-sdk-common/src/language/go.rs @@ -0,0 +1,12 @@ +use super::Language; +lazy_static! { + pub static ref Go: Language = Language::new( + "go", + "Go", + tree_sitter_go::NODE_TYPES, + &["go"], + tree_sitter_go::LANGUAGE.into(), + tree_sitter_go::TAGS_QUERY, + ) + .unwrap(); +} diff --git a/codegen-sdk-common/src/language/ruby.rs b/codegen-sdk-common/src/language/ruby.rs new file mode 100644 index 0000000..a35153a --- /dev/null +++ b/codegen-sdk-common/src/language/ruby.rs @@ -0,0 +1,12 @@ +use super::Language; +lazy_static! { + pub static ref Ruby: Language = Language::new( + "ruby", + "Ruby", + tree_sitter_ruby::NODE_TYPES, + &["rb"], + tree_sitter_ruby::LANGUAGE.into(), + tree_sitter_ruby::TAGS_QUERY, + ) + .unwrap(); +} diff --git a/codegen-sdk-common/src/language/rust.rs b/codegen-sdk-common/src/language/rust.rs new file mode 100644 index 0000000..9a4c289 --- /dev/null +++ b/codegen-sdk-common/src/language/rust.rs @@ -0,0 +1,12 @@ +use super::Language; +lazy_static! 
{ + pub static ref Rust: Language = Language::new( + "rust", + "Rust", + tree_sitter_rust::NODE_TYPES, + &["rs"], + tree_sitter_rust::LANGUAGE.into(), + tree_sitter_rust::TAGS_QUERY, + ) + .unwrap(); +} diff --git a/codegen-sdk-common/src/lib.rs b/codegen-sdk-common/src/lib.rs index 3fa66c2..f16f9aa 100644 --- a/codegen-sdk-common/src/lib.rs +++ b/codegen-sdk-common/src/lib.rs @@ -13,6 +13,6 @@ pub mod parser; #[macro_use] extern crate lazy_static; pub mod naming; -mod point; -pub use point::Point; pub mod serialize; +pub mod tree; +pub use tree::{Point, Range}; diff --git a/codegen-sdk-common/src/naming.rs b/codegen-sdk-common/src/naming.rs index 4bd8c35..487dde2 100644 --- a/codegen-sdk-common/src/naming.rs +++ b/codegen-sdk-common/src/naming.rs @@ -65,7 +65,7 @@ pub fn normalize_string(string: &str) -> String { let escaped = String::from_iter(string.chars().map(escape_char)); escaped } -pub fn normalize_type_name(type_name: &str) -> String { +pub fn normalize_type_name(type_name: &str, named: bool) -> String { let mut cased = type_name.to_string(); if type_name.chars().any(|c| c.is_ascii_alphabetic()) { cased = cased.to_case(Case::Pascal); @@ -78,5 +78,9 @@ pub fn normalize_type_name(type_name: &str) -> String { "Type name '{}' contains invalid characters", type_name ); - escaped + if named { + escaped + } else { + format!("Anonymous{}", escaped) + } } diff --git a/codegen-sdk-common/src/parser.rs b/codegen-sdk-common/src/parser.rs index a6e2633..e68bd06 100644 --- a/codegen-sdk-common/src/parser.rs +++ b/codegen-sdk-common/src/parser.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct Node { #[serde(rename = "type")] pub type_name: String, @@ -15,13 +15,13 @@ pub struct Node { pub children: Option<Children>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct Fields { #[serde(flatten)] pub fields: std::collections::HashMap<String, FieldDefinition>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct FieldDefinition { pub multiple: bool, pub required: bool, @@ -29,14 +29,14 @@ pub struct FieldDefinition { pub types: Vec<TypeDefinition>, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)] pub struct TypeDefinition { #[serde(rename = "type")] pub type_name: String, pub named: bool, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] pub struct Children { pub multiple: bool, pub required: bool, @@ -52,7 +52,7 @@ pub fn parse_node_types(node_types: &str) -> anyhow::Result<Vec<Node>> { mod tests { use super::*; use crate::language::python::Python; - #[test] + #[test_log::test] fn test_parse_node_types() { let cst = parse_node_types(Python.node_types).unwrap(); assert!(!cst.is_empty()); diff --git a/codegen-sdk-common/src/traits.rs b/codegen-sdk-common/src/traits.rs index 310e581..1a907b6 100644 --- a/codegen-sdk-common/src/traits.rs +++ b/codegen-sdk-common/src/traits.rs @@ -3,29 +3,142 @@ use std::{fmt::Debug, sync::Arc}; use bytes::Bytes; use tree_sitter::{self}; -use crate::{errors::ParseError, point::Point}; +use crate::{Point, errors::ParseError, tree::Range}; pub trait FromNode: Sized { fn from_node(node: tree_sitter::Node, buffer: &Arc<Bytes>) -> Result<Self, ParseError>; } -pub trait CSTNode: Send + Debug { +#[enum_delegate::register] +pub trait CSTNode { + /// Returns the
byte offset where the node starts fn start_byte(&self) -> usize; + + /// Returns the byte offset where the node ends fn end_byte(&self) -> usize; + + /// Returns the position where the node starts fn start_position(&self) -> Point; + + /// Returns the position where the node ends fn end_position(&self) -> Point; + + /// Returns the range of positions that this node spans + fn range(&self) -> Range { + Range::new(self.start_position(), self.end_position()) + } + + /// Returns the source text buffer for this node fn buffer(&self) -> &Bytes; + + /// Returns the raw text content of this node as bytes fn text(&self) -> Bytes { Bytes::copy_from_slice(&self.buffer()[self.start_byte()..self.end_byte()]) } - fn source(&self) -> String { + + /// Returns the text content of this node as a String + fn source(&self) -> std::string::String { String::from_utf8(self.text().to_vec()).unwrap() } + /// Returns the node's type as a numerical id fn kind_id(&self) -> u16; + + /// Returns the node's type as a string + fn kind(&self) -> &str; + + /// Returns true if this node is named, false if it is anonymous + fn is_named(&self) -> bool; + + /// Returns true if this node represents a syntax error + fn is_error(&self) -> bool { + unimplemented!("is_error not implemented") + } + + /// Returns true if this node is *missing* from the source code + fn is_missing(&self) -> bool { + unimplemented!("is_missing not implemented") + } + + /// Returns true if this node has been edited + fn is_edited(&self) -> bool { + unimplemented!("is_edited not implemented") + } + + /// Returns true if this node represents extra tokens from the source code + fn is_extra(&self) -> bool { + unimplemented!("is_extra not implemented") + } + + /// Returns the field id for the given field name + fn field_id_for_name(&self, name: &str) -> Option<u16> { + unimplemented!("field_id_for_name not implemented") + } + + /// Returns the field name for the given field id + fn field_name_for_id(&self, id: u16) -> Option<&str> { + unimplemented!("field_name_for_id not implemented") + } + fn id(&self) -> usize; } -pub trait HasNode: Send + Debug { +trait CSTNodeExt: CSTNode { + /// Get the next sibling of this node in its parent + fn next_sibling<Parent: HasChildren<Child = Self>>( + &self, + parent: &Parent, + ) -> Option<Self> { + let mut iter = parent.children().into_iter(); + while let Some(child) = iter.next() { + if child.id() == self.id() { + return iter.next(); + } + } + None + } + fn next_named_sibling<Parent: HasChildren<Child = Self>>( + &self, + parent: &Parent, + ) -> Option<Self> { + let mut iter = parent.named_children().into_iter(); + while let Some(child) = iter.next() { + if child.id() == self.id() { + return iter.next(); + } + } + None + } + fn prev_sibling<Parent: HasChildren<Child = Self>>( + &self, + parent: &Parent, + ) -> Option<Self> { + let mut prev = None; + for child in parent.children() { + if child.id() == self.id() { + return prev; + } + prev = Some(child); + } + None + } + fn prev_named_sibling<Parent: HasChildren<Child = Self>>( + &self, + parent: &Parent, + ) -> Option<Self> { + let mut prev = None; + for child in parent.named_children() { + if child.id() == self.id() { + return prev; + } + prev = Some(child); + } + None + } +} +pub trait HasNode: Send + Debug + Clone { type Node: CSTNode; fn node(&self) -> &Self::Node; } impl<T: HasNode> CSTNode for T { + fn kind(&self) -> &str { + self.node().kind() + } fn start_byte(&self) -> usize { self.node().start_byte() } @@ -44,8 +157,103 @@ impl<T: HasNode> CSTNode for T { fn kind_id(&self) -> u16 { self.node().kind_id() } + fn is_named(&self) -> bool { + self.node().is_named() + } + fn is_error(&self) -> bool { + self.node().is_error() + } + fn is_missing(&self)
-> bool { + self.node().is_missing() + } + fn is_edited(&self) -> bool { + self.node().is_edited() + } + fn is_extra(&self) -> bool { + self.node().is_extra() + } + + fn field_id_for_name(&self, name: &str) -> Option<u16> { + self.node().field_id_for_name(name) + } + fn field_name_for_id(&self, id: u16) -> Option<&str> { + self.node().field_name_for_id(id) + } + fn id(&self) -> usize { + self.node().id() + } } +// impl<T: HasNode> HasChildren for T { +// type Child = <T::Node as HasChildren>::Child; +// fn child_by_field_name(&self, field_name: &str) -> Option<Self::Child> { +// self.node().child_by_field_name(field_name) +// } +// fn children_by_field_name(&self, field_name: &str) -> Vec<Self::Child> { +// self.node().children_by_field_name(field_name) +// } +// fn children(&self) -> Vec<Self::Child> { +// self.node().children() +// } +// fn child_by_field_id(&self, field_id: u16) -> Option<Self::Child> { +// self.node().child_by_field_id(field_id) +// } +// fn child_count(&self) -> usize { +// self.node().child_count() +// } +// } pub trait HasChildren { - type Child: Send; - fn children(&self) -> &Vec<Self::Child>; + type Child: Send + Debug + Clone + CSTNode; + /// Returns the first child with the given field name + fn child_by_field_id(&self, field_id: u16) -> Option<Self::Child> { + self.children_by_field_id(field_id) + .first() + .map(|child| child.clone()) + } + + /// Returns all children with the given field name + fn children_by_field_id(&self, field_id: u16) -> Vec<Self::Child> { + unimplemented!("children_by_field_id not implemented") + } + + /// Returns the first child with the given field name + fn child_by_field_name(&self, field_name: &str) -> Option<Self::Child> { + self.children_by_field_name(field_name) + .first() + .map(|child| child.clone()) + } + + /// Returns all children with the given field name + fn children_by_field_name(&self, field_name: &str) -> Vec<Self::Child>; + + /// Returns all children of the node + fn children(&self) -> Vec<Self::Child>; + /// Returns all named children of the node + fn named_children(&self) -> Vec<Self::Child> { + self.children() + .into_iter() + .filter(|child| child.is_named()) + .collect() + } + + // /// Returns a cursor for walking the tree starting from this node + // fn walk(&self) -> TreeCursor + // where + // Self: Sized, + // { + // TreeCursor::new(self) + // } + + /// Returns the first child of the node + fn first_child(&self) -> Option<Self::Child> { + self.children().into_iter().next() + } + + /// Returns the last child of the node + fn last_child(&self) -> Option<Self::Child> { + self.children().into_iter().last() + } + /// Returns the number of children of this node + fn child_count(&self) -> usize { + self.children().len() + } } diff --git a/codegen-sdk-common/src/tree.rs b/codegen-sdk-common/src/tree.rs new file mode 100644 index 0000000..6a815a7 --- /dev/null +++ b/codegen-sdk-common/src/tree.rs @@ -0,0 +1,6 @@ +// mod cursor; +mod point; +mod range; +// pub use cursor::TreeCursor; +pub use point::Point; +pub use range::Range; diff --git a/codegen-sdk-common/src/tree/cursor.rs b/codegen-sdk-common/src/tree/cursor.rs new file mode 100644 index 0000000..6b805fe --- /dev/null +++ b/codegen-sdk-common/src/tree/cursor.rs @@ -0,0 +1,156 @@ +use std::num::NonZeroU16; + +use crate::{CSTNode, HasChildren, tree::point::Point}; +#[derive(Debug, Clone)] +pub struct TreeCursor<'cursor> { + // Private implementation details + current: &'cursor dyn CSTNode, + parents: Vec<&'cursor dyn HasChildren>, + field_id: Option<NonZeroU16>, + exhausted: bool, +} + +impl<'cursor> TreeCursor<'cursor> { + pub fn new<T: CSTNode>(node: &'cursor T) -> Self { + Self { + current: node, + parents: vec![], + field_id: None, + exhausted: false, + } + } + /// Get the tree cursor's
current Node. + pub fn node(&self) -> &'cursor dyn CSTNode { + self.current + } + + /// Get the numerical field id of this tree cursor's current node. + pub fn field_id(&self) -> Option<NonZeroU16> { + self.field_id + } + + /// Get the field name of this tree cursor's current node. + pub fn field_name(&self) -> Option<&'static str> { + unimplemented!() + } + + /// Get the depth of the cursor's current node relative to the original node + /// that the cursor was constructed with. + pub fn depth(&self) -> usize { + self.parents.len() + } + + /// Get the index of the cursor's current node out of all of the descendants + /// of the original node that the cursor was constructed with + pub fn descendant_index(&self) -> usize { + unimplemented!() + } + + /// Move this cursor to the first child of its current node. + /// + /// Returns `true` if the cursor successfully moved, and returns `false` + /// if there were no children. + pub fn goto_first_child(&mut self) -> bool { + let current: &dyn HasChildren = self.current.try_into().unwrap(); + if let Some(first_child) = &current.first_child() { + self.parents.push(current); + self.current = first_child; + return true; + } + false + } + + /// Move this cursor to the last child of its current node. + /// + /// Returns `true` if the cursor successfully moved, and returns `false` + /// if there were no children. + pub fn goto_last_child(&mut self) -> bool { + unimplemented!() + } + + /// Move this cursor to the parent of its current node. + /// + /// Returns `true` if the cursor successfully moved, and returns `false` + /// if there was no parent node. + pub fn goto_parent(&mut self) -> bool { + if let Some(parent) = self.parents.pop() { + self.current = parent; + true + } else { + false + } + } + + /// Move this cursor to the next sibling of its current node. + /// + /// Returns `true` if the cursor successfully moved, and returns `false` + /// if there was no next sibling node. + pub fn goto_next_sibling(&mut self) -> bool { + if let Some(parent) = self.parents.last_mut() { + if let Some(next_sibling) = self.current.next_sibling(parent) { + self.current = next_sibling; + return true; + } + } + false + } + + /// Move this cursor to the previous sibling of its current node. + /// + /// Returns `true` if the cursor successfully moved, and returns `false` + /// if there was no previous sibling node. + pub fn goto_previous_sibling(&mut self) -> bool { + unimplemented!() + } + + /// Move the cursor to the node that is the nth descendant of the original node + /// that the cursor was constructed with, where zero represents the original node itself. + pub fn goto_descendant(&mut self, descendant_index: usize) { + unimplemented!() + } + + /// Move this cursor to the first child of its current node that contains or + /// starts after the given byte offset. + pub fn goto_first_child_for_byte(&mut self, index: usize) -> Option<usize> { + unimplemented!() + } + + /// Move this cursor to the first child of its current node that contains or + /// starts after the given point. + pub fn goto_first_child_for_point(&mut self, point: Point) -> Option<usize> { + unimplemented!() + } + + // /// Re-initialize this tree cursor to start at the original node that the + // /// cursor was constructed with. + // pub fn reset<NewT: CSTNode>(&mut self, node: &NewT) { + // unimplemented!() + // } + + // /// Re-initialize a tree cursor to the same position as another cursor.
+ // pub fn reset_to(&mut self, cursor: &Self) { + // unimplemented!() + // } +} + +// Depth-first iterator +impl<'cursor> Iterator for TreeCursor<'cursor> { + type Item = &'cursor dyn CSTNode; + fn next(&mut self) -> Option { + if self.exhausted { + return None; + } + let ret = Some(self.current); + if !self.goto_first_child() { + if !self.goto_next_sibling() { + while self.goto_parent() { + if self.goto_next_sibling() { + break; // Found a sibling + } + } + self.exhausted = true; + } + } + ret + } +} diff --git a/codegen-sdk-common/src/point.rs b/codegen-sdk-common/src/tree/point.rs similarity index 75% rename from codegen-sdk-common/src/point.rs rename to codegen-sdk-common/src/tree/point.rs index 6244b77..1f73a28 100644 --- a/codegen-sdk-common/src/point.rs +++ b/codegen-sdk-common/src/tree/point.rs @@ -1,6 +1,6 @@ use rkyv::{Archive, Deserialize, Serialize}; -#[derive(Debug, Clone, Copy, Eq, PartialEq, Archive, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Archive, Deserialize, Serialize)] pub struct Point { pub row: usize, pub column: usize, diff --git a/codegen-sdk-common/src/tree/range.rs b/codegen-sdk-common/src/tree/range.rs new file mode 100644 index 0000000..1fe85c8 --- /dev/null +++ b/codegen-sdk-common/src/tree/range.rs @@ -0,0 +1,21 @@ +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::Point; +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Archive, Deserialize, Serialize)] +pub struct Range { + start: Point, + end: Point, +} +impl From for Range { + fn from(value: tree_sitter::Range) -> Self { + Self { + start: value.start_point.into(), + end: value.end_point.into(), + } + } +} +impl Range { + pub fn new(start: Point, end: Point) -> Self { + Self { start, end } + } +} diff --git a/codegen-sdk-cst-generator/Cargo.toml b/codegen-sdk-cst-generator/Cargo.toml index 107fef1..4d86e05 100644 --- a/codegen-sdk-cst-generator/Cargo.toml +++ b/codegen-sdk-cst-generator/Cargo.toml @@ -14,5 +14,8 @@ anyhow = { workspace = true } quote = "1.0.38" proc-macro2 = "1.0.93" tempfile = "3.8.1" + [dev-dependencies] +assert-tokenstreams-eq = "0.1.0" codegen-sdk-common = { path = "../codegen-sdk-common" , features = ["python"] } +test-log = { workspace = true } diff --git a/codegen-sdk-cst-generator/src/generator.rs b/codegen-sdk-cst-generator/src/generator.rs index 8253126..c327318 100644 --- a/codegen-sdk-cst-generator/src/generator.rs +++ b/codegen-sdk-cst-generator/src/generator.rs @@ -1,13 +1,11 @@ -use std::collections::HashSet; - -use codegen_sdk_common::{naming::normalize_type_name, parser::Node}; -use enum_generator::generate_enum; +use codegen_sdk_common::Language; use state::State; -use struct_generator::generate_struct; -mod enum_generator; +mod constants; +mod field; mod format; +mod node; mod state; -mod struct_generator; +mod utils; use std::io::Write; use proc_macro2::TokenStream; @@ -19,44 +17,21 @@ fn get_imports() -> TokenStream { use tree_sitter; use derive_more::Debug; use codegen_sdk_common::*; + use subenum::subenum; use std::backtrace::Backtrace; use bytes::Bytes; - use rkyv::{Archive, Deserialize, Serialize, Portable}; + use rkyv::{Archive, Deserialize, Serialize}; + use derive_visitor::Drive; } } -pub(crate) fn generate_cst(node_types: &Vec) -> anyhow::Result { - let mut state = State::default(); - let mut nodes = HashSet::new(); - for node in node_types { - if !node.subtypes.is_empty() { - state - .variants - .insert(normalize_type_name(&node.type_name), node.subtypes.clone()); - } else if 
node.children.is_none() && node.fields.is_none() { - state - .anonymous_nodes - .insert(node.type_name.clone(), normalize_type_name(&node.type_name)); - } - } - for node in node_types { - let name = normalize_type_name(&node.type_name); - if nodes.contains(&name) { - continue; - } - nodes.insert(name.clone()); - if name.is_empty() { - continue; - } - if !node.subtypes.is_empty() { - generate_enum(&node.subtypes, &mut state, &name, true); - } else { - generate_struct(node, &mut state, &name); - } - } +pub fn generate_cst(language: &Language) -> anyhow::Result { + let state = State::new(language); let mut result = get_imports(); - result.extend_one(state.enums); - result.extend_one(state.structs); + let enums = state.get_enum(); + let structs = state.get_structs(); + result.extend_one(enums); + result.extend_one(structs); let formatted = format::format_cst(&result.to_string()); match formatted { Ok(formatted) => return Ok(formatted), @@ -78,7 +53,7 @@ mod tests { use codegen_sdk_common::{language::python::Python, parser::parse_node_types}; use super::*; - #[test] + #[test_log::test] fn test_generate_cst() { let node_types = parse_node_types(&Python.node_types).unwrap(); let cst = generate_cst(&node_types).unwrap(); diff --git a/codegen-sdk-cst-generator/src/generator/constants.rs b/codegen-sdk-cst-generator/src/generator/constants.rs new file mode 100644 index 0000000..1f77b4a --- /dev/null +++ b/codegen-sdk-cst-generator/src/generator/constants.rs @@ -0,0 +1 @@ +pub const TYPE_NAME: &str = "NodeTypes"; diff --git a/codegen-sdk-cst-generator/src/generator/enum_generator.rs b/codegen-sdk-cst-generator/src/generator/enum_generator.rs deleted file mode 100644 index 34b3984..0000000 --- a/codegen-sdk-cst-generator/src/generator/enum_generator.rs +++ /dev/null @@ -1,106 +0,0 @@ -use codegen_sdk_common::{ - naming::{normalize_string, normalize_type_name}, - parser::TypeDefinition, -}; -use proc_macro2::TokenStream; -use quote::{format_ident, quote}; - -use crate::generator::state::State; -fn get_cases( - variants: &Vec, - state: &State, - override_variant_name: Option<&str>, - existing_cases: &mut Vec, -) -> Vec<(String, TokenStream)> { - let mut cases = Vec::new(); - for t in variants { - let normalized_variant_name = normalize_type_name(&t.type_name); - if normalized_variant_name.is_empty() { - continue; - } - let variant_name = override_variant_name.unwrap_or_else(|| &normalized_variant_name); - if let Some(variants) = state.variants.get(&normalized_variant_name) { - cases.extend(get_cases( - variants, - state, - Some(variant_name), - existing_cases, - )); - } else if !existing_cases.contains(&t.type_name) { - existing_cases.push(t.type_name.clone()); - let variant_name = format_ident!("{}", variant_name); - cases.push(( - t.type_name.clone(), - quote! { Self::#variant_name (#variant_name::from_node(node, buffer)?)}, - )); - // cases.insert(t.type_name.clone(), quote!{ - // #t.type_name => Ok(#(#prefix)::from_node(node, buffer)?), - // }.to_string()); - } - } - return cases; -} -pub fn generate_enum( - variants: &Vec, - state: &mut State, - enum_name: &str, - anonymous_nodes: bool, -) { - let mut variant_tokens = Vec::new(); - for t in variants { - let variant_name = normalize_type_name(&t.type_name); - if variant_name.is_empty() { - continue; - } - let variant_name = format_ident!("{}", variant_name); - variant_tokens.push(quote! { - #variant_name(#variant_name) - }); - } - if anonymous_nodes { - variant_tokens.push(quote! 
{ - Anonymous, - }); - } - let enum_name = format_ident!("{}", enum_name); - state.enums.extend_one(quote! { - #[derive(Debug, Clone, Archive, Portable, Deserialize, Serialize)] - #[repr(C, u8)] - pub enum #enum_name { - #(#variant_tokens),* - } - }); - let mut existing_cases = Vec::new(); - let mut cases = get_cases(variants, state, None, &mut existing_cases); - if anonymous_nodes { - for (name, _variant_name) in state.anonymous_nodes.iter() { - if name.is_empty() { - continue; - } - if existing_cases.contains(name) { - continue; - } - let normalized_name = normalize_string(name); - cases.push((normalized_name, quote! {Self::Anonymous})); - } - } - let mut keys = Vec::new(); - let mut values = Vec::new(); - for (key, value) in cases { - keys.push(key); - values.push(value); - } - state.enums.extend_one(quote! { - impl FromNode for #enum_name { - fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { - match node.kind() { - #(#keys => Ok(#values)),*, - _ => Err(ParseError::UnexpectedNode { - node_type: node.kind().to_string(), - backtrace: Backtrace::capture(), - }), - } - } - } - }); -} diff --git a/codegen-sdk-cst-generator/src/generator/field.rs b/codegen-sdk-cst-generator/src/generator/field.rs new file mode 100644 index 0000000..5f5f3d7 --- /dev/null +++ b/codegen-sdk-cst-generator/src/generator/field.rs @@ -0,0 +1,374 @@ +use codegen_sdk_common::{ + Language, + naming::{normalize_field_name, normalize_type_name}, + parser::{FieldDefinition, TypeDefinition}, +}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +use super::constants::TYPE_NAME; +#[derive(Debug)] +pub struct Field<'a> { + raw: &'a FieldDefinition, + name: String, + node_name: String, + language: &'a Language, +} + +impl<'a> Field<'a> { + pub fn new( + node_name: &str, + name: &str, + raw: &'a FieldDefinition, + language: &'a Language, + ) -> Self { + Self { + node_name: node_name.to_string(), + name: name.to_string(), + raw, + language, + } + } + fn field_id(&self) -> u16 { + self.language.field_id(&self.name).unwrap().into() + } + pub fn name(&self) -> String { + normalize_field_name(&self.name) + } + pub fn normalized_name(&self) -> String { + normalize_type_name(&self.name, true) + } + pub fn types(&self) -> Vec<&TypeDefinition> { + self.raw.types.iter().collect() + } + pub fn type_name(&self) -> String { + let types = self.types(); + if types.len() == 1 { + normalize_type_name(&types[0].type_name, types[0].named) + } else { + format!("{}{}", self.node_name, self.normalized_name()) + } + } + pub fn get_constructor_field(&self) -> TokenStream { + let field_name_ident = format_ident!("{}", self.name()); + let original_name = &self.name; + if self.raw.multiple { + quote! { + #field_name_ident: get_multiple_children_by_field_name(&node, #original_name, buffer)? + } + } else if !self.raw.required { + quote! { + #field_name_ident: Box::new(get_optional_child_by_field_name(&node, #original_name, buffer)?) + } + } else { + quote! { + #field_name_ident: Box::new(get_child_by_field_name(&node, #original_name, buffer)?) + } + } + } + pub fn get_convert_child(&self, convert_children: bool) -> TokenStream { + let field_name_ident = format_ident!("{}", self.name()); + let types = format_ident!("{}", TYPE_NAME); + if convert_children { + if self.raw.multiple { + quote! { + Self::Child::try_from(#types::from(child.clone())).unwrap() + } + } else if !self.raw.required { + quote! { + Self::Child::try_from(#types::from(child.clone())).unwrap() + } + } else { + quote! 
{ + Self::Child::try_from(#types::from(self.#field_name_ident.as_ref().clone())).unwrap() + } + } + } else if self.raw.multiple || !self.raw.required { + quote! { + child.clone() + } + } else { + quote! { + self.#field_name_ident.as_ref().clone() + } + } + } + pub fn get_children_field(&self, convert_children: bool) -> TokenStream { + let field_name_ident = format_ident!("{}", self.name()); + let convert_child = self.get_convert_child(convert_children); + + if self.raw.multiple { + quote! { + children.extend(self.#field_name_ident.iter().map(|child| #convert_child)); + } + } else if self.raw.required { + quote! { + children.push(#convert_child); + } + } else { + quote! { + if let Some(child) = self.#field_name_ident.as_ref() { + children.push(#convert_child); + } + } + } + } + pub fn get_children_by_field_name_field(&self, convert_children: bool) -> TokenStream { + let field_name = &self.name; + let field_name_ident = format_ident!("{}", self.name()); + let convert_child = self.get_convert_child(convert_children); + + if self.raw.multiple { + quote! { + #field_name => self.#field_name_ident.iter().map(|child| #convert_child).collect() + } + } else if self.raw.required { + quote! { + #field_name => vec![#convert_child] + } + } else { + quote! { + #field_name => self.#field_name_ident.as_ref().iter().map(|child| #convert_child).collect() + } + } + } + pub fn get_children_by_field_id_field(&self, convert_children: bool) -> TokenStream { + let field_id = self.field_id(); + let field_name_ident = format_ident!("{}", self.name()); + let convert_child = self.get_convert_child(convert_children); + + if self.raw.multiple { + quote! { + #field_id => self.#field_name_ident.iter().map(|child| #convert_child).collect() + } + } else if self.raw.required { + quote! { + #field_id => vec![#convert_child] + } + } else { + quote! { + #field_id => self.#field_name_ident.as_ref().iter().map(|child| #convert_child).collect() + } + } + } + pub fn get_struct_field(&self) -> TokenStream { + let field_name_ident = format_ident!("{}", self.name()); + let converted_type_name = format_ident!("{}", self.type_name()); + if self.raw.multiple { + quote! { + #[rkyv(omit_bounds)] + pub #field_name_ident: Vec<#converted_type_name> + } + } else if !self.raw.required { + quote! { + #[rkyv(omit_bounds)] + pub #field_name_ident: Box> + } + } else { + quote! 
{ + #[rkyv(omit_bounds)] + pub #field_name_ident: Box<#converted_type_name> + } + } + } +} + +#[cfg(test)] +mod tests { + use codegen_sdk_common::parser::TypeDefinition; + + use super::*; + + fn create_test_field_definition(name: &str, multiple: bool, required: bool) -> FieldDefinition { + FieldDefinition { + types: vec![TypeDefinition { + type_name: name.to_string(), + named: true, + }], + multiple, + required, + } + } + fn create_test_field_definition_variants( + name: &Vec, + multiple: bool, + required: bool, + ) -> FieldDefinition { + FieldDefinition { + types: name + .iter() + .map(|n| TypeDefinition { + type_name: n.to_string(), + named: true, + }) + .collect(), + multiple, + required, + } + } + + #[test] + fn test_field_normalized_name() { + let field_definition = create_test_field_definition("test_type", false, true); + let field = Field::new("node", "field", &field_definition); + assert_eq!(field.normalized_name(), "Field"); + } + + #[test] + fn test_field_types() { + let field_definition = create_test_field_definition_variants( + &vec!["type_a".to_string(), "type_b".to_string()], + false, + true, + ); + let field = Field::new("test_node", "test_field", &field_definition); + assert_eq!( + field.types(), + field_definition.types.iter().collect::>() + ); + } + + #[test] + fn test_field_type_name() { + let field_definition = create_test_field_definition_variants( + &vec!["test_type".to_string(), "test_type".to_string()], + false, + true, + ); + let field = Field::new("Node", "field", &field_definition); + assert_eq!(field.type_name(), "NodeField"); + } + + #[test] + fn test_get_struct_field() { + let field_definition = create_test_field_definition("test_type", false, true); + let field = Field::new("test_node", "test_field", &field_definition); + + assert_eq!( + field.get_struct_field().to_string(), + quote! { + #[rkyv(omit_bounds)] + pub test_field: Box + } + .to_string() + ); + + // Test optional field + let optional_definition = create_test_field_definition("test_type", false, false); + let optional_field = Field::new("test_node", "test_field", &optional_definition); + + assert_eq!( + optional_field.get_struct_field().to_string(), + quote! { + #[rkyv(omit_bounds)] + pub test_field: Box> + } + .to_string() + ); + + // Test multiple field + let multiple_definition = create_test_field_definition("test_type", true, true); + let multiple_field = Field::new("test_node", "test_field", &multiple_definition); + + assert_eq!( + multiple_field.get_struct_field().to_string(), + quote! 
{ + #[rkyv(omit_bounds)] + pub test_field: Vec + } + .to_string() + ); + } + + #[test] + fn test_get_constructor_field() { + let field_definition = create_test_field_definition("test_type", false, true); + let field = Field::new("test_node", "test_field", &field_definition); + + assert_eq!( + field.get_constructor_field().to_string(), + quote!(test_field: Box::new(get_child_by_field_name(&node, "test_field", buffer)?)) + .to_string() + ); + + // Test optional field + let optional_definition = create_test_field_definition("test_type", false, false); + let optional_field = Field::new("test_node", "test_field", &optional_definition); + + assert_eq!( + optional_field.get_constructor_field().to_string(), + quote!(test_field: Box::new(get_optional_child_by_field_name(&node, "test_field", buffer)?)).to_string() + ); + + // Test multiple field + let multiple_definition = create_test_field_definition("test_type", true, true); + let multiple_field = Field::new("test_node", "test_field", &multiple_definition); + + assert_eq!( + multiple_field.get_constructor_field().to_string(), + quote!(test_field: get_multiple_children_by_field_name(&node, "test_field", buffer)?) + .to_string() + ); + } + + #[test] + fn test_get_children_field() { + let field_definition = create_test_field_definition("test_type", false, true); + let field = Field::new("test_node", "test_field", &field_definition); + + assert_eq!( + field.get_children_field(true).to_string(), + quote!(children.push(Self::Child::try_from(NodeTypes::from(self.test_field.as_ref().clone())).unwrap());).to_string() + ); + + // Test optional field + let optional_definition = create_test_field_definition("test_type", false, false); + let optional_field = Field::new("test_node", "test_field", &optional_definition); + + assert_eq!( + optional_field.get_children_field(true).to_string(), + quote!(if let Some(child) = self.test_field.as_ref() { + children.push(Self::Child::try_from(NodeTypes::from(child.clone())).unwrap()); + }) + .to_string() + ); + + // Test multiple field + let multiple_definition = create_test_field_definition("test_type", true, true); + let multiple_field = Field::new("test_node", "test_field", &multiple_definition); + + assert_eq!( + multiple_field.get_children_field(true).to_string(), + quote!(children.extend(self.test_field.iter().map(|child| Self::Child::try_from(NodeTypes::from(child.clone())).unwrap()));).to_string() + ); + } + + #[test] + fn test_get_children_by_field_name_field() { + let field_definition = create_test_field_definition("test_type", false, true); + let field = Field::new("test_node", "test_field", &field_definition); + + assert_eq!( + field.get_children_by_field_name_field(true).to_string(), + quote!("test_field" => vec![Self::Child::try_from(NodeTypes::from(self.test_field.as_ref().clone())).unwrap()]).to_string() + ); + + // Test optional field + let optional_definition = create_test_field_definition("test_type", false, false); + let optional_field = Field::new("test_node", "test_field", &optional_definition); + + assert_eq!( + optional_field.get_children_by_field_name_field(true).to_string(), + quote!("test_field" => self.test_field.as_ref().iter().map(|child| Self::Child::try_from(NodeTypes::from(child.clone())).unwrap()).collect()).to_string() + ); + + // Test multiple field + let multiple_definition = create_test_field_definition("test_type", true, true); + let multiple_field = Field::new("test_node", "test_field", &multiple_definition); + + assert_eq!( + 
multiple_field.get_children_by_field_name_field(true).to_string(), + quote!("test_field" => self.test_field.iter().map(|child| Self::Child::try_from(NodeTypes::from(child.clone())).unwrap()).collect()).to_string() + ); + } +} diff --git a/codegen-sdk-cst-generator/src/generator/node.rs b/codegen-sdk-cst-generator/src/generator/node.rs new file mode 100644 index 0000000..f51a97c --- /dev/null +++ b/codegen-sdk-cst-generator/src/generator/node.rs @@ -0,0 +1,959 @@ +use codegen_sdk_common::{Language, naming::normalize_type_name, parser::TypeDefinition}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +use super::field::Field; +use crate::generator::utils::{get_comment_type, get_serialize_bounds}; +#[derive(Debug)] +pub struct Node<'a> { + raw: &'a codegen_sdk_common::parser::Node, + pub subenums: Vec<String>, + pub fields: Vec<Field<'a>>, + language: &'a Language, +} +impl<'a> Node<'a> { + pub fn new(raw: &'a codegen_sdk_common::parser::Node, language: &'a Language) -> Self { + let mut fields = Vec::new(); + let normalized_name = normalize_type_name(&raw.type_name, raw.named); + if let Some(raw_fields) = &raw.fields { + for (name, field) in raw_fields.fields.iter() { + fields.push(Field::new(&normalized_name, name, field, language)); + } + } + fields.sort_by_key(|f| f.normalized_name().clone()); + Node { + raw, + subenums: Vec::new(), + fields, + language, + } + } + pub fn kind(&self) -> &str { + &self.raw.type_name + } + pub fn kind_id(&self) -> u16 { + self.language.kind_id(&self.raw.type_name, self.raw.named) + } + pub fn normalize_name(&self) -> String { + normalize_type_name(&self.raw.type_name, self.raw.named) + } + pub fn type_definition(&self) -> TypeDefinition { + TypeDefinition { + type_name: self.raw.type_name.clone(), + named: self.raw.named, + } + } + pub fn add_subenum(&mut self, subenum: String) { + if !self.subenums.contains(&subenum) { + self.subenums.push(subenum); + } + } + pub fn get_enum_tokens(&self) -> TokenStream { + let name = format_ident!("{}", self.normalize_name()); + let subenum_names = &self + .subenums + .iter() + .map(|s| format_ident!("{}", normalize_type_name(s, true))) + .collect::<Vec<_>>(); + if subenum_names.is_empty() { + quote! { + #name(#name) + } + } else { + quote! { + #[subenum(#(#subenum_names), *)] + #name(#name) + } + } + } + pub fn get_children_names(&self) -> Vec<TypeDefinition> { + let mut children_names = vec![]; + let comment = get_comment_type(); + if let Some(children) = &self.raw.children { + children_names.extend(children.types.iter().cloned()); + } + for field in &self.fields { + children_names.extend(field.types().into_iter().cloned()); + } + if children_names.len() > 0 && !children_names.contains(&comment) { + children_names.push(comment); + } + children_names.sort(); + children_names.dedup(); + children_names + } + pub fn children_struct_name(&self) -> String { + let children_names = self.get_children_names(); + match children_names.len() { + 0 => "Self".to_string(), + 1 => normalize_type_name(&children_names[0].type_name, children_names[0].named), + _ => format!("{}Children", self.normalize_name()), + } + } + fn has_children(&self) -> bool { + self.raw.children.is_some() + } + fn get_children_field(&self) -> TokenStream { + if self.has_children() { + let children_type_name = format_ident!("{}", self.children_struct_name()); + quote! { + #[rkyv(omit_bounds)] + pub children: Vec<#children_type_name>, + } + } else { + quote!
{} + } + } + + pub fn get_struct_tokens(&self) -> TokenStream { + let constructor = self.get_constructor(); + let struct_fields = self + .fields + .iter() + .map(|f| f.get_struct_field()) + .collect::>(); + let children_field = self.get_children_field(); + let name = format_ident!("{}", self.normalize_name()); + let serialize_bounds = get_serialize_bounds(); + let trait_impls = self.get_trait_implementations(); + + quote! { + #[derive(Debug, Clone, Deserialize, Archive, Serialize, Drive)] + #serialize_bounds + pub struct #name { + #[drive(skip)] + start_byte: usize, + #[drive(skip)] + end_byte: usize, + #[drive(skip)] + _kind: std::string::String, + #[debug("[{},{}]", start_position.row, start_position.column)] + #[drive(skip)] + start_position: Point, + #[debug("[{},{}]", end_position.row, end_position.column)] + #[drive(skip)] + end_position: Point, + #[debug(ignore)] + #[drive(skip)] + buffer: Arc, + #[debug(ignore)] + #[drive(skip)] + kind_id: u16, + #[debug(ignore)] + #[drive(skip)] + is_error: bool, + #[debug(ignore)] + #[drive(skip)] + named: bool, + #[debug(ignore)] + #[drive(skip)] + id: usize, + #children_field + #(#struct_fields),* + } + #constructor + #trait_impls + } + } + fn get_children_constructor(&self) -> TokenStream { + if self.has_children() { + quote! { + children: named_children_without_field_names(node, buffer)? + } + } else { + quote! {} + } + } + pub fn get_constructor(&self) -> TokenStream { + let name = format_ident!("{}", self.normalize_name()); + let mut constructor_fields = Vec::new(); + for field in &self.fields { + constructor_fields.push(field.get_constructor_field()); + } + constructor_fields.push(self.get_children_constructor()); + + quote! { + impl FromNode for #name { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + Ok(Self { + start_byte: node.start_byte(), + end_byte: node.end_byte(), + _kind: node.kind().to_string(), + start_position: node.start_position().into(), + end_position: node.end_position().into(), + buffer: buffer.clone(), + kind_id: node.kind_id(), + is_error: node.is_error(), + named: node.is_named(), + id: node.id(), + #(#constructor_fields),* + }) + } + } + } + } + fn get_children_impl(&self) -> TokenStream { + let name = format_ident!("{}", self.normalize_name()); + let children_type_name = format_ident!("{}", self.children_struct_name()); + let children_field = self.get_children_field_impl(); + let children_by_field_name = self.get_children_by_field_name_impl(); + quote! { + impl HasChildren for #name { + type Child = #children_type_name; + #children_field + #children_by_field_name + } + } + } + pub fn get_trait_implementations(&self) -> TokenStream { + let name = format_ident!("{}", self.normalize_name()); + let children_impl = self.get_children_impl(); + + quote! { + impl CSTNode for #name { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + fn is_error(&self) -> bool { + self.is_error + } + fn is_named(&self) -> bool { + self.named + } + fn id(&self) -> usize { + self.id + } + } + #children_impl + } + } + fn get_children_field_impl(&self) -> TokenStream { + let mut children_fields = Vec::new(); + let num_children = self.get_children_names().len(); + if num_children == 0 { + return quote! 
{ + fn children(&self) -> Vec { + vec![] + } + }; + } + let convert_children = num_children > 1; + for field in &self.fields { + children_fields.push(field.get_children_field(convert_children)); + } + + let m = if children_fields.is_empty() { + quote! {} + } else { + quote! {mut} + }; + let children_init = if self.has_children() { + quote! { + self.children.iter().cloned().collect() + } + } else { + quote! { + vec![] + } + }; + quote! { + fn children(&self) -> Vec { + let #m children: Vec<_> = #children_init; + #(#children_fields;)* + children + } + } + } + fn get_children_by_field_name_impl(&self) -> TokenStream { + let convert_children = self.get_children_names().len() > 1; + let field_matches = self + .fields + .iter() + .map(|f| f.get_children_by_field_name_field(convert_children)) + .collect::>(); + + quote! { + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + #(#field_matches,)* + _ => vec![], + } + } + } + } +} +#[cfg(test)] +mod tests { + use assert_tokenstreams_eq::assert_tokenstreams_eq; + use codegen_sdk_common::parser::{FieldDefinition, Fields, TypeDefinition}; + + use super::*; + + fn create_test_node(name: &str) -> codegen_sdk_common::parser::Node { + codegen_sdk_common::parser::Node { + type_name: name.to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + } + } + + fn create_test_node_with_fields( + name: &str, + fields: Vec<(&str, FieldDefinition)>, + ) -> codegen_sdk_common::parser::Node { + codegen_sdk_common::parser::Node { + type_name: name.to_string(), + subtypes: vec![], + named: true, + root: false, + fields: Some(Fields { + fields: fields + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect(), + }), + children: None, + } + } + + fn create_test_node_with_children( + name: &str, + child_types: Vec<&str>, + ) -> codegen_sdk_common::parser::Node { + codegen_sdk_common::parser::Node { + type_name: name.to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: Some(codegen_sdk_common::parser::Children { + types: child_types + .into_iter() + .map(|t| TypeDefinition { + type_name: t.to_string(), + named: true, + }) + .collect(), + multiple: false, + required: true, + }), + } + } + + #[test_log::test] + fn test_get_enum_tokens() { + let base_node = create_test_node("test"); + let mut node = Node::from(&base_node); + + let tokens = node.get_enum_tokens(); + assert_eq!(quote! { Test(Test) }.to_string(), tokens.to_string()); + node.add_subenum("subenum".to_string()); + let tokens = node.get_enum_tokens(); + assert_eq!( + quote! { + #[subenum(Subenum)] + Test(Test) + } + .to_string(), + tokens.to_string() + ); + } + + #[test] + fn test_get_struct_tokens_simple() { + let raw_node = create_test_node("test_node"); + let node = Node::from(&raw_node); + let serialize_bounds = get_serialize_bounds(); + assert_tokenstreams_eq!( + "e! 
{ + #[derive(Debug, Clone, Deserialize, Archive, Serialize)] + #serialize_bounds + pub struct TestNode { + start_byte: usize, + end_byte: usize, + _kind: std::string::String, + #[debug("[{},{}]", start_position.row, start_position.column)] + start_position: Point, + #[debug("[{},{}]", end_position.row, end_position.column)] + end_position: Point, + #[debug(ignore)] + buffer: Arc, + #[debug(ignore)] + kind_id: u16, + } + + impl FromNode for TestNode { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + Ok(Self { + start_byte: node.start_byte(), + end_byte: node.end_byte(), + _kind: node.kind().to_string(), + start_position: node.start_position().into(), + end_position: node.end_position().into(), + buffer: buffer.clone(), + kind_id: node.kind_id(), + }) + } + } + impl CSTNode for TestNode { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + } + impl HasChildren for TestNode { + type Child = Self; + fn children(&self) -> Vec { + vec![] + } + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + _ => vec![], + } + } + } + }, + &node.get_struct_tokens() + ); + } + + #[test] + fn test_get_struct_tokens_with_fields() { + let raw_node = create_test_node_with_fields( + "test_node", + vec![( + "test_field", + FieldDefinition { + types: vec![TypeDefinition { + type_name: "test_type".to_string(), + named: true, + }], + multiple: false, + required: true, + }, + )], + ); + let node = Node::from(&raw_node); + let serialize_bounds = get_serialize_bounds(); + assert_tokenstreams_eq!( + "e! 
{ + #[derive(Debug, Clone, Deserialize, Archive, Serialize)] + #serialize_bounds + pub struct TestNode { + start_byte: usize, + end_byte: usize, + _kind: std::string::String, + #[debug("[{},{}]", start_position.row, start_position.column)] + start_position: Point, + #[debug("[{},{}]", end_position.row, end_position.column)] + end_position: Point, + #[debug(ignore)] + buffer: Arc, + #[debug(ignore)] + kind_id: u16, + #[rkyv(omit_bounds)] + pub test_field: Box, + } + + impl FromNode for TestNode { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + Ok(Self { + start_byte: node.start_byte(), + end_byte: node.end_byte(), + _kind: node.kind().to_string(), + start_position: node.start_position().into(), + end_position: node.end_position().into(), + buffer: buffer.clone(), + kind_id: node.kind_id(), + test_field: Box::new(get_child_by_field_name(&node, "test_field", buffer)?), + }) + } + } + impl CSTNode for TestNode { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + } + impl HasChildren for TestNode { + type Child = TestType; + fn children(&self) -> Vec { + let mut children: Vec<_> = vec![]; + children.push(self.test_field.as_ref().clone()); + children + } + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + "test_field" => vec![self.test_field.as_ref().clone()], + _ => vec![], + } + } + } + }, + &node.get_struct_tokens() + ); + } + + #[test] + fn test_get_struct_tokens_complex() { + let raw_node = create_test_node_with_fields( + "test_node", + vec![ + ( + "required_field", + FieldDefinition { + types: vec![TypeDefinition { + type_name: "test_type".to_string(), + named: true, + }], + multiple: false, + required: true, + }, + ), + ( + "optional_field", + FieldDefinition { + types: vec![TypeDefinition { + type_name: "test_type".to_string(), + named: true, + }], + multiple: false, + required: false, + }, + ), + ( + "multiple_field", + FieldDefinition { + types: vec![TypeDefinition { + type_name: "test_type".to_string(), + named: true, + }], + multiple: true, + required: true, + }, + ), + ], + ); + let node = Node::from(&raw_node); + let serialize_bounds = get_serialize_bounds(); + assert_tokenstreams_eq!( + "e! 
{ + #[derive(Debug, Clone, Deserialize, Archive, Serialize)] + #serialize_bounds + pub struct TestNode { + start_byte: usize, + end_byte: usize, + _kind: std::string::String, + #[debug("[{},{}]", start_position.row, start_position.column)] + start_position: Point, + #[debug("[{},{}]", end_position.row, end_position.column)] + end_position: Point, + #[debug(ignore)] + buffer: Arc, + #[debug(ignore)] + kind_id: u16, + #[rkyv(omit_bounds)] + pub multiple_field: Vec, + #[rkyv(omit_bounds)] + pub optional_field: Box>, + #[rkyv(omit_bounds)] + pub required_field: Box, + } + + impl FromNode for TestNode { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + Ok(Self { + start_byte: node.start_byte(), + end_byte: node.end_byte(), + _kind: node.kind().to_string(), + start_position: node.start_position().into(), + end_position: node.end_position().into(), + buffer: buffer.clone(), + kind_id: node.kind_id(), + multiple_field: get_multiple_children_by_field_name(&node, "multiple_field", buffer)?, + optional_field: Box::new(get_optional_child_by_field_name(&node, "optional_field", buffer)?), + required_field: Box::new(get_child_by_field_name(&node, "required_field", buffer)?), + }) + } + } + impl CSTNode for TestNode { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + } + impl HasChildren for TestNode { + type Child = TestType; + fn children(&self) -> Vec { + let mut children: Vec<_> = vec![]; + children.extend(self.multiple_field.iter().map(|child| child.clone())); + if let Some(child) = self.optional_field.as_ref() { + children.push(child.clone()); + }; + children.push(self.required_field.as_ref().clone()); + children + } + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + "multiple_field" => self + .multiple_field + .iter() + .map(|child| child.clone()) + .collect(), + "optional_field" => self + .optional_field + .as_ref() + .iter() + .map(|child| child.clone()) + .collect(), + "required_field" => vec![self.required_field.as_ref().clone()], + _ => vec![], + } + } + } + }, + &node.get_struct_tokens() + ); + } + + #[test] + fn test_get_struct_tokens_with_children() { + let raw_node = + create_test_node_with_children("test_node", vec!["child_type_a", "child_type_b"]); + let node = Node::from(&raw_node); + let serialize_bounds = get_serialize_bounds(); + + assert_tokenstreams_eq!( + "e! 
{ + #[derive(Debug, Clone, Deserialize, Archive, Serialize)] + #serialize_bounds + pub struct TestNode { + start_byte: usize, + end_byte: usize, + _kind: std::string::String, + #[debug("[{},{}]", start_position.row, start_position.column)] + start_position: Point, + #[debug("[{},{}]", end_position.row, end_position.column)] + end_position: Point, + #[debug(ignore)] + buffer: Arc, + #[debug(ignore)] + kind_id: u16, + #[rkyv(omit_bounds)] + pub children: Vec, + } + + impl FromNode for TestNode { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + Ok(Self { + start_byte: node.start_byte(), + end_byte: node.end_byte(), + _kind: node.kind().to_string(), + start_position: node.start_position().into(), + end_position: node.end_position().into(), + buffer: buffer.clone(), + kind_id: node.kind_id(), + children: named_children_without_field_names(node, buffer)?, + }) + } + } + impl CSTNode for TestNode { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + } + impl HasChildren for TestNode { + type Child = TestNodeChildren; + fn children(&self) -> Vec { + let children: Vec<_> = self.children.iter().cloned().collect(); + children + } + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + _ => vec![], + } + } + } + }, + &node.get_struct_tokens() + ); + } + + #[test] + fn test_get_struct_tokens_with_single_child_type() { + let raw_node = create_test_node_with_children("test_node", vec!["child_type"]); + let node = Node::new(&raw_node, &Language::Typescript); + let serialize_bounds = get_serialize_bounds(); + + assert_tokenstreams_eq!( + "e! 
{ + #[derive(Debug, Clone, Deserialize, Archive, Serialize)] + #serialize_bounds + pub struct TestNode { + start_byte: usize, + end_byte: usize, + _kind: std::string::String, + #[debug("[{},{}]", start_position.row, start_position.column)] + start_position: Point, + #[debug("[{},{}]", end_position.row, end_position.column)] + end_position: Point, + #[debug(ignore)] + buffer: Arc, + #[debug(ignore)] + kind_id: u16, + #[rkyv(omit_bounds)] + pub children: Vec, + } + + impl FromNode for TestNode { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + Ok(Self { + start_byte: node.start_byte(), + end_byte: node.end_byte(), + _kind: node.kind().to_string(), + start_position: node.start_position().into(), + end_position: node.end_position().into(), + buffer: buffer.clone(), + kind_id: node.kind_id(), + children: named_children_without_field_names(node, buffer)?, + }) + } + } + impl CSTNode for TestNode { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + } + impl HasChildren for TestNode { + type Child = ChildType; + fn children(&self) -> Vec { + let children: Vec<_> = self.children.iter().cloned().collect(); + children + } + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + _ => vec![], + } + } + } + }, + &node.get_struct_tokens() + ); + } + + #[test] + fn test_get_trait_implementations() { + let raw_node = create_test_node("test_node"); + let node = Node::from(&raw_node); + let tokens = node.get_trait_implementations(); + + assert_tokenstreams_eq!( + "e! { + impl CSTNode for TestNode { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + } + impl HasChildren for TestNode { + type Child = Self; + fn children(&self) -> Vec { + vec![] + } + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + _ => vec![], + } + } + } + }, + &tokens + ); + } + + #[test] + fn test_get_children_field_impl() { + let raw_node = create_test_node_with_fields( + "test_node", + vec![( + "test_field", + FieldDefinition { + types: vec![TypeDefinition { + type_name: "test_type".to_string(), + named: true, + }], + multiple: false, + required: true, + }, + )], + ); + let node = Node::from(&raw_node); + + assert_tokenstreams_eq!( + "e! { + fn children(&self) -> Vec { + let mut children: Vec<_> = vec![]; + children.push(self.test_field.as_ref().clone()); + children + } + }, + &node.get_children_field_impl() + ); + } + + #[test] + fn test_get_children_by_field_name_impl() { + let raw_node = create_test_node_with_fields( + "test_node", + vec![( + "test_field", + FieldDefinition { + types: vec![TypeDefinition { + type_name: "test_type".to_string(), + named: true, + }], + multiple: false, + required: true, + }, + )], + ); + let node = Node::from(&raw_node); + + assert_tokenstreams_eq!( + "e! 
{ + fn children_by_field_name(&self, field_name: &str) -> Vec<Self::Child> { + match field_name { + "test_field" => vec![self.test_field.as_ref().clone()], + _ => vec![], + } + } + }, + &node.get_children_by_field_name_impl() + ); + } +} diff --git a/codegen-sdk-cst-generator/src/generator/state.rs b/codegen-sdk-cst-generator/src/generator/state.rs index 95383fb..e5e0cfb 100644 --- a/codegen-sdk-cst-generator/src/generator/state.rs +++ b/codegen-sdk-cst-generator/src/generator/state.rs @@ -1,11 +1,690 @@ -use std::collections::HashMap; +use std::collections::{BTreeMap, BTreeSet, VecDeque}; -use codegen_sdk_common::parser::TypeDefinition; +use codegen_sdk_common::{Language, naming::normalize_type_name, parser::TypeDefinition}; use proc_macro2::TokenStream; -#[derive(Default, Debug)] -pub struct State { - pub enums: TokenStream, - pub structs: TokenStream, - pub variants: HashMap>, - pub anonymous_nodes: HashMap, +use quote::{format_ident, quote}; + +use super::{node::Node, utils::get_from_node}; +use crate::generator::{ + constants::TYPE_NAME, + utils::{get_comment_type, get_from_for_enum}, +}; +#[derive(Debug)] +pub struct State<'a> { + pub subenums: BTreeSet<String>, + nodes: BTreeMap<String, Node<'a>>, + language: &'a Language, +} +impl<'a> State<'a> { + pub fn new(language: &'a Language) -> Self { + let mut nodes = BTreeMap::new(); + let mut subenums = BTreeSet::new(); + let raw_nodes = language.nodes(); + for raw_node in raw_nodes { + if raw_node.subtypes.is_empty() { + let node = Node::new(raw_node, language); + nodes.insert(node.normalize_name(), node); + } else { + subenums.insert(raw_node.type_name.clone()); + } + } + let mut ret = Self { + nodes, + subenums, + language, + }; + let mut subenums = VecDeque::new(); + for raw_node in raw_nodes.iter().filter(|n| !n.subtypes.is_empty()) { + subenums.push_back(raw_node.clone()); + } + while let Some(raw_node) = subenums.pop_front() { + if raw_node + .subtypes + .iter() + .any(|s| subenums.iter().any(|n| n.type_name == s.type_name)) + { + subenums.push_back(raw_node); + } else { + // Add subtypes to the state + ret.add_subenum(&raw_node.type_name, &raw_node.subtypes.iter().collect()); + } + } + log::info!("Adding child subenums"); + ret.add_child_subenums(); + log::info!("Adding field subenums"); + ret.add_field_subenums(); + ret + } + fn add_child_subenums(&mut self) { + let keys = self.nodes.keys().cloned().collect::<Vec<_>>(); + for name in keys.into_iter() { + let normalized_name = normalize_type_name(&name, true); + let node = self.nodes.get(&normalized_name).unwrap(); + let mut children_types = node.get_children_names(); + if children_types.len() > 1 { + children_types.sort(); + children_types.dedup(); + self.add_subenum( + &node.children_struct_name(), + &children_types.iter().collect(), + ); + } + } + } + fn add_field_subenums(&mut self) { + let mut to_add: Vec<(String, Vec<TypeDefinition>)> = Vec::new(); + for node in self.nodes.values() { + for field in &node.fields { + log::debug!("Adding field subenum: {}", field.normalized_name()); + if field.types().len() > 1 { + to_add.push(( + field.type_name(), + field.types().into_iter().cloned().collect(), + )); + } + } + } + for (name, types) in to_add.into_iter() { + self.add_subenum(&name, &types.iter().collect()); + } + } + fn add_subenum(&mut self, name: &str, nodes: &Vec<&TypeDefinition>) { + self.subenums.insert(name.to_string()); + let mut nodes = nodes.clone(); + let comment = get_comment_type(); + if self.nodes.contains_key(&comment.type_name) { + nodes.push(&comment); + } + for node in nodes { + let normalized_name =
normalize_type_name(&node.type_name, node.named); + if !self.subenums.contains(&node.type_name) { + log::debug!("Adding subenum: {} to {}", name, normalized_name); + if let Some(node) = self.nodes.get_mut(&normalized_name) { + node.add_subenum(name.to_string()); + } + } else { + let variants = self.get_variants(&node.type_name); + self.add_subenum(name, &variants.iter().collect()); + } + } + } + fn get_variants(&self, subenum: &str) -> Vec<TypeDefinition> { + let comment = get_comment_type(); + let mut variants = vec![comment]; + for node in self.nodes.values() { + log::debug!("Checking subenum: {} for {}", subenum, node.kind()); + if node.subenums.contains(&subenum.to_string()) { + log::debug!("Found variant: {} for {}", node.kind(), subenum); + variants.push(node.type_definition()); + } + } + variants + } + fn get_variant_map(&self, enum_name: &str) -> BTreeMap<u16, TokenStream> { + let mut variant_map = BTreeMap::new(); + for node in self.nodes.values() { + let variant_name = format_ident!("{}", node.normalize_name()); + if node.subenums.contains(&enum_name.to_string()) { + log::debug!("Adding variant: {} for {}", node.kind(), enum_name); + variant_map.insert( + node.kind_id(), + quote! { + Ok(Self::#variant_name(#variant_name::from_node(node, buffer)?)) + }, + ); + } + } + variant_map + } + // Implement the TSNode => CSTNode conversion trait on a given subenum + fn get_from_node(&self, enum_name: &str) -> TokenStream { + let variant_map = self.get_variant_map(enum_name); + get_from_node(enum_name, true, &variant_map) + } + // Get the overarching enum for the nodes + pub fn get_enum(&self) -> TokenStream { + let mut enum_tokens = Vec::new(); + let mut from_tokens = TokenStream::new(); + let mut subenums = Vec::new(); + for node in self.nodes.values() { + enum_tokens.push(node.get_enum_tokens()); + let variant_name = node.normalize_name(); + } + for subenum in self.subenums.iter() { + assert!( + self.get_variants(subenum).len() > 0, + "Subenum {} has no variants", + subenum + ); + from_tokens.extend_one(self.get_from_node(subenum)); + subenums.push(format_ident!("{}", normalize_type_name(&subenum, true))); + } + let subenum_tokens = if !subenums.is_empty() { + subenums.sort(); + subenums.dedup(); + quote! { + #[subenum(#(#subenums(derive(Archive, Deserialize, Serialize))),*)] + } + } else { + quote! {} + }; + let enum_name = format_ident!("{}", TYPE_NAME); + quote! { + #subenum_tokens + #[derive(Debug, Clone, Drive)] + #[enum_delegate::implement(CSTNode)] + pub enum #enum_name { + #(#enum_tokens),* + } + #from_tokens + } + } + pub fn get_structs(&self) -> TokenStream { + let mut struct_tokens = TokenStream::new(); + for node in self.nodes.values() { + struct_tokens.extend_one(node.get_struct_tokens()); + } + struct_tokens + } +} +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use assert_tokenstreams_eq::assert_tokenstreams_eq; + + use super::*; + use crate::generator::utils::get_serialize_bounds; + + #[test_log::test] + fn test_get_enum() { + let node = codegen_sdk_common::parser::Node { + type_name: "test".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }; + let nodes = vec![node]; + let state = State::from(&nodes); + let enum_tokens = state.get_enum(); + assert_tokenstreams_eq!( + &quote!
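+ // Expected: a single concrete node yields a one-variant NodeTypes enum plus its From conversion into NodeTypes.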
{ + #[derive(Debug, Clone)] + pub enum NodeTypes { + Test(Test) + } + impl std::convert::From for NodeTypes { + fn from(variant: Test) -> Self { + Self::Test(variant) + } + } + }, + &enum_tokens + ); + } + #[test_log::test] + fn test_parse_children() { + let child = codegen_sdk_common::parser::Node { + type_name: "child".to_string(), + subtypes: vec![], + named: false, + root: false, + fields: None, + children: None, + }; + let child_two = codegen_sdk_common::parser::Node { + type_name: "child_two".to_string(), + subtypes: vec![], + named: false, + root: false, + fields: None, + children: None, + }; + let node = codegen_sdk_common::parser::Node { + type_name: "test".to_string(), + subtypes: vec![], + named: false, + root: false, + fields: None, + children: Some(codegen_sdk_common::parser::Children { + multiple: true, + required: false, + types: vec![ + codegen_sdk_common::parser::TypeDefinition { + type_name: "child".to_string(), + named: true, + }, + codegen_sdk_common::parser::TypeDefinition { + type_name: "child_two".to_string(), + named: true, + }, + ], + }), + }; + let nodes = vec![child, child_two, node]; + let state = State::from(&nodes); + let enum_tokens = state.get_enum(); + assert_tokenstreams_eq!( + "e! { + #[subenum(TestChildren(derive(Archive, Deserialize, Serialize)))] + #[derive(Debug, Clone)] + pub enum NodeTypes { + #[subenum(TestChildren)] + Child(Child), + #[subenum(TestChildren)] + ChildTwo(ChildTwo), + Test(Test) + } + impl std::convert::From for NodeTypes { + fn from(variant: Child) -> Self { + Self::Child(variant) + } + } + impl std::convert::From for TestChildren { + fn from(variant: Child) -> Self { + Self::Child(variant) + } + } + impl std::convert::From for NodeTypes { + fn from(variant: ChildTwo) -> Self { + Self::ChildTwo(variant) + } + } + impl std::convert::From for TestChildren { + fn from(variant: ChildTwo) -> Self { + Self::ChildTwo(variant) + } + } + impl std::convert::From for NodeTypes { + fn from(variant: Test) -> Self { + Self::Test(variant) + } + } + impl FromNode for TestChildren { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "child" => Ok(Self::Child(node.from_node(node, buffer)?)), + "child_two" => Ok(Self::ChildTwo(node.from_node(node, buffer)?)), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), } + } + } + }, + &enum_tokens + ); + } + #[test_log::test] + fn test_parse_children_subtypes() { + let definition = codegen_sdk_common::parser::Node { + type_name: "definition".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "function".to_string(), + named: true, + }, + TypeDefinition { + type_name: "class".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }; + let class = codegen_sdk_common::parser::Node { + type_name: "class".to_string(), + subtypes: vec![], + named: false, + root: false, + fields: None, + children: Some(codegen_sdk_common::parser::Children { + multiple: true, + required: false, + types: vec![codegen_sdk_common::parser::TypeDefinition { + type_name: "definition".to_string(), + named: true, + }], + }), + }; + let function = codegen_sdk_common::parser::Node { + type_name: "function".to_string(), + subtypes: vec![], + named: false, + root: false, + fields: None, + children: None, + }; + let nodes = vec![definition, class, function]; + let state = State::from(&nodes); + let enum_tokens = state.get_enum(); + assert_tokenstreams_eq!( + "e! 
{ + #[subenum(Definition(derive(Archive, Deserialize, Serialize)))] + #[derive(Debug, Clone)] + pub enum NodeTypes { + #[subenum(Definition)] + Class(Class), + #[subenum(Definition)] + Function(Function) + } + impl std::convert::From for NodeTypes { + fn from(variant: Class) -> Self { + Self::Class(variant) + } + } + impl std::convert::From for Definition { + fn from(variant: Class) -> Self { + Self::Class(variant) + } + } + impl std::convert::From for NodeTypes { + fn from(variant: Function) -> Self { + Self::Function(variant) + } + } + impl std::convert::From for Definition { + fn from(variant: Function) -> Self { + Self::Function(variant) + } + } + impl FromNode for Definition { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "class" => Ok(Self::Class(node.from_node(node)?)), + "function" => Ok(Self::Function(node.from_node(node)?)), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), } + } + } + }, + &enum_tokens + ); + } + #[test_log::test] + fn test_add_field_subenums() { + let node_a = codegen_sdk_common::parser::Node { + type_name: "node_a".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }; + let node_b = codegen_sdk_common::parser::Node { + type_name: "node_b".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }; + let field = codegen_sdk_common::parser::FieldDefinition { + types: vec![ + TypeDefinition { + type_name: "node_a".to_string(), + named: true, + }, + TypeDefinition { + type_name: "node_b".to_string(), + named: true, + }, + ], + multiple: true, + required: false, + }; + let node_c = codegen_sdk_common::parser::Node { + type_name: "node_c".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: Some(codegen_sdk_common::parser::Fields { + fields: HashMap::from([("field".to_string(), field)]), + }), + children: None, + }; + let nodes = vec![node_a, node_b, node_c]; + let state = State::from(&nodes); + let enum_tokens = state.get_enum(); + assert_tokenstreams_eq!( + "e! 
{ + #[subenum(NodeCChildren(derive(Archive, Deserialize, Serialize)), NodeCField(derive(Archive, Deserialize, Serialize)))] + #[derive(Debug, Clone)] + pub enum NodeTypes { + #[subenum(NodeCChildren, NodeCField)] + NodeA(NodeA), + #[subenum(NodeCChildren, NodeCField)] + NodeB(NodeB), + NodeC(NodeC) + } + impl std::convert::From for NodeTypes { + fn from(variant: NodeA) -> Self { + Self::NodeA(variant) + } + } + impl std::convert::From for NodeCChildren { + fn from(variant: NodeA) -> Self { + Self::NodeA(variant) + } + } + impl std::convert::From for NodeCField { + fn from(variant: NodeA) -> Self { + Self::NodeA(variant) + } + } + + impl std::convert::From for NodeTypes { + fn from(variant: NodeB) -> Self { + Self::NodeB(variant) + } + } + impl std::convert::From for NodeCChildren { + fn from(variant: NodeB) -> Self { + Self::NodeB(variant) + } + } + impl std::convert::From for NodeCField { + fn from(variant: NodeB) -> Self { + Self::NodeB(variant) + } + } + impl std::convert::From for NodeTypes { + fn from(variant: NodeC) -> Self { + Self::NodeC(variant) + } + } + impl FromNode for NodeCChildren { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "node_a" => Ok(Self::NodeA(NodeA::from_node(node, buffer)?)), + "node_b" => Ok(Self::NodeB(NodeB::from_node(node, buffer)?)), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), } + } + } + impl FromNode for NodeCField { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "node_a" => Ok(Self::NodeA(NodeA::from_node(node, buffer)?)), + "node_b" => Ok(Self::NodeB(NodeB::from_node(node, buffer)?)), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), } + } + } + + }, + &enum_tokens + ); + } + #[test_log::test] + fn test_get_structs() { + let node = codegen_sdk_common::parser::Node { + type_name: "test".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }; + let nodes = vec![node]; + let state = State::from(&nodes); + let struct_tokens = state.get_structs(); + let serialize_bounds = get_serialize_bounds(); + assert_tokenstreams_eq!( + "e! 
{ + #[derive(Debug, Clone, Deserialize, Archive, Serialize)] + #serialize_bounds + pub struct Test { + start_byte: usize, + end_byte: usize, + _kind: std::string::String, + #[debug("[{},{}]", start_position.row, start_position.column)] + start_position: Point, + #[debug("[{},{}]", end_position.row, end_position.column)] + end_position: Point, + #[debug(ignore)] + buffer: Arc, + #[debug(ignore)] + kind_id: u16, + } + impl FromNode for Test { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + Ok(Self { + start_byte: node.start_byte(), + end_byte: node.end_byte(), + _kind: node.kind().to_string(), + start_position: node.start_position().into(), + end_position: node.end_position().into(), + buffer: buffer.clone(), + kind_id: node.kind_id(), + }) + } + } + impl CSTNode for Test { + fn kind(&self) -> &str { + &self._kind + } + fn start_byte(&self) -> usize { + self.start_byte + } + fn end_byte(&self) -> usize { + self.end_byte + } + fn start_position(&self) -> Point { + self.start_position + } + fn end_position(&self) -> Point { + self.end_position + } + fn buffer(&self) -> &Bytes { + &self.buffer + } + fn kind_id(&self) -> u16 { + self.kind_id + } + } + impl HasChildren for Test { + type Child = Self; + fn children(&self) -> Vec { + vec![] + } + fn children_by_field_name(&self, field_name: &str) -> Vec { + match field_name { + _ => vec![], + } + } + } + + }, + &struct_tokens + ); + } + #[test_log::test] + fn test_get_variants() { + let node_a = codegen_sdk_common::parser::Node { + type_name: "node_a".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }; + let node_b = codegen_sdk_common::parser::Node { + type_name: "node_b".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }; + let parent = codegen_sdk_common::parser::Node { + type_name: "parent".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "node_a".to_string(), + named: true, + }, + TypeDefinition { + type_name: "node_b".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }; + let nodes = vec![node_a, node_b, parent]; + let state = State::from(&nodes); + + let variants = state.get_variants("parent"); + assert_eq!( + vec![ + TypeDefinition { + type_name: "node_a".to_string(), + named: true, + }, + TypeDefinition { + type_name: "node_b".to_string(), + named: true, + }, + ], + variants + ); + } + #[test_log::test] + fn test_add_subenum() { + let node_a = codegen_sdk_common::parser::Node { + type_name: "node_a".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }; + let nodes = vec![node_a]; + let mut state = State::from(&nodes); + + state.add_subenum( + "TestEnum", + &vec![&TypeDefinition { + type_name: "node_a".to_string(), + named: true, + }], + ); + assert!(state.subenums.contains("TestEnum")); + + let node = state.nodes.get("NodeA").unwrap(); + assert!(node.subenums.contains(&"TestEnum".to_string())); + } } diff --git a/codegen-sdk-cst-generator/src/generator/struct_generator.rs b/codegen-sdk-cst-generator/src/generator/struct_generator.rs deleted file mode 100644 index 4efa9fb..0000000 --- a/codegen-sdk-cst-generator/src/generator/struct_generator.rs +++ /dev/null @@ -1,213 +0,0 @@ -use codegen_sdk_common::{ - naming::{normalize_field_name, normalize_type_name}, - parser::{Children, FieldDefinition, Fields, Node, TypeDefinition}, -}; -use proc_macro2::TokenStream; -use quote::{format_ident, quote}; - -use 
super::enum_generator::generate_enum; -use crate::generator::state::State; -fn convert_type_definition( - type_name: &Vec, - state: &mut State, - field_name: &str, - node_name: &str, -) -> String { - let include_anonymous_nodes = true; - if type_name.len() == 1 && !include_anonymous_nodes { - normalize_type_name(&type_name[0].type_name) - } else { - let enum_name = normalize_type_name( - format!( - "{}{}", - normalize_type_name(node_name), - normalize_type_name(field_name) - ) - .as_str(), - ); - generate_enum(type_name, state, &enum_name, include_anonymous_nodes); - enum_name - } -} - -fn generate_multiple_field( - field_name: &str, - converted_type_name: &str, - original_name: &str, -) -> (TokenStream, TokenStream) { - let field_name = format_ident!("{}", field_name); - let converted_type_name = format_ident!("{}", converted_type_name); - let struct_field = quote! { - pub #field_name: Vec<#converted_type_name> - }; - let constructor_field = quote! { - #field_name: get_multiple_children_by_field_name(&node, #original_name, buffer)? - }; - (struct_field, constructor_field) -} -fn generate_required_field( - field_name: &str, - converted_type_name: &str, - original_name: &str, -) -> (TokenStream, TokenStream) { - let field_name = format_ident!("{}", field_name); - let converted_type_name = format_ident!("{}", converted_type_name); - let struct_field = quote! { - #[rkyv(omit_bounds)] - pub #field_name: Box<#converted_type_name> - }; - let constructor_field = quote! { - #field_name: Box::new(get_child_by_field_name(&node, #original_name, buffer)?) - }; - (struct_field, constructor_field) -} -fn generate_optional_field( - field_name: &str, - converted_type_name: &str, - original_name: &str, -) -> (TokenStream, TokenStream) { - let field_name = format_ident!("{}", field_name); - let converted_type_name = format_ident!("{}", converted_type_name); - let struct_field = quote! { - #[rkyv(omit_bounds)] - pub #field_name: Box> - }; - let constructor_field = quote! { - #field_name: Box::new(get_optional_child_by_field_name(&node, #original_name, buffer)?) - }; - (struct_field, constructor_field) -} -fn generate_field( - field: &FieldDefinition, - state: &mut State, - node: &Node, - name: &str, -) -> (TokenStream, TokenStream) { - let field_name = normalize_field_name(name); - let converted_type_name = convert_type_definition(&field.types, state, &node.type_name, name); - if field.multiple { - return generate_multiple_field(&field_name, &converted_type_name, name); - } else if field.required { - return generate_required_field(&field_name, &converted_type_name, name); - } else { - return generate_optional_field(&field_name, &converted_type_name, name); - } -} -fn generate_fields( - fields: &Fields, - state: &mut State, - node: &Node, -) -> (Vec, Vec) { - let mut struct_fields = Vec::new(); - let mut constructor_fields = Vec::new(); - for (name, field) in &fields.fields { - let (struct_field, constructor_field) = generate_field(field, state, node, name); - struct_fields.push(struct_field); - constructor_fields.push(constructor_field); - } - (struct_fields, constructor_fields) -} -fn generate_children( - children: &Children, - state: &mut State, - node_name: &str, -) -> (String, TokenStream) { - let converted_type_name = - convert_type_definition(&children.types, state, node_name, "children"); - let constructor_field = quote! { - children: named_children_without_field_names(node, buffer)? 
- }; - (converted_type_name, constructor_field) -} -pub fn generate_struct(node: &Node, state: &mut State, name: &str) { - let mut constructor_fields = Vec::new(); - let mut struct_fields = Vec::new(); - if let Some(fields) = &node.fields { - (struct_fields, constructor_fields) = generate_fields(fields, state, node); - } - let mut children_type_name = "Self".to_string(); - if let Some(children) = &node.children { - let constructor_field; - (children_type_name, constructor_field) = - generate_children(children, state, &node.type_name); - constructor_fields.push(constructor_field); - } else { - constructor_fields.push(quote! { - children: vec![] - }); - } - let name = format_ident!("{}", name); - let children_type_name = format_ident!("{}", children_type_name); - let definition = quote! { - #[derive(Debug, Clone, Deserialize, Archive, Serialize)] - #[rkyv(serialize_bounds( - __S: rkyv::ser::Writer + rkyv::ser::Allocator, - __S::Error: rkyv::rancor::Source, - ))] - #[rkyv(deserialize_bounds(__D::Error: rkyv::rancor::Source))] - #[rkyv(bytecheck( - bounds( - __C: rkyv::validation::ArchiveContext, - __C::Error: rkyv::rancor::Source, - ) - ))] - pub struct #name { - start_byte: usize, - end_byte: usize, - #[debug("[{},{}]", start_position.row, start_position.column)] - start_position: Point, - #[debug("[{},{}]", end_position.row, end_position.column)] - end_position: Point, - #[debug(ignore)] - buffer: Arc, - #[debug(ignore)] - kind_id: u16, - #[rkyv(omit_bounds)] - pub children: Vec<#children_type_name>, - #(#struct_fields),* - } - }; - state.structs.extend_one(definition); - let implementation = quote! { - impl CSTNode for #name { - fn start_byte(&self) -> usize { - self.start_byte - } - fn end_byte(&self) -> usize { - self.end_byte - } - fn start_position(&self) -> Point { - self.start_position - } - fn end_position(&self) -> Point { - self.end_position - } - fn buffer(&self) -> &Bytes { - &self.buffer - } - fn kind_id(&self) -> u16 { - self.kind_id - } - } - impl HasChildren for #name { - type Child = #children_type_name; - fn children(&self) -> &Vec { - self.children.as_ref() - } - } - impl FromNode for #name { - fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { - Ok(Self { - start_byte: node.start_byte(), - end_byte: node.end_byte(), - start_position: node.start_position().into(), - end_position: node.end_position().into(), - buffer: buffer.clone(), - kind_id: node.kind_id(), - #(#constructor_fields),* - }) - } - } - }; - state.structs.extend_one(implementation); -} diff --git a/codegen-sdk-cst-generator/src/generator/utils.rs b/codegen-sdk-cst-generator/src/generator/utils.rs new file mode 100644 index 0000000..8fbf47c --- /dev/null +++ b/codegen-sdk-cst-generator/src/generator/utils.rs @@ -0,0 +1,62 @@ +use std::collections::BTreeMap; + +use codegen_sdk_common::{naming::normalize_type_name, parser::TypeDefinition}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +pub fn get_from_for_enum(variant: &str, enum_name: &str) -> TokenStream { + let enum_name = format_ident!("{}", enum_name); + let variant = format_ident!("{}", variant); + quote! { + impl std::convert::From<#variant> for #enum_name { + fn from(variant: #variant) -> Self { + Self::#variant(variant) + } + } + } +} +pub fn get_serialize_bounds() -> TokenStream { + quote! 
{ + #[rkyv(serialize_bounds( + __S: rkyv::ser::Writer + rkyv::ser::Allocator, + __S::Error: rkyv::rancor::Source, + ))] + #[rkyv(deserialize_bounds(__D::Error: rkyv::rancor::Source))] + #[rkyv(bytecheck( + bounds( + __C: rkyv::validation::ArchiveContext, + __C::Error: rkyv::rancor::Source, + ) + ))] + } +} +pub fn get_from_node( + node: &str, + named: bool, + variant_map: &BTreeMap, +) -> TokenStream { + let node = format_ident!("{}", normalize_type_name(node, named)); + let mut keys = Vec::new(); + let mut values = Vec::new(); + for (key, value) in variant_map.iter() { + keys.push(key); + values.push(value); + } + quote! { + impl FromNode for #node { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind_id() { + #(#keys => #values,)* + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), } + } + } + } +} +pub fn get_comment_type() -> TypeDefinition { + TypeDefinition { + type_name: "comment".to_string(), + named: true, + } +} diff --git a/codegen-sdk-cst-generator/src/lib.rs b/codegen-sdk-cst-generator/src/lib.rs index 914366d..45b342f 100644 --- a/codegen-sdk-cst-generator/src/lib.rs +++ b/codegen-sdk-cst-generator/src/lib.rs @@ -1,9 +1,9 @@ #![feature(extend_one)] mod generator; use codegen_sdk_common::language::Language; -pub fn generate_cst(language: &Language) -> anyhow::Result<()> { - let node_types = language.nodes(); - let cst = generator::generate_cst(&node_types)?; +pub use generator::generate_cst; +pub fn generate_cst_to_file(language: &Language) -> anyhow::Result<()> { + let cst = generator::generate_cst(language)?; let out_dir = std::env::var("OUT_DIR")?; let out_file = format!("{}/{}.rs", out_dir, language.name); std::fs::write(out_file, cst)?; diff --git a/codegen-sdk-cst-generator/tests/test_subtypes.rs b/codegen-sdk-cst-generator/tests/test_subtypes.rs new file mode 100644 index 0000000..cdc27b5 --- /dev/null +++ b/codegen-sdk-cst-generator/tests/test_subtypes.rs @@ -0,0 +1,558 @@ +use std::collections::HashMap; + +use assert_tokenstreams_eq::assert_tokenstreams_eq; +use codegen_sdk_common::parser::{Fields, Node, TypeDefinition}; +use codegen_sdk_cst_generator::generate_cst; +use quote::quote; + +#[test] +fn test_basic_subtypes() { + // Define nodes with basic subtype relationships + let nodes = vec![ + Node { + type_name: "expression".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "binary_expression".to_string(), + named: true, + }, + TypeDefinition { + type_name: "unary_expression".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }, + Node { + type_name: "binary_expression".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + Node { + type_name: "unary_expression".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + ]; + + let output = generate_cst(&nodes).unwrap(); + let expected = quote! 
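+ // Expected: the "expression" supertype becomes an Expression subenum over both concrete variants, with From conversions and a kind()-based FromNode dispatch.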
{ + use std::{backtrace::Backtrace, sync::Arc}; + use bytes::Bytes; + use codegen_sdk_common::*; + use derive_more::Debug; + use rkyv::{Archive, Deserialize, Serialize}; + use subenum::subenum; + use tree_sitter; + + #[derive(Debug, Clone)] + #[subenum(Expression(derive(Archive, Deserialize, Serialize)))] + pub enum NodeTypes { + #[subenum(Expression)] + BinaryExpression(BinaryExpression), + #[subenum(Expression)] + UnaryExpression(UnaryExpression), + } + + impl std::convert::From for NodeTypes { + fn from(variant: BinaryExpression) -> Self { + Self::BinaryExpression(variant) + } + } + + impl std::convert::From for NodeTypes { + fn from(variant: UnaryExpression) -> Self { + Self::UnaryExpression(variant) + } + } + + impl std::convert::From for Expression { + fn from(variant: BinaryExpression) -> Self { + Self::BinaryExpression(variant) + } + } + + impl std::convert::From for Expression { + fn from(variant: UnaryExpression) -> Self { + Self::UnaryExpression(variant) + } + } + + impl FromNode for Expression { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "binary_expression" => Ok(Self::BinaryExpression(BinaryExpression::from_node(node, buffer)?)), + "unary_expression" => Ok(Self::UnaryExpression(UnaryExpression::from_node(node, buffer)?)), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), + } + } + } + }; + assert_tokenstreams_eq!(&output, &expected); +} + +#[test_log::test] +fn test_nested_subtypes() { + // Define nodes with nested subtype relationships + let nodes = vec![ + // Top level statement type + Node { + type_name: "statement".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "declaration".to_string(), + named: true, + }, + TypeDefinition { + type_name: "expression_statement".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }, + // Declaration is both a statement subtype and has its own subtypes + Node { + type_name: "declaration".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "function_declaration".to_string(), + named: true, + }, + TypeDefinition { + type_name: "class_declaration".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }, + // Concrete node types + Node { + type_name: "function_declaration".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + Node { + type_name: "class_declaration".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + Node { + type_name: "expression_statement".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + ]; + + let output = generate_cst(&nodes).unwrap(); + let expected = quote! 
{ + use std::{backtrace::Backtrace, sync::Arc}; + use bytes::Bytes; + use codegen_sdk_common::*; + use derive_more::Debug; + use rkyv::{Archive, Deserialize, Serialize}; + use subenum::subenum; + use tree_sitter; + + #[subenum( + Declaration(derive(Archive, Deserialize, Serialize)), + Statement(derive(Archive, Deserialize, Serialize)) + )] + #[derive(Debug, Clone)] + pub enum NodeTypes { + #[subenum(Declaration, Statement)] + ClassDeclaration(ClassDeclaration), + #[subenum(Statement)] + ExpressionStatement(ExpressionStatement), + #[subenum(Declaration, Statement)] + FunctionDeclaration(FunctionDeclaration), + } + + impl std::convert::From for NodeTypes { + fn from(variant: ClassDeclaration) -> Self { + Self::ClassDeclaration(variant) + } + } + + impl std::convert::From for Declaration { + fn from(variant: ClassDeclaration) -> Self { + Self::ClassDeclaration(variant) + } + } + + impl std::convert::From for NodeTypes { + fn from(variant: ExpressionStatement) -> Self { + Self::ExpressionStatement(variant) + } + } + + impl std::convert::From for Statement { + fn from(variant: ExpressionStatement) -> Self { + Self::ExpressionStatement(variant) + } + } + + impl std::convert::From for NodeTypes { + fn from(variant: FunctionDeclaration) -> Self { + Self::FunctionDeclaration(variant) + } + } + + impl std::convert::From for Declaration { + fn from(variant: FunctionDeclaration) -> Self { + Self::FunctionDeclaration(variant) + } + } + + impl FromNode for Declaration { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "class_declaration" => Ok(Self::ClassDeclaration(ClassDeclaration::from_node( + node, buffer, + )?)), + "function_declaration" => Ok(Self::FunctionDeclaration( + FunctionDeclaration::from_node(node, buffer)?, + )), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), + } + } + } + + impl FromNode for Statement { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "class_declaration" => Ok(Self::ClassDeclaration( + ClassDeclaration::from_node(node, buffer)?, + )), + "function_declaration" => Ok(Self::FunctionDeclaration( + FunctionDeclaration::from_node(node, buffer)?, + )), + "expression_statement" => Ok(Self::ExpressionStatement( + ExpressionStatement::from_node(node, buffer)?, + )), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), + } + } + } + + + // ... struct definitions and implementations for ClassDeclaration, ExpressionStatement, and FunctionDeclaration ... 
+ }; + assert_tokenstreams_eq!(&output, &expected); +} + +#[test] +fn test_subtypes_with_fields() { + let nodes = vec![ + Node { + type_name: "expression".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "binary_expression".to_string(), + named: true, + }, + TypeDefinition { + type_name: "literal".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }, + Node { + type_name: "binary_expression".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: Some(Fields { + fields: HashMap::from([ + ( + "left".to_string(), + codegen_sdk_common::parser::FieldDefinition { + types: vec![TypeDefinition { + type_name: "expression".to_string(), + named: true, + }], + multiple: false, + required: true, + }, + ), + ( + "right".to_string(), + codegen_sdk_common::parser::FieldDefinition { + types: vec![TypeDefinition { + type_name: "expression".to_string(), + named: true, + }], + multiple: false, + required: true, + }, + ), + ]), + }), + children: None, + }, + Node { + type_name: "literal".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + ]; + + let output = generate_cst(&nodes).unwrap(); + let expected = quote! { + use std::{backtrace::Backtrace, sync::Arc}; + use bytes::Bytes; + use codegen_sdk_common::*; + use derive_more::Debug; + use rkyv::{Archive, Deserialize, Serialize}; + use subenum::subenum; + use tree_sitter; + #[derive(Debug, Clone)] + pub struct BinaryExpression { + left: Box, + right: Box, + // ... other required fields ... + } + // ... expected impl blocks ... + }; + assert_tokenstreams_eq!(&output, &expected); +} + +#[test_log::test] +fn test_deeply_nested_subtypes() { + let nodes = vec![ + // Top level statement type + Node { + type_name: "statement".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "declaration".to_string(), + named: true, + }, + TypeDefinition { + type_name: "expression_statement".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }, + // Declaration with its subtypes + Node { + type_name: "declaration".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "function_declaration".to_string(), + named: true, + }, + TypeDefinition { + type_name: "class_declaration".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }, + // Function declaration with its subtype + Node { + type_name: "function_declaration".to_string(), + subtypes: vec![TypeDefinition { + type_name: "method_declaration".to_string(), + named: true, + }], + named: true, + root: false, + fields: None, + children: None, + }, + // Concrete node types + Node { + type_name: "method_declaration".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + Node { + type_name: "class_declaration".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + Node { + type_name: "expression_statement".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + ]; + + let output = generate_cst(&nodes).unwrap(); + let expected = quote! 
{ + use std::{backtrace::Backtrace, sync::Arc}; + use bytes::Bytes; + use codegen_sdk_common::*; + use derive_more::Debug; + use rkyv::{Archive, Deserialize, Serialize}; + use subenum::subenum; + use tree_sitter; + + #[subenum( + Declaration(derive(Archive, Deserialize, Serialize)), + Statement(derive(Archive, Deserialize, Serialize)), + FunctionDeclaration(derive(Archive, Deserialize, Serialize)) + )] + #[derive(Debug, Clone)] + pub enum NodeTypes { + #[subenum(Declaration, Statement)] + ClassDeclaration(ClassDeclaration), + #[subenum(Statement)] + ExpressionStatement(ExpressionStatement), + #[subenum(Declaration, Statement, FunctionDeclaration)] + MethodDeclaration(MethodDeclaration), + } + + impl std::convert::From for NodeTypes { + fn from(variant: ClassDeclaration) -> Self { + Self::ClassDeclaration(variant) + } + } + + impl std::convert::From for Declaration { + fn from(variant: ClassDeclaration) -> Self { + Self::ClassDeclaration(variant) + } + } + + impl std::convert::From for NodeTypes { + fn from(variant: ExpressionStatement) -> Self { + Self::ExpressionStatement(variant) + } + } + + impl std::convert::From for Statement { + fn from(variant: ExpressionStatement) -> Self { + Self::ExpressionStatement(variant) + } + } + + impl std::convert::From for NodeTypes { + fn from(variant: MethodDeclaration) -> Self { + Self::MethodDeclaration(variant) + } + } + + impl std::convert::From for Declaration { + fn from(variant: MethodDeclaration) -> Self { + Self::MethodDeclaration(variant) + } + } + + impl std::convert::From for FunctionDeclaration { + fn from(variant: MethodDeclaration) -> Self { + Self::MethodDeclaration(variant) + } + } + + impl FromNode for Declaration { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "class_declaration" => Ok(Self::ClassDeclaration(ClassDeclaration::from_node( + node, buffer, + )?)), + "method_declaration" => Ok(Self::MethodDeclaration( + MethodDeclaration::from_node(node, buffer)?, + )), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), + } + } + } + + impl FromNode for Statement { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "class_declaration" => Ok(Self::ClassDeclaration( + ClassDeclaration::from_node(node, buffer)?, + )), + "method_declaration" => Ok(Self::MethodDeclaration( + MethodDeclaration::from_node(node, buffer)?, + )), + "expression_statement" => Ok(Self::ExpressionStatement( + ExpressionStatement::from_node(node, buffer)?, + )), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), + } + } + } + + impl FromNode for FunctionDeclaration { + fn from_node(node: tree_sitter::Node, buffer: &Arc) -> Result { + match node.kind() { + "method_declaration" => Ok(Self::MethodDeclaration( + MethodDeclaration::from_node(node, buffer)?, + )), + _ => Err(ParseError::UnexpectedNode { + node_type: node.kind().to_string(), + backtrace: Backtrace::capture(), + }), + } + } + } + + // ... struct definitions and implementations for ClassDeclaration, ExpressionStatement, and MethodDeclaration ... 
+ }; + assert_tokenstreams_eq!(&output, &expected); +} diff --git a/codegen-sdk-cst-generator/tests/test_subtypes_children.rs b/codegen-sdk-cst-generator/tests/test_subtypes_children.rs new file mode 100644 index 0000000..1860f16 --- /dev/null +++ b/codegen-sdk-cst-generator/tests/test_subtypes_children.rs @@ -0,0 +1,107 @@ +use assert_tokenstreams_eq::assert_tokenstreams_eq; +use codegen_sdk_common::parser::{Children, Node, TypeDefinition}; +use codegen_sdk_cst_generator::generate_cst; +use quote::quote; + +#[test] +fn test_subtypes_with_children() { + let nodes = vec![ + // A block can contain multiple statements + Node { + type_name: "block".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: Some(Children { + multiple: true, + required: false, + types: vec![TypeDefinition { + type_name: "statement".to_string(), + named: true, + }], + }), + }, + // Statement is a subtype with its own subtypes + Node { + type_name: "statement".to_string(), + subtypes: vec![ + TypeDefinition { + type_name: "if_statement".to_string(), + named: true, + }, + TypeDefinition { + type_name: "return_statement".to_string(), + named: true, + }, + ], + named: true, + root: false, + fields: None, + children: None, + }, + // Concrete statement types + Node { + type_name: "if_statement".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: Some(Children { + multiple: false, + required: true, + types: vec![TypeDefinition { + type_name: "block".to_string(), + named: true, + }], + }), + }, + Node { + type_name: "return_statement".to_string(), + subtypes: vec![], + named: true, + root: false, + fields: None, + children: None, + }, + ]; + + let output = generate_cst(&nodes).unwrap(); + let expected = quote! 
+        use bytes::Bytes;
+        use codegen_sdk_common::*;
+        use derive_more::Debug;
+        use rkyv::{Archive, Deserialize, Serialize};
+        use subenum::subenum;
+        use tree_sitter;
+
+        #[derive(Debug, Clone)]
+        pub struct Block {
+            start_byte: usize,
+            end_byte: usize,
+            _kind: String,
+            #[debug("[{},{}]", start_position.row, start_position.column)]
+            start_position: Point,
+            #[debug("[{},{}]", end_position.row, end_position.column)]
+            end_position: Point,
+            #[debug(ignore)]
+            buffer: Arc<Bytes>,
+            #[debug(ignore)]
+            kind_id: u16,
+            children: Vec<Statement>,
+        }
+
+        impl HasChildren for Block {
+            type Child = Statement;
+            fn children(&self) -> Vec<Self::Child> {
+                self.children.iter().cloned().collect()
+            }
+            fn children_by_field_name(&self, field_name: &str) -> Vec<Self::Child> {
+                match field_name {
+                    _ => vec![],
+                }
+            }
+        }
+    };
+    assert_tokenstreams_eq!(&expected, &output);
+}
diff --git a/codegen-sdk-cst-generator/tests/test_subtypes_multiple_inheritance.rs b/codegen-sdk-cst-generator/tests/test_subtypes_multiple_inheritance.rs
new file mode 100644
index 0000000..9e8e713
--- /dev/null
+++ b/codegen-sdk-cst-generator/tests/test_subtypes_multiple_inheritance.rs
@@ -0,0 +1,81 @@
+use assert_tokenstreams_eq::assert_tokenstreams_eq;
+use codegen_sdk_common::parser::{Node, TypeDefinition};
+use codegen_sdk_cst_generator::generate_cst;
+use quote::quote;
+
+#[test]
+fn test_multiple_inheritance() {
+    let nodes = vec![
+        // Base types
+        Node {
+            type_name: "declaration".to_string(),
+            subtypes: vec![TypeDefinition {
+                type_name: "class_method".to_string(),
+                named: true,
+            }],
+            named: true,
+            root: false,
+            fields: None,
+            children: None,
+        },
+        Node {
+            type_name: "class_member".to_string(),
+            subtypes: vec![TypeDefinition {
+                type_name: "class_method".to_string(),
+                named: true,
+            }],
+            named: true,
+            root: false,
+            fields: None,
+            children: None,
+        },
+        // ClassMethod inherits from both Declaration and ClassMember
+        Node {
+            type_name: "class_method".to_string(),
+            subtypes: vec![],
+            named: true,
+            root: false,
+            fields: None,
+            children: None,
+        },
+    ];
+
+    let output = generate_cst(&nodes).unwrap();
+    let expected = quote! {
+        use bytes::Bytes;
+        use codegen_sdk_common::*;
+        use derive_more::Debug;
+        use rkyv::{Archive, Deserialize, Serialize};
+        use subenum::subenum;
+        use tree_sitter;
+
+        #[derive(Debug, Clone)]
+        #[subenum(
+            ClassMember(derive(Archive, Deserialize, Serialize)),
+            Declaration(derive(Archive, Deserialize, Serialize))
+        )]
+        pub enum NodeTypes {
+            #[subenum(ClassMember, Declaration)]
+            ClassMethod(ClassMethod),
+        }
+
+        impl std::convert::From<ClassMethod> for NodeTypes {
+            fn from(variant: ClassMethod) -> Self {
+                Self::ClassMethod(variant)
+            }
+        }
+
+        impl std::convert::From<ClassMethod> for ClassMember {
+            fn from(variant: ClassMethod) -> Self {
+                Self::ClassMethod(variant)
+            }
+        }
+
+        impl std::convert::From<ClassMethod> for Declaration {
+            fn from(variant: ClassMethod) -> Self {
+                Self::ClassMethod(variant)
+            }
+        }
+    };
+    assert_tokenstreams_eq!(&output, &expected);
+}
diff --git a/codegen-sdk-cst-generator/tests/test_subtypes_recursive.rs b/codegen-sdk-cst-generator/tests/test_subtypes_recursive.rs
new file mode 100644
index 0000000..1c907ad
--- /dev/null
+++ b/codegen-sdk-cst-generator/tests/test_subtypes_recursive.rs
@@ -0,0 +1,152 @@
+use std::collections::HashMap;
+
+use assert_tokenstreams_eq::assert_tokenstreams_eq;
+use codegen_sdk_common::parser::{Children, Fields, Node, TypeDefinition};
+use codegen_sdk_cst_generator::generate_cst;
+use quote::quote;
+
+#[test]
+fn test_recursive_subtypes() {
+    let nodes = vec![
+        // Expression can contain other expressions recursively
+        Node {
+            type_name: "expression".to_string(),
+            subtypes: vec![
+                TypeDefinition {
+                    type_name: "binary_expression".to_string(),
+                    named: true,
+                },
+                TypeDefinition {
+                    type_name: "call_expression".to_string(),
+                    named: true,
+                },
+            ],
+            named: true,
+            root: false,
+            fields: None,
+            children: None,
+        },
+        Node {
+            type_name: "binary_expression".to_string(),
+            subtypes: vec![],
+            named: true,
+            root: false,
+            fields: Some(Fields {
+                fields: HashMap::from([
+                    (
+                        "left".to_string(),
+                        codegen_sdk_common::parser::FieldDefinition {
+                            types: vec![TypeDefinition {
+                                type_name: "expression".to_string(),
+                                named: true,
+                            }],
+                            multiple: false,
+                            required: true,
+                        },
+                    ),
+                    (
+                        "right".to_string(),
+                        codegen_sdk_common::parser::FieldDefinition {
+                            types: vec![TypeDefinition {
+                                type_name: "expression".to_string(),
+                                named: true,
+                            }],
+                            multiple: false,
+                            required: true,
+                        },
+                    ),
+                ]),
+            }),
+            children: None,
+        },
+        Node {
+            type_name: "call_expression".to_string(),
+            subtypes: vec![],
+            named: true,
+            root: false,
+            fields: Some(Fields {
+                fields: HashMap::from([(
+                    "callee".to_string(),
+                    codegen_sdk_common::parser::FieldDefinition {
+                        types: vec![TypeDefinition {
+                            type_name: "expression".to_string(),
+                            named: true,
+                        }],
+                        multiple: false,
+                        required: true,
+                    },
+                )]),
+            }),
+            children: Some(Children {
+                multiple: true,
+                required: false,
+                types: vec![TypeDefinition {
+                    type_name: "expression".to_string(),
+                    named: true,
+                }],
+            }),
+        },
+    ];
+
+    let output = generate_cst(&nodes).unwrap();
+    let expected = quote! {
+        use bytes::Bytes;
+        use codegen_sdk_common::*;
+        use derive_more::Debug;
+        use rkyv::{Archive, Deserialize, Serialize};
+        use subenum::subenum;
+        use tree_sitter;
+
+        #[derive(Debug, Clone)]
+        #[subenum(Expression(derive(Archive, Deserialize, Serialize)))]
+        pub enum NodeTypes {
+            #[subenum(Expression)]
+            CallExpression(CallExpression),
+        }
+
+        impl std::convert::From<CallExpression> for NodeTypes {
+            fn from(variant: CallExpression) -> Self {
+                Self::CallExpression(variant)
+            }
+        }
+
+        impl std::convert::From<CallExpression> for Expression {
+            fn from(variant: CallExpression) -> Self {
+                Self::CallExpression(variant)
+            }
+        }
+
+        #[derive(Debug, Clone)]
+        pub struct CallExpression {
+            start_byte: usize,
+            end_byte: usize,
+            _kind: String,
+            #[debug("[{},{}]", start_position.row, start_position.column)]
+            start_position: Point,
+            #[debug("[{},{}]", end_position.row, end_position.column)]
+            end_position: Point,
+            #[debug(ignore)]
+            buffer: Arc<Bytes>,
+            #[debug(ignore)]
+            kind_id: u16,
+            callee: Box<Expression>,
+            children: Vec<Expression>,
+        }
+
+        impl HasChildren for CallExpression {
+            type Child = Expression;
+            fn children(&self) -> Vec<Self::Child> {
+                let mut children: Vec<_> = self.children.iter().cloned().collect();
+                children.push(self.callee.as_ref().clone());
+                children
+            }
+            fn children_by_field_name(&self, field_name: &str) -> Vec<Self::Child> {
+                match field_name {
+                    "callee" => vec![self.callee.as_ref().clone()],
+                    _ => vec![],
+                }
+            }
+        }
+    };
+    assert_tokenstreams_eq!(&output, &expected);
+}
diff --git a/codegen-sdk-cst/Cargo.toml b/codegen-sdk-cst/Cargo.toml
index 794d69e..54067ac 100644
--- a/codegen-sdk-cst/Cargo.toml
+++ b/codegen-sdk-cst/Cargo.toml
@@ -11,14 +11,19 @@ codegen-sdk-macros = { path = "../codegen-sdk-macros" }
 derive_more = { version = "2.0.1", features = ["debug", "display"] }
 convert_case = { workspace = true }
 rkyv = { workspace = true }
+subenum = "1.1.2"
 log = { workspace = true }
+enum_delegate = { workspace = true }
+derive-visitor = "0.4.0"
 [build-dependencies]
 codegen-sdk-cst-generator = { path = "../codegen-sdk-cst-generator"}
 codegen-sdk-common = { path = "../codegen-sdk-common", features = ["all"] }
 rayon = { workspace = true }
 env_logger = { workspace = true }
+log = { workspace = true }
 [dev-dependencies]
 tempfile = "3.16.0"
+test-log = { workspace = true }
 [features]
 python = [ "codegen-sdk-common/python"]
 typescript = [ "codegen-sdk-common/typescript"]
@@ -28,4 +33,4 @@ javascript = [ "codegen-sdk-common/typescript"]
 json = [ "codegen-sdk-common/json"]
 java = [ "codegen-sdk-common/java"]
 ts_query = []
-default = ["typescript", "tsx", "jsx", "javascript", "json", "ts_query"]
+default = ["json", "ts_query"]
diff --git a/codegen-sdk-cst/build.rs b/codegen-sdk-cst/build.rs
index 6e21cae..bc09f0a 100644
--- a/codegen-sdk-cst/build.rs
+++ b/codegen-sdk-cst/build.rs
@@ -1,10 +1,12 @@
 use codegen_sdk_common::language::LANGUAGES;
-use codegen_sdk_cst_generator::generate_cst;
+use codegen_sdk_cst_generator::generate_cst_to_file;
 use rayon::prelude::*;
 fn main() {
     env_logger::init();
-    println!("cargo:rerun-if-changed=build.rs");
+    // println!("cargo:rerun-if-changed=build.rs");
     LANGUAGES.par_iter().for_each(|language| {
-        generate_cst(language).unwrap();
+        generate_cst_to_file(language).unwrap_or_else(|e| {
+            log::error!("Error generating CST for {}: {}", language.name, e);
+        });
     });
 }
diff --git a/codegen-sdk-cst/src/language.rs b/codegen-sdk-cst/src/language.rs
new file mode 100644
index 0000000..6dcaeda
--- /dev/null
+++ b/codegen-sdk-cst/src/language.rs
@@ -0,0 +1,40 @@
+use std::{path::PathBuf, sync::Arc};
+
+use bytes::Bytes;
+use codegen_sdk_common::{
+    ParseError,
+    language::Language,
+    serialize::Cache,
+    traits::{CSTNode, FromNode},
+};
+use codegen_sdk_macros::{include_languages, parse_languages};
+use rkyv::{api::high::to_bytes_in, from_bytes};
+pub trait CSTLanguage {
+    type Program: CSTNode + FromNode + Send;
+    fn language() -> &'static Language;
+    fn parse(content: &str) -> Result<Self::Program, ParseError> {
+        let buffer = Bytes::from(content.as_bytes().to_vec());
+        let tree = Self::language().parse_tree_sitter(content)?;
+        if tree.root_node().has_error() {
+            Err(ParseError::SyntaxError)
+        } else {
+            let buffer = Arc::new(buffer);
+            Self::Program::from_node(tree.root_node(), &buffer)
+        }
+    }
+    fn parse_file(file_path: &PathBuf) -> Result<Self::Program, ParseError> {
+        let content = std::fs::read_to_string(file_path)?;
+        let parsed = Self::parse(&content)?;
+        Ok(parsed)
+    }
+
+    fn should_parse(file_path: &PathBuf) -> Result<bool, ParseError> {
+        Ok(Self::language().file_extensions.contains(
+            &file_path
+                .extension()
+                .ok_or(ParseError::Miscelaneous)?
+                .to_str()
+                .ok_or(ParseError::Miscelaneous)?,
+        ))
+    }
+}
diff --git a/codegen-sdk-cst/src/lib.rs b/codegen-sdk-cst/src/lib.rs
index d7bb14c..b3ec157 100644
--- a/codegen-sdk-cst/src/lib.rs
+++ b/codegen-sdk-cst/src/lib.rs
@@ -1,4 +1,4 @@
-#![recursion_limit = "256"]
+#![recursion_limit = "512"]
 #![feature(trivial_bounds)]
 use std::{path::PathBuf, sync::Arc};
@@ -11,35 +11,8 @@ use codegen_sdk_common::{
 };
 use codegen_sdk_macros::{include_languages, parse_languages};
 use rkyv::{api::high::to_bytes_in, from_bytes};
-pub trait CSTLanguage {
-    type Program: CSTNode + FromNode + Send;
-    fn parse(content: &str) -> Result<Self::Program, ParseError> {
-        let buffer = Bytes::from(content.as_bytes().to_vec());
-        let tree = Self::language().parse_tree_sitter(content)?;
-        if tree.root_node().has_error() {
-            Err(ParseError::SyntaxError)
-        } else {
-            let buffer = Arc::new(buffer);
-            Self::Program::from_node(tree.root_node(), &buffer)
-        }
-    }
-    fn parse_file(file_path: &PathBuf) -> Result<Self::Program, ParseError> {
-        let content = std::fs::read_to_string(file_path)?;
-        let parsed = Self::parse(&content)?;
-        Ok(parsed)
-    }
-
-    fn should_parse(file_path: &PathBuf) -> Result<bool, ParseError> {
-        Ok(Self::language().file_extensions.contains(
-            &file_path
-                .extension()
-                .ok_or(ParseError::Miscelaneous)?
-                .to_str()
-                .ok_or(ParseError::Miscelaneous)?,
-        ))
-    }
-}
+mod language;
+use language::CSTLanguage;
 include_languages!();
 pub fn parse_file(
     cache: &Cache,
@@ -56,7 +29,7 @@ mod tests {
     use codegen_sdk_common::traits::HasChildren;
     use super::*;
-    #[test]
+    #[test_log::test]
     fn test_snazzy_items() {
         let content = "
     class SnazzyItems {
diff --git a/codegen-sdk-cst/src/query.rs b/codegen-sdk-cst/src/query.rs
index 696b4c7..3a23993 100644
--- a/codegen-sdk-cst/src/query.rs
+++ b/codegen-sdk-cst/src/query.rs
@@ -6,32 +6,38 @@ use derive_more::Debug;
 use crate::{CSTLanguage, ts_query};
 fn captures_for_field_definition(
     node: &ts_query::FieldDefinition,
-) -> impl Iterator<Item = &ts_query::Capture> {
+) -> impl Iterator<Item = ts_query::Capture> {
     let mut captures = Vec::new();
     for child in node.children() {
         match child {
-            ts_query::ChildrenFieldDefinition::Definition(definition) => {
-                captures.extend(captures_for_node(definition));
+            ts_query::FieldDefinitionChildren::NamedNode(named) => {
+                captures.extend(captures_for_named_node(&named));
+            }
+            ts_query::FieldDefinitionChildren::FieldDefinition(field) => {
+                captures.extend(captures_for_field_definition(&field));
             }
             _ => {}
         }
     }
     captures.into_iter()
 }
-fn captures_for_named_node(node: &ts_query::NamedNode) -> impl Iterator<Item = &ts_query::Capture> {
+fn captures_for_named_node(node: &ts_query::NamedNode) -> impl Iterator<Item = ts_query::Capture> {
     let mut captures = Vec::new();
     for child in node.children() {
         match child {
-            ts_query::ChildrenNamedNode::Capture(capture) => captures.push(capture),
-            ts_query::ChildrenNamedNode::Definition(definition) => {
-                captures.extend(captures_for_node(definition));
+            ts_query::NamedNodeChildren::Capture(capture) => captures.push(capture),
+            ts_query::NamedNodeChildren::NamedNode(named) => {
+                captures.extend(captures_for_named_node(&named));
+            }
+            ts_query::NamedNodeChildren::FieldDefinition(field) => {
+                captures.extend(captures_for_field_definition(&field));
             }
             _ => {}
         }
     }
     captures.into_iter()
 }
-fn captures_for_node(node: &ts_query::Definition) -> impl Iterator<Item = &ts_query::Capture> {
+fn captures_for_node(node: &ts_query::Definition) -> impl Iterator<Item = ts_query::Capture> {
     let mut captures = Vec::new();
     match node {
         ts_query::Definition::NamedNode(named) => captures.extend(captures_for_named_node(named)),
@@ -52,15 +58,13 @@ impl Query {
         let parsed = ts_query::Query::parse(source).unwrap();
         let mut queries = HashMap::new();
         for node in parsed.children() {
-            if let ts_query::ChildrenProgram::Definition(definition) = node {
-                match definition {
-                    ts_query::Definition::NamedNode(named) => {
-                        let query = Self::from_named_node(named);
-                        queries.insert(query.name(), query);
-                    }
-                    node => {
-                        println!("Unhandled query: {:#?}", node);
-                    }
+            match node {
+                ts_query::ProgramChildren::NamedNode(named) => {
+                    let query = Self::from_named_node(&named);
+                    queries.insert(query.name(), query);
+                }
+                node => {
+                    println!("Unhandled query: {:#?}", node);
                 }
             }
         }
@@ -86,16 +90,16 @@ impl Query {
     }
     /// Get the kind of the query (the node to be matched)
     pub fn kind(&self) -> String {
-        if let ts_query::NameNamedNode::Identifier(identifier) = &(*self.node.name) {
+        if let ts_query::NamedNodeName::Identifier(identifier) = &(*self.node.name) {
             return identifier.source();
         }
         panic!("No kind found for query. {:#?}", self.node);
{:#?}", self.node); } pub fn struct_name(&self) -> String { - normalize_type_name(&self.kind()) + normalize_type_name(&self.kind(), true) } - fn captures(&self) -> Vec<&ts_query::Capture> { + fn captures(&self) -> Vec { captures_for_named_node(&self.node).collect() } /// Get the name of the query (IE @reference.class) @@ -123,6 +127,20 @@ impl Query { pub fn source(&self) -> String { self.node.source() } + // fn execute(&self, node: &T) -> Vec> { + // let mut result = Vec::new(); + + // for child in node.children() { + // if self + // .captures() + // .iter() + // .any(|capture| capture.source() == child.kind()) + // { + // result.push(child); + // } + // } + // result + // } } pub trait HasQuery { diff --git a/src/main.rs b/src/main.rs index e048756..0c42c74 100644 --- a/src/main.rs +++ b/src/main.rs @@ -70,11 +70,6 @@ fn parse_files(dir: String) -> (Vec>, Vec) { cached += 1; } } - log::info!( - "{} files cached. {}% of total", - cached, - cached * 100 / files_to_parse.len() - ); let files: Vec> = files_to_parse .par_iter() .filter_map(|file| parse_file(&cache, file, &tx)) @@ -83,6 +78,11 @@ fn parse_files(dir: String) -> (Vec>, Vec) { for e in rx.iter() { errors.push(e); } + log::info!( + "{} files cached. {}% of total", + cached, + cached * 100 / files_to_parse.len() + ); (files, errors) } fn main() {