diff --git a/.cargo/config.toml b/.cargo/config.toml index d958320..4b2316c 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,2 @@ [build] -rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zthreads=16"] +rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zthreads=16", "-Ctarget-cpu=native"] diff --git a/.gitignore b/.gitignore index 11cbd47..dff150c 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,4 @@ tests/integration/verified_codemods/codemod_data/repo_commits.json target/* .benchmarks/* **.snap.new +flamegraph.svg diff --git a/Cargo.lock b/Cargo.lock index 860f9b1..e3e0cb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -364,6 +364,7 @@ dependencies = [ "codegen-sdk-ts_query", "convert_case", "derive_more", + "indextree", "insta", "log", "proc-macro2", @@ -384,6 +385,8 @@ dependencies = [ "buildid", "bytes", "convert_case", + "hashbrown 0.15.2", + "indexmap", "indextree", "lazy_static", "mockall", @@ -392,6 +395,7 @@ dependencies = [ "proc-macro2", "quote", "rkyv", + "rustc-hash", "salsa", "serde", "serde_json", @@ -425,6 +429,7 @@ dependencies = [ "codegen-sdk-analyzer", "codegen-sdk-ast", "codegen-sdk-common", + "codegen-sdk-python", "codegen-sdk-resolution", "codegen-sdk-typescript", "criterion", @@ -446,6 +451,7 @@ dependencies = [ "codegen-sdk-common", "convert_case", "dashmap", + "indextree", "log", "rkyv", "salsa", @@ -484,6 +490,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -504,6 +511,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -524,6 +532,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -544,6 +553,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -565,6 +575,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -594,6 +605,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -619,6 +631,7 @@ dependencies = [ "env_logger", "indextree", "log", + "memchr", "salsa", "subenum", "tempfile", @@ -630,7 +643,15 @@ dependencies = [ name = "codegen-sdk-resolution" version = "0.1.0" dependencies = [ + "ambassador", + "anyhow", + "codegen-sdk-ast", + "codegen-sdk-common", + "codegen-sdk-cst", + "indicatif", + "log", "salsa", + "smallvec", ] [[package]] @@ -644,6 +665,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -664,6 +686,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -684,6 +707,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -722,6 +746,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -742,6 +767,7 @@ dependencies = [ "codegen-sdk-common", 
"codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -764,6 +790,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -1202,6 +1229,7 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash", + "rayon", ] [[package]] @@ -2316,7 +2344,7 @@ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "salsa" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#ceb9b083b3c0f6a1634e5a0b75b7bb5c7ca7b33f" +source = "git+https://github.com/salsa-rs/salsa?rev=dbb0e5f6ab2cd61e42b372f333ab694f24141cf1#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" dependencies = [ "boxcar", "crossbeam-queue", @@ -2325,6 +2353,7 @@ dependencies = [ "hashlink", "indexmap", "parking_lot", + "portable-atomic", "rayon", "rustc-hash", "salsa-macro-rules", @@ -2336,12 +2365,12 @@ dependencies = [ [[package]] name = "salsa-macro-rules" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#ceb9b083b3c0f6a1634e5a0b75b7bb5c7ca7b33f" +source = "git+https://github.com/salsa-rs/salsa?rev=dbb0e5f6ab2cd61e42b372f333ab694f24141cf1#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" [[package]] name = "salsa-macros" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#ceb9b083b3c0f6a1634e5a0b75b7bb5c7ca7b33f" +source = "git+https://github.com/salsa-rs/salsa?rev=dbb0e5f6ab2cd61e42b372f333ab694f24141cf1#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" dependencies = [ "heck 0.5.0", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index f42a448..c96fe6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,20 +7,21 @@ edition = "2024" [dependencies] clap = { version = "4.5.28", features = ["derive"] } -codegen-sdk-analyzer = { path = "codegen-sdk-analyzer" } +codegen-sdk-analyzer = { path = "codegen-sdk-analyzer", default-features = false } codegen-sdk-ast = { workspace = true} codegen-sdk-common = { workspace = true} anyhow = { workspace = true} salsa = { workspace = true} -codegen-sdk-typescript = { workspace = true} +codegen-sdk-typescript = { workspace = true, optional = true } +codegen-sdk-python = { workspace = true, optional = true } env_logger = { workspace = true } log = { workspace = true } codegen-sdk-resolution = { workspace = true} sysinfo = "0.33.1" rkyv.workspace = true [features] -python = [ "codegen-sdk-analyzer/python"] # TODO: Add python support -typescript = [ "codegen-sdk-analyzer/typescript"] +python = [ "codegen-sdk-analyzer/python", "codegen-sdk-python"] +typescript = [ "codegen-sdk-analyzer/typescript", "codegen-sdk-typescript"] tsx = [ "codegen-sdk-analyzer/tsx"] jsx = [ "codegen-sdk-analyzer/jsx"] javascript = [ "codegen-sdk-analyzer/javascript"] @@ -116,7 +117,7 @@ insta = "1.42.1" prettyplease = "0.2.29" syn = { version = "2.0.98", features = ["proc-macro", "full"] } derive_more = { version = "2.0.1", features = ["debug", "display"] } -salsa = {git = "https://github.com/salsa-rs/salsa", branch = "master"} +salsa = {git = "https://github.com/salsa-rs/salsa", rev ="dbb0e5f6ab2cd61e42b372f333ab694f24141cf1"} subenum = {git = "https://github.com/mrenow/subenum", branch = "main"} indicatif-log-bridge = "0.2.3" indicatif = { version = "0.17.11", features = ["rayon"] } @@ -124,7 +125,8 @@ crossbeam-channel = "0.5.11" rstest = "0.25.0" indextree = "4.7.3" thiserror = "2.0.11" - +indexmap = "2" +smallvec = "1.11.0" [profile.dev] # 
codegen-backend = "cranelift" # split-debuginfo = "unpacked" @@ -157,3 +159,7 @@ lto = false name = "parse" harness = false required-features = ["stable"] + +[profile.profiling] +inherits = "release" +debug = true diff --git a/codegen-sdk-analyzer/src/codebase.rs b/codegen-sdk-analyzer/src/codebase.rs index 01b56e6..5f8e7b3 100644 --- a/codegen-sdk-analyzer/src/codebase.rs +++ b/codegen-sdk-analyzer/src/codebase.rs @@ -1,21 +1,18 @@ use std::path::PathBuf; use anyhow::Context; -use codegen_sdk_ast::Input; #[cfg(feature = "serialization")] use codegen_sdk_common::serialization::Cache; -use codegen_sdk_resolution::CodebaseContext; +use codegen_sdk_resolution::{CodebaseContext, Db}; use discovery::FilesToParse; use notify_debouncer_mini::DebounceEventResult; use salsa::Setter; -use crate::{ - ParsedFile, - database::{CodegenDatabase, Db}, - parser::parse_file, -}; +use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; mod discovery; mod parser; +use parser::execute_op_with_progress; + pub struct Codebase { db: CodegenDatabase, root: PathBuf, @@ -26,8 +23,9 @@ pub struct Codebase { impl Codebase { pub fn new(root: PathBuf) -> Self { + let root = root.canonicalize().unwrap(); let (tx, rx) = crossbeam_channel::unbounded(); - let mut db = CodegenDatabase::new(tx); + let mut db = CodegenDatabase::new(tx, root.clone()); db.watch_dir(PathBuf::from(&root)).unwrap(); let codebase = Self { db, root, rx }; codebase.sync(); @@ -47,8 +45,7 @@ impl Codebase { // to kick in, just like any other update to a salsa input. let contents = std::fs::read_to_string(path) .with_context(|| format!("Failed to read file {}", event.path.display()))?; - let input = Input::new(&self.db, contents); - file.set_contents(&mut self.db).to(input); + file.set_content(&mut self.db).to(contents); } Err(e) => { log::error!( @@ -68,7 +65,7 @@ impl Codebase { pub fn errors(&self) -> Vec<()> { let mut errors = Vec::new(); - for file in self.discover().files(&self.db) { + for file in self.db.files() { if self.get_file(file.path(&self.db)).is_none() { errors.push(()); } @@ -84,9 +81,29 @@ impl Codebase { files, ) } + fn _db(&self) -> &dyn Db { + &self.db + } + pub fn execute_op_with_progress( + &self, + name: &str, + parallel: bool, + op: fn(&dyn Db, codegen_sdk_common::FileNodeId<'_>) -> T, + ) -> Vec { + execute_op_with_progress( + self._db(), + codegen_sdk_resolution::files(self._db()), + name, + parallel, + op, + ) + } } impl CodebaseContext for Codebase { type File<'a> = ParsedFile<'a>; + fn root_path(&self) -> PathBuf { + self.root.clone() + } fn files<'a>(&'a self) -> Vec<&'a Self::File<'a>> { let mut files = Vec::new(); for file in self.discover().files(&self.db) { @@ -96,13 +113,16 @@ impl CodebaseContext for Codebase { } files } - fn db(&self) -> &dyn salsa::Database { + fn db(&self) -> &dyn Db { &self.db } fn get_file<'a>(&'a self, path: PathBuf) -> Option<&'a Self::File<'a>> { - let file = self.db.files.get(&path); - if let Some(file) = file { - return parse_file(&self.db, file.clone()).file(&self.db).as_ref(); + if let Ok(path) = path.canonicalize() { + let file = self.db.files.get(&path); + if let Some(_) = file { + let file_id = codegen_sdk_common::FileNodeId::new(&self.db, path); + return parse_file(&self.db, file_id).file(&self.db).as_ref(); + } } None } diff --git a/codegen-sdk-analyzer/src/codebase/discovery.rs b/codegen-sdk-analyzer/src/codebase/discovery.rs index 825a508..97fa0e2 100644 --- a/codegen-sdk-analyzer/src/codebase/discovery.rs +++ b/codegen-sdk-analyzer/src/codebase/discovery.rs @@ 
-3,12 +3,14 @@ use std::path::PathBuf; use codegen_sdk_ast::*; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; +use codegen_sdk_resolution::Db; use glob::glob; -use crate::database::{CodegenDatabase, Db}; +use crate::database::CodegenDatabase; #[salsa::input] pub struct FilesToParse { - pub files: Vec, + pub files: codegen_sdk_common::hash::FxHashSet, + pub root: PathBuf, } pub fn log_languages() { for language in LANGUAGES.iter() { @@ -40,7 +42,8 @@ pub fn collect_files(db: &CodegenDatabase, dir: &PathBuf) -> FilesToParse { .into_iter() .filter_map(|file| file.ok()) .filter(|file| !file.is_dir() && !file.is_symlink()) + .filter_map(|file| file.canonicalize().ok()) .map(|file| db.input(file).unwrap()) .collect(); - FilesToParse::new(db, files) + FilesToParse::new(db, files, dir) } diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index a1f86e0..3f2c490 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -1,19 +1,23 @@ -use codegen_sdk_ast::{Definitions, References, input::File}; +use std::path::PathBuf; + +use codegen_sdk_ast::{Definitions, References}; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; +use codegen_sdk_resolution::Db; use indicatif::{ProgressBar, ProgressStyle}; use super::discovery::{FilesToParse, log_languages}; -use crate::{ - ParsedFile, - database::{CodegenDatabase, Db}, - parser::parse_file, -}; -fn execute_op_with_progress( +use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; +pub fn execute_op_with_progress< + Database: Db + ?Sized + 'static, + Input: Send + Sync, + T: Send + Sync, +>( db: &Database, - files: FilesToParse, + files: codegen_sdk_common::hash::FxHashSet, name: &str, - op: fn(&Database, File) -> T, + parallel: bool, + op: fn(&Database, Input) -> T, ) -> Vec { let multi = db.multi_progress(); let style = ProgressStyle::with_template( @@ -21,28 +25,43 @@ fn execute_op_with_progress( ) .unwrap(); let pg = multi.add( - ProgressBar::new(files.files(db).len() as u64) + ProgressBar::new(files.len() as u64) .with_style(style) .with_message(name.to_string()), ); let inputs = files - .files(db) .into_iter() .map(|file| (&pg, file, op)) .collect::>(); - let results: Vec = salsa::par_map(db, inputs, move |db, input| { - let (pg, file, op) = input; - let res = op( - db, - #[cfg(feature = "serialization")] - &cache, - file, - ); - pg.inc(1); - res - }); + let results: Vec<_> = if parallel { + salsa::par_map(db, inputs, move |db, input| { + let (pg, file, op) = input; + let res = op( + db, + #[cfg(feature = "serialization")] + &cache, + file, + ); + pg.inc(1); + res + }) + } else { + inputs + .into_iter() + .map(|input| { + let (pg, file, op) = input; + let res = op( + db, + #[cfg(feature = "serialization")] + &cache, + file, + ); + pg.inc(1); + res + }) + .collect() + }; pg.finish(); - multi.remove(&pg); results } // #[salsa::tracked] @@ -53,8 +72,13 @@ fn execute_op_with_progress( // } #[salsa::tracked] fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { - let _: Vec<_> = execute_op_with_progress(db, files, "Parsing Files", |db, file| { - let file = parse_file(db, file); + let ids = files + .files(db) + .iter() + .map(|input| codegen_sdk_common::FileNodeId::new(db, input.path(db))) + .collect::>(); + let _: Vec<_> = execute_op_with_progress(db, ids, "Parsing Files", true, |db, input| { + let file = parse_file(db, input.clone()); if let Some(parsed) = file.file(db) { 
#[cfg(feature = "typescript")] if let ParsedFile::Typescript(parsed) = parsed { @@ -65,11 +89,47 @@ fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { if let ParsedFile::Python(parsed) = parsed { parsed.definitions(db); parsed.references(db); + // let deps = codegen_sdk_python::ast::dependencies(db, input); + // for dep in deps.dependencies(db).keys() { + // codegen_sdk_resolution::ast::references_impl(db, dep); + // } } } () }); } +#[salsa::tracked] +fn compute_dependencies_par(db: &dyn Db, files: FilesToParse) { + let ids = files + .files(db) + .iter() + .map(|input| codegen_sdk_common::FileNodeId::new(db, input.path(db))) + .collect::>(); + let _targets: codegen_sdk_common::hash::FxHashSet<(PathBuf, String)> = + execute_op_with_progress(db, ids, "Computing Dependencies", true, |db, input| { + let file = parse_file(db, input.clone()); + if let Some(parsed) = file.file(db) { + #[cfg(feature = "python")] + if let ParsedFile::Python(_parsed) = parsed { + let deps = codegen_sdk_python::ast::dependency_keys(db, input); + return deps + .iter() + .map(|dep| (dep.path(db).path(db).clone(), dep.name(db).clone())) + .collect::>(); + } + } + Vec::new() + }) + .into_iter() + .flatten() + .collect(); + // let _: Vec<_> = execute_op_with_progress(db, targets, "Finding Usages", true, |db, input: (PathBuf, String)| { + // let file_node_id = codegen_sdk_common::FileNodeId::new(db, input.0); + // let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new(db, file_node_id, input.1); + // codegen_sdk_python::ast::references_impl(db, fully_qualified_name); + // }); +} + pub fn parse_files<'db>( db: &'db CodegenDatabase, #[cfg(feature = "serialization")] cache: &'db Cache, @@ -91,6 +151,12 @@ pub fn parse_files<'db>( &cache, files_to_parse, ); + compute_dependencies_par( + db, + #[cfg(feature = "serialization")] + &cache, + files_to_parse, + ); #[cfg(feature = "serialization")] report_cached_count(cached, &files_to_parse.files(db)); } diff --git a/codegen-sdk-analyzer/src/database.rs b/codegen-sdk-analyzer/src/database.rs index 588c802..7d004e1 100644 --- a/codegen-sdk-analyzer/src/database.rs +++ b/codegen-sdk-analyzer/src/database.rs @@ -5,8 +5,8 @@ use std::{ }; use anyhow::Context; -use codegen_sdk_ast::input::File; -use codegen_sdk_cst::Input; +use codegen_sdk_cst::File; +use codegen_sdk_resolution::Db; use dashmap::{DashMap, mapref::entry::Entry}; use indicatif::MultiProgress; use notify_debouncer_mini::{ @@ -16,20 +16,15 @@ use notify_debouncer_mini::{ use crate::progress::get_multi_progress; #[salsa::db] -pub trait Db: salsa::Database + Send { - fn input(&self, path: PathBuf) -> anyhow::Result; - fn multi_progress(&self) -> &MultiProgress; - fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; -} -#[salsa::db] #[derive(Clone)] // Basic Database implementation for Query generation. This is not used for anything else. 
pub struct CodegenDatabase { storage: salsa::Storage, - pub files: DashMap, + pub files: Arc>, dirs: Vec, multi_progress: MultiProgress, file_watcher: Arc>>, + root: PathBuf, } fn get_watcher( tx: crossbeam_channel::Sender, @@ -40,14 +35,15 @@ fn get_watcher( Arc::new(Mutex::new(new_debouncer_opt(config, tx).unwrap())) } impl CodegenDatabase { - pub fn new(tx: crossbeam_channel::Sender) -> Self { + pub fn new(tx: crossbeam_channel::Sender, root: PathBuf) -> Self { let multi_progress = get_multi_progress(); Self { file_watcher: get_watcher(tx), storage: salsa::Storage::default(), multi_progress, - files: DashMap::new(), + files: Arc::new(DashMap::new()), dirs: Vec::new(), + root, } } fn _watch_file(&self, path: &PathBuf) -> anyhow::Result<()> { @@ -66,16 +62,22 @@ impl CodegenDatabase { } #[salsa::db] impl salsa::Database for CodegenDatabase { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) { // don't log boring events - let event = event(); - if let salsa::EventKind::WillExecute { .. } = event.kind { - log::debug!("{:?}", event); - } + // let event = event(); + // if let salsa::EventKind::WillExecute { .. } = event.kind { + // log::debug!("{:?}", event); + // } } } #[salsa::db] impl Db for CodegenDatabase { + fn files(&self) -> codegen_sdk_common::hash::FxHashSet> { + self.files + .iter() + .map(|entry| codegen_sdk_common::FileNodeId::new(self, entry.key().clone())) + .collect() + } fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()> { let path = path.canonicalize()?; let watcher = &mut *self.file_watcher.lock().unwrap(); @@ -86,10 +88,11 @@ impl Db for CodegenDatabase { self.dirs.push(path); Ok(()) } + fn get_file(&self, path: PathBuf) -> Option { + self.files.get(&path).map(|entry| entry.value().clone()) + } fn input(&self, path: PathBuf) -> anyhow::Result { - let path = path - .canonicalize() - .with_context(|| format!("Failed to read {}", path.display()))?; + let path = path.canonicalize()?; Ok(match self.files.entry(path.clone()) { // If the file already exists in our cache then just return it. 
Entry::Occupied(entry) => *entry.get(), @@ -101,8 +104,11 @@ impl Db for CodegenDatabase { self._watch_file(&path)?; let contents = std::fs::read_to_string(&path) .with_context(|| format!("Failed to read {}", path.display()))?; - let input = Input::new(self, contents); - *entry.insert(File::new(self, path, input)) + *entry.insert( + File::builder(path, contents, self.root.clone()) + .root_durability(salsa::Durability::HIGH) + .new(self), + ) } }) } diff --git a/codegen-sdk-analyzer/src/lib.rs b/codegen-sdk-analyzer/src/lib.rs index f6e05d9..22004e5 100644 --- a/codegen-sdk-analyzer/src/lib.rs +++ b/codegen-sdk-analyzer/src/lib.rs @@ -2,6 +2,6 @@ mod database; mod parser; mod progress; -pub use parser::{Parsed, ParsedFile}; +pub use parser::{Parsed, ParsedFile, parse_file}; mod codebase; pub use codebase::Codebase; diff --git a/codegen-sdk-analyzer/src/parser.rs b/codegen-sdk-analyzer/src/parser.rs index b3cf807..16f3eda 100644 --- a/codegen-sdk-analyzer/src/parser.rs +++ b/codegen-sdk-analyzer/src/parser.rs @@ -12,7 +12,10 @@ pub struct Parsed<'db> { pub file: Option>, } #[salsa::tracked(return_ref)] -pub fn parse_file(db: &dyn salsa::Database, file: codegen_sdk_ast::input::File) -> Parsed<'_> { +pub fn parse_file<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + file: codegen_sdk_common::FileNodeId<'db>, +) -> Parsed<'db> { parse_language!(); - Parsed::new(db, FileNodeId::new(db, file.path(db)), None) + Parsed::new(db, file, None) } diff --git a/codegen-sdk-ast-generator/Cargo.toml b/codegen-sdk-ast-generator/Cargo.toml index 574d4b3..00fce7a 100644 --- a/codegen-sdk-ast-generator/Cargo.toml +++ b/codegen-sdk-ast-generator/Cargo.toml @@ -1,6 +1,7 @@ [package] name = "codegen-sdk-ast-generator" version = "0.1.0" +description = "Generates the AST for the given language from the tree-sitter queries" edition = "2024" [dependencies] @@ -16,6 +17,7 @@ codegen-sdk-ts_query = { workspace = true } convert_case = { workspace = true } salsa = { workspace = true } syn = { workspace = true } +indextree = { workspace = true } [dev-dependencies] test-log = { workspace = true } insta = { workspace = true } diff --git a/codegen-sdk-ast-generator/src/generator.rs b/codegen-sdk-ast-generator/src/generator.rs index 0b8bf1e..f105d0f 100644 --- a/codegen-sdk-ast-generator/src/generator.rs +++ b/codegen-sdk-ast-generator/src/generator.rs @@ -71,22 +71,24 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { #[id] pub id: codegen_sdk_common::FileNodeId<'db>, } + impl<'db> codegen_sdk_resolution::Parse<'db> for #language_struct_name<'db> { + fn parse(db: &'db dyn codegen_sdk_resolution::Db, input: codegen_sdk_common::FileNodeId<'db>) -> &'db Self { + parse(db, input) + } + } // impl<'db> File for {language_struct_name}File<'db> {{ // fn path(&self) -> &PathBuf {{ // &self.path(db) // }} // }} - pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { + #[salsa::tracked(return_ref)] + pub fn parse<'db>(db: &'db dyn codegen_sdk_resolution::Db, input: codegen_sdk_common::FileNodeId<'db>) -> #language_struct_name<'db> { + let input = db.input(input.path(db)).unwrap(); log::debug!("Parsing {} file: {}", input.path(db).display(), #language_name_str); - let ast = crate::cst::parse_program_raw(db, input.contents(db), input.path(db).clone()); + let ast = crate::cst::parse_program_raw(db, input); let file_id = codegen_sdk_common::FileNodeId::new(db, input.path(db).clone()); #language_struct_name::new(db, ast, file_id) } - #[salsa::tracked(return_ref)] - pub fn 
parse_query(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { - parse(db, input) - } - impl<'db> #language_struct_name<'db> { pub fn tree(&self, db: &'db dyn salsa::Database) -> &'db codegen_sdk_common::Tree> { self.node(db).unwrap().tree(db) diff --git a/codegen-sdk-ast-generator/src/lib.rs b/codegen-sdk-ast-generator/src/lib.rs index 4e92ede..9023a56 100644 --- a/codegen-sdk-ast-generator/src/lib.rs +++ b/codegen-sdk-ast-generator/src/lib.rs @@ -10,11 +10,12 @@ use syn::parse_quote; pub fn generate_ast(language: &Language) -> anyhow::Result<()> { let db = CSTDatabase::default(); let imports = quote! { - use codegen_sdk_common::*; - use std::path::PathBuf; - use codegen_sdk_cst::CSTLanguage; - use std::collections::BTreeMap; - use std::sync::mpsc::Sender; + use codegen_sdk_common::*; + use std::path::PathBuf; + use codegen_sdk_cst::CSTLanguage; + use std::collections::BTreeMap; + use codegen_sdk_resolution::HasFile; + use codegen_sdk_resolution::Parse; }; let ast = generator::generate_ast(language)?; let definition_visitor = visitor::generate_visitor(&db, language, "definition"); diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index 319344d..b1a70b8 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -1,4 +1,7 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; use codegen_sdk_common::{ CSTNode, HasChildren, Language, Tree, @@ -8,10 +11,22 @@ use codegen_sdk_cst::CSTLanguage; use codegen_sdk_cst_generator::{Config, Field, State}; use codegen_sdk_ts_query::cst as ts_query; use derive_more::Debug; +use indextree::NodeId; use log::{debug, info, warn}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; +use syn::parse_quote; use ts_query::NodeTypes; +fn name_for_capture<'a>(capture: &'a ts_query::Capture<'a>) -> String { + full_name_for_capture(capture) + .split(".") + .last() + .unwrap() + .to_string() +} +fn full_name_for_capture<'a>(capture: &'a ts_query::Capture<'a>) -> String { + capture.source().split_off(1) +} fn captures_for_field_definition<'a>( node: &'a ts_query::FieldDefinition<'a>, tree: &'a Tree>, @@ -25,6 +40,54 @@ fn captures_for_field_definition<'a>( ts_query::FieldDefinitionChildrenRef::FieldDefinition(field) => { captures.extend(captures_for_field_definition(&field, tree)); } + ts_query::FieldDefinitionChildrenRef::Grouping(grouping) => { + captures.extend(captures_for_grouping(&grouping, tree)); + } + ts_query::FieldDefinitionChildrenRef::List(list) => { + captures.extend(captures_for_list(&list, tree)); + } + _ => {} + } + } + captures.into_iter() +} +fn captures_for_list<'a>( + list: &'a ts_query::List<'a>, + tree: &'a Tree>, +) -> impl Iterator> { + let mut captures = Vec::new(); + for child in list.children(tree) { + match child { + ts_query::ListChildrenRef::NamedNode(named) => { + captures.extend(captures_for_named_node(&named, tree)); + } + ts_query::ListChildrenRef::List(list) => { + captures.extend(captures_for_list(&list, tree)); + } + ts_query::ListChildrenRef::FieldDefinition(field) => { + captures.extend(captures_for_field_definition(&field, tree)); + } + _ => {} + } + } + captures.into_iter() +} +fn captures_for_grouping<'a>( + grouping: &'a ts_query::Grouping<'a>, + tree: &'a Tree>, +) -> impl Iterator> { + let mut captures = Vec::new(); + for child in grouping.children(tree) { + match child { + 
ts_query::GroupingChildrenRef::NamedNode(named) => { + captures.extend(captures_for_named_node(&named, tree)); + } + ts_query::GroupingChildrenRef::Grouping(grouping) => { + captures.extend(captures_for_grouping(&grouping, tree)); + } + ts_query::GroupingChildrenRef::FieldDefinition(field) => { + captures.extend(captures_for_field_definition(&field, tree)); + } _ => {} } } @@ -44,6 +107,12 @@ fn captures_for_named_node<'a>( ts_query::NamedNodeChildrenRef::FieldDefinition(field) => { captures.extend(captures_for_field_definition(&field, tree)); } + ts_query::NamedNodeChildrenRef::Grouping(grouping) => { + captures.extend(captures_for_grouping(&grouping, tree)); + } + ts_query::NamedNodeChildrenRef::List(list) => { + captures.extend(captures_for_list(&list, tree)); + } _ => {} } } @@ -54,6 +123,7 @@ pub struct Query<'a> { node: &'a ts_query::NamedNode<'a>, language: &'a Language, tree: &'a Tree>, + root_id: NodeId, pub(crate) state: Arc>, } impl<'a> Query<'a> { @@ -63,14 +133,15 @@ impl<'a> Query<'a> { language: &'a Language, ) -> BTreeMap { let result = ts_query::Query::parse(db, source.to_string()).unwrap(); - let (parsed, tree) = result; + let (parsed, tree, program_id) = result; let config = Config::default(); let state = Arc::new(State::new(language, config)); let mut queries = BTreeMap::new(); for node in parsed.children(tree) { match node { ts_query::ProgramChildrenRef::NamedNode(named) => { - let query = Self::from_named_node(&named, language, state.clone(), tree); + let query = + Self::from_named_node(&named, language, state.clone(), tree, program_id); queries.insert(query.name(), query); } node => { @@ -102,12 +173,14 @@ impl<'a> Query<'a> { language: &'a Language, state: Arc>, tree: &'a Tree>, + root_id: indextree::NodeId, ) -> Self { Query { node: named, language: language, state: state, tree: tree, + root_id: root_id, } } /// Get the kind of the query (the node to be matched) @@ -187,6 +260,17 @@ impl<'a> Query<'a> { format_ident!("{}s", name) } } + pub fn symbol_name(&self) -> Ident { + let raw_name = self.name(); + let name = raw_name.split(".").last().unwrap(); + let symbol = format_ident!("{}", normalize_type_name(name, true)); + // References can produce duplicate names. We can be reasonably sure that there is no @definition.call. 
+ if raw_name.starts_with("reference") && !["call"].contains(&name) { + format_ident!("{}Ref", symbol) + } else { + symbol + } + } fn get_field_for_field_name(&self, field_name: &str, struct_name: &str) -> Option<&Field> { debug!( "Getting field for: {:#?} on node: {:#?}", @@ -208,7 +292,8 @@ impl<'a> Query<'a> { field: &ts_query::FieldDefinition, struct_name: &str, current_node: &Ident, - name_value: Option, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, + query_values: &mut HashMap, ) -> TokenStream { let other_child: ts_query::NodeTypesRef = field .children(self.tree) @@ -219,6 +304,7 @@ impl<'a> Query<'a> { .into(); for name in &field.name(self.tree) { if let ts_query::FieldDefinitionNameRef::Identifier(identifier) = name { + let doc = format!("Code for field: {}", field.source()); let name = normalize_field_name(&identifier.source()); if let Some(field) = self.get_field_for_field_name(&name, struct_name) { let field_name = format_ident!("{}", name); @@ -227,7 +313,8 @@ impl<'a> Query<'a> { &normalized_struct_name, other_child.clone(), &field_name, - name_value, + existing, + query_values, ); // assert!( // wrapped.to_string().len() > 0, @@ -236,13 +323,22 @@ impl<'a> Query<'a> { // other_child.source(), // other_child.kind() // ); - if !field.is_optional() { + if field.is_multiple() { return quote! { + #[doc = #doc] + for #field_name in #current_node.#field_name(tree) { + #wrapped + } + }; + } else if !field.is_optional() { + return quote! { + #[doc = #doc] let #field_name = #current_node.#field_name(tree); #wrapped }; } else { return quote! { + #[doc = #doc] if let Some(#field_name) = #current_node.#field_name(tree) { #wrapped } @@ -270,7 +366,8 @@ impl<'a> Query<'a> { node: &ts_query::Grouping, struct_name: &str, current_node: &Ident, - name_value: Option, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, + query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); for group in node.children(self.tree) { @@ -278,7 +375,8 @@ impl<'a> Query<'a> { struct_name, group.into(), current_node, - name_value.clone(), + existing, + query_values, ); matchers.extend_one(result); } @@ -291,10 +389,10 @@ impl<'a> Query<'a> { target_kind: &str, current_node: &Ident, remaining_nodes: Vec>, - name_value: Option, + query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); - let mut field_matchers = TokenStream::new(); + let mut field_matchers = Vec::new(); let mut comment_variant = None; let variants = self .state @@ -311,19 +409,15 @@ impl<'a> Query<'a> { } for child in remaining_nodes { - if child.kind_name() == "field_definition" { - field_matchers.extend_one(self.get_matcher_for_definition( - &target_name, - child.into(), - current_node, - name_value.clone(), - )); + if let ts_query::NamedNodeChildrenRef::FieldDefinition(_) = child { + field_matchers.push((child.into(), target_name, current_node)); } else { let result = self.get_matcher_for_definition( &target_name, child.into(), &format_ident!("child"), - name_value.clone(), + &mut Vec::new(), + query_values, ); if let Some(ref variant) = comment_variant { @@ -355,6 +449,17 @@ impl<'a> Query<'a> { "Code for query: {}", &self.node().source().replace("\n", " ") // Newlines mess with quote's doc comments ); + let field_matchers = if let Some(prev) = field_matchers.pop() { + self.get_matcher_for_definition( + &prev.1, + prev.0, + &prev.2, + &mut field_matchers, + query_values, + ) + } else { + quote! 
{} + }; if matchers.is_empty() && field_matchers.is_empty() { return quote! {}; } @@ -382,21 +487,23 @@ impl<'a> Query<'a> { &'b self, node: &'b ts_query::NamedNode<'b>, first_node: &ts_query::NamedNodeChildrenRef<'_>, - mut name_value: Option, + query_values: &mut HashMap, current_node: &Ident, - ) -> (Option, Vec>) { + ) -> Vec> { let mut prev = first_node.clone(); let mut remaining_nodes = Vec::new(); + log::info!( + "Grouping children for: {:#?} of kind: {:#?}", + node.source(), + node.kind_name() + ); for child in node.children(self.tree).into_iter().skip(1) { if child.kind_name() == "capture" { - if child.source() == "@name" { - log::info!( - "Found @name! prev: {:#?}, {:#?}", - prev.source(), - prev.kind_name() - ); + let capture_name = name_for_capture(child.try_into().unwrap()); + if self.target_capture_names().contains(&capture_name) { match prev { ts_query::NamedNodeChildrenRef::FieldDefinition(field) => { + log::info!("Found @{}! on field: {:#?}", capture_name, field.source(),); let field_name = field .name(self.tree) .iter() @@ -404,24 +511,40 @@ impl<'a> Query<'a> { .map(|c| format_ident!("{}", c.source())) .next() .unwrap(); - name_value = Some(quote! { - #current_node.#field_name.source() - }); + query_values.insert( + capture_name, + quote! { + + #current_node.#field_name + }, + ); } ts_query::NamedNodeChildrenRef::Identifier(named) => { log::info!( - "Found @name! prev: {:#?}, {:#?}", + "Found @{}! prev: {:#?}, {:#?}", + capture_name, named.source(), named.kind_name() ); - name_value = Some(quote! { - #current_node.source() - }); + query_values.insert( + capture_name, + quote! { + #current_node + }, + ); } ts_query::NamedNodeChildrenRef::AnonymousUnderscore(_) => { - name_value = Some(quote! { - #current_node.source() - }); + log::info!( + "Found @{}! on anonymous underscore: {:#?}", + capture_name, + node.source() + ); + query_values.insert( + capture_name, + quote! { + #current_node + }, + ); } _ => panic!( "Unexpected prev: {:#?}, source: {:#?}. 
Query: {:#?}", @@ -430,33 +553,33 @@ impl<'a> Query<'a> { self.node().source() ), } - break; } continue; } prev = child.clone(); remaining_nodes.push(child); } - (name_value, remaining_nodes) + remaining_nodes } fn get_matcher_for_named_node( &self, node: &ts_query::NamedNode, struct_name: &str, current_node: &Ident, - name_value: Option, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, + query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); let first_node = node.children(self.tree).into_iter().next().unwrap(); - let (name_value, remaining_nodes) = - self.group_children(node, &first_node, name_value, current_node); + let remaining_nodes = self.group_children(node, &first_node, query_values, current_node); if remaining_nodes.len() == 0 { log::info!("single node, {}", first_node.source()); return self.get_matcher_for_definition( struct_name, first_node.into(), current_node, - name_value, + existing, + query_values, ); } @@ -469,7 +592,7 @@ impl<'a> Query<'a> { name_node.kind(), current_node, remaining_nodes, - name_value, + query_values, ); matchers.extend_one(matcher); } else { @@ -489,7 +612,7 @@ impl<'a> Query<'a> { variant.kind(), current_node, remaining_nodes.clone(), - name_value.clone(), + query_values, ); matchers.extend_one(matcher); } @@ -498,27 +621,64 @@ impl<'a> Query<'a> { #matchers } } - fn get_default_matcher(&self, name_value: Option) -> TokenStream { + fn get_default_matcher( + &self, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, + query_values: &mut HashMap, + ) -> TokenStream { + if let Some(prev) = existing.pop() { + log::info!( + "Executing previous matcher on: {:#?} with {:#?}", + prev.0.source(), + query_values + ); + return self.get_matcher_for_definition( + &prev.1, + prev.0, + &prev.2, + existing, + query_values, + ); + } + let to_append = self.executor_id(); - if let Some(name_value) = name_value { - return quote! { - #to_append.entry(#name_value).or_default().push(id); - }; + let mut args = Vec::new(); + for target in self.target_capture_names() { + if let Some(value) = query_values.get(&target) { + args.push(value); + } else { + log::warn!("No value found for: {} on {}", target, self.node().source()); + return quote! {}; + } } - log::warn!("No name value found for: {}", self.node().source()); - quote! {} + let name = query_values.get("name").unwrap_or_else(|| { + panic!( + "No name found for: {}. Query_values: {:#?} Target Capture Names: {:#?}", + self.node().source(), + query_values, + self.target_capture_names() + ); + }); + let symbol_name = self.symbol_name(); + return quote! 
{ + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new(db, node.file_id(),#name.source()); + let symbol = #symbol_name::new(db, fully_qualified_name, id, #(#args.clone().into()),*); + #to_append.entry(#name.source()).or_default().push(symbol); + }; } fn get_matcher_for_identifier( &self, identifier: &ts_query::Identifier, struct_name: &str, current_node: &Ident, - name_value: Option, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, + query_values: &mut HashMap, ) -> TokenStream { // We have 2 nodes, the parent node and the identifier node - let to_append = self.get_default_matcher(name_value); + let to_append = self.get_default_matcher(existing, query_values); // Case 1: The identifier is the same as the struct name (IE: we know this is the corrent node) - if normalize_type_name(&identifier.source(), true) == struct_name { + let target_name = normalize_type_name(&identifier.source(), true); + if target_name == struct_name { return to_append; } // Case 2: We have a node for the parent struct @@ -539,7 +699,13 @@ impl<'a> Query<'a> { } else { // Case 3: This is a subenum // If this is a field, we may be dealing with multiple types and can't operate over all of them - return to_append; // TODO: Handle this case + let target_name = format_ident!("{}", target_name); + let struct_name = format_ident!("{}Ref", struct_name); + return quote! { + if let crate::cst::#struct_name::#target_name(#current_node) = #current_node { + #to_append + } + }; // TODO: Handle this case } } fn get_matcher_for_definition( @@ -547,22 +713,31 @@ impl<'a> Query<'a> { struct_name: &str, node: ts_query::NodeTypesRef, current_node: &Ident, - name_value: Option, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, + query_values: &mut HashMap, ) -> TokenStream { if !node.is_named() { - return self.get_default_matcher(name_value); + return self.get_default_matcher(existing, query_values); } match node { - ts_query::NodeTypesRef::FieldDefinition(field) => { - self.get_matcher_for_field(&field, struct_name, current_node, name_value) - } + ts_query::NodeTypesRef::FieldDefinition(field) => self.get_matcher_for_field( + &field, + struct_name, + current_node, + existing, + query_values, + ), ts_query::NodeTypesRef::Capture(named) => { info!("Capture: {:#?}", named.source()); quote! {} } - ts_query::NodeTypesRef::NamedNode(named) => { - self.get_matcher_for_named_node(&named, struct_name, current_node, name_value) - } + ts_query::NodeTypesRef::NamedNode(named) => self.get_matcher_for_named_node( + &named, + struct_name, + current_node, + existing, + query_values, + ), ts_query::NodeTypesRef::Comment(_) => { quote! {} } @@ -572,19 +747,28 @@ impl<'a> Query<'a> { struct_name, child.into(), current_node, - name_value.clone(), + existing, + query_values, ); // Currently just returns the first child return result; // TODO: properly handle list } quote! 
{} } - ts_query::NodeTypesRef::Grouping(grouping) => { - self.get_matchers_for_grouping(&grouping, struct_name, current_node, name_value) - } - ts_query::NodeTypesRef::Identifier(identifier) => { - self.get_matcher_for_identifier(&identifier, struct_name, current_node, name_value) - } + ts_query::NodeTypesRef::Grouping(grouping) => self.get_matchers_for_grouping( + &grouping, + struct_name, + current_node, + existing, + query_values, + ), + ts_query::NodeTypesRef::Identifier(identifier) => self.get_matcher_for_identifier( + &identifier, + struct_name, + current_node, + existing, + query_values, + ), unhandled => { log::warn!( "Unhandled definition in language {}: {:#?}, {:#?}", @@ -592,7 +776,7 @@ impl<'a> Query<'a> { unhandled.kind_name(), unhandled.source() ); - self.get_default_matcher(name_value) + self.get_default_matcher(existing, query_values) } } } @@ -605,10 +789,11 @@ impl<'a> Query<'a> { struct_name }; let starting_node = format_ident!("node"); - let (name_value, remaining_nodes) = self.group_children( + let mut query_values = HashMap::new(); + let remaining_nodes = self.group_children( &self.node(), &self.node().children(self.tree).into_iter().next().unwrap(), - None, + &mut query_values, &starting_node, ); return self._get_matcher_for_named_node( @@ -617,9 +802,124 @@ impl<'a> Query<'a> { kind, &starting_node, remaining_nodes, - name_value, + &mut query_values, ); } + fn target_captures(&self) -> Vec<&ts_query::Capture> { + let mut captures: Vec<&ts_query::Capture> = self + .captures() + .into_iter() + .filter(|c| !full_name_for_capture(c).contains(".")) + .collect(); + captures.sort_by_key(|c| full_name_for_capture(c)); + captures.dedup_by_key(|c| full_name_for_capture(c)); + captures + } + fn target_capture_names(&self) -> Vec { + self.target_captures() + .into_iter() + .map(|c| name_for_capture(c)) + .collect() + } + fn get_type_for_field( + &self, + parent: &ts_query::NamedNode, + field: &ts_query::FieldDefinition, + ) -> String { + let parent_name = normalize_type_name(&parent.name(self.tree).source(), true); + let field_name = normalize_field_name( + &field + .name(self.tree) + .into_iter() + .filter(|n| n.is_named()) + .next() + .unwrap() + .source(), + ); + let parsed_field = self.get_field_for_field_name(&field_name, &parent_name); + if let Some(parsed_field) = parsed_field { + parsed_field.type_name() + } else { + panic!( + "No field found for: {:#?}, {:#?}, {:#?}", + field, field_name, parent_name + ); + } + } + pub fn struct_fields(&self) -> Vec { + let mut fields = Vec::new(); + for capture in self.target_captures() { + let name = name_for_capture(capture); + let mut type_name = format_ident!("NodeTypes"); + for (node, id) in self.tree.descendants(&self.root_id) { + if let ts_query::NodeTypesRef::Capture(other) = node.as_ref() { + if other == capture { + let mut preceding_siblings = + id.preceding_siblings(self.tree.arena()).skip(1); + while let Some(prev) = preceding_siblings.next() { + if let Some(prev_capture) = self.tree.arena().get(prev) { + match prev_capture.get().as_ref() { + ts_query::NodeTypesRef::NamedNode(prev_capture) => { + type_name = format_ident!( + "{}", + normalize_type_name(&prev_capture.source(), true) + ); + break; + } + ts_query::NodeTypesRef::Identifier(prev_capture) => { + type_name = format_ident!( + "{}", + normalize_type_name(&prev_capture.source(), true) + ); + break; + } + ts_query::NodeTypesRef::AnonymousUnderscore(_) => { + let mut ancestors = id.ancestors(self.tree.arena()).skip(2); + if let Some(field) = ancestors.next() { + if 
let Some(parent) = ancestors.next() { + // Field definitions (example) + // (new_expression\n constructor: (_) @name) @reference.class + let parent: &ts_query::NamedNode = self + .tree + .get(&parent) + .unwrap() + .as_ref() + .try_into() + .unwrap(); + + let field: &ts_query::FieldDefinition = self + .tree + .get(&field) + .unwrap() + .as_ref() + .try_into() + .unwrap(); + type_name = format_ident!( + "{}", + self.get_type_for_field(parent, field) + ); + } + } + break; // Could be any type + } + _ => { + panic!("Unexpected capture: {:#?}", prev_capture); + } + } + } + } + } + } + } + let name_ident = format_ident!("{}", name); + fields.push(parse_quote!( + #[tracked] + #[return_ref] + pub #name_ident: crate::cst::#type_name<'db> + )); + } + fields + } } pub trait HasQuery { diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap index f655a51..bf961cf 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap @@ -3,34 +3,284 @@ source: codegen-sdk-ast-generator/src/visitor.rs expression: "codegen_sdk_common::generator::format_code_string(&visitor.to_string()).unwrap()" --- #[salsa::tracked] +pub struct Class<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::Identifier<'db>, +} +impl<'db> Class<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::ClassDefinition<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Class<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Class<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[salsa::tracked] +pub struct Constant<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::Identifier<'db>, +} +impl<'db> Constant<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::Module<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Constant<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Constant<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> 
codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[salsa::tracked] +pub struct Function<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::Identifier<'db>, +} +impl<'db> Function<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::FunctionDefinition<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Function<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Function<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[salsa::tracked] +pub struct Import<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub module: crate::cst::DottedName<'db>, + #[tracked] + #[return_ref] + pub name: crate::cst::DottedName<'db>, +} +impl<'db> Import<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::ImportFromStatement<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Import<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Import<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] +pub enum Symbol<'db> { + Class(Class<'db>), + Constant(Constant<'db>), + Function(Function<'db>), + Import(Import<'db>), +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Symbol<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + match self { + Self::Class(symbol) => symbol.file(db), + Self::Constant(symbol) => symbol.file(db), + Self::Function(symbol) => symbol.file(db), + Self::Import(symbol) => symbol.file(db), + } + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + match self { + Self::Class(symbol) => symbol.root_path(db), + Self::Constant(symbol) => symbol.root_path(db), + Self::Function(symbol) => symbol.root_path(db), + Self::Import(symbol) => symbol.root_path(db), + } + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Symbol<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + match self { + Self::Class(symbol) => symbol.fully_qualified_name(db), 
+ Self::Constant(symbol) => symbol.fully_qualified_name(db), + Self::Function(symbol) => symbol.fully_qualified_name(db), + Self::Import(symbol) => symbol.fully_qualified_name(db), + } + } +} +#[salsa::tracked] pub struct Definitions<'db> { #[return_ref] - pub _classes: BTreeMap>, + pub classes: BTreeMap>>, + #[return_ref] + pub constants: BTreeMap>>, #[return_ref] - pub _constants: BTreeMap>, + pub functions: BTreeMap>>, #[return_ref] - pub _functions: BTreeMap>, + pub imports: BTreeMap>>, } impl<'db> Definitions<'db> { pub fn visit( db: &'db dyn salsa::Database, root: &'db crate::cst::Parsed<'db>, ) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut constants: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); + let mut classes: BTreeMap>> = BTreeMap::new(); + let mut constants: BTreeMap>> = BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut imports: BTreeMap>> = BTreeMap::new(); let tree = root.tree(db); for (node, id) in tree.descendants(&root.program(db)) { match node { crate::cst::NodeTypes::ClassDefinition(node) => { ///Code for query: (class_definition name: (identifier) @name) @definition.class + ///Code for field: name: (identifier) @name let name = node.name(tree); - classes.entry(name.source()).or_default().push(id); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Class::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); + classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::FunctionDefinition(node) => { ///Code for query: (function_definition name: (identifier) @name) @definition.function + ///Code for field: name: (identifier) @name let name = node.name(tree); - functions.entry(name.source()).or_default().push(id); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Function::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); + functions.entry(name.source()).or_default().push(symbol); + } + crate::cst::NodeTypes::ImportFromStatement(node) => { + ///Code for query: (import_from_statement module_name: (dotted_name) @module name: (dotted_name) @name) @definition.import + ///Code for field: name: (dotted_name) @name + for name in node.name(tree) { + if let crate::cst::ImportFromStatementNameRef::DottedName( + name, + ) = name { + ///Code for field: module_name: (dotted_name) @module + let module_name = node.module_name(tree); + if let crate::cst::ImportFromStatementModuleNameRef::DottedName( + module_name, + ) = module_name { + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Import::new( + db, + fully_qualified_name, + id, + module_name.clone().into(), + name.clone().into(), + ); + imports.entry(name.source()).or_default().push(symbol); + } + } + } } crate::cst::NodeTypes::Module(node) => { ///Code for query: (module (expression_statement (assignment left: (identifier) @name) @definition.constant)) @@ -44,8 +294,22 @@ impl<'db> Definitions<'db> { child, ) = child { ///Code for query: (module (expression_statement (assignment left: (identifier) @name) @definition.constant)) + ///Code for field: left: (identifier) @name let left = child.left(tree); - constants.entry(left.source()).or_default().push(id); + if let crate::cst::AssignmentLeftRef::Identifier(left) = left { + let 
fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + left.source(), + ); + let symbol = Constant::new( + db, + fully_qualified_name, + id, + left.clone().into(), + ); + constants.entry(left.source()).or_default().push(symbol); + } } break; } @@ -56,60 +320,13 @@ impl<'db> Definitions<'db> { _ => {} } } - Self::new(db, classes, constants, functions) + Self::new(db, classes, constants, functions, imports) } pub fn default(db: &'db dyn salsa::Database) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut constants: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); - Self::new(db, classes, constants, functions) - } - pub fn classes( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._classes(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn constants( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._constants(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn functions( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._functions(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() + let mut classes: BTreeMap>> = BTreeMap::new(); + let mut constants: BTreeMap>> = BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut imports: BTreeMap>> = BTreeMap::new(); + Self::new(db, classes, constants, functions, imports) } } diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap index b3bee56..a059f4b 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap @@ -3,55 +3,354 @@ source: codegen-sdk-ast-generator/src/visitor.rs expression: "codegen_sdk_common::generator::format_code_string(&visitor.to_string()).unwrap()" --- #[salsa::tracked] +pub struct Class<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::TypeIdentifier<'db>, +} +impl<'db> Class<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::AbstractClassDeclaration<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Class<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Class<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> 
codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[salsa::tracked] +pub struct Function<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::Identifier<'db>, +} +impl<'db> Function<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::FunctionSignature<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Function<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Function<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[salsa::tracked] +pub struct Interface<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::TypeIdentifier<'db>, +} +impl<'db> Interface<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::InterfaceDeclaration<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Interface<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Interface<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[salsa::tracked] +pub struct Method<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::PropertyIdentifier<'db>, +} +impl<'db> Method<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::AbstractMethodSignature<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Method<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Method<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[salsa::tracked] +pub struct 
Module<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub name: crate::cst::Identifier<'db>, +} +impl<'db> Module<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::Module<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Module<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Module<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} +#[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] +pub enum Symbol<'db> { + Class(Class<'db>), + Function(Function<'db>), + Interface(Interface<'db>), + Method(Method<'db>), + Module(Module<'db>), +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Symbol<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + match self { + Self::Class(symbol) => symbol.file(db), + Self::Function(symbol) => symbol.file(db), + Self::Interface(symbol) => symbol.file(db), + Self::Method(symbol) => symbol.file(db), + Self::Module(symbol) => symbol.file(db), + } + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + match self { + Self::Class(symbol) => symbol.root_path(db), + Self::Function(symbol) => symbol.root_path(db), + Self::Interface(symbol) => symbol.root_path(db), + Self::Method(symbol) => symbol.root_path(db), + Self::Module(symbol) => symbol.root_path(db), + } + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Symbol<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + match self { + Self::Class(symbol) => symbol.fully_qualified_name(db), + Self::Function(symbol) => symbol.fully_qualified_name(db), + Self::Interface(symbol) => symbol.fully_qualified_name(db), + Self::Method(symbol) => symbol.fully_qualified_name(db), + Self::Module(symbol) => symbol.fully_qualified_name(db), + } + } +} +#[salsa::tracked] pub struct Definitions<'db> { #[return_ref] - pub _classes: BTreeMap>, + pub classes: BTreeMap>>, #[return_ref] - pub _functions: BTreeMap>, + pub functions: BTreeMap>>, #[return_ref] - pub _interfaces: BTreeMap>, + pub interfaces: BTreeMap>>, #[return_ref] - pub _methods: BTreeMap>, + pub methods: BTreeMap>>, #[return_ref] - pub _modules: BTreeMap>, + pub modules: BTreeMap>>, } impl<'db> Definitions<'db> { pub fn visit( db: &'db dyn salsa::Database, root: &'db crate::cst::Parsed<'db>, ) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); - let mut interfaces: BTreeMap> = BTreeMap::new(); - let mut methods: BTreeMap> = BTreeMap::new(); - let mut modules: BTreeMap> = BTreeMap::new(); + let mut classes: BTreeMap>> = BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut interfaces: BTreeMap>> = BTreeMap::new(); + let mut methods: BTreeMap>> = 
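// Rough shape of what each generated symbol stores: just an id pair
// (fully qualified name + node id). The CST node itself is re-fetched from the
// file's tree on demand, mirroring the generated `node(&self, db)` accessors
// above. Types below are simplified stand-ins (no salsa database, no indextree).
#[derive(Debug, Clone, PartialEq, Eq)]
struct NodeId(usize);

#[derive(Debug)]
struct CstNode {
    kind: &'static str,
    text: String,
}

#[derive(Debug, Default)]
struct Tree {
    nodes: Vec<CstNode>,
}

impl Tree {
    fn get(&self, id: &NodeId) -> Option<&CstNode> {
        self.nodes.get(id.0)
    }
}

#[derive(Debug)]
struct Function {
    fully_qualified_name: String,
    node_id: NodeId,
}

impl Function {
    // Look the node back up in the file's tree instead of storing it inline,
    // so the symbol stays a small, cheaply comparable id.
    fn node<'a>(&self, tree: &'a Tree) -> &'a CstNode {
        tree.get(&self.node_id).expect("symbol points at a live node")
    }
}

fn main() {
    let tree = Tree {
        nodes: vec![CstNode { kind: "function_signature", text: "foo".into() }],
    };
    let f = Function { fully_qualified_name: "a.ts::foo".into(), node_id: NodeId(0) };
    assert_eq!(f.node(&tree).kind, "function_signature");
    assert_eq!(f.node(&tree).text, "foo");
    println!("{}", f.fully_qualified_name);
}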
BTreeMap::new(); + let mut modules: BTreeMap>> = BTreeMap::new(); let tree = root.tree(db); for (node, id) in tree.descendants(&root.program(db)) { match node { crate::cst::NodeTypes::AbstractClassDeclaration(node) => { ///Code for query: (abstract_class_declaration name: (type_identifier) @name) @definition.class + ///Code for field: name: (type_identifier) @name let name = node.name(tree); - classes.entry(name.source()).or_default().push(id); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Class::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); + classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::AbstractMethodSignature(node) => { ///Code for query: (abstract_method_signature name: (property_identifier) @name) @definition.method + ///Code for field: name: (property_identifier) @name let name = node.name(tree); - methods.entry(name.source()).or_default().push(id); + if let crate::cst::AbstractMethodSignatureNameRef::PropertyIdentifier( + name, + ) = name { + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Method::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); + methods.entry(name.source()).or_default().push(symbol); + } } crate::cst::NodeTypes::FunctionSignature(node) => { ///Code for query: (function_signature name: (identifier) @name) @definition.function + ///Code for field: name: (identifier) @name let name = node.name(tree); - functions.entry(name.source()).or_default().push(id); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Function::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); + functions.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::InterfaceDeclaration(node) => { ///Code for query: (interface_declaration name: (type_identifier) @name) @definition.interface + ///Code for field: name: (type_identifier) @name let name = node.name(tree); - interfaces.entry(name.source()).or_default().push(id); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Interface::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); + interfaces.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::Module(node) => { ///Code for query: (module name: (identifier) @name) @definition.module + ///Code for field: name: (identifier) @name let name = node.name(tree); - modules.entry(name.source()).or_default().push(id); + if let crate::cst::ModuleNameRef::Identifier(name) = name { + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Module::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); + modules.entry(name.source()).or_default().push(symbol); + } } _ => {} } @@ -59,91 +358,11 @@ impl<'db> Definitions<'db> { Self::new(db, classes, functions, interfaces, methods, modules) } pub fn default(db: &'db dyn salsa::Database) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); - let mut interfaces: BTreeMap> = BTreeMap::new(); - let mut methods: BTreeMap> = BTreeMap::new(); - let mut modules: BTreeMap> = BTreeMap::new(); + let mut classes: BTreeMap>> = 
BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut interfaces: BTreeMap>> = BTreeMap::new(); + let mut methods: BTreeMap>> = BTreeMap::new(); + let mut modules: BTreeMap>> = BTreeMap::new(); Self::new(db, classes, functions, interfaces, methods, modules) } - pub fn classes( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._classes(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn functions( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._functions(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn interfaces( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._interfaces(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn methods( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._methods(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn modules( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._modules(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } } diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 7c3826a..f4e1554 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet}; use codegen_sdk_common::Language; use convert_case::{Case, Casing}; use proc_macro2::{Span, TokenStream}; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, quote_spanned}; use syn::parse_quote_spanned; use super::query::Query; @@ -24,9 +24,11 @@ pub fn generate_visitor<'db>( let mut types = Vec::new(); let mut variants = BTreeSet::new(); let mut enter_methods = BTreeMap::new(); + let mut symbol_names = Vec::new(); for query in queries { names.push(query.executor_id()); types.push(format_ident!("{}", query.struct_name())); + symbol_names.push(query.symbol_name()); for variant in query.struct_variants() { variants.insert(format_ident!("{}", variant)); enter_methods @@ -36,7 +38,7 @@ pub fn generate_visitor<'db>( } } let mut methods: Vec = Vec::new(); - for (variant, queries) in enter_methods { + for (variant, queries) in enter_methods.iter() { let mut matchers = TokenStream::new(); let struct_name = format_ident!("{}", variant); for query in queries { @@ -49,15 +51,108 @@ pub fn generate_visitor<'db>( } }); } + + let symbol_name = if name == "definition" { + format_ident!("Symbol") + } else { + format_ident!("Reference") + }; let maps = quote! { #( - let mut #names: BTreeMap> = BTreeMap::new(); + let mut #names: BTreeMap>> = BTreeMap::new(); )* }; let constructor = quote! 
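// A small illustration of the quoting pattern codegen-sdk-ast-generator/src/visitor.rs
// uses to emit the per-language symbol enum with `quote!` and `format_ident!`.
// The names passed in below are hypothetical inputs; this shows only the
// pattern, not the generator's actual output. Assumes the `quote` and
// `proc-macro2` crates already in the workspace.
use proc_macro2::TokenStream;
use quote::{format_ident, quote};

fn generate_symbol_enum(symbol_names: &[&str]) -> TokenStream {
    // One `Variant(Variant<'db>)` arm per symbol kind.
    let variant_arms: Vec<TokenStream> = symbol_names
        .iter()
        .map(|name| {
            let ident = format_ident!("{}", name);
            quote! { #ident(#ident<'db>), }
        })
        .collect();
    quote! {
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub enum Symbol<'db> {
            #(#variant_arms)*
        }
    }
}

fn main() {
    // Prints the token stream for a Symbol enum over three hypothetical kinds.
    println!("{}", generate_symbol_enum(&["Class", "Function", "Interface"]));
}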
{ Self::new(db, #(#names),*) }; + let mut defs = Vec::new(); + let language_struct = format_ident!("{}File", language.struct_name()); + for (variant, type_name) in symbol_names.iter().zip(types.iter()) { + let query = enter_methods + .get(&type_name.to_string()) + .unwrap() + .first() + .unwrap(); + let fields = query.struct_fields(); + let span = Span::mixed_site(); + defs.push(quote_spanned! { + span => + #[salsa::tracked] + pub struct #variant<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, + #[id] + node_id: indextree::NodeId, + // #[tracked] + // #[return_ref] + // pub node: crate::cst::#type_name<'db>, + #(#fields),* + } + impl<'db> #variant<'db> { + pub fn node(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db crate::cst::#type_name<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } + } + impl<'db> codegen_sdk_resolution::HasFile<'db> for #variant<'db> { + type File<'db1> = #language_struct<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } + } + impl<'db> codegen_sdk_resolution::HasId<'db> for #variant<'db> { + fn fully_qualified_name(&self, db: &'db dyn salsa::Database) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } + } + }); + } + let symbol = if defs.len() > 0 { + quote! { + #( + #defs + )* + #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] + pub enum #symbol_name<'db> { + #( + #symbol_names(#symbol_names<'db>), + )* + } + impl<'db> codegen_sdk_resolution::HasFile<'db> for #symbol_name<'db> { + type File<'db1> = #language_struct<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + match self { + #(Self::#symbol_names(symbol) => symbol.file(db),)* + } + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + match self { + #(Self::#symbol_names(symbol) => symbol.root_path(db),)* + } + } + } + impl<'db> codegen_sdk_resolution::HasId<'db> for #symbol_name<'db> { + fn fully_qualified_name(&self, db: &'db dyn salsa::Database) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + match self { + #(Self::#symbol_names(symbol) => symbol.fully_qualified_name(db),)* + } + } + } + } + } else { + quote! { + #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] + pub enum #symbol_name<'db> { + _Phantom(std::marker::PhantomData<&'db ()>) + } + } + }; let name = format_ident!("{}s", name.to_case(Case::Pascal)); let output_constructor = quote! { pub fn visit(db: &'db dyn salsa::Database, root: &'db crate::cst::Parsed<'db>) -> Self { @@ -76,11 +171,9 @@ pub fn generate_visitor<'db>( #constructor } }; - let underscored_names = names - .iter() - .map(|name| format_ident!("_{}", name)) - .collect::>(); + quote! 
{ + #symbol // Three lifetimes: // db: the lifetime of the database // db1: the lifetime of the visitor executing per-node @@ -89,17 +182,11 @@ pub fn generate_visitor<'db>( pub struct #name<'db> { #( #[return_ref] - pub #underscored_names: BTreeMap>, + pub #names: BTreeMap>>, )* } impl<'db> #name<'db> { #output_constructor - #( - pub fn #names(&self, db: &'db dyn salsa::Database, tree: &'db codegen_sdk_common::tree::Tree>) -> BTreeMap>> { - self.#underscored_names(db).iter().map(|(k, v)| - (k.clone(), v.iter().map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()).collect())).collect() - } - )* } } } diff --git a/codegen-sdk-ast/src/input.rs b/codegen-sdk-ast/src/input.rs deleted file mode 100644 index 9e97d42..0000000 --- a/codegen-sdk-ast/src/input.rs +++ /dev/null @@ -1,10 +0,0 @@ -use std::path::PathBuf; - -use codegen_sdk_cst::Input; -#[salsa::input] -pub struct File { - #[id] - pub path: PathBuf, - // #[return_ref] - pub contents: Input, -} diff --git a/codegen-sdk-ast/src/lib.rs b/codegen-sdk-ast/src/lib.rs index cb9958f..4cc0d27 100644 --- a/codegen-sdk-ast/src/lib.rs +++ b/codegen-sdk-ast/src/lib.rs @@ -1,17 +1,16 @@ #![recursion_limit = "512"] -pub mod input; use ambassador::delegatable_trait; -use codegen_sdk_common::File; +// use codegen_sdk_common::File; pub use codegen_sdk_common::language::LANGUAGES; pub use codegen_sdk_cst::*; -pub trait Named { - fn name(&self) -> &str; -} -impl Named for T { - fn name(&self) -> &str { - self.path().file_name().unwrap().to_str().unwrap() - } -} +// pub trait Named { +// fn name(&self) -> &str; +// } +// impl Named for T { +// fn name(&self) -> &str { +// self.path().file_name().unwrap().to_str().unwrap() +// } +// } #[delegatable_trait] pub trait Definitions<'db> { type Definitions; diff --git a/codegen-sdk-common/Cargo.toml b/codegen-sdk-common/Cargo.toml index a1b418b..d8fcd0e 100644 --- a/codegen-sdk-common/Cargo.toml +++ b/codegen-sdk-common/Cargo.toml @@ -39,6 +39,9 @@ syn = { workspace = true } prettyplease = { workspace = true } salsa = { workspace = true } indextree = { workspace = true } +indexmap = { workspace = true } +rustc-hash = "2.1.1" +hashbrown = { version = "0.15.2", features = ["rayon"] } [dev-dependencies] test-log = { workspace = true } [features] diff --git a/codegen-sdk-common/src/hash.rs b/codegen-sdk-common/src/hash.rs new file mode 100644 index 0000000..ba9e4d9 --- /dev/null +++ b/codegen-sdk-common/src/hash.rs @@ -0,0 +1,13 @@ +// Taken from https://github.com/salsa-rs/salsa/blob/9d2a9786c45000f5fa396ad2872391e302a2836a/src/hash.rs#L1 +use std::hash::{BuildHasher, Hash}; + +pub type FxHasher = std::hash::BuildHasherDefault; +pub type FxIndexSet = indexmap::IndexSet; +pub type FxIndexMap = indexmap::IndexMap; +// pub type FxDashMap = dashmap::DashMap; +// pub type FxLinkedHashSet = hashlink::LinkedHashSet; +pub type FxHashSet = hashbrown::HashSet; +pub type FxHashMap = hashbrown::HashMap; +pub fn hash(t: &T) -> u64 { + FxHasher::default().hash_one(t) +} diff --git a/codegen-sdk-common/src/language/python.rs b/codegen-sdk-common/src/language/python.rs index 62188a1..efb1534 100644 --- a/codegen-sdk-common/src/language/python.rs +++ b/codegen-sdk-common/src/language/python.rs @@ -1,12 +1,17 @@ use super::Language; +const PYTHON_TAGS_QUERY: &'static str = tree_sitter_python::TAGS_QUERY; +const EXTRA_TAGS_QUERY: &'static str = " + (import_from_statement module_name: (dotted_name) @module name: (dotted_name) @name) @definition.import + "; lazy_static! 
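// The new codegen-sdk-common hash module mirrors salsa's FxHash aliases (per the
// comment in the diff). A cut-down version of the same idea, assuming the hasher
// parameter elided above is `rustc_hash::FxHasher`, consistent with the
// `rustc-hash`, `hashbrown`, and `indexmap` dependencies added to Cargo.toml.
use std::hash::BuildHasherDefault;

pub type FxHasher = BuildHasherDefault<rustc_hash::FxHasher>;
pub type FxHashMap<K, V> = hashbrown::HashMap<K, V, FxHasher>;
pub type FxIndexMap<K, V> = indexmap::IndexMap<K, V, FxHasher>;

fn main() {
    // Same API as the std maps, but with the faster, non-DoS-resistant FxHash,
    // which is fine for interned ids and file paths.
    let mut map: FxHashMap<&str, usize> = FxHashMap::default();
    map.insert("codegen-sdk-common", 1);
    assert_eq!(map["codegen-sdk-common"], 1);
}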
{ + static ref TAGS_QUERY: String = [PYTHON_TAGS_QUERY, EXTRA_TAGS_QUERY].join("\n"); pub static ref Python: Language = Language::new( "python", "Python", tree_sitter_python::NODE_TYPES, &["py"], tree_sitter_python::LANGUAGE.into(), - tree_sitter_python::TAGS_QUERY, + &TAGS_QUERY, ) .unwrap(); } diff --git a/codegen-sdk-common/src/lib.rs b/codegen-sdk-common/src/lib.rs index 67d9a91..594d23d 100644 --- a/codegen-sdk-common/src/lib.rs +++ b/codegen-sdk-common/src/lib.rs @@ -1,6 +1,7 @@ #![feature(error_generic_member_access)] #![feature(trivial_bounds)] mod errors; +pub mod hash; pub mod language; pub mod traits; pub mod utils; diff --git a/codegen-sdk-common/src/tree/context.rs b/codegen-sdk-common/src/tree/context.rs index 1cf94dc..9210b85 100644 --- a/codegen-sdk-common/src/tree/context.rs +++ b/codegen-sdk-common/src/tree/context.rs @@ -6,15 +6,17 @@ use crate::tree::{FileNodeId, Tree, TreeNode}; pub struct ParseContext<'db, T: TreeNode> { pub db: &'db dyn salsa::Database, pub file_id: FileNodeId<'db>, + pub root: FileNodeId<'db>, pub buffer: Arc, pub tree: Tree, } impl<'db, T: TreeNode> ParseContext<'db, T> { - pub fn new(db: &'db dyn salsa::Database, path: PathBuf, content: Bytes) -> Self { + pub fn new(db: &'db dyn salsa::Database, path: PathBuf, root: PathBuf, content: Bytes) -> Self { let file_id = FileNodeId::new(db, path); Self { db, file_id, + root: FileNodeId::new(db, root), buffer: Arc::new(content), tree: Tree::default(), } diff --git a/codegen-sdk-common/src/tree/id.rs b/codegen-sdk-common/src/tree/id.rs index b665416..057689b 100644 --- a/codegen-sdk-common/src/tree/id.rs +++ b/codegen-sdk-common/src/tree/id.rs @@ -2,11 +2,12 @@ use std::path::PathBuf; #[salsa::interned] pub struct FileNodeId<'db> { - file_path: PathBuf, + pub path: PathBuf, } #[salsa::interned] pub struct CSTNodeId<'db> { - file_id: FileNodeId<'db>, + pub file: FileNodeId<'db>, node_id: usize, + pub root: FileNodeId<'db>, // TODO: add a marker for tree-sitter generation } diff --git a/codegen-sdk-common/src/tree/tree.rs b/codegen-sdk-common/src/tree/tree.rs index ae884cd..01518f8 100644 --- a/codegen-sdk-common/src/tree/tree.rs +++ b/codegen-sdk-common/src/tree/tree.rs @@ -39,6 +39,9 @@ impl Tree { id.children(&self.ids) .map(|id| (self.get(&id).unwrap(), id)) } + pub fn arena(&self) -> &Arena { + &self.ids + } } unsafe impl Update for Tree where diff --git a/codegen-sdk-cst-generator/Cargo.toml b/codegen-sdk-cst-generator/Cargo.toml index aecadca..2581f23 100644 --- a/codegen-sdk-cst-generator/Cargo.toml +++ b/codegen-sdk-cst-generator/Cargo.toml @@ -2,7 +2,7 @@ name = "codegen-sdk-cst-generator" version = "0.1.0" edition = "2024" - +description = "Generates the CST for the given language from the tree-sitter node-types.json" [dependencies] syn = { workspace = true } tree-sitter = { workspace = true } diff --git a/codegen-sdk-cst-generator/src/generator.rs b/codegen-sdk-cst-generator/src/generator.rs index 0827658..48c6c66 100644 --- a/codegen-sdk-cst-generator/src/generator.rs +++ b/codegen-sdk-cst-generator/src/generator.rs @@ -53,10 +53,11 @@ fn get_parser(language: &Language) -> TokenStream { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } - pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::Input, path: PathBuf) -> Option> { + pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::File) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let 
tree = codegen_sdk_common::language::#language_name::#language_struct_name.parse_tree_sitter(&input.content(db)); match tree { @@ -65,7 +66,7 @@ fn get_parser(language: &Language) -> TokenStream { ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new(db, input.path(db), input.root(db), buffer); let root_id = #program_id::orphaned(&mut context, tree.root_node()) .map_or_else(|e| { e.report(db); @@ -74,7 +75,7 @@ fn get_parser(language: &Language) -> TokenStream { Some(program) }); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some(Parsed::new(db, context.file_id, Arc::new(context.tree), program)) } else { None } @@ -87,8 +88,8 @@ fn get_parser(language: &Language) -> TokenStream { } } #[salsa::tracked(return_ref)] - pub fn parse_program(db: &dyn salsa::Database, input: codegen_sdk_cst::Input) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + pub fn parse_program(db: &dyn salsa::Database, input: codegen_sdk_cst::File) -> Parsed<'_> { + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -102,13 +103,13 @@ fn get_parser(language: &Language) -> TokenStream { fn language() -> &'static codegen_sdk_common::language::Language { &codegen_sdk_common::language::#language_name::#language_struct_name } - fn parse<'db>(db: &'db dyn salsa::Database, content: std::string::String) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + fn parse<'db>(db: &'db dyn salsa::Database, content: std::string::String) -> Option<(&'db Self::Program<'db>, &'db Tree>, indextree::NodeId)> { + let input = codegen_sdk_cst::File::new(db, std::path::PathBuf::new(), content, std::path::PathBuf::new()); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } } diff --git a/codegen-sdk-cst-generator/src/generator/node.rs b/codegen-sdk-cst-generator/src/generator/node.rs index 5e64a3a..cd080aa 100644 --- a/codegen-sdk-cst-generator/src/generator/node.rs +++ b/codegen-sdk-cst-generator/src/generator/node.rs @@ -224,7 +224,7 @@ impl<'a> Node<'a> { fn from_node(context: &mut ParseContext<'db, NodeTypes<'db>>, node: tree_sitter::Node) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); #(#constructor_fields)* Ok((Self { diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap index a5dfaca..0c3eab2 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap +++ 
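// The ParseContext / CSTNodeId changes above thread the repository root through
// parsing alongside the file path, so a node id alone can recover both. A
// simplified stand-in without salsa interning:
use std::path::PathBuf;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct FileNodeId {
    path: PathBuf,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CstNodeId {
    file: FileNodeId,
    node_id: usize,
    root: FileNodeId,
}

struct ParseContext {
    file_id: FileNodeId,
    root: FileNodeId,
}

impl ParseContext {
    fn new(path: PathBuf, root: PathBuf) -> Self {
        Self {
            file_id: FileNodeId { path },
            root: FileNodeId { path: root },
        }
    }

    // Every node id minted during the walk carries the same (file, root) pair.
    fn node_id(&self, ts_node_id: usize) -> CstNodeId {
        CstNodeId {
            file: self.file_id.clone(),
            node_id: ts_node_id,
            root: self.root.clone(),
        }
    }
}

fn main() {
    let ctx = ParseContext::new(PathBuf::from("src/lib.py"), PathBuf::from("."));
    let id = ctx.node_id(42);
    assert_eq!(id.root.path, PathBuf::from("."));
    assert_eq!(id.node_id, 42);
}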
b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap @@ -24,7 +24,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let multiple_field = get_multiple_children_by_field_name::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap index 3488e0a..8dacc1e 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap @@ -21,7 +21,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap index 71c16c8..5ca94b1 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap @@ -22,7 +22,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap index f2c05f6..7a52fb5 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap @@ -22,7 +22,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let 
end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let test_field = get_child_by_field_name::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap index 71c16c8..5ca94b1 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap @@ -22,7 +22,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap index a6c909e..def0094 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap @@ -21,7 +21,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for AnonymousNodeA<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -142,7 +142,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for NodeB<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -264,7 +264,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for NodeC<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let field = get_multiple_children_by_field_name::< NodeTypes<'db>, diff --git 
a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap index 5c0d90f..cc9210e 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap @@ -21,7 +21,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for Test<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { diff --git a/codegen-sdk-cst-generator/src/generator/utils.rs b/codegen-sdk-cst-generator/src/generator/utils.rs index 1dbf87f..6efd477 100644 --- a/codegen-sdk-cst-generator/src/generator/utils.rs +++ b/codegen-sdk-cst-generator/src/generator/utils.rs @@ -65,6 +65,20 @@ pub fn get_from_enum_to_ref(enum_name: &str, variant_names: &Vec) -> Toke node.as_ref().into() } } + impl<'db3> From<#name_ref<'db3>> for #name<'db3> { + fn from(node: #name_ref<'db3>) -> Self { + match node { + #(#name_ref::#variant_names(data) => Self::#variant_names((*data).clone()),)* + } + } + } + impl<'db3> From<&'db3 #name_ref<'db3>> for #name<'db3> { + fn from(node: &'db3 #name_ref<'db3>) -> Self { + match node { + #(#name_ref::#variant_names(data) => Self::#variant_names((*data).clone()),)* + } + } + } #( impl<'db3> TryFrom<#name_ref<'db3>> for &'db3 #variant_names<'db3> { type Error = codegen_sdk_cst::ConversionError; diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap index 487887b..0abe8b5 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap @@ -93,6 +93,26 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::UnaryExpression(data) => Self::UnaryExpression((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::UnaryExpression(data) => Self::UnaryExpression((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -144,6 +164,30 @@ impl<'db3> From<&'db3 Expression<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Expression<'db3> { + fn from(node: ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + 
ExpressionRef::UnaryExpression(data) => { + Self::UnaryExpression((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 ExpressionRef<'db3>> for Expression<'db3> { + fn from(node: &'db3 ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::UnaryExpression(data) => { + Self::UnaryExpression((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ExpressionRef<'db3>) -> Result { @@ -191,7 +235,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for BinaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -312,7 +356,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for UnaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -422,13 +466,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -439,7 +483,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -449,7 +498,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -464,9 +515,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -483,12 +534,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - 
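// The generated `From<...Ref>` impls added above all follow one shape: match the
// borrowing variant and clone it into the owned enum. The same pattern with toy
// node types, self-contained:
#[derive(Debug, Clone, PartialEq)]
struct BinaryExpression(String);
#[derive(Debug, Clone, PartialEq)]
struct UnaryExpression(String);

#[derive(Debug, Clone, PartialEq)]
enum Expression {
    BinaryExpression(BinaryExpression),
    UnaryExpression(UnaryExpression),
}

// Borrowing view over the same variants, as the generated `*Ref` enums are.
enum ExpressionRef<'a> {
    BinaryExpression(&'a BinaryExpression),
    UnaryExpression(&'a UnaryExpression),
}

impl<'a> From<ExpressionRef<'a>> for Expression {
    fn from(node: ExpressionRef<'a>) -> Self {
        match node {
            ExpressionRef::BinaryExpression(data) => Self::BinaryExpression(data.clone()),
            ExpressionRef::UnaryExpression(data) => Self::UnaryExpression(data.clone()),
        }
    }
}

fn main() {
    let bin = BinaryExpression("a + b".into());
    let owned: Expression = ExpressionRef::BinaryExpression(&bin).into();
    assert_eq!(owned, Expression::BinaryExpression(bin));
    let _ = ExpressionRef::UnaryExpression(&UnaryExpression("-a".into()));
}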
Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap index 55f053b..4ca160f 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap @@ -147,6 +147,36 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + NodeTypesRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + NodeTypesRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + NodeTypesRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + NodeTypesRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 ClassDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -212,6 +242,30 @@ impl<'db3> From<&'db3 Declaration<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Declaration<'db3> { + fn from(node: DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + DeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 DeclarationRef<'db3>> for Declaration<'db3> { + fn from(node: &'db3 DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + DeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 ClassDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: DeclarationRef<'db3>) -> Result { @@ -263,6 +317,24 @@ impl<'db3> From<&'db3 FunctionDeclaration<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for FunctionDeclaration<'db3> { + fn from(node: FunctionDeclarationRef<'db3>) -> Self { + match node { + FunctionDeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 FunctionDeclarationRef<'db3>> for FunctionDeclaration<'db3> { + fn from(node: &'db3 FunctionDeclarationRef<'db3>) -> Self { + match node { + FunctionDeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 MethodDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: FunctionDeclarationRef<'db3>) -> Result { @@ -302,6 +374,36 @@ impl<'db3> From<&'db3 Statement<'db3>> for 
NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Statement<'db3> { + fn from(node: StatementRef<'db3>) -> Self { + match node { + StatementRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + StatementRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + StatementRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 StatementRef<'db3>> for Statement<'db3> { + fn from(node: &'db3 StatementRef<'db3>) -> Self { + match node { + StatementRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + StatementRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + StatementRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 ClassDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: StatementRef<'db3>) -> Result { @@ -363,7 +465,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ClassDeclaration<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -484,7 +586,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ExpressionStatement<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -605,7 +707,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for MethodDeclaration<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -715,13 +817,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -732,7 +834,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -742,7 +849,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -757,9 +866,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - 
input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -776,12 +885,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap index 30c03a1..4f98e87 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap @@ -112,6 +112,26 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -165,6 +185,27 @@ impl<'db3> From<&'db3 BinaryExpressionChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for BinaryExpressionChildren<'db3> { + fn from(node: BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} +impl<'db3> From<&'db3 BinaryExpressionChildrenRef<'db3>> +for BinaryExpressionChildren<'db3> { + fn from(node: &'db3 BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: BinaryExpressionChildrenRef<'db3>) -> Result { @@ -216,6 +257,26 @@ impl<'db3> From<&'db3 Expression<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Expression<'db3> { + fn 
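// The parser entry points now take a `File` input (path + content + root) rather
// than a bare content-only `Input`, and `Parsed` holds its tree behind an `Arc`
// so the tracked struct stays cheap to clone and is not compared by deep
// equality (`#[no_eq]`). A simplified, salsa-free sketch of that shape:
use std::path::PathBuf;
use std::sync::Arc;

#[derive(Debug, Clone)]
struct File {
    path: PathBuf,
    content: String,
    root: PathBuf,
}

#[derive(Debug, Default)]
struct Tree {
    node_count: usize,
}

#[derive(Debug, Clone)]
struct Parsed {
    file: PathBuf,
    tree: Arc<Tree>, // shared; cloning Parsed no longer copies the whole tree
    program: usize,  // stands in for the indextree::NodeId of the root node
}

fn parse_program(input: &File) -> Parsed {
    // A real implementation would run tree-sitter over `input.content` and
    // build the CST; here we only record a node count.
    let tree = Tree { node_count: input.content.len() };
    Parsed {
        file: input.path.clone(),
        tree: Arc::new(tree),
        program: 0,
    }
}

fn main() {
    let file = File {
        path: PathBuf::from("src/lib.ts"),
        content: "function foo() {}".into(),
        root: PathBuf::from("."),
    };
    let parsed = parse_program(&file);
    println!(
        "{:?} (root {:?}, {} nodes, program id {})",
        parsed.file, file.root, parsed.tree.node_count, parsed.program
    );
}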
from(node: ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} +impl<'db3> From<&'db3 ExpressionRef<'db3>> for Expression<'db3> { + fn from(node: &'db3 ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ExpressionRef<'db3>) -> Result { @@ -265,7 +326,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for BinaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let left = get_child_by_field_name::< NodeTypes<'db>, @@ -416,7 +477,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for Literal<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -526,13 +587,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -543,7 +604,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -553,7 +619,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -568,9 +636,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -587,12 +655,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = 
parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap index f57a980..1bd8f1b 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap @@ -147,6 +147,24 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::Block(data) => Self::Block((*data).clone()), + NodeTypesRef::IfStatement(data) => Self::IfStatement((*data).clone()), + NodeTypesRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::Block(data) => Self::Block((*data).clone()), + NodeTypesRef::IfStatement(data) => Self::IfStatement((*data).clone()), + NodeTypesRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 Block<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -212,6 +230,26 @@ impl<'db3> From<&'db3 BlockChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for BlockChildren<'db3> { + fn from(node: BlockChildrenRef<'db3>) -> Self { + match node { + BlockChildrenRef::IfStatement(data) => Self::IfStatement((*data).clone()), + BlockChildrenRef::ReturnStatement(data) => { + Self::ReturnStatement((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 BlockChildrenRef<'db3>> for BlockChildren<'db3> { + fn from(node: &'db3 BlockChildrenRef<'db3>) -> Self { + match node { + BlockChildrenRef::IfStatement(data) => Self::IfStatement((*data).clone()), + BlockChildrenRef::ReturnStatement(data) => { + Self::ReturnStatement((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 IfStatement<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: BlockChildrenRef<'db3>) -> Result { @@ -261,6 +299,20 @@ impl<'db3> From<&'db3 IfStatementChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for IfStatementChildren<'db3> { + fn from(node: IfStatementChildrenRef<'db3>) -> Self { + match node { + IfStatementChildrenRef::Block(data) => Self::Block((*data).clone()), + } + } +} +impl<'db3> From<&'db3 IfStatementChildrenRef<'db3>> for IfStatementChildren<'db3> { + fn from(node: &'db3 IfStatementChildrenRef<'db3>) -> Self { + match node { + IfStatementChildrenRef::Block(data) => Self::Block((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 Block<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: IfStatementChildrenRef<'db3>) -> Result { @@ -298,6 +350,22 @@ impl<'db3> From<&'db3 Statement<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for 
Statement<'db3> { + fn from(node: StatementRef<'db3>) -> Self { + match node { + StatementRef::IfStatement(data) => Self::IfStatement((*data).clone()), + StatementRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} +impl<'db3> From<&'db3 StatementRef<'db3>> for Statement<'db3> { + fn from(node: &'db3 StatementRef<'db3>) -> Self { + match node { + StatementRef::IfStatement(data) => Self::IfStatement((*data).clone()), + StatementRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 IfStatement<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: StatementRef<'db3>) -> Result { @@ -346,7 +414,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for Block<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, @@ -482,7 +550,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for IfStatement<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, @@ -617,7 +685,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ReturnStatement<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -727,13 +795,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -744,7 +812,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -754,7 +827,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -769,9 +844,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } 
else { @@ -788,12 +863,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap index e7e00f0..e3c2c6a 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap @@ -96,6 +96,20 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 ClassMethod<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -131,6 +145,20 @@ impl<'db3> From<&'db3 ClassMember<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for ClassMember<'db3> { + fn from(node: ClassMemberRef<'db3>) -> Self { + match node { + ClassMemberRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} +impl<'db3> From<&'db3 ClassMemberRef<'db3>> for ClassMember<'db3> { + fn from(node: &'db3 ClassMemberRef<'db3>) -> Self { + match node { + ClassMemberRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 ClassMethod<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ClassMemberRef<'db3>) -> Result { @@ -166,6 +194,20 @@ impl<'db3> From<&'db3 Declaration<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Declaration<'db3> { + fn from(node: DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} +impl<'db3> From<&'db3 DeclarationRef<'db3>> for Declaration<'db3> { + fn from(node: &'db3 DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 ClassMethod<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: DeclarationRef<'db3>) -> Result { @@ -199,7 +241,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ClassMethod<'db> { ) -> Result<(Self, Vec), 
ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -309,13 +351,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -326,7 +368,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -336,7 +383,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -351,9 +400,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -370,12 +419,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap index 5a68711..a2bb6f9 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap @@ -131,6 +131,26 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} +impl<'db3> From<&'db3 
NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -186,6 +206,31 @@ impl<'db3> From<&'db3 BinaryExpressionChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for BinaryExpressionChildren<'db3> { + fn from(node: BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 BinaryExpressionChildrenRef<'db3>> +for BinaryExpressionChildren<'db3> { + fn from(node: &'db3 BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: BinaryExpressionChildrenRef<'db3>) -> Result { @@ -239,6 +284,30 @@ impl<'db3> From<&'db3 CallExpressionChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for CallExpressionChildren<'db3> { + fn from(node: CallExpressionChildrenRef<'db3>) -> Self { + match node { + CallExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + CallExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 CallExpressionChildrenRef<'db3>> for CallExpressionChildren<'db3> { + fn from(node: &'db3 CallExpressionChildrenRef<'db3>) -> Self { + match node { + CallExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + CallExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: CallExpressionChildrenRef<'db3>) -> Result { @@ -290,6 +359,26 @@ impl<'db3> From<&'db3 Expression<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Expression<'db3> { + fn from(node: ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} +impl<'db3> From<&'db3 ExpressionRef<'db3>> for Expression<'db3> { + fn from(node: &'db3 ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ExpressionRef<'db3>) -> Result { @@ -339,7 +428,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for BinaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let 
end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let left = get_child_by_field_name::< NodeTypes<'db>, @@ -492,7 +581,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for CallExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let callee = get_child_by_field_name::< NodeTypes<'db>, @@ -632,13 +721,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -649,7 +738,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -659,7 +753,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -674,9 +770,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -693,12 +789,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst/Cargo.toml b/codegen-sdk-cst/Cargo.toml index d9d952b..e4c1340 100644 --- a/codegen-sdk-cst/Cargo.toml +++ b/codegen-sdk-cst/Cargo.toml @@ -13,6 +13,7 @@ log = { workspace = true } salsa = { workspace = true } dashmap = "6.1.0" thiserror = {workspace = true} +indextree = { workspace = true } [dev-dependencies] tempfile = { workspace = true } test-log = { workspace = true } diff --git a/codegen-sdk-cst/src/database.rs b/codegen-sdk-cst/src/database.rs index de19d04..e4c3b17 100644 --- 
a/codegen-sdk-cst/src/database.rs +++ b/codegen-sdk-cst/src/database.rs @@ -2,10 +2,10 @@ use std::{any::Any, path::PathBuf, sync::Arc}; use dashmap::{DashMap, mapref::entry::Entry}; -use crate::Input; +use crate::File; #[salsa::db] #[derive(Default, Clone)] -// Basic Database implementation for Query generation. This is not used for anything else. +// Basic Database implementation for Query generation and testing. This is not used for anything else. pub struct CSTDatabase { storage: salsa::Storage, } diff --git a/codegen-sdk-cst/src/input.rs b/codegen-sdk-cst/src/input.rs index 658de06..d668f02 100644 --- a/codegen-sdk-cst/src/input.rs +++ b/codegen-sdk-cst/src/input.rs @@ -1,5 +1,10 @@ use std::path::PathBuf; #[salsa::input] -pub struct Input { +pub struct File { + #[id] + pub path: PathBuf, + #[return_ref] pub content: String, + #[id] + pub root: PathBuf, } diff --git a/codegen-sdk-cst/src/language.rs b/codegen-sdk-cst/src/language.rs index 8fe82c0..de3d990 100644 --- a/codegen-sdk-cst/src/language.rs +++ b/codegen-sdk-cst/src/language.rs @@ -15,12 +15,23 @@ pub trait CSTLanguage { fn parse<'db>( db: &'db dyn salsa::Database, content: String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)>; + ) -> Option<( + &'db Self::Program<'db>, + &'db Tree>, + indextree::NodeId, + )>; fn parse_file_from_cache<'db>( db: &'db dyn salsa::Database, file_path: &PathBuf, #[cfg(feature = "serialization")] cache: &'db codegen_sdk_common::serialize::Cache, - ) -> Result, &'db Tree>)>, ParseError> { + ) -> Result< + Option<( + &'db Self::Program<'db>, + &'db Tree>, + indextree::NodeId, + )>, + ParseError, + > { #[cfg(feature = "serialization")] { let serialized_path = cache.get_path(file_path); @@ -35,7 +46,14 @@ pub trait CSTLanguage { db: &'db dyn salsa::Database, file_path: &PathBuf, #[cfg(feature = "serialization")] cache: &'db codegen_sdk_common::serialize::Cache, - ) -> Result, &'db Tree>)>, ParseError> { + ) -> Result< + Option<( + &'db Self::Program<'db>, + &'db Tree>, + indextree::NodeId, + )>, + ParseError, + > { if let Some(parsed) = Self::parse_file_from_cache( db, file_path, diff --git a/codegen-sdk-cst/src/lib.rs b/codegen-sdk-cst/src/lib.rs index 700938d..11407fc 100644 --- a/codegen-sdk-cst/src/lib.rs +++ b/codegen-sdk-cst/src/lib.rs @@ -10,7 +10,7 @@ use dashmap::{DashMap, mapref::entry::Entry}; mod database; use codegen_sdk_common::{ParseError, traits::CSTNode}; pub use database::CSTDatabase; -pub use input::Input; +pub use input::File; mod language; pub use codegen_sdk_common::language::LANGUAGES; pub use language::CSTLanguage; diff --git a/codegen-sdk-macros/src/lib.rs b/codegen-sdk-macros/src/lib.rs index cb168d4..57d9ccb 100644 --- a/codegen-sdk-macros/src/lib.rs +++ b/codegen-sdk-macros/src/lib.rs @@ -101,7 +101,7 @@ pub fn parse_language(_item: TokenStream) -> TokenStream { let variant: proc_macro2::TokenStream = quote! 
{ #[cfg(feature = #name)] if #package_name::cst::#struct_name::should_parse(&file.path(db)).unwrap_or(false) { - let parsed = #package_name::ast::parse(db, file); + let parsed = #package_name::ast::parse(db, file).clone(); return Parsed::new( db, FileNodeId::new(db, file.path(db)), diff --git a/codegen-sdk-resolution/Cargo.toml b/codegen-sdk-resolution/Cargo.toml index fba1a9c..130b35d 100644 --- a/codegen-sdk-resolution/Cargo.toml +++ b/codegen-sdk-resolution/Cargo.toml @@ -6,3 +6,11 @@ edition = "2024" [dependencies] salsa = { workspace = true } +log = {workspace = true} +codegen-sdk-ast = { workspace = true } +codegen-sdk-common = { workspace = true } +anyhow = { workspace = true } +indicatif = { workspace = true } +ambassador = { workspace = true } +codegen-sdk-cst = { workspace = true } +smallvec = { workspace = true } diff --git a/codegen-sdk-resolution/src/codebase.rs b/codegen-sdk-resolution/src/codebase.rs index 35b3e6e..5cd112f 100644 --- a/codegen-sdk-resolution/src/codebase.rs +++ b/codegen-sdk-resolution/src/codebase.rs @@ -1,6 +1,8 @@ use std::path::PathBuf; -use salsa::Database; +use codegen_sdk_common::FileNodeId; + +use crate::Db; // Not sure what to name this // Equivalent to CodebaseGraph/CodebaseContext in the SDK pub trait CodebaseContext { @@ -8,6 +10,20 @@ pub trait CodebaseContext { where Self: 'a; fn files<'a>(&'a self) -> Vec<&'a Self::File<'a>>; - fn db(&self) -> &dyn Database; + fn db(&self) -> &dyn Db; fn get_file<'a>(&'a self, path: PathBuf) -> Option<&'a Self::File<'a>>; + fn get_file_for_id<'a>(&'a self, id: FileNodeId) -> Option<&'a Self::File<'a>> { + self.get_file(id.path(self.db())) + } + fn get_raw_file_for_id<'a>(&'a self, id: FileNodeId) -> Option { + self.get_raw_file(id.path(self.db())) + } + fn get_raw_file<'a>(&'a self, path: PathBuf) -> Option { + if let Ok(path) = path.canonicalize() { + self.db().get_file(path) + } else { + None + } + } + fn root_path(&self) -> PathBuf; } diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs new file mode 100644 index 0000000..69aa8c2 --- /dev/null +++ b/codegen-sdk-resolution/src/database.rs @@ -0,0 +1,52 @@ +use std::path::PathBuf; + +use codegen_sdk_cst::File; +use indicatif::MultiProgress; +#[salsa::db] +pub trait Db: salsa::Database + Send { + fn input(&self, path: PathBuf) -> anyhow::Result; + fn get_file(&self, path: PathBuf) -> Option; + fn multi_progress(&self) -> &MultiProgress; + fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; + fn files(&self) -> codegen_sdk_common::hash::FxHashSet>; + fn get_file_for_id(&self, id: codegen_sdk_common::FileNodeId<'_>) -> Option { + self.get_file(id.path(self)) + } +} +#[salsa::tracked] +pub fn files<'db>( + db: &'db dyn Db, +) -> codegen_sdk_common::hash::FxHashSet> { + db.files() +} +#[salsa::db] +impl Db for codegen_sdk_cst::CSTDatabase { + fn input(&self, path: PathBuf) -> anyhow::Result { + let content = std::fs::read_to_string(&path)?; + let file = + codegen_sdk_cst::File::new(self, path.canonicalize().unwrap(), content, PathBuf::new()); + Ok(file) + } + fn multi_progress(&self) -> &MultiProgress { + unimplemented!() + } + fn watch_dir(&mut self, _path: PathBuf) -> anyhow::Result<()> { + unimplemented!() + } + fn files(&self) -> codegen_sdk_common::hash::FxHashSet> { + let path = PathBuf::from("."); + let files = std::fs::read_dir(path).unwrap(); + let mut set = codegen_sdk_common::hash::FxHashSet::default(); + for file in files { + let file = file.unwrap(); + let path = file.path().canonicalize().unwrap(); + 
if path.is_file() { + set.insert(codegen_sdk_common::FileNodeId::new(self, path)); + } + } + set + } + fn get_file(&self, path: PathBuf) -> Option { + self.input(path).ok() + } +} diff --git a/codegen-sdk-resolution/src/lib.rs b/codegen-sdk-resolution/src/lib.rs index 7a3184e..9d79bec 100644 --- a/codegen-sdk-resolution/src/lib.rs +++ b/codegen-sdk-resolution/src/lib.rs @@ -1,4 +1,7 @@ mod scope; +use std::path::PathBuf; + +use ambassador::delegatable_trait; pub use scope::Scope; mod resolve_type; pub use resolve_type::ResolveType; @@ -6,3 +9,16 @@ mod references; pub use references::References; mod codebase; pub use codebase::CodebaseContext; +mod database; +mod parse; +pub use database::{Db, files}; +pub use parse::Parse; +pub use scope::Dependencies; +mod name; +pub use name::{FullyQualifiedName, HasId}; +#[delegatable_trait] +pub trait HasFile<'db> { + type File<'db1>; + fn file(&self, db: &'db dyn Db) -> &'db Self::File<'db>; + fn root_path(&self, db: &'db dyn Db) -> PathBuf; +} diff --git a/codegen-sdk-resolution/src/name.rs b/codegen-sdk-resolution/src/name.rs new file mode 100644 index 0000000..c81e0a4 --- /dev/null +++ b/codegen-sdk-resolution/src/name.rs @@ -0,0 +1,13 @@ +use codegen_sdk_common::FileNodeId; + +#[salsa::interned] +pub struct FullyQualifiedName<'db> { + #[id] + pub path: FileNodeId<'db>, + #[return_ref] + pub name: String, +} + +pub trait HasId<'db> { + fn fully_qualified_name(&self, db: &'db dyn salsa::Database) -> FullyQualifiedName<'db>; +} diff --git a/codegen-sdk-resolution/src/parse.rs b/codegen-sdk-resolution/src/parse.rs new file mode 100644 index 0000000..ca3ffb3 --- /dev/null +++ b/codegen-sdk-resolution/src/parse.rs @@ -0,0 +1,5 @@ +use crate::Db; + +pub trait Parse<'db> { + fn parse(db: &'db dyn Db, input: codegen_sdk_common::FileNodeId<'db>) -> &'db Self; +} diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index 5ee2062..db21c56 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -1,30 +1,34 @@ -use crate::{CodebaseContext, ResolveType}; +use std::hash::Hash; + +use crate::{Db, Dependencies, FullyQualifiedName, HasFile, HasId, Parse, ResolveType}; pub trait References< 'db, - ReferenceType: ResolveType<'db, Scope, Type = Self> + Clone, // References must resolve to this type - Scope: crate::Scope<'db, Type = Self, ReferenceType = ReferenceType> + Clone, ->: Eq + PartialEq where Self:'db + Dep: Dependencies<'db, FullyQualifiedName<'db>, ReferenceType> + 'db, + ReferenceType: ResolveType<'db, Type = Self> + Eq + Hash + Clone + 'db, // References must resolve to this type + Scope: crate::Scope<'db, Type = Self, ReferenceType = ReferenceType, Dependencies = Dep> + + Clone + 'db, +>: Eq + PartialEq + Hash + HasFile<'db, File<'db> = Scope> + HasId<'db> + Sized + 'db where Self:'db { - fn references + Clone + 'db, T>(&self, codebase: &'db T, scope: &Scope) -> Vec + fn references(self, db: &'db dyn Db) -> Vec where Self: Sized, - for<'b> T: CodebaseContext = F> + 'static, - { - let scopes: Vec = codebase.files().into_iter().filter_map(|file| file.clone().try_into().ok()).collect(); - return self.references_for_scopes(codebase.db(), scopes, scope); - } - fn references_for_scopes(&self, db: &'db dyn salsa::Database, scopes: Vec, scope: &Scope) -> Vec - where - Self: Sized + 'db, - { - let mut results = Vec::new(); - for reference in scope.clone().resolvables(db) { - let resolved = reference.clone().resolve_type(db, scope.clone(), scopes.clone()); - if 
resolved.iter().any(|result| *result == *self) { - results.push(reference); - } - } - results - } + Scope: Parse<'db>; + // { + // // let files = files(db); + // // log::info!(target: "resolution", "Finding references across {:?} files", files.len()); + // // let mut results = Vec::new(); + // // for input in files { + // // // if !self.filter(db, &input) { + // // // continue; + // // // } + // // let file = Scope::parse(db, input.clone()); + // // let dependencies = file.clone().compute_dependencies_query(db); + // // if let Some(references) = dependencies.get(db, &self.fully_qualified_name(db)) { + // // results.extend(references.iter().cloned()); + // // } + // // } + // results + // } + fn filter(&self, db: &'db dyn Db, input: &codegen_sdk_cst::File) -> bool; } diff --git a/codegen-sdk-resolution/src/resolve_type.rs b/codegen-sdk-resolution/src/resolve_type.rs index 62d6a85..e5fdf84 100644 --- a/codegen-sdk-resolution/src/resolve_type.rs +++ b/codegen-sdk-resolution/src/resolve_type.rs @@ -1,12 +1,6 @@ -use crate::Scope; - +use crate::Db; // Get definitions for a given type -pub trait ResolveType<'db, T: Scope<'db>> { +pub trait ResolveType<'db> { type Type; // Possible types this trait can be defined as - fn resolve_type( - self, - db: &'db dyn salsa::Database, - scope: T, - scopes: Vec, - ) -> &'db Vec; + fn resolve_type(self, db: &'db dyn Db) -> &'db Vec; } diff --git a/codegen-sdk-resolution/src/scope.rs b/codegen-sdk-resolution/src/scope.rs index ec7baa9..84c2727 100644 --- a/codegen-sdk-resolution/src/scope.rs +++ b/codegen-sdk-resolution/src/scope.rs @@ -1,10 +1,45 @@ -use crate::ResolveType; +use std::hash::Hash; +use crate::{Db, FullyQualifiedName, HasId, ResolveType}; +pub trait Dependencies<'db, Type, ReferenceType>: Eq + Hash + Clone { + fn get( + &'db self, + db: &'db dyn Db, + key: &Type, + ) -> Option<&'db codegen_sdk_common::hash::FxIndexSet>; +} // Resolve a given string name in a scope to a given type pub trait Scope<'db>: Sized { - type Type; - type ReferenceType: ResolveType<'db, Self, Type = Self::Type>; - fn resolve(self, db: &'db dyn salsa::Database, name: String) -> &'db Vec; + type Type: Eq + Hash + Clone + HasId<'db>; + type Dependencies: Dependencies<'db, FullyQualifiedName<'db>, Self::ReferenceType>; + type ReferenceType: ResolveType<'db, Type = Self::Type> + Eq + Hash + Clone; + fn resolve(self, db: &'db dyn Db, name: String) -> &'db Vec; /// Get all the resolvables (IE: function_calls) in the scope - fn resolvables(self, db: &'db dyn salsa::Database) -> Vec; + fn resolvables(self, db: &'db dyn Db) -> Vec; + fn compute_dependencies_query(self, db: &'db dyn Db) -> &'db Self::Dependencies; + fn compute_dependencies( + self, + db: &'db dyn Db, + ) -> codegen_sdk_common::hash::FxHashMap< + FullyQualifiedName<'db>, + codegen_sdk_common::hash::FxIndexSet, + > + where + Self: 'db, + { + let mut dependencies: codegen_sdk_common::hash::FxHashMap< + FullyQualifiedName<'db>, + codegen_sdk_common::hash::FxIndexSet, + > = codegen_sdk_common::hash::FxHashMap::default(); + for reference in self.resolvables(db) { + let resolved = reference.clone().resolve_type(db); + for resolved in resolved { + dependencies + .entry(resolved.fully_qualified_name(db)) + .or_default() + .insert(reference.clone()); + } + } + dependencies + } } diff --git a/languages/codegen-sdk-go/Cargo.toml b/languages/codegen-sdk-go/Cargo.toml index 0d8c919..611be6d 100644 --- a/languages/codegen-sdk-go/Cargo.toml +++ b/languages/codegen-sdk-go/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } 
codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-java/Cargo.toml b/languages/codegen-sdk-java/Cargo.toml index b49e867..8a041d3 100644 --- a/languages/codegen-sdk-java/Cargo.toml +++ b/languages/codegen-sdk-java/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-javascript/Cargo.toml b/languages/codegen-sdk-javascript/Cargo.toml index 9b96f6a..9c6f8b8 100644 --- a/languages/codegen-sdk-javascript/Cargo.toml +++ b/languages/codegen-sdk-javascript/Cargo.toml @@ -14,6 +14,7 @@ indextree ={ workspace = true } subenum = {workspace = true} bytes = { workspace = true } codegen-sdk-cst = { workspace = true } +codegen-sdk-resolution = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } [build-dependencies] diff --git a/languages/codegen-sdk-json/Cargo.toml b/languages/codegen-sdk-json/Cargo.toml index f2bb976..c56732c 100644 --- a/languages/codegen-sdk-json/Cargo.toml +++ b/languages/codegen-sdk-json/Cargo.toml @@ -14,6 +14,7 @@ subenum = {workspace = true} bytes = { workspace = true } codegen-sdk-cst = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } log = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } diff --git a/languages/codegen-sdk-json/src/lib.rs b/languages/codegen-sdk-json/src/lib.rs index 81485e6..5fb5f2c 100644 --- a/languages/codegen-sdk-json/src/lib.rs +++ b/languages/codegen-sdk-json/src/lib.rs @@ -20,7 +20,7 @@ mod tests { "; let db = codegen_sdk_cst::CSTDatabase::default(); let module = crate::cst::JSON::parse(&db, content.to_string()).unwrap(); - let (root, tree) = module; + let (root, tree, _) = module; assert!(root.children(tree).len() > 0); } } diff --git a/languages/codegen-sdk-jsx/Cargo.toml b/languages/codegen-sdk-jsx/Cargo.toml index bd203ae..5422b2d 100644 --- a/languages/codegen-sdk-jsx/Cargo.toml +++ b/languages/codegen-sdk-jsx/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-markdown/Cargo.toml b/languages/codegen-sdk-markdown/Cargo.toml index 420a66c..3f5246d 100644 --- a/languages/codegen-sdk-markdown/Cargo.toml +++ b/languages/codegen-sdk-markdown/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-python/Cargo.toml b/languages/codegen-sdk-python/Cargo.toml index b14c697..ddec4aa 100644 --- a/languages/codegen-sdk-python/Cargo.toml +++ 
b/languages/codegen-sdk-python/Cargo.toml @@ -17,13 +17,14 @@ codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } codegen-sdk-resolution = { workspace = true } - +memchr = { version = "2" } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } codegen-sdk-common = { workspace = true, features = ["python"] } env_logger = { workspace = true } log = { workspace = true } + [dev-dependencies] test-log = { workspace = true } tempfile = {workspace = true} diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 5e3dfba..d9a3d74 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -9,47 +9,276 @@ pub mod ast { use codegen_sdk_resolution::{ResolveType, Scope}; include!(concat!(env!("OUT_DIR"), "/python-ast.rs")); #[salsa::tracked] + impl<'db> Import<'db> { + #[salsa::tracked] + fn resolve_import( + self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> Option> { + let root_path = self.root_path(db); + let module = self.module(db).source().replace(".", "/"); + let target_path = root_path.join(module).with_extension("py"); + log::info!(target: "resolution", "Resolving import to path: {:?}", target_path); + target_path + .canonicalize() + .ok() + .map(|path| codegen_sdk_common::FileNodeId::new(db, path)) + } + } + #[salsa::tracked] + pub struct PythonDependencies<'db> { + #[id] + id: codegen_sdk_common::FileNodeId<'db>, + #[return_ref] + #[tracked] + #[no_eq] + pub dependencies: codegen_sdk_common::hash::FxHashMap< + codegen_sdk_resolution::FullyQualifiedName<'db>, + codegen_sdk_common::hash::FxIndexSet>, + >, + } + impl<'db> + codegen_sdk_resolution::Dependencies< + 'db, + codegen_sdk_resolution::FullyQualifiedName<'db>, + crate::ast::Call<'db>, + > for PythonDependencies<'db> + { + fn get( + &'db self, + db: &'db dyn codegen_sdk_resolution::Db, + key: &codegen_sdk_resolution::FullyQualifiedName<'db>, + ) -> Option<&'db codegen_sdk_common::hash::FxIndexSet>> { + self.dependencies(db).get(key) + } + } + #[salsa::tracked(return_ref, no_eq)] + pub fn dependencies<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + input: codegen_sdk_common::FileNodeId<'db>, + ) -> PythonDependencies<'db> { + let file = parse(db, input); + PythonDependencies::new(db, file.id(db), file.compute_dependencies(db)) + } + #[salsa::tracked(return_ref, no_eq)] + pub fn dependency_keys<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + input: codegen_sdk_common::FileNodeId<'db>, + ) -> codegen_sdk_common::hash::FxHashSet> { + let dependencies = dependencies(db, input); + dependencies.dependencies(db).keys().cloned().collect() + } + #[salsa::tracked] + struct UsagesInput<'db> { + #[id] + input: codegen_sdk_common::FileNodeId<'db>, + name: codegen_sdk_resolution::FullyQualifiedName<'db>, + } + #[salsa::tracked(return_ref)] + pub fn usages<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + input: UsagesInput<'db>, + ) -> codegen_sdk_common::hash::FxIndexSet> { + let file = parse(db, input.input(db)); + let mut results = codegen_sdk_common::hash::FxIndexSet::default(); + for reference in file.resolvables(db) { + let resolved = reference.clone().resolve_type(db); + for resolved in resolved { + if resolved.fully_qualified_name(db) == input.name(db) { + results.insert(reference); + continue; + } + } + } + results + } + #[salsa::tracked] impl<'db> Scope<'db> for PythonFile<'db> { - type Type = crate::cst::FunctionDefinition<'db>; 
- type ReferenceType = crate::cst::Call<'db>; + type Type = crate::ast::Symbol<'db>; + type Dependencies = PythonDependencies<'db>; + type ReferenceType = crate::ast::Call<'db>; #[salsa::tracked(return_ref)] - fn resolve(self, db: &'db dyn salsa::Database, name: String) -> Vec { + fn resolve(self, db: &'db dyn codegen_sdk_resolution::Db, name: String) -> Vec { let tree = self.node(db).unwrap().tree(db); let mut results = Vec::new(); - for (def_name, defs) in self.definitions(db).functions(db, &tree).into_iter() { - if def_name == name { - results.extend(defs.into_iter().cloned()); + for (def_name, defs) in self.definitions(db).functions(db).into_iter() { + if *def_name == name { + results.extend( + defs.into_iter() + .cloned() + .map(|def| crate::ast::Symbol::Function(def)), + ); + } + } + for (def_name, defs) in self.definitions(db).imports(db).into_iter() { + if *def_name == name { + for def in defs { + results.push(crate::ast::Symbol::Import(def.clone())); + for resolved in def.resolve_type(db) { + results.push(resolved.clone()); + } + } } } results } #[salsa::tracked] - fn resolvables(self, db: &'db dyn salsa::Database) -> Vec { + fn resolvables(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { let mut results = Vec::new(); - let tree = self.node(db).unwrap().tree(db); - for (_, refs) in self.references(db).calls(db, &tree).into_iter() { + for (_, refs) in self.references(db).calls(db).into_iter() { results.extend(refs.into_iter().cloned()); } results } + fn compute_dependencies( + self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> codegen_sdk_common::hash::FxHashMap< + codegen_sdk_resolution::FullyQualifiedName<'db>, + codegen_sdk_common::hash::FxIndexSet, + > + where + Self: 'db, + { + let mut dependencies: codegen_sdk_common::hash::FxHashMap< + codegen_sdk_resolution::FullyQualifiedName<'db>, + codegen_sdk_common::hash::FxIndexSet, + > = codegen_sdk_common::hash::FxHashMap::default(); + for reference in self.resolvables(db) { + let resolved = reference.clone().resolve_type(db); + for resolved in resolved { + dependencies + .entry(resolved.fully_qualified_name(db)) + .or_default() + .insert(reference.clone()); + } + } + dependencies + } + + #[salsa::tracked(return_ref)] + fn compute_dependencies_query( + self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> PythonDependencies<'db> { + PythonDependencies::new(db, self.id(db), self.compute_dependencies(db)) + } } #[salsa::tracked] - impl<'db> ResolveType<'db, PythonFile<'db>> for crate::cst::Call<'db> { - type Type = crate::cst::FunctionDefinition<'db>; + impl<'db> ResolveType<'db> for crate::ast::Import<'db> { + type Type = crate::ast::Symbol<'db>; #[salsa::tracked(return_ref)] - fn resolve_type( - self, - db: &'db dyn salsa::Database, - scope: PythonFile<'db>, - _scopes: Vec>, - ) -> Vec { + fn resolve_type(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { + let target_path = self.resolve_import(db); + if let Some(target_path) = target_path { + if let Some(_) = db.get_file_for_id(target_path) { + return PythonFile::parse(db, target_path) + .resolve(db, self.name(db).source()) + .to_vec(); + } + } + Vec::new() + } + } + #[salsa::tracked] + impl<'db> ResolveType<'db> for crate::ast::Call<'db> { + type Type = crate::ast::Symbol<'db>; + #[salsa::tracked(return_ref)] + fn resolve_type(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { + let scope = self.file(db); let tree = scope.node(db).unwrap().tree(db); - scope.resolve(db, self.function(tree).source()).clone() + scope + .resolve(db, 
self.node(db).function(tree).source()) + .clone() + } + } + use codegen_sdk_resolution::{Db, Dependencies, HasId}; + // #[salsa::tracked(return_ref)] + pub fn references_for_file<'db>( + db: &'db dyn Db, + file: codegen_sdk_common::FileNodeId<'db>, + ) -> usize { + let parsed = parse(db, file); + let definitions = parsed.definitions(db); + let functions = definitions.functions(db); + let mut total_references = 0; + let total_functions = functions.len(); + let functions = functions + .into_iter() + .map(|(_, functions)| functions) + .flatten() + .map(|function| function.fully_qualified_name(db)) + .collect::>(); + let files = codegen_sdk_resolution::files(db); + log::info!(target: "resolution", "Finding references across {:?} files", files.len()); + let mut results = 0; + for input in files.into_iter() { + let keys = dependency_keys(db, input.clone()); + if keys.is_disjoint(&functions) { + continue; + } + // if !self.filter(db, &input) { + // continue; + // } + // let input = UsagesInput::new(db, input.clone(), name.clone()); + // results.extend(usages(db, input)); + let dependencies = dependencies(db, input.clone()); + for function in functions.iter() { + if let Some(references) = dependencies.get(db, function) { + results += references.len(); + } + } + } + results + } + pub fn references_impl<'db>( + db: &'db dyn Db, + name: codegen_sdk_resolution::FullyQualifiedName<'db>, + ) -> Vec> { + let files = codegen_sdk_resolution::files(db); + log::info!(target: "resolution", "Finding references across {:?} files", files.len()); + let mut results = Vec::new(); + for input in files.into_iter() { + let keys = dependency_keys(db, input.clone()); + if keys.contains(&name) { + // if !self.filter(db, &input) { + // continue; + // } + // let input = UsagesInput::new(db, input.clone(), name.clone()); + // results.extend(usages(db, input)); + let dependencies = dependencies(db, input.clone()); + if let Some(references) = dependencies.get(db, &name) { + results.extend(references.iter().cloned()); + } + } } + results } #[salsa::tracked] - impl<'db> codegen_sdk_resolution::References<'db, crate::cst::Call<'db>, PythonFile<'db>> - for crate::cst::FunctionDefinition<'db> + impl<'db> + codegen_sdk_resolution::References< + 'db, + PythonDependencies<'db>, + crate::ast::Call<'db>, + PythonFile<'db>, + > for crate::ast::Symbol<'db> { + fn references(self, db: &'db dyn Db) -> Vec> { + references_impl(db, self.fully_qualified_name(db)) + } + fn filter( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + input: &codegen_sdk_cst::File, + ) -> bool { + match self { + crate::ast::Symbol::Function(function) => { + let content = input.content(db); + let target = function.name(db).text(); + memchr::memmem::find(&content.as_bytes(), &target).is_some() + } + _ => true, + } + } } } diff --git a/languages/codegen-sdk-python/tests/test_python.rs b/languages/codegen-sdk-python/tests/test_python.rs index 5d86b67..9e3768f 100644 --- a/languages/codegen-sdk-python/tests/test_python.rs +++ b/languages/codegen-sdk-python/tests/test_python.rs @@ -1,13 +1,29 @@ #![recursion_limit = "512"] -use std::path::PathBuf; +use std::{env, path::PathBuf}; use codegen_sdk_ast::{Definitions, References}; use codegen_sdk_resolution::References as _; -fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { - let file_path = temp_dir.path().join("test.ts"); +fn write_to_temp_file_with_name( + content: &str, + temp_dir: &tempfile::TempDir, + name: &str, +) -> PathBuf { + let file_path = temp_dir.path().join(name); 
std::fs::write(&file_path, content).unwrap(); file_path } +fn parse_file<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + content: &str, + temp_dir: &tempfile::TempDir, + name: &str, +) -> &'db codegen_sdk_python::ast::PythonFile<'db> { + let file_path = write_to_temp_file_with_name(content, temp_dir, name); + db.input(file_path.clone()).unwrap(); + let file_node_id = codegen_sdk_common::FileNodeId::new(db, file_path); + let file = codegen_sdk_python::ast::parse(db, file_node_id); + file +} // TODO: Fix queries for classes and functions // #[test_log::test] // fn test_typescript_ast_class() { @@ -31,13 +47,9 @@ fn test_python_ast_class() { let content = " class Test: pass"; - let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); - let file = codegen_sdk_python::ast::parse_query(&db, input); - let tree = file.tree(&db); - assert_eq!(file.definitions(&db).classes(&db, &tree).len(), 1); + let file = parse_file(&db, content, &temp_dir, "filea.py"); + assert_eq!(file.definitions(&db).classes(&db).len(), 1); } #[test_log::test] fn test_python_ast_function() { @@ -45,36 +57,61 @@ fn test_python_ast_function() { let content = " def test(): pass"; - let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); - let file = codegen_sdk_python::ast::parse_query(&db, input); - let tree = file.tree(&db); - assert_eq!(file.definitions(&db).functions(&db, &tree).len(), 1); + let file = parse_file(&db, content, &temp_dir, "filea.py"); + assert_eq!(file.definitions(&db).functions(&db).len(), 1); } +// +// for function in codebase.functions(): +// function.rename("test2") +// codebase.commit() + +// 3 bounds +// 1. Codebase updated, everything else is invalidated +// 2. Files + codebase updated, everything else is invalidated +// 3. 
Everything updated, nothing is invalidated + #[test_log::test] fn test_python_ast_function_usages() { let temp_dir = tempfile::tempdir().unwrap(); + assert!(env::set_current_dir(&temp_dir).is_ok()); + let content = " +def test(): + pass + +test()"; + let db = codegen_sdk_cst::CSTDatabase::default(); + let file = parse_file(&db, content, &temp_dir, "filea.py"); + assert_eq!(file.references(&db).calls(&db).len(), 1); + let definitions = file.definitions(&db); + let functions = definitions.functions(&db); + let function = functions.get("test").unwrap().first().unwrap(); + let function = codegen_sdk_python::ast::Symbol::Function(function.clone().clone()); + assert_eq!(function.references(&db).len(), 1); +} +#[test_log::test] +fn test_python_ast_function_usages_cross_file() { + let temp_dir = tempfile::tempdir().unwrap(); + assert!(env::set_current_dir(&temp_dir).is_ok()); let content = " def test(): pass +"; + let usage_file_content = " +from filea import test test()"; - let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); - let file = codegen_sdk_python::ast::parse_query(&db, input); - let tree = file.tree(&db); - assert_eq!(file.references(&db).calls(&db, &tree).len(), 1); + let file = parse_file(&db, content, &temp_dir, "filea.py"); + let usage_file = parse_file(&db, usage_file_content, &temp_dir, "fileb.py"); + assert_eq!(usage_file.references(&db).calls(&db).len(), 1); let definitions = file.definitions(&db); - let functions = definitions.functions(&db, &tree); + let functions = definitions.functions(&db); let function = functions.get("test").unwrap().first().unwrap(); - assert_eq!( - function - .references_for_scopes(&db, vec![*file], &file) - .len(), - 1 - ); + let function = codegen_sdk_python::ast::Symbol::Function(function.clone().clone()); + let imports = usage_file.definitions(&db).imports(&db); + let import = imports.get("test").unwrap().first().unwrap(); + let import = codegen_sdk_python::ast::Symbol::Import(import.clone().clone()); + assert_eq!(import.references(&db,).len(), 1); + assert_eq!(function.references(&db,).len(), 1); } diff --git a/languages/codegen-sdk-ruby/Cargo.toml b/languages/codegen-sdk-ruby/Cargo.toml index bb55830..5f89054 100644 --- a/languages/codegen-sdk-ruby/Cargo.toml +++ b/languages/codegen-sdk-ruby/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-rust/Cargo.toml b/languages/codegen-sdk-rust/Cargo.toml index af96d5f..f432491 100644 --- a/languages/codegen-sdk-rust/Cargo.toml +++ b/languages/codegen-sdk-rust/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-toml/Cargo.toml b/languages/codegen-sdk-toml/Cargo.toml index 2fb48ba..8fe907a 100644 --- a/languages/codegen-sdk-toml/Cargo.toml +++ b/languages/codegen-sdk-toml/Cargo.toml @@ -16,6 +16,7 
@@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-tsx/Cargo.toml b/languages/codegen-sdk-tsx/Cargo.toml index d26ebc6..3dd20cd 100644 --- a/languages/codegen-sdk-tsx/Cargo.toml +++ b/languages/codegen-sdk-tsx/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-typescript/Cargo.toml b/languages/codegen-sdk-typescript/Cargo.toml index 56445c6..17b34f7 100644 --- a/languages/codegen-sdk-typescript/Cargo.toml +++ b/languages/codegen-sdk-typescript/Cargo.toml @@ -15,6 +15,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } indextree = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } diff --git a/languages/codegen-sdk-typescript/tests/test_typescript.rs b/languages/codegen-sdk-typescript/tests/test_typescript.rs index c7a5920..2a2b11b 100644 --- a/languages/codegen-sdk-typescript/tests/test_typescript.rs +++ b/languages/codegen-sdk-typescript/tests/test_typescript.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use codegen_sdk_ast::Definitions; +use codegen_sdk_resolution::Db; fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { let file_path = temp_dir.path().join("test.ts"); std::fs::write(&file_path, content).unwrap(); @@ -27,12 +28,11 @@ fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { #[test_log::test] fn test_typescript_ast_interface() { let temp_dir = tempfile::tempdir().unwrap(); - let content = "interface Test { }"; - let file_path = write_to_temp_file(content, &temp_dir); + let content = "interface Test { }".to_string(); + let file_path = write_to_temp_file(&content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); - let file = codegen_sdk_typescript::ast::parse_query(&db, input); - let tree = file.node(&db).unwrap().tree(&db); - assert_eq!(file.definitions(&db).interfaces(&db, &tree).len(), 1); + db.input(file_path.clone()).unwrap(); + let input = codegen_sdk_common::FileNodeId::new(&db, file_path); + let file = codegen_sdk_typescript::ast::parse(&db, input); + assert_eq!(file.definitions(&db).interfaces(&db).len(), 1); } diff --git a/languages/codegen-sdk-yaml/Cargo.toml b/languages/codegen-sdk-yaml/Cargo.toml index 2efebd0..1b306a3 100644 --- a/languages/codegen-sdk-yaml/Cargo.toml +++ b/languages/codegen-sdk-yaml/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/src/main.rs b/src/main.rs index c387676..28ef18d 100644 --- 
a/src/main.rs +++ b/src/main.rs @@ -2,64 +2,54 @@ use std::{path::PathBuf, time::Instant}; use clap::Parser; -use codegen_sdk_analyzer::{Codebase, ParsedFile}; +use codegen_sdk_analyzer::{Codebase, ParsedFile, parse_file}; use codegen_sdk_ast::Definitions; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; use codegen_sdk_core::system::get_memory; -use codegen_sdk_resolution::{CodebaseContext, References}; +use codegen_sdk_resolution::CodebaseContext; #[derive(Debug, Parser)] struct Args { input: String, } +// #[salsa::tracked] +fn get_definitions<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + file: codegen_sdk_common::FileNodeId<'db>, +) -> (usize, usize, usize, usize, usize, usize) { + if let Some(parsed) = parse_file(db, file).file(db) { + #[cfg(feature = "typescript")] + if let ParsedFile::Typescript(file) = parsed { + let definitions = file.definitions(db); + return ( + definitions.classes(db).len(), + definitions.functions(db).len(), + definitions.interfaces(db).len(), + definitions.methods(db).len(), + definitions.modules(db).len(), + 0, + ); + } + #[cfg(feature = "python")] + if let ParsedFile::Python(file) = parsed { + let definitions = file.definitions(db); + let functions = definitions.functions(db); + let total_references = codegen_sdk_python::ast::references_for_file(db, file.id(db)); + return ( + definitions.classes(db).len(), + functions.len(), + 0, + 0, + 0, + total_references, + ); + } + } + (0, 0, 0, 0, 0, 0) +} fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize, usize, usize)> { - codebase - .files() - .into_iter() - .map(|parsed| { - #[cfg(feature = "typescript")] - if let ParsedFile::Typescript(file) = parsed { - let definitions = file.definitions(codebase.db()); - if let Some(node) = file.node(codebase.db()) { - let tree = node.tree(codebase.db()); - return ( - definitions.classes(codebase.db(), &tree).len(), - definitions.functions(codebase.db(), &tree).len(), - definitions.interfaces(codebase.db(), &tree).len(), - definitions.methods(codebase.db(), &tree).len(), - definitions.modules(codebase.db(), &tree).len(), - 0, - ); - } - } - #[cfg(feature = "python")] - if let ParsedFile::Python(file) = parsed { - let definitions = file.definitions(codebase.db()); - let tree = file.node(codebase.db()).unwrap().tree(codebase.db()); - let functions = definitions.functions(codebase.db(), &tree); - let mut total_references = 0; - let total_functions = functions.len(); - for function in functions - .into_iter() - .map(|(_, functions)| functions) - .flatten() - { - total_references += function - .references_for_scopes(codebase.db(), vec![*file], &file) - .len(); - } - return ( - definitions.classes(codebase.db(), &tree).len(), - total_functions, - 0, - 0, - 0, - total_references, - ); - } - (0, 0, 0, 0, 0, 0) - }) - .collect() + log::info!("Getting total definitions"); + codebase.execute_op_with_progress("Getting Usages", true, |db, file| get_definitions(db, file)) } fn print_definitions(codebase: &Codebase) { let mut total_classes = 0;
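Note on the resolution scheme introduced above: `Scope::compute_dependencies` inverts per-file resolution results into a map from fully qualified names to the references that resolve to them, `dependency_keys` exposes just the key set so unrelated files can be skipped cheaply, and `references_impl` merges the per-file sets for one name. A minimal self-contained sketch of that shape follows; `SymbolName`, `Reference`, and `FileIndex` are hypothetical stand-ins built on std collections, not the crate's salsa-tracked types.

```rust
use std::collections::{HashMap, HashSet};

// Hypothetical stand-ins: `SymbolName` plays the role of FullyQualifiedName,
// `Reference` the role of a call site.
type SymbolName = String;
type Reference = String;

// Per-file reverse index (resolved symbol -> call sites that resolve to it),
// mirroring the shape of `Scope::compute_dependencies` / `PythonDependencies`.
struct FileIndex {
    dependencies: HashMap<SymbolName, HashSet<Reference>>,
}

impl FileIndex {
    // Build the index from (call site, names it resolves to) pairs, i.e. the
    // result of resolving every resolvable in one file.
    fn compute(resolvables: &[(Reference, Vec<SymbolName>)]) -> Self {
        let mut dependencies: HashMap<SymbolName, HashSet<Reference>> = HashMap::new();
        for (reference, resolved) in resolvables {
            for name in resolved {
                dependencies
                    .entry(name.clone())
                    .or_default()
                    .insert(reference.clone());
            }
        }
        Self { dependencies }
    }

    // Analogue of `dependency_keys`: only the symbols this file depends on.
    fn keys(&self) -> HashSet<&SymbolName> {
        self.dependencies.keys().collect()
    }
}

// Analogue of `references_impl`: skip files whose key set never mentions the
// target, otherwise collect the pre-computed reference set for it.
fn references<'a>(files: &'a [FileIndex], target: &SymbolName) -> Vec<&'a Reference> {
    let mut results = Vec::new();
    for file in files {
        if !file.keys().contains(target) {
            continue; // cheap skip before touching the full map
        }
        if let Some(refs) = file.dependencies.get(target) {
            results.extend(refs.iter());
        }
    }
    results
}

fn main() {
    let file_b = FileIndex::compute(&[
        ("test() at fileb.py:4".to_string(), vec!["filea.test".to_string()]),
        ("other() at fileb.py:5".to_string(), vec!["filea.other".to_string()]),
    ]);
    assert_eq!(references(&[file_b], &"filea.test".to_string()).len(), 1);
}
```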
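The `filter` added for `Symbol::Function` avoids resolving against files that cannot possibly reference a symbol by scanning the raw bytes for its name with `memchr::memmem` (the new `memchr` dependency in codegen-sdk-python). A small sketch of that pre-filter, assuming only the memchr crate; file contents are made up for illustration.

```rust
// Cheap textual pre-filter: if the target name never appears in the raw bytes,
// the file cannot reference it, so parsing and resolution can be skipped.
fn may_reference(content: &str, target: &str) -> bool {
    memchr::memmem::find(content.as_bytes(), target.as_bytes()).is_some()
}

fn main() {
    let file_a = "from filea import test\n\ntest()\n";
    let file_b = "print('unrelated')\n";
    assert!(may_reference(file_a, "test"));
    assert!(!may_reference(file_b, "test"));
}
```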
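`CSTLanguage::parse` now returns the root `indextree::NodeId` alongside the typed program and the tree, as the updated JSON test shows. A hedged usage sketch follows; the external crate path `codegen_sdk_json` and the standalone function are assumptions, while the `parse`, `get`, and `children` calls mirror code in this diff.

```rust
use codegen_sdk_cst::CSTLanguage;

// Sketch only: mirrors the updated JSON test, from outside the json crate.
fn root_child_count() -> usize {
    let db = codegen_sdk_cst::CSTDatabase::default();
    let content = r#"{ "key": [1, 2, 3] }"#;
    // parse() now yields the typed root, the node arena, and the root's NodeId.
    let (root, tree, root_id) = codegen_sdk_json::cst::JSON::parse(&db, content.to_string())
        .expect("well-formed JSON should parse");
    // The extra NodeId lets callers re-enter the arena without the typed root.
    assert!(tree.get(&root_id).is_some());
    root.children(tree).len()
}
```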