From 95764ff3b5883252060d27d53e00826ed4f02cca Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 09:25:38 -0800 Subject: [PATCH 01/16] WIP: ast --- Cargo.lock | 2 + codegen-sdk-ast-generator/Cargo.toml | 1 + codegen-sdk-ast-generator/src/query.rs | 306 ++++++++++++++---- ...ast_generator__visitor__tests__python.snap | 148 +++++---- ...generator__visitor__tests__typescript.snap | 184 +++++------ codegen-sdk-ast-generator/src/visitor.rs | 61 +++- codegen-sdk-common/src/language/python.rs | 7 +- codegen-sdk-common/src/tree/tree.rs | 3 + codegen-sdk-cst-generator/src/generator.rs | 8 +- codegen-sdk-cst/Cargo.toml | 1 + codegen-sdk-cst/src/language.rs | 24 +- languages/codegen-sdk-python/src/lib.rs | 38 ++- .../codegen-sdk-python/tests/test_python.rs | 63 +++- 13 files changed, 582 insertions(+), 264 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 860f9b1a..443fe155 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -364,6 +364,7 @@ dependencies = [ "codegen-sdk-ts_query", "convert_case", "derive_more", + "indextree", "insta", "log", "proc-macro2", @@ -446,6 +447,7 @@ dependencies = [ "codegen-sdk-common", "convert_case", "dashmap", + "indextree", "log", "rkyv", "salsa", diff --git a/codegen-sdk-ast-generator/Cargo.toml b/codegen-sdk-ast-generator/Cargo.toml index 574d4b30..ddf4a180 100644 --- a/codegen-sdk-ast-generator/Cargo.toml +++ b/codegen-sdk-ast-generator/Cargo.toml @@ -16,6 +16,7 @@ codegen-sdk-ts_query = { workspace = true } convert_case = { workspace = true } salsa = { workspace = true } syn = { workspace = true } +indextree = { workspace = true } [dev-dependencies] test-log = { workspace = true } insta = { workspace = true } diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index 319344d3..a0fd32fa 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -1,4 +1,7 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; use codegen_sdk_common::{ CSTNode, HasChildren, Language, Tree, @@ -8,10 +11,22 @@ use codegen_sdk_cst::CSTLanguage; use codegen_sdk_cst_generator::{Config, Field, State}; use codegen_sdk_ts_query::cst as ts_query; use derive_more::Debug; +use indextree::NodeId; use log::{debug, info, warn}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; +use syn::parse_quote; use ts_query::NodeTypes; +fn name_for_capture<'a>(capture: &'a ts_query::Capture<'a>) -> String { + full_name_for_capture(capture) + .split(".") + .last() + .unwrap() + .to_string() +} +fn full_name_for_capture<'a>(capture: &'a ts_query::Capture<'a>) -> String { + capture.source().split_off(1) +} fn captures_for_field_definition<'a>( node: &'a ts_query::FieldDefinition<'a>, tree: &'a Tree>, @@ -25,6 +40,54 @@ fn captures_for_field_definition<'a>( ts_query::FieldDefinitionChildrenRef::FieldDefinition(field) => { captures.extend(captures_for_field_definition(&field, tree)); } + ts_query::FieldDefinitionChildrenRef::Grouping(grouping) => { + captures.extend(captures_for_grouping(&grouping, tree)); + } + ts_query::FieldDefinitionChildrenRef::List(list) => { + captures.extend(captures_for_list(&list, tree)); + } + _ => {} + } + } + captures.into_iter() +} +fn captures_for_list<'a>( + list: &'a ts_query::List<'a>, + tree: &'a Tree>, +) -> impl Iterator> { + let mut captures = Vec::new(); + for child in list.children(tree) { + match child { + ts_query::ListChildrenRef::NamedNode(named) => { + captures.extend(captures_for_named_node(&named, tree)); + } + ts_query::ListChildrenRef::List(list) => { + captures.extend(captures_for_list(&list, tree)); + } + ts_query::ListChildrenRef::FieldDefinition(field) => { + captures.extend(captures_for_field_definition(&field, tree)); + } + _ => {} + } + } + captures.into_iter() +} +fn captures_for_grouping<'a>( + grouping: &'a ts_query::Grouping<'a>, + tree: &'a Tree>, +) -> impl Iterator> { + let mut captures = Vec::new(); + for child in grouping.children(tree) { + match child { + ts_query::GroupingChildrenRef::NamedNode(named) => { + captures.extend(captures_for_named_node(&named, tree)); + } + ts_query::GroupingChildrenRef::Grouping(grouping) => { + captures.extend(captures_for_grouping(&grouping, tree)); + } + ts_query::GroupingChildrenRef::FieldDefinition(field) => { + captures.extend(captures_for_field_definition(&field, tree)); + } _ => {} } } @@ -44,6 +107,12 @@ fn captures_for_named_node<'a>( ts_query::NamedNodeChildrenRef::FieldDefinition(field) => { captures.extend(captures_for_field_definition(&field, tree)); } + ts_query::NamedNodeChildrenRef::Grouping(grouping) => { + captures.extend(captures_for_grouping(&grouping, tree)); + } + ts_query::NamedNodeChildrenRef::List(list) => { + captures.extend(captures_for_list(&list, tree)); + } _ => {} } } @@ -54,6 +123,7 @@ pub struct Query<'a> { node: &'a ts_query::NamedNode<'a>, language: &'a Language, tree: &'a Tree>, + root_id: NodeId, pub(crate) state: Arc>, } impl<'a> Query<'a> { @@ -63,14 +133,15 @@ impl<'a> Query<'a> { language: &'a Language, ) -> BTreeMap { let result = ts_query::Query::parse(db, source.to_string()).unwrap(); - let (parsed, tree) = result; + let (parsed, tree, program_id) = result; let config = Config::default(); let state = Arc::new(State::new(language, config)); let mut queries = BTreeMap::new(); for node in parsed.children(tree) { match node { ts_query::ProgramChildrenRef::NamedNode(named) => { - let query = Self::from_named_node(&named, language, state.clone(), tree); + let query = + Self::from_named_node(&named, language, state.clone(), tree, program_id); queries.insert(query.name(), query); } node => { @@ -102,12 +173,14 @@ impl<'a> Query<'a> { language: &'a Language, state: Arc>, tree: &'a Tree>, + root_id: indextree::NodeId, ) -> Self { Query { node: named, language: language, state: state, tree: tree, + root_id: root_id, } } /// Get the kind of the query (the node to be matched) @@ -187,6 +260,11 @@ impl<'a> Query<'a> { format_ident!("{}s", name) } } + pub fn symbol_name(&self) -> Ident { + let raw_name = self.name(); + let name = raw_name.split(".").last().unwrap(); + format_ident!("{}", normalize_type_name(name, true)) + } fn get_field_for_field_name(&self, field_name: &str, struct_name: &str) -> Option<&Field> { debug!( "Getting field for: {:#?} on node: {:#?}", @@ -208,7 +286,7 @@ impl<'a> Query<'a> { field: &ts_query::FieldDefinition, struct_name: &str, current_node: &Ident, - name_value: Option, + query_values: &mut HashMap, ) -> TokenStream { let other_child: ts_query::NodeTypesRef = field .children(self.tree) @@ -227,7 +305,7 @@ impl<'a> Query<'a> { &normalized_struct_name, other_child.clone(), &field_name, - name_value, + query_values, ); // assert!( // wrapped.to_string().len() > 0, @@ -236,7 +314,13 @@ impl<'a> Query<'a> { // other_child.source(), // other_child.kind() // ); - if !field.is_optional() { + if field.is_multiple() { + return quote! { + for #field_name in #current_node.#field_name(tree) { + #wrapped + } + }; + } else if !field.is_optional() { return quote! { let #field_name = #current_node.#field_name(tree); #wrapped @@ -270,7 +354,7 @@ impl<'a> Query<'a> { node: &ts_query::Grouping, struct_name: &str, current_node: &Ident, - name_value: Option, + query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); for group in node.children(self.tree) { @@ -278,7 +362,7 @@ impl<'a> Query<'a> { struct_name, group.into(), current_node, - name_value.clone(), + query_values, ); matchers.extend_one(result); } @@ -291,7 +375,7 @@ impl<'a> Query<'a> { target_kind: &str, current_node: &Ident, remaining_nodes: Vec>, - name_value: Option, + query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); let mut field_matchers = TokenStream::new(); @@ -316,14 +400,14 @@ impl<'a> Query<'a> { &target_name, child.into(), current_node, - name_value.clone(), + query_values, )); } else { let result = self.get_matcher_for_definition( &target_name, child.into(), &format_ident!("child"), - name_value.clone(), + query_values, ); if let Some(ref variant) = comment_variant { @@ -382,21 +466,19 @@ impl<'a> Query<'a> { &'b self, node: &'b ts_query::NamedNode<'b>, first_node: &ts_query::NamedNodeChildrenRef<'_>, - mut name_value: Option, + query_values: &mut HashMap, current_node: &Ident, - ) -> (Option, Vec>) { + ) -> Vec> { let mut prev = first_node.clone(); let mut remaining_nodes = Vec::new(); + log::info!("Grouping children for: {:#?}", node.source()); for child in node.children(self.tree).into_iter().skip(1) { if child.kind_name() == "capture" { - if child.source() == "@name" { - log::info!( - "Found @name! prev: {:#?}, {:#?}", - prev.source(), - prev.kind_name() - ); + let capture_name = name_for_capture(child.try_into().unwrap()); + if self.target_capture_names().contains(&capture_name) { match prev { ts_query::NamedNodeChildrenRef::FieldDefinition(field) => { + log::info!("Found @{}! on field: {:#?}", capture_name, field.source(),); let field_name = field .name(self.tree) .iter() @@ -404,24 +486,40 @@ impl<'a> Query<'a> { .map(|c| format_ident!("{}", c.source())) .next() .unwrap(); - name_value = Some(quote! { - #current_node.#field_name.source() - }); + query_values.insert( + capture_name, + quote! { + + #current_node.#field_name + }, + ); } ts_query::NamedNodeChildrenRef::Identifier(named) => { log::info!( - "Found @name! prev: {:#?}, {:#?}", + "Found @{}! prev: {:#?}, {:#?}", + capture_name, named.source(), named.kind_name() ); - name_value = Some(quote! { - #current_node.source() - }); + query_values.insert( + capture_name, + quote! { + #current_node + }, + ); } ts_query::NamedNodeChildrenRef::AnonymousUnderscore(_) => { - name_value = Some(quote! { - #current_node.source() - }); + log::info!( + "Found @{}! on anonymous underscore: {:#?}", + capture_name, + node.source() + ); + query_values.insert( + capture_name, + quote! { + #current_node + }, + ); } _ => panic!( "Unexpected prev: {:#?}, source: {:#?}. Query: {:#?}", @@ -430,33 +528,31 @@ impl<'a> Query<'a> { self.node().source() ), } - break; } continue; } prev = child.clone(); remaining_nodes.push(child); } - (name_value, remaining_nodes) + remaining_nodes } fn get_matcher_for_named_node( &self, node: &ts_query::NamedNode, struct_name: &str, current_node: &Ident, - name_value: Option, + query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); let first_node = node.children(self.tree).into_iter().next().unwrap(); - let (name_value, remaining_nodes) = - self.group_children(node, &first_node, name_value, current_node); + let remaining_nodes = self.group_children(node, &first_node, query_values, current_node); if remaining_nodes.len() == 0 { log::info!("single node, {}", first_node.source()); return self.get_matcher_for_definition( struct_name, first_node.into(), current_node, - name_value, + query_values, ); } @@ -469,7 +565,7 @@ impl<'a> Query<'a> { name_node.kind(), current_node, remaining_nodes, - name_value, + query_values, ); matchers.extend_one(matcher); } else { @@ -489,7 +585,7 @@ impl<'a> Query<'a> { variant.kind(), current_node, remaining_nodes.clone(), - name_value.clone(), + query_values, ); matchers.extend_one(matcher); } @@ -498,27 +594,43 @@ impl<'a> Query<'a> { #matchers } } - fn get_default_matcher(&self, name_value: Option) -> TokenStream { + fn get_default_matcher(&self, query_values: &mut HashMap) -> TokenStream { let to_append = self.executor_id(); - if let Some(name_value) = name_value { - return quote! { - #to_append.entry(#name_value).or_default().push(id); - }; + let mut args = Vec::new(); + for target in self.target_capture_names() { + if let Some(value) = query_values.get(&target) { + args.push(value); + } else { + log::warn!("No value found for: {} on {}", target, self.node().source()); + return quote! {}; + } } - log::warn!("No name value found for: {}", self.node().source()); - quote! {} + let name = query_values.get("name").unwrap_or_else(|| { + panic!( + "No name found for: {}. Query_values: {:#?} Target Capture Names: {:#?}", + self.node().source(), + query_values, + self.target_capture_names() + ); + }); + let symbol_name = self.symbol_name(); + return quote! { + let symbol = #symbol_name::new(db, id, node.clone(), #(#args.clone()),*); + #to_append.entry(#name.source()).or_default().push(symbol); + }; } fn get_matcher_for_identifier( &self, identifier: &ts_query::Identifier, struct_name: &str, current_node: &Ident, - name_value: Option, + query_values: &mut HashMap, ) -> TokenStream { // We have 2 nodes, the parent node and the identifier node - let to_append = self.get_default_matcher(name_value); + let to_append = self.get_default_matcher(query_values); // Case 1: The identifier is the same as the struct name (IE: we know this is the corrent node) - if normalize_type_name(&identifier.source(), true) == struct_name { + let target_name = normalize_type_name(&identifier.source(), true); + if target_name == struct_name { return to_append; } // Case 2: We have a node for the parent struct @@ -539,7 +651,13 @@ impl<'a> Query<'a> { } else { // Case 3: This is a subenum // If this is a field, we may be dealing with multiple types and can't operate over all of them - return to_append; // TODO: Handle this case + let target_name = format_ident!("{}", target_name); + let struct_name = format_ident!("{}Ref", struct_name); + return quote! { + if let crate::cst::#struct_name::#target_name(#current_node) = #current_node { + #to_append + } + }; // TODO: Handle this case } } fn get_matcher_for_definition( @@ -547,21 +665,21 @@ impl<'a> Query<'a> { struct_name: &str, node: ts_query::NodeTypesRef, current_node: &Ident, - name_value: Option, + query_values: &mut HashMap, ) -> TokenStream { if !node.is_named() { - return self.get_default_matcher(name_value); + return self.get_default_matcher(query_values); } match node { ts_query::NodeTypesRef::FieldDefinition(field) => { - self.get_matcher_for_field(&field, struct_name, current_node, name_value) + self.get_matcher_for_field(&field, struct_name, current_node, query_values) } ts_query::NodeTypesRef::Capture(named) => { info!("Capture: {:#?}", named.source()); quote! {} } ts_query::NodeTypesRef::NamedNode(named) => { - self.get_matcher_for_named_node(&named, struct_name, current_node, name_value) + self.get_matcher_for_named_node(&named, struct_name, current_node, query_values) } ts_query::NodeTypesRef::Comment(_) => { quote! {} @@ -572,7 +690,7 @@ impl<'a> Query<'a> { struct_name, child.into(), current_node, - name_value.clone(), + query_values, ); // Currently just returns the first child return result; // TODO: properly handle list @@ -580,11 +698,14 @@ impl<'a> Query<'a> { quote! {} } ts_query::NodeTypesRef::Grouping(grouping) => { - self.get_matchers_for_grouping(&grouping, struct_name, current_node, name_value) - } - ts_query::NodeTypesRef::Identifier(identifier) => { - self.get_matcher_for_identifier(&identifier, struct_name, current_node, name_value) + self.get_matchers_for_grouping(&grouping, struct_name, current_node, query_values) } + ts_query::NodeTypesRef::Identifier(identifier) => self.get_matcher_for_identifier( + &identifier, + struct_name, + current_node, + query_values, + ), unhandled => { log::warn!( "Unhandled definition in language {}: {:#?}, {:#?}", @@ -592,7 +713,7 @@ impl<'a> Query<'a> { unhandled.kind_name(), unhandled.source() ); - self.get_default_matcher(name_value) + self.get_default_matcher(query_values) } } } @@ -605,10 +726,11 @@ impl<'a> Query<'a> { struct_name }; let starting_node = format_ident!("node"); - let (name_value, remaining_nodes) = self.group_children( + let mut query_values = HashMap::new(); + let remaining_nodes = self.group_children( &self.node(), &self.node().children(self.tree).into_iter().next().unwrap(), - None, + &mut query_values, &starting_node, ); return self._get_matcher_for_named_node( @@ -617,9 +739,71 @@ impl<'a> Query<'a> { kind, &starting_node, remaining_nodes, - name_value, + &mut query_values, ); } + fn target_captures(&self) -> Vec<&ts_query::Capture> { + let mut captures: Vec<&ts_query::Capture> = self + .captures() + .into_iter() + .filter(|c| !full_name_for_capture(c).contains(".")) + .collect(); + captures.sort_by_key(|c| full_name_for_capture(c)); + captures.dedup_by_key(|c| full_name_for_capture(c)); + captures + } + fn target_capture_names(&self) -> Vec { + self.target_captures() + .into_iter() + .map(|c| name_for_capture(c)) + .collect() + } + pub fn struct_fields(&self) -> Vec { + let mut fields = Vec::new(); + for capture in self.target_captures() { + let name = name_for_capture(capture); + let mut type_name = format_ident!("NodeTypes"); + for (node, id) in self.tree.descendants(&self.root_id) { + if let ts_query::NodeTypesRef::Capture(other) = node.as_ref() { + if other == capture { + let mut preceding_siblings = + id.preceding_siblings(self.tree.arena()).skip(1); + while let Some(prev) = preceding_siblings.next() { + if let Some(prev_capture) = self.tree.arena().get(prev) { + match prev_capture.get().as_ref() { + ts_query::NodeTypesRef::NamedNode(prev_capture) => { + type_name = format_ident!( + "{}", + normalize_type_name(&prev_capture.source(), true) + ); + break; + } + ts_query::NodeTypesRef::Identifier(prev_capture) => { + type_name = format_ident!( + "{}", + normalize_type_name(&prev_capture.source(), true) + ); + break; + } + ts_query::NodeTypesRef::AnonymousUnderscore(_) => { + break; // Could be any type + } + _ => { + panic!("Unexpected capture: {:#?}", prev_capture); + } + } + } + } + } + } + } + let name_ident = format_ident!("{}", name); + fields.push(parse_quote!( + pub #name_ident: crate::cst::#type_name<'db> + )); + } + fields + } } pub trait HasQuery { diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap index f655a511..e2901c3f 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap @@ -3,34 +3,97 @@ source: codegen-sdk-ast-generator/src/visitor.rs expression: "codegen_sdk_common::generator::format_code_string(&visitor.to_string()).unwrap()" --- #[salsa::tracked] +pub struct Class<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::ClassDefinition<'db>, + pub name: crate::cst::Identifier<'db>, +} +#[salsa::tracked] +pub struct Constant<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::Module<'db>, + pub name: crate::cst::Identifier<'db>, +} +#[salsa::tracked] +pub struct Function<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::FunctionDefinition<'db>, + pub name: crate::cst::Identifier<'db>, +} +#[salsa::tracked] +pub struct Import<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::ImportFromStatement<'db>, + pub module: crate::cst::NodeTypes<'db>, + pub name: crate::cst::NodeTypes<'db>, +} +#[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] +pub enum Symbol<'db> { + Class(Class<'db>), + Constant(Constant<'db>), + Function(Function<'db>), + Import(Import<'db>), +} +#[salsa::tracked] pub struct Definitions<'db> { #[return_ref] - pub _classes: BTreeMap>, + pub classes: BTreeMap>>, + #[return_ref] + pub constants: BTreeMap>>, #[return_ref] - pub _constants: BTreeMap>, + pub functions: BTreeMap>>, #[return_ref] - pub _functions: BTreeMap>, + pub imports: BTreeMap>>, } impl<'db> Definitions<'db> { pub fn visit( db: &'db dyn salsa::Database, root: &'db crate::cst::Parsed<'db>, ) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut constants: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); + let mut classes: BTreeMap>> = BTreeMap::new(); + let mut constants: BTreeMap>> = BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut imports: BTreeMap>> = BTreeMap::new(); let tree = root.tree(db); for (node, id) in tree.descendants(&root.program(db)) { match node { crate::cst::NodeTypes::ClassDefinition(node) => { ///Code for query: (class_definition name: (identifier) @name) @definition.class let name = node.name(tree); - classes.entry(name.source()).or_default().push(id); + let symbol = Class::new(db, id, node.clone(), name.clone()); + classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::FunctionDefinition(node) => { ///Code for query: (function_definition name: (identifier) @name) @definition.function let name = node.name(tree); - functions.entry(name.source()).or_default().push(id); + let symbol = Function::new(db, id, node.clone(), name.clone()); + functions.entry(name.source()).or_default().push(symbol); + } + crate::cst::NodeTypes::ImportFromStatement(node) => { + ///Code for query: (import_from_statement module_name: (_) @module name: (_) @name) @definition.import + let module_name = node.module_name(tree); + for name in node.name(tree) { + let symbol = Import::new( + db, + id, + node.clone(), + module_name.clone(), + name.clone(), + ); + imports.entry(name.source()).or_default().push(symbol); + } } crate::cst::NodeTypes::Module(node) => { ///Code for query: (module (expression_statement (assignment left: (identifier) @name) @definition.constant)) @@ -45,7 +108,15 @@ impl<'db> Definitions<'db> { ) = child { ///Code for query: (module (expression_statement (assignment left: (identifier) @name) @definition.constant)) let left = child.left(tree); - constants.entry(left.source()).or_default().push(id); + if let crate::cst::AssignmentLeftRef::Identifier(left) = left { + let symbol = Constant::new( + db, + id, + node.clone(), + left.clone(), + ); + constants.entry(left.source()).or_default().push(symbol); + } } break; } @@ -56,60 +127,13 @@ impl<'db> Definitions<'db> { _ => {} } } - Self::new(db, classes, constants, functions) + Self::new(db, classes, constants, functions, imports) } pub fn default(db: &'db dyn salsa::Database) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut constants: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); - Self::new(db, classes, constants, functions) - } - pub fn classes( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._classes(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn constants( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._constants(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn functions( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._functions(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() + let mut classes: BTreeMap>> = BTreeMap::new(); + let mut constants: BTreeMap>> = BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut imports: BTreeMap>> = BTreeMap::new(); + Self::new(db, classes, constants, functions, imports) } } diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap index b3bee56b..451434dc 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap @@ -3,55 +3,119 @@ source: codegen-sdk-ast-generator/src/visitor.rs expression: "codegen_sdk_common::generator::format_code_string(&visitor.to_string()).unwrap()" --- #[salsa::tracked] +pub struct Class<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::AbstractClassDeclaration<'db>, + pub name: crate::cst::TypeIdentifier<'db>, +} +#[salsa::tracked] +pub struct Function<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::FunctionSignature<'db>, + pub name: crate::cst::Identifier<'db>, +} +#[salsa::tracked] +pub struct Interface<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::InterfaceDeclaration<'db>, + pub name: crate::cst::TypeIdentifier<'db>, +} +#[salsa::tracked] +pub struct Method<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::AbstractMethodSignature<'db>, + pub name: crate::cst::PropertyIdentifier<'db>, +} +#[salsa::tracked] +pub struct Module<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::Module<'db>, + pub name: crate::cst::Identifier<'db>, +} +#[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] +pub enum Symbol<'db> { + Class(Class<'db>), + Function(Function<'db>), + Interface(Interface<'db>), + Method(Method<'db>), + Module(Module<'db>), +} +#[salsa::tracked] pub struct Definitions<'db> { #[return_ref] - pub _classes: BTreeMap>, + pub classes: BTreeMap>>, #[return_ref] - pub _functions: BTreeMap>, + pub functions: BTreeMap>>, #[return_ref] - pub _interfaces: BTreeMap>, + pub interfaces: BTreeMap>>, #[return_ref] - pub _methods: BTreeMap>, + pub methods: BTreeMap>>, #[return_ref] - pub _modules: BTreeMap>, + pub modules: BTreeMap>>, } impl<'db> Definitions<'db> { pub fn visit( db: &'db dyn salsa::Database, root: &'db crate::cst::Parsed<'db>, ) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); - let mut interfaces: BTreeMap> = BTreeMap::new(); - let mut methods: BTreeMap> = BTreeMap::new(); - let mut modules: BTreeMap> = BTreeMap::new(); + let mut classes: BTreeMap>> = BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut interfaces: BTreeMap>> = BTreeMap::new(); + let mut methods: BTreeMap>> = BTreeMap::new(); + let mut modules: BTreeMap>> = BTreeMap::new(); let tree = root.tree(db); for (node, id) in tree.descendants(&root.program(db)) { match node { crate::cst::NodeTypes::AbstractClassDeclaration(node) => { ///Code for query: (abstract_class_declaration name: (type_identifier) @name) @definition.class let name = node.name(tree); - classes.entry(name.source()).or_default().push(id); + let symbol = Class::new(db, id, node.clone(), name.clone()); + classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::AbstractMethodSignature(node) => { ///Code for query: (abstract_method_signature name: (property_identifier) @name) @definition.method let name = node.name(tree); - methods.entry(name.source()).or_default().push(id); + if let crate::cst::AbstractMethodSignatureNameRef::PropertyIdentifier( + name, + ) = name { + let symbol = Method::new(db, id, node.clone(), name.clone()); + methods.entry(name.source()).or_default().push(symbol); + } } crate::cst::NodeTypes::FunctionSignature(node) => { ///Code for query: (function_signature name: (identifier) @name) @definition.function let name = node.name(tree); - functions.entry(name.source()).or_default().push(id); + let symbol = Function::new(db, id, node.clone(), name.clone()); + functions.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::InterfaceDeclaration(node) => { ///Code for query: (interface_declaration name: (type_identifier) @name) @definition.interface let name = node.name(tree); - interfaces.entry(name.source()).or_default().push(id); + let symbol = Interface::new(db, id, node.clone(), name.clone()); + interfaces.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::Module(node) => { ///Code for query: (module name: (identifier) @name) @definition.module let name = node.name(tree); - modules.entry(name.source()).or_default().push(id); + if let crate::cst::ModuleNameRef::Identifier(name) = name { + let symbol = Module::new(db, id, node.clone(), name.clone()); + modules.entry(name.source()).or_default().push(symbol); + } } _ => {} } @@ -59,91 +123,11 @@ impl<'db> Definitions<'db> { Self::new(db, classes, functions, interfaces, methods, modules) } pub fn default(db: &'db dyn salsa::Database) -> Self { - let mut classes: BTreeMap> = BTreeMap::new(); - let mut functions: BTreeMap> = BTreeMap::new(); - let mut interfaces: BTreeMap> = BTreeMap::new(); - let mut methods: BTreeMap> = BTreeMap::new(); - let mut modules: BTreeMap> = BTreeMap::new(); + let mut classes: BTreeMap>> = BTreeMap::new(); + let mut functions: BTreeMap>> = BTreeMap::new(); + let mut interfaces: BTreeMap>> = BTreeMap::new(); + let mut methods: BTreeMap>> = BTreeMap::new(); + let mut modules: BTreeMap>> = BTreeMap::new(); Self::new(db, classes, functions, interfaces, methods, modules) } - pub fn classes( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._classes(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn functions( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._functions(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn interfaces( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._interfaces(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn methods( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._methods(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } - pub fn modules( - &self, - db: &'db dyn salsa::Database, - tree: &'db codegen_sdk_common::tree::Tree>, - ) -> BTreeMap>> { - self._modules(db) - .iter() - .map(|(k, v)| ( - k.clone(), - v - .iter() - .map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()) - .collect(), - )) - .collect() - } } diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 7c3826a1..2cca0de4 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet}; use codegen_sdk_common::Language; use convert_case::{Case, Casing}; use proc_macro2::{Span, TokenStream}; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, quote_spanned}; use syn::parse_quote_spanned; use super::query::Query; @@ -24,9 +24,11 @@ pub fn generate_visitor<'db>( let mut types = Vec::new(); let mut variants = BTreeSet::new(); let mut enter_methods = BTreeMap::new(); + let mut symbol_names = Vec::new(); for query in queries { names.push(query.executor_id()); types.push(format_ident!("{}", query.struct_name())); + symbol_names.push(query.symbol_name()); for variant in query.struct_variants() { variants.insert(format_ident!("{}", variant)); enter_methods @@ -36,7 +38,7 @@ pub fn generate_visitor<'db>( } } let mut methods: Vec = Vec::new(); - for (variant, queries) in enter_methods { + for (variant, queries) in enter_methods.iter() { let mut matchers = TokenStream::new(); let struct_name = format_ident!("{}", variant); for query in queries { @@ -49,15 +51,54 @@ pub fn generate_visitor<'db>( } }); } + + let symbol_name = if name == "definition" { + format_ident!("Symbol") + } else { + format_ident!("Reference") + }; let maps = quote! { #( - let mut #names: BTreeMap> = BTreeMap::new(); + let mut #names: BTreeMap>> = BTreeMap::new(); )* }; let constructor = quote! { Self::new(db, #(#names),*) }; + let mut defs = Vec::new(); + for (variant, type_name) in symbol_names.iter().zip(types.iter()) { + let query = enter_methods + .get(&type_name.to_string()) + .unwrap() + .first() + .unwrap(); + let fields = query.struct_fields(); + let span = Span::mixed_site(); + defs.push(quote_spanned! { + span => + #[salsa::tracked] + pub struct #variant<'db> { + #[id] + node_id: indextree::NodeId, + #[tracked] + #[return_ref] + pub node: crate::cst::#type_name<'db>, + #(#fields),* + } + }); + } + let symbol = quote! { + #( + #defs + )* + #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] + pub enum #symbol_name<'db> { + #( + #symbol_names(#symbol_names<'db>), + )* + } + }; let name = format_ident!("{}s", name.to_case(Case::Pascal)); let output_constructor = quote! { pub fn visit(db: &'db dyn salsa::Database, root: &'db crate::cst::Parsed<'db>) -> Self { @@ -76,11 +117,9 @@ pub fn generate_visitor<'db>( #constructor } }; - let underscored_names = names - .iter() - .map(|name| format_ident!("_{}", name)) - .collect::>(); + quote! { + #symbol // Three lifetimes: // db: the lifetime of the database // db1: the lifetime of the visitor executing per-node @@ -89,17 +128,11 @@ pub fn generate_visitor<'db>( pub struct #name<'db> { #( #[return_ref] - pub #underscored_names: BTreeMap>, + pub #names: BTreeMap>>, )* } impl<'db> #name<'db> { #output_constructor - #( - pub fn #names(&self, db: &'db dyn salsa::Database, tree: &'db codegen_sdk_common::tree::Tree>) -> BTreeMap>> { - self.#underscored_names(db).iter().map(|(k, v)| - (k.clone(), v.iter().map(|id| tree.get(id).unwrap().as_ref().try_into().unwrap()).collect())).collect() - } - )* } } } diff --git a/codegen-sdk-common/src/language/python.rs b/codegen-sdk-common/src/language/python.rs index 62188a17..efb15341 100644 --- a/codegen-sdk-common/src/language/python.rs +++ b/codegen-sdk-common/src/language/python.rs @@ -1,12 +1,17 @@ use super::Language; +const PYTHON_TAGS_QUERY: &'static str = tree_sitter_python::TAGS_QUERY; +const EXTRA_TAGS_QUERY: &'static str = " + (import_from_statement module_name: (dotted_name) @module name: (dotted_name) @name) @definition.import + "; lazy_static! { + static ref TAGS_QUERY: String = [PYTHON_TAGS_QUERY, EXTRA_TAGS_QUERY].join("\n"); pub static ref Python: Language = Language::new( "python", "Python", tree_sitter_python::NODE_TYPES, &["py"], tree_sitter_python::LANGUAGE.into(), - tree_sitter_python::TAGS_QUERY, + &TAGS_QUERY, ) .unwrap(); } diff --git a/codegen-sdk-common/src/tree/tree.rs b/codegen-sdk-common/src/tree/tree.rs index ae884cd6..01518f8b 100644 --- a/codegen-sdk-common/src/tree/tree.rs +++ b/codegen-sdk-common/src/tree/tree.rs @@ -39,6 +39,9 @@ impl Tree { id.children(&self.ids) .map(|id| (self.get(&id).unwrap(), id)) } + pub fn arena(&self) -> &Arena { + &self.ids + } } unsafe impl Update for Tree where diff --git a/codegen-sdk-cst-generator/src/generator.rs b/codegen-sdk-cst-generator/src/generator.rs index 08276586..252f010f 100644 --- a/codegen-sdk-cst-generator/src/generator.rs +++ b/codegen-sdk-cst-generator/src/generator.rs @@ -102,13 +102,13 @@ fn get_parser(language: &Language) -> TokenStream { fn language() -> &'static codegen_sdk_common::language::Language { &codegen_sdk_common::language::#language_name::#language_struct_name } - fn parse<'db>(db: &'db dyn salsa::Database, content: std::string::String) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { + fn parse<'db>(db: &'db dyn salsa::Database, content: std::string::String) -> Option<(&'db Self::Program<'db>, &'db Tree>, indextree::NodeId)> { let input = codegen_sdk_cst::Input::new(db, content); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } } diff --git a/codegen-sdk-cst/Cargo.toml b/codegen-sdk-cst/Cargo.toml index d9d952b9..e4c13402 100644 --- a/codegen-sdk-cst/Cargo.toml +++ b/codegen-sdk-cst/Cargo.toml @@ -13,6 +13,7 @@ log = { workspace = true } salsa = { workspace = true } dashmap = "6.1.0" thiserror = {workspace = true} +indextree = { workspace = true } [dev-dependencies] tempfile = { workspace = true } test-log = { workspace = true } diff --git a/codegen-sdk-cst/src/language.rs b/codegen-sdk-cst/src/language.rs index 8fe82c04..de3d9905 100644 --- a/codegen-sdk-cst/src/language.rs +++ b/codegen-sdk-cst/src/language.rs @@ -15,12 +15,23 @@ pub trait CSTLanguage { fn parse<'db>( db: &'db dyn salsa::Database, content: String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)>; + ) -> Option<( + &'db Self::Program<'db>, + &'db Tree>, + indextree::NodeId, + )>; fn parse_file_from_cache<'db>( db: &'db dyn salsa::Database, file_path: &PathBuf, #[cfg(feature = "serialization")] cache: &'db codegen_sdk_common::serialize::Cache, - ) -> Result, &'db Tree>)>, ParseError> { + ) -> Result< + Option<( + &'db Self::Program<'db>, + &'db Tree>, + indextree::NodeId, + )>, + ParseError, + > { #[cfg(feature = "serialization")] { let serialized_path = cache.get_path(file_path); @@ -35,7 +46,14 @@ pub trait CSTLanguage { db: &'db dyn salsa::Database, file_path: &PathBuf, #[cfg(feature = "serialization")] cache: &'db codegen_sdk_common::serialize::Cache, - ) -> Result, &'db Tree>)>, ParseError> { + ) -> Result< + Option<( + &'db Self::Program<'db>, + &'db Tree>, + indextree::NodeId, + )>, + ParseError, + > { if let Some(parsed) = Self::parse_file_from_cache( db, file_path, diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 5e3dfba4..0535cb13 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -10,15 +10,28 @@ pub mod ast { include!(concat!(env!("OUT_DIR"), "/python-ast.rs")); #[salsa::tracked] impl<'db> Scope<'db> for PythonFile<'db> { - type Type = crate::cst::FunctionDefinition<'db>; - type ReferenceType = crate::cst::Call<'db>; + type Type = crate::ast::Symbol<'db>; + type ReferenceType = crate::ast::Call<'db>; #[salsa::tracked(return_ref)] fn resolve(self, db: &'db dyn salsa::Database, name: String) -> Vec { let tree = self.node(db).unwrap().tree(db); let mut results = Vec::new(); - for (def_name, defs) in self.definitions(db).functions(db, &tree).into_iter() { - if def_name == name { - results.extend(defs.into_iter().cloned()); + for (def_name, defs) in self.definitions(db).functions(db).into_iter() { + if *def_name == name { + results.extend( + defs.into_iter() + .cloned() + .map(|def| crate::ast::Symbol::Function(def)), + ); + } + } + for (def_name, defs) in self.definitions(db).imports(db).into_iter() { + if *def_name == name { + results.extend( + defs.into_iter() + .cloned() + .map(|def| crate::ast::Symbol::Import(def)), + ); } } results @@ -26,16 +39,15 @@ pub mod ast { #[salsa::tracked] fn resolvables(self, db: &'db dyn salsa::Database) -> Vec { let mut results = Vec::new(); - let tree = self.node(db).unwrap().tree(db); - for (_, refs) in self.references(db).calls(db, &tree).into_iter() { + for (_, refs) in self.references(db).calls(db).into_iter() { results.extend(refs.into_iter().cloned()); } results } } #[salsa::tracked] - impl<'db> ResolveType<'db, PythonFile<'db>> for crate::cst::Call<'db> { - type Type = crate::cst::FunctionDefinition<'db>; + impl<'db> ResolveType<'db, PythonFile<'db>> for crate::ast::Call<'db> { + type Type = crate::ast::Symbol<'db>; #[salsa::tracked(return_ref)] fn resolve_type( self, @@ -44,12 +56,14 @@ pub mod ast { _scopes: Vec>, ) -> Vec { let tree = scope.node(db).unwrap().tree(db); - scope.resolve(db, self.function(tree).source()).clone() + scope + .resolve(db, self.node(db).function(tree).source()) + .clone() } } #[salsa::tracked] - impl<'db> codegen_sdk_resolution::References<'db, crate::cst::Call<'db>, PythonFile<'db>> - for crate::cst::FunctionDefinition<'db> + impl<'db> codegen_sdk_resolution::References<'db, crate::ast::Call<'db>, PythonFile<'db>> + for crate::ast::Symbol<'db> { } } diff --git a/languages/codegen-sdk-python/tests/test_python.rs b/languages/codegen-sdk-python/tests/test_python.rs index 5d86b675..8a0635fd 100644 --- a/languages/codegen-sdk-python/tests/test_python.rs +++ b/languages/codegen-sdk-python/tests/test_python.rs @@ -4,7 +4,14 @@ use std::path::PathBuf; use codegen_sdk_ast::{Definitions, References}; use codegen_sdk_resolution::References as _; fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { - let file_path = temp_dir.path().join("test.ts"); + write_to_temp_file_with_name(content, temp_dir, "test.py") +} +fn write_to_temp_file_with_name( + content: &str, + temp_dir: &tempfile::TempDir, + name: &str, +) -> PathBuf { + let file_path = temp_dir.path().join(name); std::fs::write(&file_path, content).unwrap(); file_path } @@ -36,8 +43,7 @@ class Test: let content = codegen_sdk_cst::Input::new(&db, content.to_string()); let input = codegen_sdk_ast::input::File::new(&db, file_path, content); let file = codegen_sdk_python::ast::parse_query(&db, input); - let tree = file.tree(&db); - assert_eq!(file.definitions(&db).classes(&db, &tree).len(), 1); + assert_eq!(file.definitions(&db).classes(&db).len(), 1); } #[test_log::test] fn test_python_ast_function() { @@ -50,8 +56,7 @@ def test(): let content = codegen_sdk_cst::Input::new(&db, content.to_string()); let input = codegen_sdk_ast::input::File::new(&db, file_path, content); let file = codegen_sdk_python::ast::parse_query(&db, input); - let tree = file.tree(&db); - assert_eq!(file.definitions(&db).functions(&db, &tree).len(), 1); + assert_eq!(file.definitions(&db).functions(&db).len(), 1); } #[test_log::test] fn test_python_ast_function_usages() { @@ -67,10 +72,11 @@ test()"; let input = codegen_sdk_ast::input::File::new(&db, file_path, content); let file = codegen_sdk_python::ast::parse_query(&db, input); let tree = file.tree(&db); - assert_eq!(file.references(&db).calls(&db, &tree).len(), 1); + assert_eq!(file.references(&db).calls(&db).len(), 1); let definitions = file.definitions(&db); - let functions = definitions.functions(&db, &tree); + let functions = definitions.functions(&db); let function = functions.get("test").unwrap().first().unwrap(); + let function = codegen_sdk_python::ast::Symbol::Function(function.clone().clone()); assert_eq!( function .references_for_scopes(&db, vec![*file], &file) @@ -78,3 +84,46 @@ test()"; 1 ); } +#[test_log::test] +fn test_python_ast_function_usages_cross_file() { + let temp_dir = tempfile::tempdir().unwrap(); + let content = " +def test(): + pass + +"; + let usage_file_content = " +from filea import test +test()"; + let file_path = write_to_temp_file_with_name(content, &temp_dir, "filea.py"); + let usage_file_path = write_to_temp_file_with_name(usage_file_content, &temp_dir, "fileb.py"); + let db = codegen_sdk_cst::CSTDatabase::default(); + let content = codegen_sdk_cst::Input::new(&db, content.to_string()); + let usage_content = codegen_sdk_cst::Input::new(&db, usage_file_content.to_string()); + let input = codegen_sdk_ast::input::File::new(&db, file_path, content); + let usage_input = codegen_sdk_ast::input::File::new(&db, usage_file_path, usage_content); + let file = codegen_sdk_python::ast::parse_query(&db, input); + let usage_file = codegen_sdk_python::ast::parse_query(&db, usage_input); + let tree = file.tree(&db); + let usage_tree = usage_file.tree(&db); + assert_eq!(usage_file.references(&db).calls(&db).len(), 1); + let definitions = file.definitions(&db); + let functions = definitions.functions(&db); + let function = functions.get("test").unwrap().first().unwrap(); + let function = codegen_sdk_python::ast::Symbol::Function(function.clone().clone()); + let imports = usage_file.definitions(&db).imports(&db); + let import = imports.get("test").unwrap().first().unwrap(); + let import = codegen_sdk_python::ast::Symbol::Import(import.clone().clone()); + assert_eq!( + import + .references_for_scopes(&db, vec![*usage_file], &usage_file) + .len(), + 1 + ); + assert_eq!( + function + .references_for_scopes(&db, vec![*file, *usage_file], &file) + .len(), + 1 + ); +} From ac40f8552a935452713c87139d2b8fdc475796a1 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 10:25:07 -0800 Subject: [PATCH 02/16] Redo visitor generation --- codegen-sdk-ast-generator/src/query.rs | 100 ++++++++++++++---- ...ast_generator__visitor__tests__python.snap | 37 ++++--- ...generator__visitor__tests__typescript.snap | 5 + 3 files changed, 108 insertions(+), 34 deletions(-) diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index a0fd32fa..32d48bf6 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -286,6 +286,7 @@ impl<'a> Query<'a> { field: &ts_query::FieldDefinition, struct_name: &str, current_node: &Ident, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, query_values: &mut HashMap, ) -> TokenStream { let other_child: ts_query::NodeTypesRef = field @@ -297,6 +298,7 @@ impl<'a> Query<'a> { .into(); for name in &field.name(self.tree) { if let ts_query::FieldDefinitionNameRef::Identifier(identifier) = name { + let doc = format!("Code for field: {}", field.source()); let name = normalize_field_name(&identifier.source()); if let Some(field) = self.get_field_for_field_name(&name, struct_name) { let field_name = format_ident!("{}", name); @@ -305,6 +307,7 @@ impl<'a> Query<'a> { &normalized_struct_name, other_child.clone(), &field_name, + existing, query_values, ); // assert!( @@ -316,17 +319,20 @@ impl<'a> Query<'a> { // ); if field.is_multiple() { return quote! { + #[doc = #doc] for #field_name in #current_node.#field_name(tree) { #wrapped } }; } else if !field.is_optional() { return quote! { + #[doc = #doc] let #field_name = #current_node.#field_name(tree); #wrapped }; } else { return quote! { + #[doc = #doc] if let Some(#field_name) = #current_node.#field_name(tree) { #wrapped } @@ -354,6 +360,7 @@ impl<'a> Query<'a> { node: &ts_query::Grouping, struct_name: &str, current_node: &Ident, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); @@ -362,6 +369,7 @@ impl<'a> Query<'a> { struct_name, group.into(), current_node, + existing, query_values, ); matchers.extend_one(result); @@ -378,7 +386,7 @@ impl<'a> Query<'a> { query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); - let mut field_matchers = TokenStream::new(); + let mut field_matchers = Vec::new(); let mut comment_variant = None; let variants = self .state @@ -395,18 +403,14 @@ impl<'a> Query<'a> { } for child in remaining_nodes { - if child.kind_name() == "field_definition" { - field_matchers.extend_one(self.get_matcher_for_definition( - &target_name, - child.into(), - current_node, - query_values, - )); + if let ts_query::NamedNodeChildrenRef::FieldDefinition(_) = child { + field_matchers.push((child.into(), target_name, current_node)); } else { let result = self.get_matcher_for_definition( &target_name, child.into(), &format_ident!("child"), + &mut Vec::new(), query_values, ); @@ -439,6 +443,17 @@ impl<'a> Query<'a> { "Code for query: {}", &self.node().source().replace("\n", " ") // Newlines mess with quote's doc comments ); + let field_matchers = if let Some(prev) = field_matchers.pop() { + self.get_matcher_for_definition( + &prev.1, + prev.0, + &prev.2, + &mut field_matchers, + query_values, + ) + } else { + quote! {} + }; if matchers.is_empty() && field_matchers.is_empty() { return quote! {}; } @@ -471,7 +486,11 @@ impl<'a> Query<'a> { ) -> Vec> { let mut prev = first_node.clone(); let mut remaining_nodes = Vec::new(); - log::info!("Grouping children for: {:#?}", node.source()); + log::info!( + "Grouping children for: {:#?} of kind: {:#?}", + node.source(), + node.kind_name() + ); for child in node.children(self.tree).into_iter().skip(1) { if child.kind_name() == "capture" { let capture_name = name_for_capture(child.try_into().unwrap()); @@ -541,6 +560,7 @@ impl<'a> Query<'a> { node: &ts_query::NamedNode, struct_name: &str, current_node: &Ident, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, query_values: &mut HashMap, ) -> TokenStream { let mut matchers = TokenStream::new(); @@ -552,6 +572,7 @@ impl<'a> Query<'a> { struct_name, first_node.into(), current_node, + existing, query_values, ); } @@ -594,7 +615,26 @@ impl<'a> Query<'a> { #matchers } } - fn get_default_matcher(&self, query_values: &mut HashMap) -> TokenStream { + fn get_default_matcher( + &self, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, + query_values: &mut HashMap, + ) -> TokenStream { + if let Some(prev) = existing.pop() { + log::info!( + "Executing previous matcher on: {:#?} with {:#?}", + prev.0.source(), + query_values + ); + return self.get_matcher_for_definition( + &prev.1, + prev.0, + &prev.2, + existing, + query_values, + ); + } + let to_append = self.executor_id(); let mut args = Vec::new(); for target in self.target_capture_names() { @@ -624,10 +664,11 @@ impl<'a> Query<'a> { identifier: &ts_query::Identifier, struct_name: &str, current_node: &Ident, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, query_values: &mut HashMap, ) -> TokenStream { // We have 2 nodes, the parent node and the identifier node - let to_append = self.get_default_matcher(query_values); + let to_append = self.get_default_matcher(existing, query_values); // Case 1: The identifier is the same as the struct name (IE: we know this is the corrent node) let target_name = normalize_type_name(&identifier.source(), true); if target_name == struct_name { @@ -665,22 +706,31 @@ impl<'a> Query<'a> { struct_name: &str, node: ts_query::NodeTypesRef, current_node: &Ident, + existing: &mut Vec<(ts_query::NodeTypesRef, &str, &Ident)>, query_values: &mut HashMap, ) -> TokenStream { if !node.is_named() { - return self.get_default_matcher(query_values); + return self.get_default_matcher(existing, query_values); } match node { - ts_query::NodeTypesRef::FieldDefinition(field) => { - self.get_matcher_for_field(&field, struct_name, current_node, query_values) - } + ts_query::NodeTypesRef::FieldDefinition(field) => self.get_matcher_for_field( + &field, + struct_name, + current_node, + existing, + query_values, + ), ts_query::NodeTypesRef::Capture(named) => { info!("Capture: {:#?}", named.source()); quote! {} } - ts_query::NodeTypesRef::NamedNode(named) => { - self.get_matcher_for_named_node(&named, struct_name, current_node, query_values) - } + ts_query::NodeTypesRef::NamedNode(named) => self.get_matcher_for_named_node( + &named, + struct_name, + current_node, + existing, + query_values, + ), ts_query::NodeTypesRef::Comment(_) => { quote! {} } @@ -690,6 +740,7 @@ impl<'a> Query<'a> { struct_name, child.into(), current_node, + existing, query_values, ); // Currently just returns the first child @@ -697,13 +748,18 @@ impl<'a> Query<'a> { } quote! {} } - ts_query::NodeTypesRef::Grouping(grouping) => { - self.get_matchers_for_grouping(&grouping, struct_name, current_node, query_values) - } + ts_query::NodeTypesRef::Grouping(grouping) => self.get_matchers_for_grouping( + &grouping, + struct_name, + current_node, + existing, + query_values, + ), ts_query::NodeTypesRef::Identifier(identifier) => self.get_matcher_for_identifier( &identifier, struct_name, current_node, + existing, query_values, ), unhandled => { @@ -713,7 +769,7 @@ impl<'a> Query<'a> { unhandled.kind_name(), unhandled.source() ); - self.get_default_matcher(query_values) + self.get_default_matcher(existing, query_values) } } } diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap index e2901c3f..d776675b 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap @@ -36,8 +36,8 @@ pub struct Import<'db> { #[tracked] #[return_ref] pub node: crate::cst::ImportFromStatement<'db>, - pub module: crate::cst::NodeTypes<'db>, - pub name: crate::cst::NodeTypes<'db>, + pub module: crate::cst::DottedName<'db>, + pub name: crate::cst::DottedName<'db>, } #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] pub enum Symbol<'db> { @@ -71,28 +71,40 @@ impl<'db> Definitions<'db> { match node { crate::cst::NodeTypes::ClassDefinition(node) => { ///Code for query: (class_definition name: (identifier) @name) @definition.class + ///Code for field: name: (identifier) @name let name = node.name(tree); let symbol = Class::new(db, id, node.clone(), name.clone()); classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::FunctionDefinition(node) => { ///Code for query: (function_definition name: (identifier) @name) @definition.function + ///Code for field: name: (identifier) @name let name = node.name(tree); let symbol = Function::new(db, id, node.clone(), name.clone()); functions.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::ImportFromStatement(node) => { - ///Code for query: (import_from_statement module_name: (_) @module name: (_) @name) @definition.import - let module_name = node.module_name(tree); + ///Code for query: (import_from_statement module_name: (dotted_name) @module name: (dotted_name) @name) @definition.import + ///Code for field: name: (dotted_name) @name for name in node.name(tree) { - let symbol = Import::new( - db, - id, - node.clone(), - module_name.clone(), - name.clone(), - ); - imports.entry(name.source()).or_default().push(symbol); + if let crate::cst::ImportFromStatementNameRef::DottedName( + name, + ) = name { + ///Code for field: module_name: (dotted_name) @module + let module_name = node.module_name(tree); + if let crate::cst::ImportFromStatementModuleNameRef::DottedName( + module_name, + ) = module_name { + let symbol = Import::new( + db, + id, + node.clone(), + module_name.clone(), + name.clone(), + ); + imports.entry(name.source()).or_default().push(symbol); + } + } } } crate::cst::NodeTypes::Module(node) => { @@ -107,6 +119,7 @@ impl<'db> Definitions<'db> { child, ) = child { ///Code for query: (module (expression_statement (assignment left: (identifier) @name) @definition.constant)) + ///Code for field: left: (identifier) @name let left = child.left(tree); if let crate::cst::AssignmentLeftRef::Identifier(left) = left { let symbol = Constant::new( diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap index 451434dc..87717345 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap @@ -83,12 +83,14 @@ impl<'db> Definitions<'db> { match node { crate::cst::NodeTypes::AbstractClassDeclaration(node) => { ///Code for query: (abstract_class_declaration name: (type_identifier) @name) @definition.class + ///Code for field: name: (type_identifier) @name let name = node.name(tree); let symbol = Class::new(db, id, node.clone(), name.clone()); classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::AbstractMethodSignature(node) => { ///Code for query: (abstract_method_signature name: (property_identifier) @name) @definition.method + ///Code for field: name: (property_identifier) @name let name = node.name(tree); if let crate::cst::AbstractMethodSignatureNameRef::PropertyIdentifier( name, @@ -99,18 +101,21 @@ impl<'db> Definitions<'db> { } crate::cst::NodeTypes::FunctionSignature(node) => { ///Code for query: (function_signature name: (identifier) @name) @definition.function + ///Code for field: name: (identifier) @name let name = node.name(tree); let symbol = Function::new(db, id, node.clone(), name.clone()); functions.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::InterfaceDeclaration(node) => { ///Code for query: (interface_declaration name: (type_identifier) @name) @definition.interface + ///Code for field: name: (type_identifier) @name let name = node.name(tree); let symbol = Interface::new(db, id, node.clone(), name.clone()); interfaces.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::Module(node) => { ///Code for query: (module name: (identifier) @name) @definition.module + ///Code for field: name: (identifier) @name let name = node.name(tree); if let crate::cst::ModuleNameRef::Identifier(name) = name { let symbol = Module::new(db, id, node.clone(), name.clone()); From 001e500ff11bd59522fc57ade21454d4c648e4fe Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 10:32:14 -0800 Subject: [PATCH 03/16] Fix bugs --- codegen-sdk-ast-generator/src/query.rs | 8 ++++++- codegen-sdk-ast-generator/src/visitor.rs | 27 ++++++++++++++++-------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index 32d48bf6..e1ddb635 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -263,7 +263,13 @@ impl<'a> Query<'a> { pub fn symbol_name(&self) -> Ident { let raw_name = self.name(); let name = raw_name.split(".").last().unwrap(); - format_ident!("{}", normalize_type_name(name, true)) + let symbol = format_ident!("{}", normalize_type_name(name, true)); + // References can produce duplicate names. We can be reasonably sure that there is no @definition.call. + if raw_name.starts_with("reference") && !["call"].contains(&name) { + format_ident!("{}Ref", symbol) + } else { + symbol + } } fn get_field_for_field_name(&self, field_name: &str, struct_name: &str) -> Option<&Field> { debug!( diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 2cca0de4..61e8985d 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -88,15 +88,24 @@ pub fn generate_visitor<'db>( } }); } - let symbol = quote! { - #( - #defs - )* - #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] - pub enum #symbol_name<'db> { - #( - #symbol_names(#symbol_names<'db>), - )* + let symbol = if defs.len() > 0 { + quote! { + #( + #defs + )* + #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] + pub enum #symbol_name<'db> { + #( + #symbol_names(#symbol_names<'db>), + )* + } + } + } else { + quote! { + #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] + pub enum #symbol_name<'db> { + _Phantom(std::marker::PhantomData<&'db ()>) + } } }; let name = format_ident!("{}s", name.to_case(Case::Pascal)); From 11a17b85708a73f73bfeb23e6be8c92d2b2a924d Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 11:17:06 -0800 Subject: [PATCH 04/16] Misc fixes --- Cargo.lock | 1 + Cargo.toml | 9 ++-- codegen-sdk-ast-generator/src/query.rs | 53 ++++++++++++++++++- .../src/generator/utils.rs | 14 +++++ languages/codegen-sdk-json/src/lib.rs | 2 +- src/main.rs | 6 +-- 6 files changed, 76 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 443fe155..f54a79e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -426,6 +426,7 @@ dependencies = [ "codegen-sdk-analyzer", "codegen-sdk-ast", "codegen-sdk-common", + "codegen-sdk-python", "codegen-sdk-resolution", "codegen-sdk-typescript", "criterion", diff --git a/Cargo.toml b/Cargo.toml index f42a448c..085ca6d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,20 +7,21 @@ edition = "2024" [dependencies] clap = { version = "4.5.28", features = ["derive"] } -codegen-sdk-analyzer = { path = "codegen-sdk-analyzer" } +codegen-sdk-analyzer = { path = "codegen-sdk-analyzer", default-features = false } codegen-sdk-ast = { workspace = true} codegen-sdk-common = { workspace = true} anyhow = { workspace = true} salsa = { workspace = true} -codegen-sdk-typescript = { workspace = true} +codegen-sdk-typescript = { workspace = true, optional = true } +codegen-sdk-python = { workspace = true, optional = true } env_logger = { workspace = true } log = { workspace = true } codegen-sdk-resolution = { workspace = true} sysinfo = "0.33.1" rkyv.workspace = true [features] -python = [ "codegen-sdk-analyzer/python"] # TODO: Add python support -typescript = [ "codegen-sdk-analyzer/typescript"] +python = [ "codegen-sdk-analyzer/python", "codegen-sdk-python"] # TODO: Add python support +typescript = [ "codegen-sdk-analyzer/typescript", "codegen-sdk-typescript"] tsx = [ "codegen-sdk-analyzer/tsx"] jsx = [ "codegen-sdk-analyzer/jsx"] javascript = [ "codegen-sdk-analyzer/javascript"] diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index e1ddb635..f9d4b614 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -661,7 +661,7 @@ impl<'a> Query<'a> { }); let symbol_name = self.symbol_name(); return quote! { - let symbol = #symbol_name::new(db, id, node.clone(), #(#args.clone()),*); + let symbol = #symbol_name::new(db, id, node.clone(), #(#args.clone().into()),*); #to_append.entry(#name.source()).or_default().push(symbol); }; } @@ -820,6 +820,31 @@ impl<'a> Query<'a> { .map(|c| name_for_capture(c)) .collect() } + fn get_type_for_field( + &self, + parent: &ts_query::NamedNode, + field: &ts_query::FieldDefinition, + ) -> String { + let parent_name = normalize_type_name(&parent.name(self.tree).source(), true); + let field_name = normalize_field_name( + &field + .name(self.tree) + .into_iter() + .filter(|n| n.is_named()) + .next() + .unwrap() + .source(), + ); + let parsed_field = self.get_field_for_field_name(&field_name, &parent_name); + if let Some(parsed_field) = parsed_field { + parsed_field.type_name() + } else { + panic!( + "No field found for: {:#?}, {:#?}, {:#?}", + field, field_name, parent_name + ); + } + } pub fn struct_fields(&self) -> Vec { let mut fields = Vec::new(); for capture in self.target_captures() { @@ -848,6 +873,32 @@ impl<'a> Query<'a> { break; } ts_query::NodeTypesRef::AnonymousUnderscore(_) => { + let mut ancestors = id.ancestors(self.tree.arena()).skip(2); + if let Some(field) = ancestors.next() { + if let Some(parent) = ancestors.next() { + // Field definitions (example) + // (new_expression\n constructor: (_) @name) @reference.class + let parent: &ts_query::NamedNode = self + .tree + .get(&parent) + .unwrap() + .as_ref() + .try_into() + .unwrap(); + + let field: &ts_query::FieldDefinition = self + .tree + .get(&field) + .unwrap() + .as_ref() + .try_into() + .unwrap(); + type_name = format_ident!( + "{}", + self.get_type_for_field(parent, field) + ); + } + } break; // Could be any type } _ => { diff --git a/codegen-sdk-cst-generator/src/generator/utils.rs b/codegen-sdk-cst-generator/src/generator/utils.rs index 1dbf87fd..6efd477e 100644 --- a/codegen-sdk-cst-generator/src/generator/utils.rs +++ b/codegen-sdk-cst-generator/src/generator/utils.rs @@ -65,6 +65,20 @@ pub fn get_from_enum_to_ref(enum_name: &str, variant_names: &Vec) -> Toke node.as_ref().into() } } + impl<'db3> From<#name_ref<'db3>> for #name<'db3> { + fn from(node: #name_ref<'db3>) -> Self { + match node { + #(#name_ref::#variant_names(data) => Self::#variant_names((*data).clone()),)* + } + } + } + impl<'db3> From<&'db3 #name_ref<'db3>> for #name<'db3> { + fn from(node: &'db3 #name_ref<'db3>) -> Self { + match node { + #(#name_ref::#variant_names(data) => Self::#variant_names((*data).clone()),)* + } + } + } #( impl<'db3> TryFrom<#name_ref<'db3>> for &'db3 #variant_names<'db3> { type Error = codegen_sdk_cst::ConversionError; diff --git a/languages/codegen-sdk-json/src/lib.rs b/languages/codegen-sdk-json/src/lib.rs index 81485e68..5fb5f2c4 100644 --- a/languages/codegen-sdk-json/src/lib.rs +++ b/languages/codegen-sdk-json/src/lib.rs @@ -20,7 +20,7 @@ mod tests { "; let db = codegen_sdk_cst::CSTDatabase::default(); let module = crate::cst::JSON::parse(&db, content.to_string()).unwrap(); - let (root, tree) = module; + let (root, tree, _) = module; assert!(root.children(tree).len() > 0); } } diff --git a/src/main.rs b/src/main.rs index c387676f..da9e2799 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,21 +35,21 @@ fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize #[cfg(feature = "python")] if let ParsedFile::Python(file) = parsed { let definitions = file.definitions(codebase.db()); - let tree = file.node(codebase.db()).unwrap().tree(codebase.db()); - let functions = definitions.functions(codebase.db(), &tree); + let functions = definitions.functions(codebase.db()); let mut total_references = 0; let total_functions = functions.len(); for function in functions .into_iter() .map(|(_, functions)| functions) .flatten() + .map(|function| codegen_sdk_python::ast::Symbol::Function(function.clone())) { total_references += function .references_for_scopes(codebase.db(), vec![*file], &file) .len(); } return ( - definitions.classes(codebase.db(), &tree).len(), + definitions.classes(codebase.db()).len(), total_functions, 0, 0, From 250c88a7e0bda53b8e49a9f9609163b40bf424fe Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 11:48:00 -0800 Subject: [PATCH 05/16] holy it works --- Cargo.lock | 1 + codegen-sdk-resolution/Cargo.toml | 1 + codegen-sdk-resolution/src/codebase.rs | 1 + codegen-sdk-resolution/src/references.rs | 9 ++-- codegen-sdk-resolution/src/resolve_type.rs | 4 +- codegen-sdk-resolution/src/scope.rs | 11 +++- languages/codegen-sdk-python/src/lib.rs | 51 ++++++++++++++++--- .../codegen-sdk-python/tests/test_python.rs | 19 ++++--- 8 files changed, 77 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f54a79e2..b6f6394a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -633,6 +633,7 @@ dependencies = [ name = "codegen-sdk-resolution" version = "0.1.0" dependencies = [ + "log", "salsa", ] diff --git a/codegen-sdk-resolution/Cargo.toml b/codegen-sdk-resolution/Cargo.toml index fba1a9c2..bbbca93a 100644 --- a/codegen-sdk-resolution/Cargo.toml +++ b/codegen-sdk-resolution/Cargo.toml @@ -6,3 +6,4 @@ edition = "2024" [dependencies] salsa = { workspace = true } +log = {workspace = true} diff --git a/codegen-sdk-resolution/src/codebase.rs b/codegen-sdk-resolution/src/codebase.rs index 35b3e6e1..7e28b5bc 100644 --- a/codegen-sdk-resolution/src/codebase.rs +++ b/codegen-sdk-resolution/src/codebase.rs @@ -10,4 +10,5 @@ pub trait CodebaseContext { fn files<'a>(&'a self) -> Vec<&'a Self::File<'a>>; fn db(&self) -> &dyn Database; fn get_file<'a>(&'a self, path: PathBuf) -> Option<&'a Self::File<'a>>; + fn root_path(&self) -> PathBuf; } diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index 5ee20621..db5b35f2 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use crate::{CodebaseContext, ResolveType}; pub trait References< @@ -12,15 +14,16 @@ pub trait References< for<'b> T: CodebaseContext = F> + 'static, { let scopes: Vec = codebase.files().into_iter().filter_map(|file| file.clone().try_into().ok()).collect(); - return self.references_for_scopes(codebase.db(), scopes, scope); + return self.references_for_scopes(codebase.db(), codebase.root_path(), scopes, scope); } - fn references_for_scopes(&self, db: &'db dyn salsa::Database, scopes: Vec, scope: &Scope) -> Vec + fn references_for_scopes(&self, db: &'db dyn salsa::Database, root_path: PathBuf, scopes: Vec, scope: &Scope) -> Vec where Self: Sized + 'db, { + log::info!("Finding references across {:?} scopes", scopes.len()); let mut results = Vec::new(); for reference in scope.clone().resolvables(db) { - let resolved = reference.clone().resolve_type(db, scope.clone(), scopes.clone()); + let resolved = reference.clone().resolve_type(db, scope.clone(), root_path.clone(), scopes.clone()); if resolved.iter().any(|result| *result == *self) { results.push(reference); } diff --git a/codegen-sdk-resolution/src/resolve_type.rs b/codegen-sdk-resolution/src/resolve_type.rs index 62d6a85a..0917471d 100644 --- a/codegen-sdk-resolution/src/resolve_type.rs +++ b/codegen-sdk-resolution/src/resolve_type.rs @@ -1,5 +1,6 @@ -use crate::Scope; +use std::path::PathBuf; +use crate::Scope; // Get definitions for a given type pub trait ResolveType<'db, T: Scope<'db>> { type Type; // Possible types this trait can be defined as @@ -7,6 +8,7 @@ pub trait ResolveType<'db, T: Scope<'db>> { self, db: &'db dyn salsa::Database, scope: T, + root_path: PathBuf, scopes: Vec, ) -> &'db Vec; } diff --git a/codegen-sdk-resolution/src/scope.rs b/codegen-sdk-resolution/src/scope.rs index ec7baa9d..30175c31 100644 --- a/codegen-sdk-resolution/src/scope.rs +++ b/codegen-sdk-resolution/src/scope.rs @@ -1,10 +1,17 @@ -use crate::ResolveType; +use std::path::PathBuf; +use crate::ResolveType; // Resolve a given string name in a scope to a given type pub trait Scope<'db>: Sized { type Type; type ReferenceType: ResolveType<'db, Self, Type = Self::Type>; - fn resolve(self, db: &'db dyn salsa::Database, name: String) -> &'db Vec; + fn resolve( + self, + db: &'db dyn salsa::Database, + name: String, + root_path: PathBuf, + scopes: Vec, + ) -> &'db Vec; /// Get all the resolvables (IE: function_calls) in the scope fn resolvables(self, db: &'db dyn salsa::Database) -> Vec; } diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 0535cb13..c5d909aa 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -13,7 +13,13 @@ pub mod ast { type Type = crate::ast::Symbol<'db>; type ReferenceType = crate::ast::Call<'db>; #[salsa::tracked(return_ref)] - fn resolve(self, db: &'db dyn salsa::Database, name: String) -> Vec { + fn resolve( + self, + db: &'db dyn salsa::Database, + name: String, + root_path: PathBuf, + scopes: Vec>, + ) -> Vec { let tree = self.node(db).unwrap().tree(db); let mut results = Vec::new(); for (def_name, defs) in self.definitions(db).functions(db).into_iter() { @@ -27,11 +33,14 @@ pub mod ast { } for (def_name, defs) in self.definitions(db).imports(db).into_iter() { if *def_name == name { - results.extend( - defs.into_iter() - .cloned() - .map(|def| crate::ast::Symbol::Import(def)), - ); + for def in defs { + results.push(crate::ast::Symbol::Import(def.clone())); + for resolved in + def.resolve_type(db, self, root_path.clone(), scopes.clone()) + { + results.push(resolved.clone()); + } + } } } results @@ -46,6 +55,30 @@ pub mod ast { } } #[salsa::tracked] + impl<'db> ResolveType<'db, PythonFile<'db>> for crate::ast::Import<'db> { + type Type = crate::ast::Symbol<'db>; + #[salsa::tracked(return_ref)] + fn resolve_type( + self, + db: &'db dyn salsa::Database, + scope: PythonFile<'db>, + root_path: PathBuf, + scopes: Vec>, + ) -> Vec { + let module = self.module(db).source().replace(".", "/"); + let target_path = FileNodeId::new(db, root_path.join(module).with_extension("py")); + log::info!("Target path: {:?}", target_path); + let name = self.name(db).source(); + for scope in &scopes { + log::info!("Checking scope {:?}", scope.id(db)); + if scope.id(db) == target_path { + return scope.resolve(db, name, root_path, scopes).to_vec(); + } + } + Vec::new() + } + } + #[salsa::tracked] impl<'db> ResolveType<'db, PythonFile<'db>> for crate::ast::Call<'db> { type Type = crate::ast::Symbol<'db>; #[salsa::tracked(return_ref)] @@ -53,11 +86,13 @@ pub mod ast { self, db: &'db dyn salsa::Database, scope: PythonFile<'db>, - _scopes: Vec>, + root_path: PathBuf, + scopes: Vec>, ) -> Vec { + log::info!("Resolving call with {:?} scopes", scopes.len()); let tree = scope.node(db).unwrap().tree(db); scope - .resolve(db, self.node(db).function(tree).source()) + .resolve(db, self.node(db).function(tree).source(), root_path, scopes) .clone() } } diff --git a/languages/codegen-sdk-python/tests/test_python.rs b/languages/codegen-sdk-python/tests/test_python.rs index 8a0635fd..77dc9b22 100644 --- a/languages/codegen-sdk-python/tests/test_python.rs +++ b/languages/codegen-sdk-python/tests/test_python.rs @@ -71,7 +71,6 @@ test()"; let content = codegen_sdk_cst::Input::new(&db, content.to_string()); let input = codegen_sdk_ast::input::File::new(&db, file_path, content); let file = codegen_sdk_python::ast::parse_query(&db, input); - let tree = file.tree(&db); assert_eq!(file.references(&db).calls(&db).len(), 1); let definitions = file.definitions(&db); let functions = definitions.functions(&db); @@ -79,7 +78,7 @@ test()"; let function = codegen_sdk_python::ast::Symbol::Function(function.clone().clone()); assert_eq!( function - .references_for_scopes(&db, vec![*file], &file) + .references_for_scopes(&db, temp_dir.path().to_path_buf(), vec![*file], &file) .len(), 1 ); @@ -104,8 +103,6 @@ test()"; let usage_input = codegen_sdk_ast::input::File::new(&db, usage_file_path, usage_content); let file = codegen_sdk_python::ast::parse_query(&db, input); let usage_file = codegen_sdk_python::ast::parse_query(&db, usage_input); - let tree = file.tree(&db); - let usage_tree = usage_file.tree(&db); assert_eq!(usage_file.references(&db).calls(&db).len(), 1); let definitions = file.definitions(&db); let functions = definitions.functions(&db); @@ -116,13 +113,23 @@ test()"; let import = codegen_sdk_python::ast::Symbol::Import(import.clone().clone()); assert_eq!( import - .references_for_scopes(&db, vec![*usage_file], &usage_file) + .references_for_scopes( + &db, + temp_dir.path().to_path_buf(), + vec![*usage_file], + &usage_file + ) .len(), 1 ); assert_eq!( function - .references_for_scopes(&db, vec![*file, *usage_file], &file) + .references_for_scopes( + &db, + temp_dir.path().to_path_buf(), + vec![*file, *usage_file], + &usage_file + ) .len(), 1 ); From a04622fd62fd31d9f9b2b810b1284cb815f0186a Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 13:29:39 -0800 Subject: [PATCH 06/16] Parallel import resolution --- Cargo.lock | 7 ++ codegen-sdk-analyzer/src/codebase.rs | 17 ++-- .../src/codebase/discovery.rs | 6 +- codegen-sdk-analyzer/src/codebase/parser.rs | 21 ++--- codegen-sdk-analyzer/src/database.rs | 13 ++-- codegen-sdk-analyzer/src/parser.rs | 8 +- codegen-sdk-ast-generator/src/generator.rs | 13 +++- codegen-sdk-ast-generator/src/lib.rs | 12 +-- codegen-sdk-ast-generator/src/visitor.rs | 26 +++++++ codegen-sdk-common/src/tree/context.rs | 4 +- codegen-sdk-common/src/tree/id.rs | 5 +- codegen-sdk-cst-generator/src/generator.rs | 6 +- .../src/generator/node.rs | 2 +- codegen-sdk-macros/src/lib.rs | 2 +- codegen-sdk-resolution/Cargo.toml | 5 ++ codegen-sdk-resolution/src/codebase.rs | 15 +++- codegen-sdk-resolution/src/database.rs | 11 +++ codegen-sdk-resolution/src/lib.rs | 13 ++++ codegen-sdk-resolution/src/parse.rs | 11 +++ codegen-sdk-resolution/src/references.rs | 37 +++++---- codegen-sdk-resolution/src/resolve_type.rs | 14 +--- codegen-sdk-resolution/src/scope.rs | 24 +++--- languages/codegen-sdk-javascript/Cargo.toml | 1 + languages/codegen-sdk-json/Cargo.toml | 1 + languages/codegen-sdk-python/src/lib.rs | 77 ++++++++++--------- src/main.rs | 4 +- 26 files changed, 226 insertions(+), 129 deletions(-) create mode 100644 codegen-sdk-resolution/src/database.rs create mode 100644 codegen-sdk-resolution/src/parse.rs diff --git a/Cargo.lock b/Cargo.lock index b6f6394a..a335a39a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -527,6 +527,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -547,6 +548,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -633,6 +635,11 @@ dependencies = [ name = "codegen-sdk-resolution" version = "0.1.0" dependencies = [ + "ambassador", + "anyhow", + "codegen-sdk-ast", + "codegen-sdk-common", + "indicatif", "log", "salsa", ] diff --git a/codegen-sdk-analyzer/src/codebase.rs b/codegen-sdk-analyzer/src/codebase.rs index 01b56e61..c6db6e67 100644 --- a/codegen-sdk-analyzer/src/codebase.rs +++ b/codegen-sdk-analyzer/src/codebase.rs @@ -4,16 +4,12 @@ use anyhow::Context; use codegen_sdk_ast::Input; #[cfg(feature = "serialization")] use codegen_sdk_common::serialization::Cache; -use codegen_sdk_resolution::CodebaseContext; +use codegen_sdk_resolution::{CodebaseContext, Db}; use discovery::FilesToParse; use notify_debouncer_mini::DebounceEventResult; use salsa::Setter; -use crate::{ - ParsedFile, - database::{CodegenDatabase, Db}, - parser::parse_file, -}; +use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; mod discovery; mod parser; pub struct Codebase { @@ -87,6 +83,9 @@ impl Codebase { } impl CodebaseContext for Codebase { type File<'a> = ParsedFile<'a>; + fn root_path(&self) -> PathBuf { + self.root.clone() + } fn files<'a>(&'a self) -> Vec<&'a Self::File<'a>> { let mut files = Vec::new(); for file in self.discover().files(&self.db) { @@ -96,13 +95,15 @@ impl CodebaseContext for Codebase { } files } - fn db(&self) -> &dyn salsa::Database { + fn db(&self) -> &dyn Db { &self.db } fn get_file<'a>(&'a self, path: PathBuf) -> Option<&'a Self::File<'a>> { let file = self.db.files.get(&path); if let Some(file) = file { - return parse_file(&self.db, file.clone()).file(&self.db).as_ref(); + return parse_file(&self.db, file.clone(), self.root.clone()) + .file(&self.db) + .as_ref(); } None } diff --git a/codegen-sdk-analyzer/src/codebase/discovery.rs b/codegen-sdk-analyzer/src/codebase/discovery.rs index 825a5086..040e017d 100644 --- a/codegen-sdk-analyzer/src/codebase/discovery.rs +++ b/codegen-sdk-analyzer/src/codebase/discovery.rs @@ -3,12 +3,14 @@ use std::path::PathBuf; use codegen_sdk_ast::*; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; +use codegen_sdk_resolution::Db; use glob::glob; -use crate::database::{CodegenDatabase, Db}; +use crate::database::CodegenDatabase; #[salsa::input] pub struct FilesToParse { pub files: Vec, + pub root: PathBuf, } pub fn log_languages() { for language in LANGUAGES.iter() { @@ -42,5 +44,5 @@ pub fn collect_files(db: &CodegenDatabase, dir: &PathBuf) -> FilesToParse { .filter(|file| !file.is_dir() && !file.is_symlink()) .map(|file| db.input(file).unwrap()) .collect(); - FilesToParse::new(db, files) + FilesToParse::new(db, files, dir) } diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index a1f86e0f..7cb62603 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -1,19 +1,18 @@ +use std::path::PathBuf; + use codegen_sdk_ast::{Definitions, References, input::File}; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; +use codegen_sdk_resolution::{Db, Scope}; use indicatif::{ProgressBar, ProgressStyle}; use super::discovery::{FilesToParse, log_languages}; -use crate::{ - ParsedFile, - database::{CodegenDatabase, Db}, - parser::parse_file, -}; +use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; fn execute_op_with_progress( db: &Database, files: FilesToParse, name: &str, - op: fn(&Database, File) -> T, + op: fn(&Database, File, PathBuf) -> T, ) -> Vec { let multi = db.multi_progress(); let style = ProgressStyle::with_template( @@ -28,15 +27,16 @@ fn execute_op_with_progress( let inputs = files .files(db) .into_iter() - .map(|file| (&pg, file, op)) + .map(|file| (&pg, file, files.root(db).clone(), op)) .collect::>(); let results: Vec = salsa::par_map(db, inputs, move |db, input| { - let (pg, file, op) = input; + let (pg, file, root, op) = input; let res = op( db, #[cfg(feature = "serialization")] &cache, file, + root, ); pg.inc(1); res @@ -53,8 +53,8 @@ fn execute_op_with_progress( // } #[salsa::tracked] fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { - let _: Vec<_> = execute_op_with_progress(db, files, "Parsing Files", |db, file| { - let file = parse_file(db, file); + let _: Vec<_> = execute_op_with_progress(db, files, "Parsing Files", |db, file, root| { + let file = parse_file(db, file, root); if let Some(parsed) = file.file(db) { #[cfg(feature = "typescript")] if let ParsedFile::Typescript(parsed) = parsed { @@ -65,6 +65,7 @@ fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { if let ParsedFile::Python(parsed) = parsed { parsed.definitions(db); parsed.references(db); + parsed.compute_dependencies(db); } } () diff --git a/codegen-sdk-analyzer/src/database.rs b/codegen-sdk-analyzer/src/database.rs index 588c8028..f6593176 100644 --- a/codegen-sdk-analyzer/src/database.rs +++ b/codegen-sdk-analyzer/src/database.rs @@ -7,6 +7,7 @@ use std::{ use anyhow::Context; use codegen_sdk_ast::input::File; use codegen_sdk_cst::Input; +use codegen_sdk_resolution::Db; use dashmap::{DashMap, mapref::entry::Entry}; use indicatif::MultiProgress; use notify_debouncer_mini::{ @@ -16,12 +17,6 @@ use notify_debouncer_mini::{ use crate::progress::get_multi_progress; #[salsa::db] -pub trait Db: salsa::Database + Send { - fn input(&self, path: PathBuf) -> anyhow::Result; - fn multi_progress(&self) -> &MultiProgress; - fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; -} -#[salsa::db] #[derive(Clone)] // Basic Database implementation for Query generation. This is not used for anything else. pub struct CodegenDatabase { @@ -76,6 +71,12 @@ impl salsa::Database for CodegenDatabase { } #[salsa::db] impl Db for CodegenDatabase { + fn files(&self) -> Vec { + self.files + .iter() + .map(|entry| entry.value().clone()) + .collect() + } fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()> { let path = path.canonicalize()?; let watcher = &mut *self.file_watcher.lock().unwrap(); diff --git a/codegen-sdk-analyzer/src/parser.rs b/codegen-sdk-analyzer/src/parser.rs index b3cf807b..1c3b2033 100644 --- a/codegen-sdk-analyzer/src/parser.rs +++ b/codegen-sdk-analyzer/src/parser.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use codegen_sdk_common::FileNodeId; use codegen_sdk_cst::CSTLanguage; use codegen_sdk_macros::{languages_ast, parse_language}; @@ -12,7 +14,11 @@ pub struct Parsed<'db> { pub file: Option>, } #[salsa::tracked(return_ref)] -pub fn parse_file(db: &dyn salsa::Database, file: codegen_sdk_ast::input::File) -> Parsed<'_> { +pub fn parse_file( + db: &dyn salsa::Database, + file: codegen_sdk_ast::input::File, + root: PathBuf, +) -> Parsed<'_> { parse_language!(); Parsed::new(db, FileNodeId::new(db, file.path(db)), None) } diff --git a/codegen-sdk-ast-generator/src/generator.rs b/codegen-sdk-ast-generator/src/generator.rs index 0b8bf1e6..e8f6da94 100644 --- a/codegen-sdk-ast-generator/src/generator.rs +++ b/codegen-sdk-ast-generator/src/generator.rs @@ -71,20 +71,25 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { #[id] pub id: codegen_sdk_common::FileNodeId<'db>, } + impl<'db> codegen_sdk_resolution::Parse<'db> for #language_struct_name<'db> { + fn parse(db: &'db dyn salsa::Database, input: codegen_sdk_ast::input::File, root: PathBuf) -> &'db Self { + parse_query(db, input, root) + } + } // impl<'db> File for {language_struct_name}File<'db> {{ // fn path(&self) -> &PathBuf {{ // &self.path(db) // }} // }} - pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { + pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File, root: PathBuf) -> #language_struct_name<'_> { log::debug!("Parsing {} file: {}", input.path(db).display(), #language_name_str); - let ast = crate::cst::parse_program_raw(db, input.contents(db), input.path(db).clone()); + let ast = crate::cst::parse_program_raw(db, input.contents(db), input.path(db).clone(), root); let file_id = codegen_sdk_common::FileNodeId::new(db, input.path(db).clone()); #language_struct_name::new(db, ast, file_id) } #[salsa::tracked(return_ref)] - pub fn parse_query(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { - parse(db, input) + pub fn parse_query(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File, root: PathBuf) -> #language_struct_name<'_> { + parse(db, input, root) } impl<'db> #language_struct_name<'db> { diff --git a/codegen-sdk-ast-generator/src/lib.rs b/codegen-sdk-ast-generator/src/lib.rs index 4e92edec..bf97d8ef 100644 --- a/codegen-sdk-ast-generator/src/lib.rs +++ b/codegen-sdk-ast-generator/src/lib.rs @@ -10,11 +10,13 @@ use syn::parse_quote; pub fn generate_ast(language: &Language) -> anyhow::Result<()> { let db = CSTDatabase::default(); let imports = quote! { - use codegen_sdk_common::*; - use std::path::PathBuf; - use codegen_sdk_cst::CSTLanguage; - use std::collections::BTreeMap; - use std::sync::mpsc::Sender; + use codegen_sdk_common::*; + use std::path::PathBuf; + use codegen_sdk_cst::CSTLanguage; + use std::collections::BTreeMap; + use std::sync::mpsc::Sender; + use codegen_sdk_resolution::HasFile; + use codegen_sdk_resolution::Parse; }; let ast = generator::generate_ast(language)?; let definition_visitor = visitor::generate_visitor(&db, language, "definition"); diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 61e8985d..5d7d15f9 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -67,6 +67,7 @@ pub fn generate_visitor<'db>( }; let mut defs = Vec::new(); + let language_struct = format_ident!("{}File", language.struct_name()); for (variant, type_name) in symbol_names.iter().zip(types.iter()) { let query = enter_methods .get(&type_name.to_string()) @@ -86,6 +87,18 @@ pub fn generate_visitor<'db>( pub node: crate::cst::#type_name<'db>, #(#fields),* } + impl<'db> codegen_sdk_resolution::HasFile<'db> for #variant<'db> { + type File<'db1> = #language_struct<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self.node(db).id().file(db).path(db); + let root = self.root_path(db); + let input = db.input(path).unwrap(); + parse_query(db, input, root) + } + fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf { + self.node(db).id().root(db).path(db) + } + } }); } let symbol = if defs.len() > 0 { @@ -99,6 +112,19 @@ pub fn generate_visitor<'db>( #symbol_names(#symbol_names<'db>), )* } + impl<'db> codegen_sdk_resolution::HasFile<'db> for #symbol_name<'db> { + type File<'db1> = #language_struct<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + match self { + #(Self::#symbol_names(symbol) => symbol.file(db),)* + } + } + fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf { + match self { + #(Self::#symbol_names(symbol) => symbol.root_path(db),)* + } + } + } } } else { quote! { diff --git a/codegen-sdk-common/src/tree/context.rs b/codegen-sdk-common/src/tree/context.rs index 1cf94dcc..9210b85e 100644 --- a/codegen-sdk-common/src/tree/context.rs +++ b/codegen-sdk-common/src/tree/context.rs @@ -6,15 +6,17 @@ use crate::tree::{FileNodeId, Tree, TreeNode}; pub struct ParseContext<'db, T: TreeNode> { pub db: &'db dyn salsa::Database, pub file_id: FileNodeId<'db>, + pub root: FileNodeId<'db>, pub buffer: Arc, pub tree: Tree, } impl<'db, T: TreeNode> ParseContext<'db, T> { - pub fn new(db: &'db dyn salsa::Database, path: PathBuf, content: Bytes) -> Self { + pub fn new(db: &'db dyn salsa::Database, path: PathBuf, root: PathBuf, content: Bytes) -> Self { let file_id = FileNodeId::new(db, path); Self { db, file_id, + root: FileNodeId::new(db, root), buffer: Arc::new(content), tree: Tree::default(), } diff --git a/codegen-sdk-common/src/tree/id.rs b/codegen-sdk-common/src/tree/id.rs index b665416a..057689ba 100644 --- a/codegen-sdk-common/src/tree/id.rs +++ b/codegen-sdk-common/src/tree/id.rs @@ -2,11 +2,12 @@ use std::path::PathBuf; #[salsa::interned] pub struct FileNodeId<'db> { - file_path: PathBuf, + pub path: PathBuf, } #[salsa::interned] pub struct CSTNodeId<'db> { - file_id: FileNodeId<'db>, + pub file: FileNodeId<'db>, node_id: usize, + pub root: FileNodeId<'db>, // TODO: add a marker for tree-sitter generation } diff --git a/codegen-sdk-cst-generator/src/generator.rs b/codegen-sdk-cst-generator/src/generator.rs index 252f010f..36004570 100644 --- a/codegen-sdk-cst-generator/src/generator.rs +++ b/codegen-sdk-cst-generator/src/generator.rs @@ -56,7 +56,7 @@ fn get_parser(language: &Language) -> TokenStream { pub tree: Tree>, pub program: indextree::NodeId, } - pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::Input, path: PathBuf) -> Option> { + pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::Input, path: PathBuf, root: PathBuf) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::#language_name::#language_struct_name.parse_tree_sitter(&input.content(db)); match tree { @@ -65,7 +65,7 @@ fn get_parser(language: &Language) -> TokenStream { ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new(db, path, root, buffer); let root_id = #program_id::orphaned(&mut context, tree.root_node()) .map_or_else(|e| { e.report(db); @@ -88,7 +88,7 @@ fn get_parser(language: &Language) -> TokenStream { } #[salsa::tracked(return_ref)] pub fn parse_program(db: &dyn salsa::Database, input: codegen_sdk_cst::Input) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input, std::path::PathBuf::new(), std::path::PathBuf::new()); if let Some(parsed) = raw { parsed } else { diff --git a/codegen-sdk-cst-generator/src/generator/node.rs b/codegen-sdk-cst-generator/src/generator/node.rs index 5e64a3a4..cd080aa3 100644 --- a/codegen-sdk-cst-generator/src/generator/node.rs +++ b/codegen-sdk-cst-generator/src/generator/node.rs @@ -224,7 +224,7 @@ impl<'a> Node<'a> { fn from_node(context: &mut ParseContext<'db, NodeTypes<'db>>, node: tree_sitter::Node) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); #(#constructor_fields)* Ok((Self { diff --git a/codegen-sdk-macros/src/lib.rs b/codegen-sdk-macros/src/lib.rs index cb168d40..2496dce6 100644 --- a/codegen-sdk-macros/src/lib.rs +++ b/codegen-sdk-macros/src/lib.rs @@ -101,7 +101,7 @@ pub fn parse_language(_item: TokenStream) -> TokenStream { let variant: proc_macro2::TokenStream = quote! { #[cfg(feature = #name)] if #package_name::cst::#struct_name::should_parse(&file.path(db)).unwrap_or(false) { - let parsed = #package_name::ast::parse(db, file); + let parsed = #package_name::ast::parse(db, file, root); return Parsed::new( db, FileNodeId::new(db, file.path(db)), diff --git a/codegen-sdk-resolution/Cargo.toml b/codegen-sdk-resolution/Cargo.toml index bbbca93a..3d10b2a3 100644 --- a/codegen-sdk-resolution/Cargo.toml +++ b/codegen-sdk-resolution/Cargo.toml @@ -7,3 +7,8 @@ edition = "2024" [dependencies] salsa = { workspace = true } log = {workspace = true} +codegen-sdk-ast = { workspace = true } +codegen-sdk-common = { workspace = true } +anyhow = { workspace = true } +indicatif = { workspace = true } +ambassador = { workspace = true } diff --git a/codegen-sdk-resolution/src/codebase.rs b/codegen-sdk-resolution/src/codebase.rs index 7e28b5bc..49e3a8d9 100644 --- a/codegen-sdk-resolution/src/codebase.rs +++ b/codegen-sdk-resolution/src/codebase.rs @@ -1,6 +1,8 @@ use std::path::PathBuf; -use salsa::Database; +use codegen_sdk_common::FileNodeId; + +use crate::Db; // Not sure what to name this // Equivalent to CodebaseGraph/CodebaseContext in the SDK pub trait CodebaseContext { @@ -8,7 +10,16 @@ pub trait CodebaseContext { where Self: 'a; fn files<'a>(&'a self) -> Vec<&'a Self::File<'a>>; - fn db(&self) -> &dyn Database; + fn db(&self) -> &dyn Db; fn get_file<'a>(&'a self, path: PathBuf) -> Option<&'a Self::File<'a>>; + fn get_file_for_id<'a>(&'a self, id: FileNodeId) -> Option<&'a Self::File<'a>> { + self.get_file(id.path(self.db())) + } + fn get_raw_file_for_id<'a>(&'a self, id: FileNodeId) -> Option { + self.get_raw_file(id.path(self.db())) + } + fn get_raw_file<'a>(&'a self, path: PathBuf) -> Option { + self.db().input(path).ok() + } fn root_path(&self) -> PathBuf; } diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs new file mode 100644 index 00000000..171f8374 --- /dev/null +++ b/codegen-sdk-resolution/src/database.rs @@ -0,0 +1,11 @@ +use std::path::PathBuf; + +use codegen_sdk_ast::input::File; +use indicatif::MultiProgress; +#[salsa::db] +pub trait Db: salsa::Database + Send { + fn input(&self, path: PathBuf) -> anyhow::Result; + fn multi_progress(&self) -> &MultiProgress; + fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; + fn files(&self) -> Vec; +} diff --git a/codegen-sdk-resolution/src/lib.rs b/codegen-sdk-resolution/src/lib.rs index 7a3184ef..600dc93c 100644 --- a/codegen-sdk-resolution/src/lib.rs +++ b/codegen-sdk-resolution/src/lib.rs @@ -1,4 +1,7 @@ mod scope; +use std::path::PathBuf; + +use ambassador::delegatable_trait; pub use scope::Scope; mod resolve_type; pub use resolve_type::ResolveType; @@ -6,3 +9,13 @@ mod references; pub use references::References; mod codebase; pub use codebase::CodebaseContext; +mod database; +mod parse; +pub use database::Db; +pub use parse::Parse; +#[delegatable_trait] +pub trait HasFile<'db> { + type File<'db1>; + fn file(&self, db: &'db dyn Db) -> &'db Self::File<'db>; + fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf; +} diff --git a/codegen-sdk-resolution/src/parse.rs b/codegen-sdk-resolution/src/parse.rs new file mode 100644 index 00000000..d173c14e --- /dev/null +++ b/codegen-sdk-resolution/src/parse.rs @@ -0,0 +1,11 @@ +use std::path::PathBuf; + +use salsa::Database; + +pub trait Parse<'db> { + fn parse( + db: &'db dyn Database, + input: codegen_sdk_ast::input::File, + root: PathBuf, + ) -> &'db Self; +} diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index db5b35f2..543621cd 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -1,33 +1,32 @@ -use std::path::PathBuf; - -use crate::{CodebaseContext, ResolveType}; +use crate::{Db, HasFile, Parse, ResolveType}; pub trait References< 'db, - ReferenceType: ResolveType<'db, Scope, Type = Self> + Clone, // References must resolve to this type - Scope: crate::Scope<'db, Type = Self, ReferenceType = ReferenceType> + Clone, ->: Eq + PartialEq where Self:'db + ReferenceType: ResolveType<'db, Type = Self> + Clone, // References must resolve to this type + Scope: crate::Scope<'db, Type = Self, ReferenceType = ReferenceType> + Clone + 'db, +>: Eq + PartialEq + HasFile<'db, File<'db> = Scope> + 'db where Self:'db { - fn references + Clone + 'db, T>(&self, codebase: &'db T, scope: &Scope) -> Vec + fn references(&self, db: &'db dyn Db) -> Vec where Self: Sized, - for<'b> T: CodebaseContext = F> + 'static, - { - let scopes: Vec = codebase.files().into_iter().filter_map(|file| file.clone().try_into().ok()).collect(); - return self.references_for_scopes(codebase.db(), codebase.root_path(), scopes, scope); - } - fn references_for_scopes(&self, db: &'db dyn salsa::Database, root_path: PathBuf, scopes: Vec, scope: &Scope) -> Vec - where - Self: Sized + 'db, + Scope: Parse<'db>, { - log::info!("Finding references across {:?} scopes", scopes.len()); + let files = db.files(); + let root_path = self.root_path(db); + log::info!("Finding references across {:?} files", files.len()); let mut results = Vec::new(); - for reference in scope.clone().resolvables(db) { - let resolved = reference.clone().resolve_type(db, scope.clone(), root_path.clone(), scopes.clone()); - if resolved.iter().any(|result| *result == *self) { + for input in files { + if !self.filter(db, &input) { + continue; + } + let file = Scope::parse(db, input, root_path.clone()); + for reference in file.clone().resolvables(db) { + if reference.clone().resolve_type(db).iter().any(|result| *result == *self) { results.push(reference); } + } } results } + fn filter(&self, db: &'db dyn Db, input: &codegen_sdk_ast::input::File) -> bool; } diff --git a/codegen-sdk-resolution/src/resolve_type.rs b/codegen-sdk-resolution/src/resolve_type.rs index 0917471d..e5fdf842 100644 --- a/codegen-sdk-resolution/src/resolve_type.rs +++ b/codegen-sdk-resolution/src/resolve_type.rs @@ -1,14 +1,6 @@ -use std::path::PathBuf; - -use crate::Scope; +use crate::Db; // Get definitions for a given type -pub trait ResolveType<'db, T: Scope<'db>> { +pub trait ResolveType<'db> { type Type; // Possible types this trait can be defined as - fn resolve_type( - self, - db: &'db dyn salsa::Database, - scope: T, - root_path: PathBuf, - scopes: Vec, - ) -> &'db Vec; + fn resolve_type(self, db: &'db dyn Db) -> &'db Vec; } diff --git a/codegen-sdk-resolution/src/scope.rs b/codegen-sdk-resolution/src/scope.rs index 30175c31..7ac3a8a3 100644 --- a/codegen-sdk-resolution/src/scope.rs +++ b/codegen-sdk-resolution/src/scope.rs @@ -1,17 +1,17 @@ -use std::path::PathBuf; - -use crate::ResolveType; +use crate::{Db, ResolveType}; // Resolve a given string name in a scope to a given type pub trait Scope<'db>: Sized { type Type; - type ReferenceType: ResolveType<'db, Self, Type = Self::Type>; - fn resolve( - self, - db: &'db dyn salsa::Database, - name: String, - root_path: PathBuf, - scopes: Vec, - ) -> &'db Vec; + type ReferenceType: ResolveType<'db>; + fn resolve(self, db: &'db dyn Db, name: String) -> &'db Vec; /// Get all the resolvables (IE: function_calls) in the scope - fn resolvables(self, db: &'db dyn salsa::Database) -> Vec; + fn resolvables(self, db: &'db dyn Db) -> Vec; + fn compute_dependencies(self, db: &'db dyn Db) + where + Self: 'db, + { + for reference in self.resolvables(db) { + reference.resolve_type(db); + } + } } diff --git a/languages/codegen-sdk-javascript/Cargo.toml b/languages/codegen-sdk-javascript/Cargo.toml index 9b96f6ac..9c6f8b8c 100644 --- a/languages/codegen-sdk-javascript/Cargo.toml +++ b/languages/codegen-sdk-javascript/Cargo.toml @@ -14,6 +14,7 @@ indextree ={ workspace = true } subenum = {workspace = true} bytes = { workspace = true } codegen-sdk-cst = { workspace = true } +codegen-sdk-resolution = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } [build-dependencies] diff --git a/languages/codegen-sdk-json/Cargo.toml b/languages/codegen-sdk-json/Cargo.toml index f2bb9761..c56732c2 100644 --- a/languages/codegen-sdk-json/Cargo.toml +++ b/languages/codegen-sdk-json/Cargo.toml @@ -14,6 +14,7 @@ subenum = {workspace = true} bytes = { workspace = true } codegen-sdk-cst = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } log = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index c5d909aa..e4c66fb9 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -9,17 +9,22 @@ pub mod ast { use codegen_sdk_resolution::{ResolveType, Scope}; include!(concat!(env!("OUT_DIR"), "/python-ast.rs")); #[salsa::tracked] + impl<'db> Import<'db> { + #[salsa::tracked] + fn resolve_import(self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + let root_path = self.root_path(db); + let module = self.module(db).source().replace(".", "/"); + let target_path = root_path.join(module).with_extension("py"); + log::info!(target: "resolution", "Resolving import to path: {:?}", target_path); + target_path + } + } + #[salsa::tracked] impl<'db> Scope<'db> for PythonFile<'db> { type Type = crate::ast::Symbol<'db>; type ReferenceType = crate::ast::Call<'db>; #[salsa::tracked(return_ref)] - fn resolve( - self, - db: &'db dyn salsa::Database, - name: String, - root_path: PathBuf, - scopes: Vec>, - ) -> Vec { + fn resolve(self, db: &'db dyn codegen_sdk_resolution::Db, name: String) -> Vec { let tree = self.node(db).unwrap().tree(db); let mut results = Vec::new(); for (def_name, defs) in self.definitions(db).functions(db).into_iter() { @@ -35,9 +40,7 @@ pub mod ast { if *def_name == name { for def in defs { results.push(crate::ast::Symbol::Import(def.clone())); - for resolved in - def.resolve_type(db, self, root_path.clone(), scopes.clone()) - { + for resolved in def.resolve_type(db) { results.push(resolved.clone()); } } @@ -46,7 +49,7 @@ pub mod ast { results } #[salsa::tracked] - fn resolvables(self, db: &'db dyn salsa::Database) -> Vec { + fn resolvables(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { let mut results = Vec::new(); for (_, refs) in self.references(db).calls(db).into_iter() { results.extend(refs.into_iter().cloned()); @@ -55,44 +58,28 @@ pub mod ast { } } #[salsa::tracked] - impl<'db> ResolveType<'db, PythonFile<'db>> for crate::ast::Import<'db> { + impl<'db> ResolveType<'db> for crate::ast::Import<'db> { type Type = crate::ast::Symbol<'db>; #[salsa::tracked(return_ref)] - fn resolve_type( - self, - db: &'db dyn salsa::Database, - scope: PythonFile<'db>, - root_path: PathBuf, - scopes: Vec>, - ) -> Vec { - let module = self.module(db).source().replace(".", "/"); - let target_path = FileNodeId::new(db, root_path.join(module).with_extension("py")); - log::info!("Target path: {:?}", target_path); - let name = self.name(db).source(); - for scope in &scopes { - log::info!("Checking scope {:?}", scope.id(db)); - if scope.id(db) == target_path { - return scope.resolve(db, name, root_path, scopes).to_vec(); - } + fn resolve_type(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { + let target_path = self.resolve_import(db); + if let Ok(input) = db.input(target_path) { + return PythonFile::parse(db, input, self.root_path(db)) + .resolve(db, self.name(db).source()) + .to_vec(); } Vec::new() } } #[salsa::tracked] - impl<'db> ResolveType<'db, PythonFile<'db>> for crate::ast::Call<'db> { + impl<'db> ResolveType<'db> for crate::ast::Call<'db> { type Type = crate::ast::Symbol<'db>; #[salsa::tracked(return_ref)] - fn resolve_type( - self, - db: &'db dyn salsa::Database, - scope: PythonFile<'db>, - root_path: PathBuf, - scopes: Vec>, - ) -> Vec { - log::info!("Resolving call with {:?} scopes", scopes.len()); + fn resolve_type(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { + let scope = self.file(db); let tree = scope.node(db).unwrap().tree(db); scope - .resolve(db, self.node(db).function(tree).source(), root_path, scopes) + .resolve(db, self.node(db).function(tree).source()) .clone() } } @@ -100,5 +87,19 @@ pub mod ast { impl<'db> codegen_sdk_resolution::References<'db, crate::ast::Call<'db>, PythonFile<'db>> for crate::ast::Symbol<'db> { + fn filter( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + input: &codegen_sdk_ast::input::File, + ) -> bool { + input.path(db).extension().unwrap() == "py" + && match self { + crate::ast::Symbol::Function(function) => input + .contents(db) + .content(db) + .contains(&function.name(db).source()), + _ => false, + } + } } } diff --git a/src/main.rs b/src/main.rs index da9e2799..5ddeecfa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,9 +44,7 @@ fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize .flatten() .map(|function| codegen_sdk_python::ast::Symbol::Function(function.clone())) { - total_references += function - .references_for_scopes(codebase.db(), vec![*file], &file) - .len(); + total_references += function.references(codebase.db()).len(); } return ( definitions.classes(codebase.db()).len(), From 4ac3b7fb3504757dc06e4659f1f91aaece6b2883 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 15:18:31 -0800 Subject: [PATCH 07/16] Progress bars --- .cargo/config.toml | 2 +- .gitignore | 1 + Cargo.lock | 74 ++++++++++++++----- Cargo.toml | 4 +- codegen-sdk-analyzer/src/codebase.rs | 28 +++++-- .../src/codebase/discovery.rs | 1 + codegen-sdk-analyzer/src/codebase/parser.rs | 2 +- codegen-sdk-analyzer/src/database.rs | 6 +- codegen-sdk-analyzer/src/lib.rs | 2 +- codegen-sdk-macros/src/lib.rs | 2 +- codegen-sdk-resolution/src/codebase.rs | 6 +- codegen-sdk-resolution/src/database.rs | 1 + codegen-sdk-resolution/src/references.rs | 6 +- languages/codegen-sdk-python/Cargo.toml | 2 + languages/codegen-sdk-python/src/lib.rs | 28 +++---- src/main.rs | 38 +++++----- 16 files changed, 132 insertions(+), 71 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index d958320a..4b2316cb 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,2 @@ [build] -rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zthreads=16"] +rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zthreads=16", "-Ctarget-cpu=native"] diff --git a/.gitignore b/.gitignore index 11cbd479..dff150c9 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,4 @@ tests/integration/verified_codemods/codemod_data/repo_commits.json target/* .benchmarks/* **.snap.new +flamegraph.svg diff --git a/Cargo.lock b/Cargo.lock index a335a39a..5a9ebaa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -109,6 +121,12 @@ dependencies = [ "backtrace", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "autocfg" version = "1.4.0" @@ -624,6 +642,7 @@ dependencies = [ "env_logger", "indextree", "log", + "memchr", "salsa", "subenum", "tempfile", @@ -1053,12 +1072,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "foldhash" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" - [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1203,25 +1216,24 @@ name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" -dependencies = [ - "allocator-api2", - "equivalent", - "foldhash", -] [[package]] name = "hashlink" -version = "0.10.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.14.5", ] [[package]] @@ -2327,12 +2339,14 @@ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "salsa" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#ceb9b083b3c0f6a1634e5a0b75b7bb5c7ca7b33f" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e59d074084ce0a89693f021d8317cbc53d23d6502d3b3e2a3d1a7db1ceb13b" dependencies = [ + "arc-swap", "boxcar", "crossbeam-queue", "dashmap", - "hashbrown 0.15.2", + "hashbrown 0.14.5", "hashlink", "indexmap", "parking_lot", @@ -2346,13 +2360,15 @@ dependencies = [ [[package]] name = "salsa-macro-rules" -version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#ceb9b083b3c0f6a1634e5a0b75b7bb5c7ca7b33f" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e354e0bdf1a23d822161e2b0f95c07846535a0e81deba77248a6ac22d19bc97" [[package]] name = "salsa-macros" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#ceb9b083b3c0f6a1634e5a0b75b7bb5c7ca7b33f" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b061c51d6c6d5d8e4459bcaa11ef18d268286c68263615d65e983071b357fd9" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -3397,6 +3413,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", +] + [[package]] name = "zerofrom" version = "0.1.6" diff --git a/Cargo.toml b/Cargo.toml index 085ca6d6..47872435 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,7 +117,7 @@ insta = "1.42.1" prettyplease = "0.2.29" syn = { version = "2.0.98", features = ["proc-macro", "full"] } derive_more = { version = "2.0.1", features = ["debug", "display"] } -salsa = {git = "https://github.com/salsa-rs/salsa", branch = "master"} +salsa = {version = "0.18.0"} subenum = {git = "https://github.com/mrenow/subenum", branch = "main"} indicatif-log-bridge = "0.2.3" indicatif = { version = "0.17.11", features = ["rayon"] } @@ -158,3 +158,5 @@ lto = false name = "parse" harness = false required-features = ["stable"] +[profile.release] +debug = true diff --git a/codegen-sdk-analyzer/src/codebase.rs b/codegen-sdk-analyzer/src/codebase.rs index c6db6e67..aa1c2cdf 100644 --- a/codegen-sdk-analyzer/src/codebase.rs +++ b/codegen-sdk-analyzer/src/codebase.rs @@ -1,17 +1,19 @@ use std::path::PathBuf; use anyhow::Context; -use codegen_sdk_ast::Input; +use codegen_sdk_ast::{Input, input::File}; #[cfg(feature = "serialization")] use codegen_sdk_common::serialization::Cache; use codegen_sdk_resolution::{CodebaseContext, Db}; use discovery::FilesToParse; use notify_debouncer_mini::DebounceEventResult; -use salsa::Setter; +use salsa::{AsDynDatabase, Database, Setter}; use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; mod discovery; mod parser; +use parser::execute_op_with_progress; + pub struct Codebase { db: CodegenDatabase, root: PathBuf, @@ -80,6 +82,16 @@ impl Codebase { files, ) } + fn _db(&self) -> &dyn Db { + &self.db + } + pub fn execute_op_with_progress( + &self, + name: &str, + op: fn(&dyn Db, File, PathBuf) -> T, + ) -> Vec { + execute_op_with_progress(self._db(), self.discover(), name, op) + } } impl CodebaseContext for Codebase { type File<'a> = ParsedFile<'a>; @@ -99,11 +111,13 @@ impl CodebaseContext for Codebase { &self.db } fn get_file<'a>(&'a self, path: PathBuf) -> Option<&'a Self::File<'a>> { - let file = self.db.files.get(&path); - if let Some(file) = file { - return parse_file(&self.db, file.clone(), self.root.clone()) - .file(&self.db) - .as_ref(); + if let Ok(path) = path.canonicalize() { + let file = self.db.files.get(&path); + if let Some(file) = file { + return parse_file(&self.db, file.clone(), self.root.clone()) + .file(&self.db) + .as_ref(); + } } None } diff --git a/codegen-sdk-analyzer/src/codebase/discovery.rs b/codegen-sdk-analyzer/src/codebase/discovery.rs index 040e017d..169e84eb 100644 --- a/codegen-sdk-analyzer/src/codebase/discovery.rs +++ b/codegen-sdk-analyzer/src/codebase/discovery.rs @@ -42,6 +42,7 @@ pub fn collect_files(db: &CodegenDatabase, dir: &PathBuf) -> FilesToParse { .into_iter() .filter_map(|file| file.ok()) .filter(|file| !file.is_dir() && !file.is_symlink()) + .filter_map(|file| file.canonicalize().ok()) .map(|file| db.input(file).unwrap()) .collect(); FilesToParse::new(db, files, dir) diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index 7cb62603..d21598ca 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -8,7 +8,7 @@ use indicatif::{ProgressBar, ProgressStyle}; use super::discovery::{FilesToParse, log_languages}; use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; -fn execute_op_with_progress( +pub fn execute_op_with_progress( db: &Database, files: FilesToParse, name: &str, diff --git a/codegen-sdk-analyzer/src/database.rs b/codegen-sdk-analyzer/src/database.rs index f6593176..de0acb1b 100644 --- a/codegen-sdk-analyzer/src/database.rs +++ b/codegen-sdk-analyzer/src/database.rs @@ -87,10 +87,10 @@ impl Db for CodegenDatabase { self.dirs.push(path); Ok(()) } + fn get_file(&self, path: PathBuf) -> Option { + self.files.get(&path).map(|entry| entry.value().clone()) + } fn input(&self, path: PathBuf) -> anyhow::Result { - let path = path - .canonicalize() - .with_context(|| format!("Failed to read {}", path.display()))?; Ok(match self.files.entry(path.clone()) { // If the file already exists in our cache then just return it. Entry::Occupied(entry) => *entry.get(), diff --git a/codegen-sdk-analyzer/src/lib.rs b/codegen-sdk-analyzer/src/lib.rs index f6e05d95..22004e5e 100644 --- a/codegen-sdk-analyzer/src/lib.rs +++ b/codegen-sdk-analyzer/src/lib.rs @@ -2,6 +2,6 @@ mod database; mod parser; mod progress; -pub use parser::{Parsed, ParsedFile}; +pub use parser::{Parsed, ParsedFile, parse_file}; mod codebase; pub use codebase::Codebase; diff --git a/codegen-sdk-macros/src/lib.rs b/codegen-sdk-macros/src/lib.rs index 2496dce6..e36dd64f 100644 --- a/codegen-sdk-macros/src/lib.rs +++ b/codegen-sdk-macros/src/lib.rs @@ -101,7 +101,7 @@ pub fn parse_language(_item: TokenStream) -> TokenStream { let variant: proc_macro2::TokenStream = quote! { #[cfg(feature = #name)] if #package_name::cst::#struct_name::should_parse(&file.path(db)).unwrap_or(false) { - let parsed = #package_name::ast::parse(db, file, root); + let parsed = #package_name::ast::parse_query(db, file, root).clone(); return Parsed::new( db, FileNodeId::new(db, file.path(db)), diff --git a/codegen-sdk-resolution/src/codebase.rs b/codegen-sdk-resolution/src/codebase.rs index 49e3a8d9..65678ec9 100644 --- a/codegen-sdk-resolution/src/codebase.rs +++ b/codegen-sdk-resolution/src/codebase.rs @@ -19,7 +19,11 @@ pub trait CodebaseContext { self.get_raw_file(id.path(self.db())) } fn get_raw_file<'a>(&'a self, path: PathBuf) -> Option { - self.db().input(path).ok() + if let Ok(path) = path.canonicalize() { + self.db().get_file(path) + } else { + None + } } fn root_path(&self) -> PathBuf; } diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs index 171f8374..31a72a21 100644 --- a/codegen-sdk-resolution/src/database.rs +++ b/codegen-sdk-resolution/src/database.rs @@ -5,6 +5,7 @@ use indicatif::MultiProgress; #[salsa::db] pub trait Db: salsa::Database + Send { fn input(&self, path: PathBuf) -> anyhow::Result; + fn get_file(&self, path: PathBuf) -> Option; fn multi_progress(&self) -> &MultiProgress; fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; fn files(&self) -> Vec; diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index 543621cd..ae8ac3cd 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -16,9 +16,9 @@ pub trait References< log::info!("Finding references across {:?} files", files.len()); let mut results = Vec::new(); for input in files { - if !self.filter(db, &input) { - continue; - } + // if !self.filter(db, &input) { + // continue; + // } let file = Scope::parse(db, input, root_path.clone()); for reference in file.clone().resolvables(db) { if reference.clone().resolve_type(db).iter().any(|result| *result == *self) { diff --git a/languages/codegen-sdk-python/Cargo.toml b/languages/codegen-sdk-python/Cargo.toml index b14c6976..1372e409 100644 --- a/languages/codegen-sdk-python/Cargo.toml +++ b/languages/codegen-sdk-python/Cargo.toml @@ -17,6 +17,7 @@ codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } codegen-sdk-resolution = { workspace = true } +memchr = { version = "2" } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } @@ -24,6 +25,7 @@ codegen-sdk-ast-generator = { workspace = true } codegen-sdk-common = { workspace = true, features = ["python"] } env_logger = { workspace = true } log = { workspace = true } + [dev-dependencies] test-log = { workspace = true } tempfile = {workspace = true} diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index e4c66fb9..77666728 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -11,12 +11,12 @@ pub mod ast { #[salsa::tracked] impl<'db> Import<'db> { #[salsa::tracked] - fn resolve_import(self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + fn resolve_import(self, db: &'db dyn codegen_sdk_resolution::Db) -> Option { let root_path = self.root_path(db); let module = self.module(db).source().replace(".", "/"); let target_path = root_path.join(module).with_extension("py"); log::info!(target: "resolution", "Resolving import to path: {:?}", target_path); - target_path + target_path.canonicalize().ok() } } #[salsa::tracked] @@ -63,10 +63,12 @@ pub mod ast { #[salsa::tracked(return_ref)] fn resolve_type(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { let target_path = self.resolve_import(db); - if let Ok(input) = db.input(target_path) { - return PythonFile::parse(db, input, self.root_path(db)) - .resolve(db, self.name(db).source()) - .to_vec(); + if let Some(target_path) = target_path { + if let Some(input) = db.get_file(target_path) { + return PythonFile::parse(db, input, self.root_path(db)) + .resolve(db, self.name(db).source()) + .to_vec(); + } } Vec::new() } @@ -92,14 +94,14 @@ pub mod ast { db: &'db dyn codegen_sdk_resolution::Db, input: &codegen_sdk_ast::input::File, ) -> bool { - input.path(db).extension().unwrap() == "py" - && match self { - crate::ast::Symbol::Function(function) => input - .contents(db) - .content(db) - .contains(&function.name(db).source()), - _ => false, + match self { + crate::ast::Symbol::Function(function) => { + let content = input.contents(db).content(db); + let target = function.name(db).text(); + memchr::memmem::find(&content.as_bytes(), &target).is_some() } + _ => true, + } } } } diff --git a/src/main.rs b/src/main.rs index 5ddeecfa..505d7424 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use std::{path::PathBuf, time::Instant}; use clap::Parser; -use codegen_sdk_analyzer::{Codebase, ParsedFile}; +use codegen_sdk_analyzer::{Codebase, ParsedFile, parse_file}; use codegen_sdk_ast::Definitions; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; @@ -13,29 +13,27 @@ struct Args { input: String, } fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize, usize, usize)> { - codebase - .files() - .into_iter() - .map(|parsed| { + codebase.execute_op_with_progress("Getting Usages", |db, file, root| { + if let Some(parsed) = parse_file(db, file, root).file(db) { #[cfg(feature = "typescript")] if let ParsedFile::Typescript(file) = parsed { - let definitions = file.definitions(codebase.db()); - if let Some(node) = file.node(codebase.db()) { - let tree = node.tree(codebase.db()); + let definitions = file.definitions(db); + if let Some(node) = file.node(db) { + let tree = node.tree(db); return ( - definitions.classes(codebase.db(), &tree).len(), - definitions.functions(codebase.db(), &tree).len(), - definitions.interfaces(codebase.db(), &tree).len(), - definitions.methods(codebase.db(), &tree).len(), - definitions.modules(codebase.db(), &tree).len(), + definitions.classes(db, &tree).len(), + definitions.functions(db, &tree).len(), + definitions.interfaces(db, &tree).len(), + definitions.methods(db, &tree).len(), + definitions.modules(db, &tree).len(), 0, ); } } #[cfg(feature = "python")] if let ParsedFile::Python(file) = parsed { - let definitions = file.definitions(codebase.db()); - let functions = definitions.functions(codebase.db()); + let definitions = file.definitions(db); + let functions = definitions.functions(db); let mut total_references = 0; let total_functions = functions.len(); for function in functions @@ -44,10 +42,10 @@ fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize .flatten() .map(|function| codegen_sdk_python::ast::Symbol::Function(function.clone())) { - total_references += function.references(codebase.db()).len(); + total_references += function.references(db).len(); } return ( - definitions.classes(codebase.db()).len(), + definitions.classes(db).len(), total_functions, 0, 0, @@ -55,9 +53,9 @@ fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize total_references, ); } - (0, 0, 0, 0, 0, 0) - }) - .collect() + } + (0, 0, 0, 0, 0, 0) + }) } fn print_definitions(codebase: &Codebase) { let mut total_classes = 0; From 776153f4fc619f6027608cf39686cec5820320b8 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Thu, 6 Mar 2025 16:00:47 -0800 Subject: [PATCH 08/16] perf improvements --- Cargo.lock | 74 ++++++--------------- Cargo.toml | 4 +- codegen-sdk-analyzer/src/codebase.rs | 11 ++- codegen-sdk-analyzer/src/codebase/parser.rs | 13 ++-- codegen-sdk-analyzer/src/database.rs | 6 +- codegen-sdk-analyzer/src/parser.rs | 6 +- codegen-sdk-ast-generator/src/generator.rs | 12 ++-- codegen-sdk-ast-generator/src/visitor.rs | 3 +- codegen-sdk-cst-generator/src/generator.rs | 8 +-- codegen-sdk-cst/src/input.rs | 1 + codegen-sdk-macros/src/lib.rs | 2 +- codegen-sdk-resolution/Cargo.toml | 1 + codegen-sdk-resolution/src/lib.rs | 1 + codegen-sdk-resolution/src/parse.rs | 8 +-- codegen-sdk-resolution/src/references.rs | 16 ++--- codegen-sdk-resolution/src/scope.rs | 24 +++++-- languages/codegen-sdk-python/Cargo.toml | 1 - languages/codegen-sdk-python/src/lib.rs | 10 ++- src/main.rs | 4 +- 19 files changed, 93 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a9ebaa2..26643d53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,18 +17,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "ahash" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", - "zerocopy", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -121,12 +109,6 @@ dependencies = [ "backtrace", ] -[[package]] -name = "arc-swap" -version = "1.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" - [[package]] name = "autocfg" version = "1.4.0" @@ -658,6 +640,7 @@ dependencies = [ "anyhow", "codegen-sdk-ast", "codegen-sdk-common", + "indexmap", "indicatif", "log", "salsa", @@ -1072,6 +1055,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1216,24 +1205,25 @@ name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", - "allocator-api2", -] [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "hashlink" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown 0.14.5", + "hashbrown 0.15.2", ] [[package]] @@ -2339,14 +2329,12 @@ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "salsa" version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e59d074084ce0a89693f021d8317cbc53d23d6502d3b3e2a3d1a7db1ceb13b" +source = "git+https://github.com/salsa-rs/salsa?branch=master#9d2a9786c45000f5fa396ad2872391e302a2836a" dependencies = [ - "arc-swap", "boxcar", "crossbeam-queue", "dashmap", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "hashlink", "indexmap", "parking_lot", @@ -2360,15 +2348,13 @@ dependencies = [ [[package]] name = "salsa-macro-rules" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e354e0bdf1a23d822161e2b0f95c07846535a0e81deba77248a6ac22d19bc97" +version = "0.18.0" +source = "git+https://github.com/salsa-rs/salsa?branch=master#9d2a9786c45000f5fa396ad2872391e302a2836a" [[package]] name = "salsa-macros" version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b061c51d6c6d5d8e4459bcaa11ef18d268286c68263615d65e983071b357fd9" +source = "git+https://github.com/salsa-rs/salsa?branch=master#9d2a9786c45000f5fa396ad2872391e302a2836a" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -3413,26 +3399,6 @@ dependencies = [ "synstructure", ] -[[package]] -name = "zerocopy" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.98", -] - [[package]] name = "zerofrom" version = "0.1.6" diff --git a/Cargo.toml b/Cargo.toml index 47872435..fc56769d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,7 +117,7 @@ insta = "1.42.1" prettyplease = "0.2.29" syn = { version = "2.0.98", features = ["proc-macro", "full"] } derive_more = { version = "2.0.1", features = ["debug", "display"] } -salsa = {version = "0.18.0"} +salsa = {git = "https://github.com/salsa-rs/salsa", branch = "master"} subenum = {git = "https://github.com/mrenow/subenum", branch = "main"} indicatif-log-bridge = "0.2.3" indicatif = { version = "0.17.11", features = ["rayon"] } @@ -125,7 +125,7 @@ crossbeam-channel = "0.5.11" rstest = "0.25.0" indextree = "4.7.3" thiserror = "2.0.11" - +indexmap = "2" [profile.dev] # codegen-backend = "cranelift" # split-debuginfo = "unpacked" diff --git a/codegen-sdk-analyzer/src/codebase.rs b/codegen-sdk-analyzer/src/codebase.rs index aa1c2cdf..d90d3a9c 100644 --- a/codegen-sdk-analyzer/src/codebase.rs +++ b/codegen-sdk-analyzer/src/codebase.rs @@ -24,8 +24,9 @@ pub struct Codebase { impl Codebase { pub fn new(root: PathBuf) -> Self { + let root = root.canonicalize().unwrap(); let (tx, rx) = crossbeam_channel::unbounded(); - let mut db = CodegenDatabase::new(tx); + let mut db = CodegenDatabase::new(tx, root.clone()); db.watch_dir(PathBuf::from(&root)).unwrap(); let codebase = Self { db, root, rx }; codebase.sync(); @@ -45,7 +46,7 @@ impl Codebase { // to kick in, just like any other update to a salsa input. let contents = std::fs::read_to_string(path) .with_context(|| format!("Failed to read file {}", event.path.display()))?; - let input = Input::new(&self.db, contents); + let input = Input::new(&self.db, contents, self.root.clone()); file.set_contents(&mut self.db).to(input); } Err(e) => { @@ -88,7 +89,7 @@ impl Codebase { pub fn execute_op_with_progress( &self, name: &str, - op: fn(&dyn Db, File, PathBuf) -> T, + op: fn(&dyn Db, File) -> T, ) -> Vec { execute_op_with_progress(self._db(), self.discover(), name, op) } @@ -114,9 +115,7 @@ impl CodebaseContext for Codebase { if let Ok(path) = path.canonicalize() { let file = self.db.files.get(&path); if let Some(file) = file { - return parse_file(&self.db, file.clone(), self.root.clone()) - .file(&self.db) - .as_ref(); + return parse_file(&self.db, file.clone()).file(&self.db).as_ref(); } } None diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index d21598ca..e3c48acf 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -12,7 +12,7 @@ pub fn execute_op_with_progress db: &Database, files: FilesToParse, name: &str, - op: fn(&Database, File, PathBuf) -> T, + op: fn(&Database, File) -> T, ) -> Vec { let multi = db.multi_progress(); let style = ProgressStyle::with_template( @@ -27,16 +27,15 @@ pub fn execute_op_with_progress let inputs = files .files(db) .into_iter() - .map(|file| (&pg, file, files.root(db).clone(), op)) + .map(|file| (&pg, file, op)) .collect::>(); let results: Vec = salsa::par_map(db, inputs, move |db, input| { - let (pg, file, root, op) = input; + let (pg, file, op) = input; let res = op( db, #[cfg(feature = "serialization")] &cache, file, - root, ); pg.inc(1); res @@ -53,8 +52,8 @@ pub fn execute_op_with_progress // } #[salsa::tracked] fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { - let _: Vec<_> = execute_op_with_progress(db, files, "Parsing Files", |db, file, root| { - let file = parse_file(db, file, root); + let _: Vec<_> = execute_op_with_progress(db, files, "Parsing Files", |db, file| { + let file = parse_file(db, file); if let Some(parsed) = file.file(db) { #[cfg(feature = "typescript")] if let ParsedFile::Typescript(parsed) = parsed { @@ -65,7 +64,7 @@ fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { if let ParsedFile::Python(parsed) = parsed { parsed.definitions(db); parsed.references(db); - parsed.compute_dependencies(db); + parsed.compute_dependencies_query(db); } } () diff --git a/codegen-sdk-analyzer/src/database.rs b/codegen-sdk-analyzer/src/database.rs index de0acb1b..3c89d07e 100644 --- a/codegen-sdk-analyzer/src/database.rs +++ b/codegen-sdk-analyzer/src/database.rs @@ -25,6 +25,7 @@ pub struct CodegenDatabase { dirs: Vec, multi_progress: MultiProgress, file_watcher: Arc>>, + root: PathBuf, } fn get_watcher( tx: crossbeam_channel::Sender, @@ -35,7 +36,7 @@ fn get_watcher( Arc::new(Mutex::new(new_debouncer_opt(config, tx).unwrap())) } impl CodegenDatabase { - pub fn new(tx: crossbeam_channel::Sender) -> Self { + pub fn new(tx: crossbeam_channel::Sender, root: PathBuf) -> Self { let multi_progress = get_multi_progress(); Self { file_watcher: get_watcher(tx), @@ -43,6 +44,7 @@ impl CodegenDatabase { multi_progress, files: DashMap::new(), dirs: Vec::new(), + root, } } fn _watch_file(&self, path: &PathBuf) -> anyhow::Result<()> { @@ -102,7 +104,7 @@ impl Db for CodegenDatabase { self._watch_file(&path)?; let contents = std::fs::read_to_string(&path) .with_context(|| format!("Failed to read {}", path.display()))?; - let input = Input::new(self, contents); + let input = Input::new(self, contents, self.root.clone()); *entry.insert(File::new(self, path, input)) } }) diff --git a/codegen-sdk-analyzer/src/parser.rs b/codegen-sdk-analyzer/src/parser.rs index 1c3b2033..f152cc20 100644 --- a/codegen-sdk-analyzer/src/parser.rs +++ b/codegen-sdk-analyzer/src/parser.rs @@ -14,11 +14,7 @@ pub struct Parsed<'db> { pub file: Option>, } #[salsa::tracked(return_ref)] -pub fn parse_file( - db: &dyn salsa::Database, - file: codegen_sdk_ast::input::File, - root: PathBuf, -) -> Parsed<'_> { +pub fn parse_file(db: &dyn salsa::Database, file: codegen_sdk_ast::input::File) -> Parsed<'_> { parse_language!(); Parsed::new(db, FileNodeId::new(db, file.path(db)), None) } diff --git a/codegen-sdk-ast-generator/src/generator.rs b/codegen-sdk-ast-generator/src/generator.rs index e8f6da94..3fe50533 100644 --- a/codegen-sdk-ast-generator/src/generator.rs +++ b/codegen-sdk-ast-generator/src/generator.rs @@ -72,8 +72,8 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { pub id: codegen_sdk_common::FileNodeId<'db>, } impl<'db> codegen_sdk_resolution::Parse<'db> for #language_struct_name<'db> { - fn parse(db: &'db dyn salsa::Database, input: codegen_sdk_ast::input::File, root: PathBuf) -> &'db Self { - parse_query(db, input, root) + fn parse(db: &'db dyn salsa::Database, input: codegen_sdk_ast::input::File) -> &'db Self { + parse_query(db, input) } } // impl<'db> File for {language_struct_name}File<'db> {{ @@ -81,15 +81,15 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { // &self.path(db) // }} // }} - pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File, root: PathBuf) -> #language_struct_name<'_> { + pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { log::debug!("Parsing {} file: {}", input.path(db).display(), #language_name_str); - let ast = crate::cst::parse_program_raw(db, input.contents(db), input.path(db).clone(), root); + let ast = crate::cst::parse_program_raw(db, input.contents(db), input.path(db).clone()); let file_id = codegen_sdk_common::FileNodeId::new(db, input.path(db).clone()); #language_struct_name::new(db, ast, file_id) } #[salsa::tracked(return_ref)] - pub fn parse_query(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File, root: PathBuf) -> #language_struct_name<'_> { - parse(db, input, root) + pub fn parse_query(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { + parse(db, input) } impl<'db> #language_struct_name<'db> { diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 5d7d15f9..72080de2 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -91,9 +91,8 @@ pub fn generate_visitor<'db>( type File<'db1> = #language_struct<'db1>; fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { let path = self.node(db).id().file(db).path(db); - let root = self.root_path(db); let input = db.input(path).unwrap(); - parse_query(db, input, root) + parse_query(db, input) } fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf { self.node(db).id().root(db).path(db) diff --git a/codegen-sdk-cst-generator/src/generator.rs b/codegen-sdk-cst-generator/src/generator.rs index 36004570..3f5d3a48 100644 --- a/codegen-sdk-cst-generator/src/generator.rs +++ b/codegen-sdk-cst-generator/src/generator.rs @@ -56,7 +56,7 @@ fn get_parser(language: &Language) -> TokenStream { pub tree: Tree>, pub program: indextree::NodeId, } - pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::Input, path: PathBuf, root: PathBuf) -> Option> { + pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::Input, path: PathBuf) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::#language_name::#language_struct_name.parse_tree_sitter(&input.content(db)); match tree { @@ -65,7 +65,7 @@ fn get_parser(language: &Language) -> TokenStream { ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, root, buffer); + let mut context = ParseContext::new(db, path, input.root(db), buffer); let root_id = #program_id::orphaned(&mut context, tree.root_node()) .map_or_else(|e| { e.report(db); @@ -88,7 +88,7 @@ fn get_parser(language: &Language) -> TokenStream { } #[salsa::tracked(return_ref)] pub fn parse_program(db: &dyn salsa::Database, input: codegen_sdk_cst::Input) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new(), std::path::PathBuf::new()); + let raw = parse_program_raw(db, input, std::path::PathBuf::new()); if let Some(parsed) = raw { parsed } else { @@ -103,7 +103,7 @@ fn get_parser(language: &Language) -> TokenStream { &codegen_sdk_common::language::#language_name::#language_struct_name } fn parse<'db>(db: &'db dyn salsa::Database, content: std::string::String) -> Option<(&'db Self::Program<'db>, &'db Tree>, indextree::NodeId)> { - let input = codegen_sdk_cst::Input::new(db, content); + let input = codegen_sdk_cst::Input::new(db, content, std::path::PathBuf::new()); let parsed = parse_program(db, input); let program_id = parsed.program(db); let tree = parsed.tree(db); diff --git a/codegen-sdk-cst/src/input.rs b/codegen-sdk-cst/src/input.rs index 658de06a..8e87676f 100644 --- a/codegen-sdk-cst/src/input.rs +++ b/codegen-sdk-cst/src/input.rs @@ -2,4 +2,5 @@ use std::path::PathBuf; #[salsa::input] pub struct Input { pub content: String, + pub root: PathBuf, } diff --git a/codegen-sdk-macros/src/lib.rs b/codegen-sdk-macros/src/lib.rs index e36dd64f..916cbae6 100644 --- a/codegen-sdk-macros/src/lib.rs +++ b/codegen-sdk-macros/src/lib.rs @@ -101,7 +101,7 @@ pub fn parse_language(_item: TokenStream) -> TokenStream { let variant: proc_macro2::TokenStream = quote! { #[cfg(feature = #name)] if #package_name::cst::#struct_name::should_parse(&file.path(db)).unwrap_or(false) { - let parsed = #package_name::ast::parse_query(db, file, root).clone(); + let parsed = #package_name::ast::parse_query(db, file).clone(); return Parsed::new( db, FileNodeId::new(db, file.path(db)), diff --git a/codegen-sdk-resolution/Cargo.toml b/codegen-sdk-resolution/Cargo.toml index 3d10b2a3..118b2a6a 100644 --- a/codegen-sdk-resolution/Cargo.toml +++ b/codegen-sdk-resolution/Cargo.toml @@ -12,3 +12,4 @@ codegen-sdk-common = { workspace = true } anyhow = { workspace = true } indicatif = { workspace = true } ambassador = { workspace = true } +indexmap = "2" diff --git a/codegen-sdk-resolution/src/lib.rs b/codegen-sdk-resolution/src/lib.rs index 600dc93c..0de8a6f8 100644 --- a/codegen-sdk-resolution/src/lib.rs +++ b/codegen-sdk-resolution/src/lib.rs @@ -19,3 +19,4 @@ pub trait HasFile<'db> { fn file(&self, db: &'db dyn Db) -> &'db Self::File<'db>; fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf; } +pub use indexmap; diff --git a/codegen-sdk-resolution/src/parse.rs b/codegen-sdk-resolution/src/parse.rs index d173c14e..3d973766 100644 --- a/codegen-sdk-resolution/src/parse.rs +++ b/codegen-sdk-resolution/src/parse.rs @@ -1,11 +1,5 @@ -use std::path::PathBuf; - use salsa::Database; pub trait Parse<'db> { - fn parse( - db: &'db dyn Database, - input: codegen_sdk_ast::input::File, - root: PathBuf, - ) -> &'db Self; + fn parse(db: &'db dyn Database, input: codegen_sdk_ast::input::File) -> &'db Self; } diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index ae8ac3cd..20f68608 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -1,10 +1,12 @@ +use std::hash::Hash; + use crate::{Db, HasFile, Parse, ResolveType}; pub trait References< 'db, - ReferenceType: ResolveType<'db, Type = Self> + Clone, // References must resolve to this type + ReferenceType: ResolveType<'db, Type = Self> + Eq + Hash + Clone + 'db, // References must resolve to this type Scope: crate::Scope<'db, Type = Self, ReferenceType = ReferenceType> + Clone + 'db, ->: Eq + PartialEq + HasFile<'db, File<'db> = Scope> + 'db where Self:'db +>: Eq + PartialEq + Hash + HasFile<'db, File<'db> = Scope> + 'db where Self:'db { fn references(&self, db: &'db dyn Db) -> Vec where @@ -12,18 +14,16 @@ pub trait References< Scope: Parse<'db>, { let files = db.files(); - let root_path = self.root_path(db); log::info!("Finding references across {:?} files", files.len()); let mut results = Vec::new(); for input in files { // if !self.filter(db, &input) { // continue; // } - let file = Scope::parse(db, input, root_path.clone()); - for reference in file.clone().resolvables(db) { - if reference.clone().resolve_type(db).iter().any(|result| *result == *self) { - results.push(reference); - } + let file = Scope::parse(db, input); + let dependencies = file.clone().compute_dependencies_query(db); + if let Some(references) = dependencies.get(self) { + results.extend(references.iter().cloned()); } } results diff --git a/codegen-sdk-resolution/src/scope.rs b/codegen-sdk-resolution/src/scope.rs index 7ac3a8a3..6174cb3f 100644 --- a/codegen-sdk-resolution/src/scope.rs +++ b/codegen-sdk-resolution/src/scope.rs @@ -1,17 +1,33 @@ +use std::hash::Hash; + +use indexmap::IndexMap; + use crate::{Db, ResolveType}; // Resolve a given string name in a scope to a given type pub trait Scope<'db>: Sized { - type Type; - type ReferenceType: ResolveType<'db>; + type Type: Eq + Hash + Clone; + type ReferenceType: ResolveType<'db, Type = Self::Type> + Eq + Hash + Clone; fn resolve(self, db: &'db dyn Db, name: String) -> &'db Vec; /// Get all the resolvables (IE: function_calls) in the scope fn resolvables(self, db: &'db dyn Db) -> Vec; - fn compute_dependencies(self, db: &'db dyn Db) + fn compute_dependencies_query( + self, + db: &'db dyn Db, + ) -> &'db IndexMap>; + fn compute_dependencies(self, db: &'db dyn Db) -> IndexMap> where Self: 'db, { + let mut dependencies: IndexMap> = IndexMap::new(); for reference in self.resolvables(db) { - reference.resolve_type(db); + let resolved = reference.clone().resolve_type(db); + for resolved in resolved { + dependencies + .entry(resolved.clone()) + .or_default() + .push(reference.clone()); + } } + dependencies } } diff --git a/languages/codegen-sdk-python/Cargo.toml b/languages/codegen-sdk-python/Cargo.toml index 1372e409..ddec4aab 100644 --- a/languages/codegen-sdk-python/Cargo.toml +++ b/languages/codegen-sdk-python/Cargo.toml @@ -18,7 +18,6 @@ log = { workspace = true } codegen-sdk-ast = { workspace = true } codegen-sdk-resolution = { workspace = true } memchr = { version = "2" } - [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 77666728..777309a4 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -56,6 +56,14 @@ pub mod ast { } results } + #[salsa::tracked(return_ref)] + fn compute_dependencies_query( + self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> codegen_sdk_resolution::indexmap::IndexMap> + { + self.compute_dependencies(db) + } } #[salsa::tracked] impl<'db> ResolveType<'db> for crate::ast::Import<'db> { @@ -65,7 +73,7 @@ pub mod ast { let target_path = self.resolve_import(db); if let Some(target_path) = target_path { if let Some(input) = db.get_file(target_path) { - return PythonFile::parse(db, input, self.root_path(db)) + return PythonFile::parse(db, input) .resolve(db, self.name(db).source()) .to_vec(); } diff --git a/src/main.rs b/src/main.rs index 505d7424..2bb2c58c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,8 +13,8 @@ struct Args { input: String, } fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize, usize, usize)> { - codebase.execute_op_with_progress("Getting Usages", |db, file, root| { - if let Some(parsed) = parse_file(db, file, root).file(db) { + codebase.execute_op_with_progress("Getting Usages", |db, file| { + if let Some(parsed) = parse_file(db, file).file(db) { #[cfg(feature = "typescript")] if let ParsedFile::Typescript(file) = parsed { let definitions = file.definitions(db); From 57660a6c8392338ac3fb0e7f556a43605a24d46b Mon Sep 17 00:00:00 2001 From: bagel897 Date: Fri, 7 Mar 2025 10:33:27 -0800 Subject: [PATCH 09/16] Perf --- Cargo.lock | 12 +++ codegen-sdk-analyzer/Cargo.toml | 1 + codegen-sdk-analyzer/src/codebase.rs | 14 ++- .../src/codebase/discovery.rs | 2 +- codegen-sdk-analyzer/src/codebase/parser.rs | 17 ++-- codegen-sdk-analyzer/src/database.rs | 23 +++-- codegen-sdk-analyzer/src/parser.rs | 2 +- codegen-sdk-ast-generator/src/generator.rs | 14 +-- codegen-sdk-ast-generator/src/lib.rs | 1 - codegen-sdk-ast-generator/src/query.rs | 3 +- codegen-sdk-ast-generator/src/visitor.rs | 18 +++- codegen-sdk-ast/src/input.rs | 10 -- codegen-sdk-ast/src/lib.rs | 19 ++-- codegen-sdk-cst-generator/src/generator.rs | 14 +-- codegen-sdk-cst/src/database.rs | 2 +- codegen-sdk-cst/src/input.rs | 6 +- codegen-sdk-cst/src/lib.rs | 2 +- codegen-sdk-macros/src/lib.rs | 2 +- codegen-sdk-resolution/Cargo.toml | 1 + codegen-sdk-resolution/src/codebase.rs | 4 +- codegen-sdk-resolution/src/database.rs | 8 +- codegen-sdk-resolution/src/lib.rs | 6 +- codegen-sdk-resolution/src/name.rs | 14 +++ codegen-sdk-resolution/src/parse.rs | 2 +- codegen-sdk-resolution/src/references.rs | 18 ++-- codegen-sdk-resolution/src/scope.rs | 23 +++-- languages/codegen-sdk-go/Cargo.toml | 1 + languages/codegen-sdk-java/Cargo.toml | 1 + languages/codegen-sdk-jsx/Cargo.toml | 1 + languages/codegen-sdk-markdown/Cargo.toml | 1 + languages/codegen-sdk-python/src/lib.rs | 97 +++++++++++++++++-- .../codegen-sdk-python/tests/test_python.rs | 44 +++------ languages/codegen-sdk-ruby/Cargo.toml | 1 + languages/codegen-sdk-rust/Cargo.toml | 1 + languages/codegen-sdk-toml/Cargo.toml | 1 + languages/codegen-sdk-tsx/Cargo.toml | 1 + languages/codegen-sdk-typescript/Cargo.toml | 1 + .../tests/test_typescript.rs | 11 +-- languages/codegen-sdk-yaml/Cargo.toml | 1 + 39 files changed, 262 insertions(+), 138 deletions(-) delete mode 100644 codegen-sdk-ast/src/input.rs create mode 100644 codegen-sdk-resolution/src/name.rs diff --git a/Cargo.lock b/Cargo.lock index 26643d53..00c7cc09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,6 +331,7 @@ dependencies = [ "dashmap", "env_logger", "glob", + "indexmap", "indicatif", "indicatif-log-bridge", "log", @@ -487,6 +488,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -507,6 +509,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -570,6 +573,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -599,6 +603,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -640,6 +645,7 @@ dependencies = [ "anyhow", "codegen-sdk-ast", "codegen-sdk-common", + "codegen-sdk-cst", "indexmap", "indicatif", "log", @@ -657,6 +663,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -677,6 +684,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -697,6 +705,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -735,6 +744,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -755,6 +765,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", @@ -777,6 +788,7 @@ dependencies = [ "codegen-sdk-common", "codegen-sdk-cst", "codegen-sdk-cst-generator", + "codegen-sdk-resolution", "derive_more", "env_logger", "indextree", diff --git a/codegen-sdk-analyzer/Cargo.toml b/codegen-sdk-analyzer/Cargo.toml index 76f61720..086fd0cd 100644 --- a/codegen-sdk-analyzer/Cargo.toml +++ b/codegen-sdk-analyzer/Cargo.toml @@ -34,6 +34,7 @@ crossbeam-channel = { workspace = true } glob = "0.3.2" rayon = { workspace = true } ambassador = { workspace = true } +indexmap = {workspace = true} [features] python = [ "codegen-sdk-python"] # TODO: Add python support typescript = [ "codegen-sdk-typescript"] diff --git a/codegen-sdk-analyzer/src/codebase.rs b/codegen-sdk-analyzer/src/codebase.rs index d90d3a9c..07ae9a52 100644 --- a/codegen-sdk-analyzer/src/codebase.rs +++ b/codegen-sdk-analyzer/src/codebase.rs @@ -1,13 +1,13 @@ use std::path::PathBuf; use anyhow::Context; -use codegen_sdk_ast::{Input, input::File}; #[cfg(feature = "serialization")] use codegen_sdk_common::serialization::Cache; +use codegen_sdk_cst::File; use codegen_sdk_resolution::{CodebaseContext, Db}; use discovery::FilesToParse; use notify_debouncer_mini::DebounceEventResult; -use salsa::{AsDynDatabase, Database, Setter}; +use salsa::Setter; use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; mod discovery; @@ -46,8 +46,7 @@ impl Codebase { // to kick in, just like any other update to a salsa input. let contents = std::fs::read_to_string(path) .with_context(|| format!("Failed to read file {}", event.path.display()))?; - let input = Input::new(&self.db, contents, self.root.clone()); - file.set_contents(&mut self.db).to(input); + file.set_content(&mut self.db).to(contents); } Err(e) => { log::error!( @@ -91,7 +90,12 @@ impl Codebase { name: &str, op: fn(&dyn Db, File) -> T, ) -> Vec { - execute_op_with_progress(self._db(), self.discover(), name, op) + execute_op_with_progress( + self._db(), + codegen_sdk_resolution::files(self._db()), + name, + op, + ) } } impl CodebaseContext for Codebase { diff --git a/codegen-sdk-analyzer/src/codebase/discovery.rs b/codegen-sdk-analyzer/src/codebase/discovery.rs index 169e84eb..1a29f4d4 100644 --- a/codegen-sdk-analyzer/src/codebase/discovery.rs +++ b/codegen-sdk-analyzer/src/codebase/discovery.rs @@ -9,7 +9,7 @@ use glob::glob; use crate::database::CodegenDatabase; #[salsa::input] pub struct FilesToParse { - pub files: Vec, + pub files: indexmap::IndexSet, pub root: PathBuf, } pub fn log_languages() { diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index e3c48acf..ac9a47be 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -1,16 +1,16 @@ -use std::path::PathBuf; - -use codegen_sdk_ast::{Definitions, References, input::File}; +use codegen_sdk_ast::{Definitions, References}; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; +use codegen_sdk_cst::File; use codegen_sdk_resolution::{Db, Scope}; +use indexmap::IndexSet; use indicatif::{ProgressBar, ProgressStyle}; use super::discovery::{FilesToParse, log_languages}; use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; pub fn execute_op_with_progress( db: &Database, - files: FilesToParse, + files: IndexSet, name: &str, op: fn(&Database, File) -> T, ) -> Vec { @@ -20,12 +20,11 @@ pub fn execute_op_with_progress ) .unwrap(); let pg = multi.add( - ProgressBar::new(files.files(db).len() as u64) + ProgressBar::new(files.len() as u64) .with_style(style) .with_message(name.to_string()), ); let inputs = files - .files(db) .into_iter() .map(|file| (&pg, file, op)) .collect::>(); @@ -52,8 +51,8 @@ pub fn execute_op_with_progress // } #[salsa::tracked] fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { - let _: Vec<_> = execute_op_with_progress(db, files, "Parsing Files", |db, file| { - let file = parse_file(db, file); + let _: Vec<_> = execute_op_with_progress(db, files.files(db), "Parsing Files", |db, input| { + let file = parse_file(db, input.clone()); if let Some(parsed) = file.file(db) { #[cfg(feature = "typescript")] if let ParsedFile::Typescript(parsed) = parsed { @@ -64,7 +63,7 @@ fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { if let ParsedFile::Python(parsed) = parsed { parsed.definitions(db); parsed.references(db); - parsed.compute_dependencies_query(db); + codegen_sdk_python::ast::dependencies(db, input); } } () diff --git a/codegen-sdk-analyzer/src/database.rs b/codegen-sdk-analyzer/src/database.rs index 3c89d07e..99a2df18 100644 --- a/codegen-sdk-analyzer/src/database.rs +++ b/codegen-sdk-analyzer/src/database.rs @@ -5,8 +5,7 @@ use std::{ }; use anyhow::Context; -use codegen_sdk_ast::input::File; -use codegen_sdk_cst::Input; +use codegen_sdk_cst::File; use codegen_sdk_resolution::Db; use dashmap::{DashMap, mapref::entry::Entry}; use indicatif::MultiProgress; @@ -63,17 +62,17 @@ impl CodegenDatabase { } #[salsa::db] impl salsa::Database for CodegenDatabase { - fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) { // don't log boring events - let event = event(); - if let salsa::EventKind::WillExecute { .. } = event.kind { - log::debug!("{:?}", event); - } + // let event = event(); + // if let salsa::EventKind::WillExecute { .. } = event.kind { + // log::debug!("{:?}", event); + // } } } #[salsa::db] impl Db for CodegenDatabase { - fn files(&self) -> Vec { + fn files(&self) -> indexmap::IndexSet { self.files .iter() .map(|entry| entry.value().clone()) @@ -93,6 +92,7 @@ impl Db for CodegenDatabase { self.files.get(&path).map(|entry| entry.value().clone()) } fn input(&self, path: PathBuf) -> anyhow::Result { + let path = path.canonicalize()?; Ok(match self.files.entry(path.clone()) { // If the file already exists in our cache then just return it. Entry::Occupied(entry) => *entry.get(), @@ -104,8 +104,11 @@ impl Db for CodegenDatabase { self._watch_file(&path)?; let contents = std::fs::read_to_string(&path) .with_context(|| format!("Failed to read {}", path.display()))?; - let input = Input::new(self, contents, self.root.clone()); - *entry.insert(File::new(self, path, input)) + *entry.insert( + File::builder(path, contents, self.root.clone()) + .root_durability(salsa::Durability::HIGH) + .new(self), + ) } }) } diff --git a/codegen-sdk-analyzer/src/parser.rs b/codegen-sdk-analyzer/src/parser.rs index f152cc20..80177f77 100644 --- a/codegen-sdk-analyzer/src/parser.rs +++ b/codegen-sdk-analyzer/src/parser.rs @@ -14,7 +14,7 @@ pub struct Parsed<'db> { pub file: Option>, } #[salsa::tracked(return_ref)] -pub fn parse_file(db: &dyn salsa::Database, file: codegen_sdk_ast::input::File) -> Parsed<'_> { +pub fn parse_file(db: &dyn salsa::Database, file: codegen_sdk_cst::File) -> Parsed<'_> { parse_language!(); Parsed::new(db, FileNodeId::new(db, file.path(db)), None) } diff --git a/codegen-sdk-ast-generator/src/generator.rs b/codegen-sdk-ast-generator/src/generator.rs index 3fe50533..40d61033 100644 --- a/codegen-sdk-ast-generator/src/generator.rs +++ b/codegen-sdk-ast-generator/src/generator.rs @@ -72,8 +72,8 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { pub id: codegen_sdk_common::FileNodeId<'db>, } impl<'db> codegen_sdk_resolution::Parse<'db> for #language_struct_name<'db> { - fn parse(db: &'db dyn salsa::Database, input: codegen_sdk_ast::input::File) -> &'db Self { - parse_query(db, input) + fn parse(db: &'db dyn salsa::Database, input: codegen_sdk_cst::File) -> &'db Self { + parse(db, input) } } // impl<'db> File for {language_struct_name}File<'db> {{ @@ -81,17 +81,13 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { // &self.path(db) // }} // }} - pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { + #[salsa::tracked(return_ref)] + pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_cst::File) -> #language_struct_name<'_> { log::debug!("Parsing {} file: {}", input.path(db).display(), #language_name_str); - let ast = crate::cst::parse_program_raw(db, input.contents(db), input.path(db).clone()); + let ast = crate::cst::parse_program_raw(db, input); let file_id = codegen_sdk_common::FileNodeId::new(db, input.path(db).clone()); #language_struct_name::new(db, ast, file_id) } - #[salsa::tracked(return_ref)] - pub fn parse_query(db: &dyn salsa::Database, input: codegen_sdk_ast::input::File) -> #language_struct_name<'_> { - parse(db, input) - } - impl<'db> #language_struct_name<'db> { pub fn tree(&self, db: &'db dyn salsa::Database) -> &'db codegen_sdk_common::Tree> { self.node(db).unwrap().tree(db) diff --git a/codegen-sdk-ast-generator/src/lib.rs b/codegen-sdk-ast-generator/src/lib.rs index bf97d8ef..9023a56c 100644 --- a/codegen-sdk-ast-generator/src/lib.rs +++ b/codegen-sdk-ast-generator/src/lib.rs @@ -14,7 +14,6 @@ pub fn generate_ast(language: &Language) -> anyhow::Result<()> { use std::path::PathBuf; use codegen_sdk_cst::CSTLanguage; use std::collections::BTreeMap; - use std::sync::mpsc::Sender; use codegen_sdk_resolution::HasFile; use codegen_sdk_resolution::Parse; }; diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index f9d4b614..883f00cd 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -661,7 +661,8 @@ impl<'a> Query<'a> { }); let symbol_name = self.symbol_name(); return quote! { - let symbol = #symbol_name::new(db, id, node.clone(), #(#args.clone().into()),*); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new(db, node.file_id(),#name.source()); + let symbol = #symbol_name::new(db, fully_qualified_name, id, node.clone(), #(#args.clone().into()),*); #to_append.entry(#name.source()).or_default().push(symbol); }; } diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 72080de2..35e1aca6 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -80,6 +80,8 @@ pub fn generate_visitor<'db>( span => #[salsa::tracked] pub struct #variant<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] @@ -91,13 +93,18 @@ pub fn generate_visitor<'db>( type File<'db1> = #language_struct<'db1>; fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { let path = self.node(db).id().file(db).path(db); - let input = db.input(path).unwrap(); - parse_query(db, input) + let input = db.get_file(path).unwrap(); + parse(db, input) } fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf { self.node(db).id().root(db).path(db) } } + impl<'db> codegen_sdk_resolution::HasId<'db> for #variant<'db> { + fn fully_qualified_name(&self, db: &'db dyn codegen_sdk_resolution::Db) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } + } }); } let symbol = if defs.len() > 0 { @@ -124,6 +131,13 @@ pub fn generate_visitor<'db>( } } } + impl<'db> codegen_sdk_resolution::HasId<'db> for #symbol_name<'db> { + fn fully_qualified_name(&self, db: &'db dyn codegen_sdk_resolution::Db) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + match self { + #(Self::#symbol_names(symbol) => symbol.fully_qualified_name(db),)* + } + } + } } } else { quote! { diff --git a/codegen-sdk-ast/src/input.rs b/codegen-sdk-ast/src/input.rs deleted file mode 100644 index 9e97d421..00000000 --- a/codegen-sdk-ast/src/input.rs +++ /dev/null @@ -1,10 +0,0 @@ -use std::path::PathBuf; - -use codegen_sdk_cst::Input; -#[salsa::input] -pub struct File { - #[id] - pub path: PathBuf, - // #[return_ref] - pub contents: Input, -} diff --git a/codegen-sdk-ast/src/lib.rs b/codegen-sdk-ast/src/lib.rs index cb9958ff..4cc0d279 100644 --- a/codegen-sdk-ast/src/lib.rs +++ b/codegen-sdk-ast/src/lib.rs @@ -1,17 +1,16 @@ #![recursion_limit = "512"] -pub mod input; use ambassador::delegatable_trait; -use codegen_sdk_common::File; +// use codegen_sdk_common::File; pub use codegen_sdk_common::language::LANGUAGES; pub use codegen_sdk_cst::*; -pub trait Named { - fn name(&self) -> &str; -} -impl Named for T { - fn name(&self) -> &str { - self.path().file_name().unwrap().to_str().unwrap() - } -} +// pub trait Named { +// fn name(&self) -> &str; +// } +// impl Named for T { +// fn name(&self) -> &str { +// self.path().file_name().unwrap().to_str().unwrap() +// } +// } #[delegatable_trait] pub trait Definitions<'db> { type Definitions; diff --git a/codegen-sdk-cst-generator/src/generator.rs b/codegen-sdk-cst-generator/src/generator.rs index 3f5d3a48..c9ffdf29 100644 --- a/codegen-sdk-cst-generator/src/generator.rs +++ b/codegen-sdk-cst-generator/src/generator.rs @@ -53,10 +53,10 @@ fn get_parser(language: &Language) -> TokenStream { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + pub tree: Arc>>, pub program: indextree::NodeId, } - pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::Input, path: PathBuf) -> Option> { + pub fn parse_program_raw<'db>(db: &'db dyn salsa::Database, input: codegen_sdk_cst::File) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::#language_name::#language_struct_name.parse_tree_sitter(&input.content(db)); match tree { @@ -65,7 +65,7 @@ fn get_parser(language: &Language) -> TokenStream { ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, input.root(db), buffer); + let mut context = ParseContext::new(db, input.path(db), input.root(db), buffer); let root_id = #program_id::orphaned(&mut context, tree.root_node()) .map_or_else(|e| { e.report(db); @@ -74,7 +74,7 @@ fn get_parser(language: &Language) -> TokenStream { Some(program) }); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some(Parsed::new(db, context.file_id, Arc::new(context.tree), program)) } else { None } @@ -87,8 +87,8 @@ fn get_parser(language: &Language) -> TokenStream { } } #[salsa::tracked(return_ref)] - pub fn parse_program(db: &dyn salsa::Database, input: codegen_sdk_cst::Input) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + pub fn parse_program(db: &dyn salsa::Database, input: codegen_sdk_cst::File) -> Parsed<'_> { + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -103,7 +103,7 @@ fn get_parser(language: &Language) -> TokenStream { &codegen_sdk_common::language::#language_name::#language_struct_name } fn parse<'db>(db: &'db dyn salsa::Database, content: std::string::String) -> Option<(&'db Self::Program<'db>, &'db Tree>, indextree::NodeId)> { - let input = codegen_sdk_cst::Input::new(db, content, std::path::PathBuf::new()); + let input = codegen_sdk_cst::File::new(db, std::path::PathBuf::new(), content, std::path::PathBuf::new()); let parsed = parse_program(db, input); let program_id = parsed.program(db); let tree = parsed.tree(db); diff --git a/codegen-sdk-cst/src/database.rs b/codegen-sdk-cst/src/database.rs index de19d043..f62cb369 100644 --- a/codegen-sdk-cst/src/database.rs +++ b/codegen-sdk-cst/src/database.rs @@ -2,7 +2,7 @@ use std::{any::Any, path::PathBuf, sync::Arc}; use dashmap::{DashMap, mapref::entry::Entry}; -use crate::Input; +use crate::File; #[salsa::db] #[derive(Default, Clone)] // Basic Database implementation for Query generation. This is not used for anything else. diff --git a/codegen-sdk-cst/src/input.rs b/codegen-sdk-cst/src/input.rs index 8e87676f..d668f025 100644 --- a/codegen-sdk-cst/src/input.rs +++ b/codegen-sdk-cst/src/input.rs @@ -1,6 +1,10 @@ use std::path::PathBuf; #[salsa::input] -pub struct Input { +pub struct File { + #[id] + pub path: PathBuf, + #[return_ref] pub content: String, + #[id] pub root: PathBuf, } diff --git a/codegen-sdk-cst/src/lib.rs b/codegen-sdk-cst/src/lib.rs index 700938dc..11407fcc 100644 --- a/codegen-sdk-cst/src/lib.rs +++ b/codegen-sdk-cst/src/lib.rs @@ -10,7 +10,7 @@ use dashmap::{DashMap, mapref::entry::Entry}; mod database; use codegen_sdk_common::{ParseError, traits::CSTNode}; pub use database::CSTDatabase; -pub use input::Input; +pub use input::File; mod language; pub use codegen_sdk_common::language::LANGUAGES; pub use language::CSTLanguage; diff --git a/codegen-sdk-macros/src/lib.rs b/codegen-sdk-macros/src/lib.rs index 916cbae6..57d9ccba 100644 --- a/codegen-sdk-macros/src/lib.rs +++ b/codegen-sdk-macros/src/lib.rs @@ -101,7 +101,7 @@ pub fn parse_language(_item: TokenStream) -> TokenStream { let variant: proc_macro2::TokenStream = quote! { #[cfg(feature = #name)] if #package_name::cst::#struct_name::should_parse(&file.path(db)).unwrap_or(false) { - let parsed = #package_name::ast::parse_query(db, file).clone(); + let parsed = #package_name::ast::parse(db, file).clone(); return Parsed::new( db, FileNodeId::new(db, file.path(db)), diff --git a/codegen-sdk-resolution/Cargo.toml b/codegen-sdk-resolution/Cargo.toml index 118b2a6a..de745b95 100644 --- a/codegen-sdk-resolution/Cargo.toml +++ b/codegen-sdk-resolution/Cargo.toml @@ -12,4 +12,5 @@ codegen-sdk-common = { workspace = true } anyhow = { workspace = true } indicatif = { workspace = true } ambassador = { workspace = true } +codegen-sdk-cst = { workspace = true } indexmap = "2" diff --git a/codegen-sdk-resolution/src/codebase.rs b/codegen-sdk-resolution/src/codebase.rs index 65678ec9..5cd112f0 100644 --- a/codegen-sdk-resolution/src/codebase.rs +++ b/codegen-sdk-resolution/src/codebase.rs @@ -15,10 +15,10 @@ pub trait CodebaseContext { fn get_file_for_id<'a>(&'a self, id: FileNodeId) -> Option<&'a Self::File<'a>> { self.get_file(id.path(self.db())) } - fn get_raw_file_for_id<'a>(&'a self, id: FileNodeId) -> Option { + fn get_raw_file_for_id<'a>(&'a self, id: FileNodeId) -> Option { self.get_raw_file(id.path(self.db())) } - fn get_raw_file<'a>(&'a self, path: PathBuf) -> Option { + fn get_raw_file<'a>(&'a self, path: PathBuf) -> Option { if let Ok(path) = path.canonicalize() { self.db().get_file(path) } else { diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs index 31a72a21..7e4020ea 100644 --- a/codegen-sdk-resolution/src/database.rs +++ b/codegen-sdk-resolution/src/database.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use codegen_sdk_ast::input::File; +use codegen_sdk_cst::File; use indicatif::MultiProgress; #[salsa::db] pub trait Db: salsa::Database + Send { @@ -8,5 +8,9 @@ pub trait Db: salsa::Database + Send { fn get_file(&self, path: PathBuf) -> Option; fn multi_progress(&self) -> &MultiProgress; fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; - fn files(&self) -> Vec; + fn files(&self) -> indexmap::IndexSet; +} +#[salsa::tracked] +pub fn files(db: &dyn Db) -> indexmap::IndexSet { + db.files() } diff --git a/codegen-sdk-resolution/src/lib.rs b/codegen-sdk-resolution/src/lib.rs index 0de8a6f8..7862928e 100644 --- a/codegen-sdk-resolution/src/lib.rs +++ b/codegen-sdk-resolution/src/lib.rs @@ -11,12 +11,16 @@ mod codebase; pub use codebase::CodebaseContext; mod database; mod parse; -pub use database::Db; +pub use database::{Db, files}; pub use parse::Parse; +pub use scope::Dependencies; +mod name; +pub use name::{FullyQualifiedName, HasId}; #[delegatable_trait] pub trait HasFile<'db> { type File<'db1>; fn file(&self, db: &'db dyn Db) -> &'db Self::File<'db>; fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf; } + pub use indexmap; diff --git a/codegen-sdk-resolution/src/name.rs b/codegen-sdk-resolution/src/name.rs new file mode 100644 index 00000000..23557b2a --- /dev/null +++ b/codegen-sdk-resolution/src/name.rs @@ -0,0 +1,14 @@ +use codegen_sdk_common::FileNodeId; + +use crate::Db; +#[salsa::interned] +pub struct FullyQualifiedName<'db> { + #[id] + path: FileNodeId<'db>, + #[return_ref] + name: String, +} + +pub trait HasId<'db> { + fn fully_qualified_name(&self, db: &'db dyn Db) -> FullyQualifiedName<'db>; +} diff --git a/codegen-sdk-resolution/src/parse.rs b/codegen-sdk-resolution/src/parse.rs index 3d973766..d9614c4a 100644 --- a/codegen-sdk-resolution/src/parse.rs +++ b/codegen-sdk-resolution/src/parse.rs @@ -1,5 +1,5 @@ use salsa::Database; pub trait Parse<'db> { - fn parse(db: &'db dyn Database, input: codegen_sdk_ast::input::File) -> &'db Self; + fn parse(db: &'db dyn Database, input: codegen_sdk_cst::File) -> &'db Self; } diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index 20f68608..4437df41 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -1,32 +1,34 @@ use std::hash::Hash; -use crate::{Db, HasFile, Parse, ResolveType}; +use crate::{Db, Dependencies, FullyQualifiedName, HasFile, HasId, Parse, ResolveType, files}; pub trait References< 'db, + Dep: Dependencies<'db, FullyQualifiedName<'db>, ReferenceType> + 'db, ReferenceType: ResolveType<'db, Type = Self> + Eq + Hash + Clone + 'db, // References must resolve to this type - Scope: crate::Scope<'db, Type = Self, ReferenceType = ReferenceType> + Clone + 'db, ->: Eq + PartialEq + Hash + HasFile<'db, File<'db> = Scope> + 'db where Self:'db + Scope: crate::Scope<'db, Type = Self, ReferenceType = ReferenceType, Dependencies = Dep> + + Clone + 'db, +>: Eq + PartialEq + Hash + HasFile<'db, File<'db> = Scope> + HasId<'db> + Sized + 'db where Self:'db { fn references(&self, db: &'db dyn Db) -> Vec where Self: Sized, Scope: Parse<'db>, { - let files = db.files(); - log::info!("Finding references across {:?} files", files.len()); + let files = files(db); + log::info!(target: "resolution", "Finding references across {:?} files", files.len()); let mut results = Vec::new(); for input in files { // if !self.filter(db, &input) { // continue; // } - let file = Scope::parse(db, input); + let file = Scope::parse(db, input.clone()); let dependencies = file.clone().compute_dependencies_query(db); - if let Some(references) = dependencies.get(self) { + if let Some(references) = dependencies.get(db, &self.fully_qualified_name(db)) { results.extend(references.iter().cloned()); } } results } - fn filter(&self, db: &'db dyn Db, input: &codegen_sdk_ast::input::File) -> bool; + fn filter(&self, db: &'db dyn Db, input: &codegen_sdk_cst::File) -> bool; } diff --git a/codegen-sdk-resolution/src/scope.rs b/codegen-sdk-resolution/src/scope.rs index 6174cb3f..98a1862a 100644 --- a/codegen-sdk-resolution/src/scope.rs +++ b/codegen-sdk-resolution/src/scope.rs @@ -1,31 +1,36 @@ use std::hash::Hash; -use indexmap::IndexMap; +use indexmap::{IndexMap, IndexSet}; -use crate::{Db, ResolveType}; +use crate::{Db, FullyQualifiedName, HasId, ResolveType}; +pub trait Dependencies<'db, Type, ReferenceType>: Eq + Hash + Clone { + fn get(&'db self, db: &'db dyn Db, key: &Type) -> Option<&'db IndexSet>; +} // Resolve a given string name in a scope to a given type pub trait Scope<'db>: Sized { - type Type: Eq + Hash + Clone; + type Type: Eq + Hash + Clone + HasId<'db>; + type Dependencies: Dependencies<'db, FullyQualifiedName<'db>, Self::ReferenceType>; type ReferenceType: ResolveType<'db, Type = Self::Type> + Eq + Hash + Clone; fn resolve(self, db: &'db dyn Db, name: String) -> &'db Vec; /// Get all the resolvables (IE: function_calls) in the scope fn resolvables(self, db: &'db dyn Db) -> Vec; - fn compute_dependencies_query( + fn compute_dependencies_query(self, db: &'db dyn Db) -> &'db Self::Dependencies; + fn compute_dependencies( self, db: &'db dyn Db, - ) -> &'db IndexMap>; - fn compute_dependencies(self, db: &'db dyn Db) -> IndexMap> + ) -> IndexMap, IndexSet> where Self: 'db, { - let mut dependencies: IndexMap> = IndexMap::new(); + let mut dependencies: IndexMap, IndexSet> = + IndexMap::new(); for reference in self.resolvables(db) { let resolved = reference.clone().resolve_type(db); for resolved in resolved { dependencies - .entry(resolved.clone()) + .entry(resolved.fully_qualified_name(db)) .or_default() - .push(reference.clone()); + .insert(reference.clone()); } } dependencies diff --git a/languages/codegen-sdk-go/Cargo.toml b/languages/codegen-sdk-go/Cargo.toml index 0d8c919c..611be6db 100644 --- a/languages/codegen-sdk-go/Cargo.toml +++ b/languages/codegen-sdk-go/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-java/Cargo.toml b/languages/codegen-sdk-java/Cargo.toml index b49e8677..8a041d3d 100644 --- a/languages/codegen-sdk-java/Cargo.toml +++ b/languages/codegen-sdk-java/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-jsx/Cargo.toml b/languages/codegen-sdk-jsx/Cargo.toml index bd203ae2..5422b2d0 100644 --- a/languages/codegen-sdk-jsx/Cargo.toml +++ b/languages/codegen-sdk-jsx/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-markdown/Cargo.toml b/languages/codegen-sdk-markdown/Cargo.toml index 420a66c9..3f5246d0 100644 --- a/languages/codegen-sdk-markdown/Cargo.toml +++ b/languages/codegen-sdk-markdown/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 777309a4..8f08073a 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -20,8 +20,44 @@ pub mod ast { } } #[salsa::tracked] + pub struct PythonDependencies<'db> { + #[id] + id: codegen_sdk_common::FileNodeId<'db>, + #[return_ref] + #[tracked] + dependencies: codegen_sdk_resolution::indexmap::IndexMap< + codegen_sdk_resolution::FullyQualifiedName<'db>, + codegen_sdk_resolution::indexmap::IndexSet>, + >, + } + impl<'db> + codegen_sdk_resolution::Dependencies< + 'db, + codegen_sdk_resolution::FullyQualifiedName<'db>, + crate::ast::Call<'db>, + > for PythonDependencies<'db> + { + fn get( + &'db self, + db: &'db dyn codegen_sdk_resolution::Db, + key: &codegen_sdk_resolution::FullyQualifiedName<'db>, + ) -> Option<&'db codegen_sdk_resolution::indexmap::IndexSet>> + { + self.dependencies(db).get(key) + } + } + #[salsa::tracked(return_ref)] + pub fn dependencies( + db: &dyn codegen_sdk_resolution::Db, + input: codegen_sdk_cst::File, + ) -> PythonDependencies<'_> { + let file = parse(db, input); + PythonDependencies::new(db, file.id(db), file.compute_dependencies(db)) + } + #[salsa::tracked] impl<'db> Scope<'db> for PythonFile<'db> { type Type = crate::ast::Symbol<'db>; + type Dependencies = PythonDependencies<'db>; type ReferenceType = crate::ast::Call<'db>; #[salsa::tracked(return_ref)] fn resolve(self, db: &'db dyn codegen_sdk_resolution::Db, name: String) -> Vec { @@ -56,13 +92,38 @@ pub mod ast { } results } + fn compute_dependencies( + self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> codegen_sdk_resolution::indexmap::IndexMap< + codegen_sdk_resolution::FullyQualifiedName<'db>, + codegen_sdk_resolution::indexmap::IndexSet, + > + where + Self: 'db, + { + let mut dependencies: codegen_sdk_resolution::indexmap::IndexMap< + codegen_sdk_resolution::FullyQualifiedName<'db>, + codegen_sdk_resolution::indexmap::IndexSet, + > = codegen_sdk_resolution::indexmap::IndexMap::new(); + for reference in self.resolvables(db) { + let resolved = reference.clone().resolve_type(db); + for resolved in resolved { + dependencies + .entry(resolved.fully_qualified_name(db)) + .or_default() + .insert(reference.clone()); + } + } + dependencies + } + #[salsa::tracked(return_ref)] fn compute_dependencies_query( self, db: &'db dyn codegen_sdk_resolution::Db, - ) -> codegen_sdk_resolution::indexmap::IndexMap> - { - self.compute_dependencies(db) + ) -> PythonDependencies<'db> { + PythonDependencies::new(db, self.id(db), self.compute_dependencies(db)) } } #[salsa::tracked] @@ -93,18 +154,40 @@ pub mod ast { .clone() } } + use codegen_sdk_resolution::{Db, Dependencies, HasId}; #[salsa::tracked] - impl<'db> codegen_sdk_resolution::References<'db, crate::ast::Call<'db>, PythonFile<'db>> - for crate::ast::Symbol<'db> + impl<'db> + codegen_sdk_resolution::References< + 'db, + PythonDependencies<'db>, + crate::ast::Call<'db>, + PythonFile<'db>, + > for crate::ast::Symbol<'db> { + fn references(&self, db: &'db dyn Db) -> Vec> { + let files = codegen_sdk_resolution::files(db); + log::info!(target: "resolution", "Finding references across {:?} files", files.len()); + let mut results = Vec::new(); + let name = self.fully_qualified_name(db); + for input in files { + // if !self.filter(db, &input) { + // continue; + // } + let dependencies = dependencies(db, input.clone()); + if let Some(references) = dependencies.get(db, &name) { + results.extend(references.iter().cloned()); + } + } + results + } fn filter( &self, db: &'db dyn codegen_sdk_resolution::Db, - input: &codegen_sdk_ast::input::File, + input: &codegen_sdk_cst::File, ) -> bool { match self { crate::ast::Symbol::Function(function) => { - let content = input.contents(db).content(db); + let content = input.content(db); let target = function.name(db).text(); memchr::memmem::find(&content.as_bytes(), &target).is_some() } diff --git a/languages/codegen-sdk-python/tests/test_python.rs b/languages/codegen-sdk-python/tests/test_python.rs index 77dc9b22..69fe0ed2 100644 --- a/languages/codegen-sdk-python/tests/test_python.rs +++ b/languages/codegen-sdk-python/tests/test_python.rs @@ -40,9 +40,9 @@ class Test: pass"; let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); - let file = codegen_sdk_python::ast::parse_query(&db, input); + let root_path = temp_dir.path().to_path_buf(); + let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path); + let file = codegen_sdk_python::ast::parse(&db, input); assert_eq!(file.definitions(&db).classes(&db).len(), 1); } #[test_log::test] @@ -54,7 +54,7 @@ def test(): let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); + let input = codegen_sdk_cst::File::new(&db, file_path, content); let file = codegen_sdk_python::ast::parse_query(&db, input); assert_eq!(file.definitions(&db).functions(&db).len(), 1); } @@ -69,7 +69,7 @@ test()"; let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); + let input = codegen_sdk_cst::File::new(&db, file_path, content); let file = codegen_sdk_python::ast::parse_query(&db, input); assert_eq!(file.references(&db).calls(&db).len(), 1); let definitions = file.definitions(&db); @@ -78,7 +78,7 @@ test()"; let function = codegen_sdk_python::ast::Symbol::Function(function.clone().clone()); assert_eq!( function - .references_for_scopes(&db, temp_dir.path().to_path_buf(), vec![*file], &file) + .references(&db, temp_dir.path().to_path_buf(), vec![*file], &file) .len(), 1 ); @@ -94,13 +94,13 @@ def test(): let usage_file_content = " from filea import test test()"; + let root_path = temp_dir.path().to_path_buf(); let file_path = write_to_temp_file_with_name(content, &temp_dir, "filea.py"); let usage_file_path = write_to_temp_file_with_name(usage_file_content, &temp_dir, "fileb.py"); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let usage_content = codegen_sdk_cst::Input::new(&db, usage_file_content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); - let usage_input = codegen_sdk_ast::input::File::new(&db, usage_file_path, usage_content); + let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path.clone()); + let usage_input = + codegen_sdk_cst::File::new(&db, usage_file_path, usage_file_content, root_path.clone()); let file = codegen_sdk_python::ast::parse_query(&db, input); let usage_file = codegen_sdk_python::ast::parse_query(&db, usage_input); assert_eq!(usage_file.references(&db).calls(&db).len(), 1); @@ -111,26 +111,6 @@ test()"; let imports = usage_file.definitions(&db).imports(&db); let import = imports.get("test").unwrap().first().unwrap(); let import = codegen_sdk_python::ast::Symbol::Import(import.clone().clone()); - assert_eq!( - import - .references_for_scopes( - &db, - temp_dir.path().to_path_buf(), - vec![*usage_file], - &usage_file - ) - .len(), - 1 - ); - assert_eq!( - function - .references_for_scopes( - &db, - temp_dir.path().to_path_buf(), - vec![*file, *usage_file], - &usage_file - ) - .len(), - 1 - ); + assert_eq!(import.references(&db,).len(), 1); + assert_eq!(function.references(&db,).len(), 1); } diff --git a/languages/codegen-sdk-ruby/Cargo.toml b/languages/codegen-sdk-ruby/Cargo.toml index bb558304..5f890541 100644 --- a/languages/codegen-sdk-ruby/Cargo.toml +++ b/languages/codegen-sdk-ruby/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-rust/Cargo.toml b/languages/codegen-sdk-rust/Cargo.toml index af96d5f0..f432491e 100644 --- a/languages/codegen-sdk-rust/Cargo.toml +++ b/languages/codegen-sdk-rust/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-toml/Cargo.toml b/languages/codegen-sdk-toml/Cargo.toml index 2fb48bae..8fe907ae 100644 --- a/languages/codegen-sdk-toml/Cargo.toml +++ b/languages/codegen-sdk-toml/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-tsx/Cargo.toml b/languages/codegen-sdk-tsx/Cargo.toml index d26ebc63..3dd20cdf 100644 --- a/languages/codegen-sdk-tsx/Cargo.toml +++ b/languages/codegen-sdk-tsx/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } diff --git a/languages/codegen-sdk-typescript/Cargo.toml b/languages/codegen-sdk-typescript/Cargo.toml index 56445c6c..17b34f73 100644 --- a/languages/codegen-sdk-typescript/Cargo.toml +++ b/languages/codegen-sdk-typescript/Cargo.toml @@ -15,6 +15,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } indextree = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } diff --git a/languages/codegen-sdk-typescript/tests/test_typescript.rs b/languages/codegen-sdk-typescript/tests/test_typescript.rs index c7a59201..8d7a54e8 100644 --- a/languages/codegen-sdk-typescript/tests/test_typescript.rs +++ b/languages/codegen-sdk-typescript/tests/test_typescript.rs @@ -27,12 +27,11 @@ fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { #[test_log::test] fn test_typescript_ast_interface() { let temp_dir = tempfile::tempdir().unwrap(); - let content = "interface Test { }"; + let content = "interface Test { }".to_string(); let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_ast::input::File::new(&db, file_path, content); - let file = codegen_sdk_typescript::ast::parse_query(&db, input); - let tree = file.node(&db).unwrap().tree(&db); - assert_eq!(file.definitions(&db).interfaces(&db, &tree).len(), 1); + let root_path = temp_dir.path().to_path_buf(); + let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path); + let file = codegen_sdk_typescript::ast::parse(&db, input); + assert_eq!(file.definitions(&db).interfaces(&db).len(), 1); } diff --git a/languages/codegen-sdk-yaml/Cargo.toml b/languages/codegen-sdk-yaml/Cargo.toml index 2efebd0a..1b306a3c 100644 --- a/languages/codegen-sdk-yaml/Cargo.toml +++ b/languages/codegen-sdk-yaml/Cargo.toml @@ -16,6 +16,7 @@ bytes = { workspace = true } codegen-sdk-cst = { workspace = true } log = { workspace = true } codegen-sdk-ast = { workspace = true } +codegen-sdk-resolution = { workspace = true } [build-dependencies] codegen-sdk-cst-generator = { workspace = true } codegen-sdk-ast-generator = { workspace = true } From 6d95909c9f152b8b16aa79af0c757bfd38118c9e Mon Sep 17 00:00:00 2001 From: bagel897 Date: Fri, 7 Mar 2025 11:10:09 -0800 Subject: [PATCH 10/16] wip: perf --- Cargo.lock | 5 +++-- Cargo.toml | 5 ++++- codegen-sdk-analyzer/Cargo.toml | 1 - .../src/codebase/discovery.rs | 2 +- codegen-sdk-analyzer/src/codebase/parser.rs | 3 +-- codegen-sdk-analyzer/src/database.rs | 2 +- codegen-sdk-common/Cargo.toml | 2 ++ codegen-sdk-common/src/hash.rs | 12 ++++++++++++ codegen-sdk-common/src/lib.rs | 1 + codegen-sdk-resolution/Cargo.toml | 2 +- codegen-sdk-resolution/src/database.rs | 4 ++-- codegen-sdk-resolution/src/lib.rs | 2 -- codegen-sdk-resolution/src/scope.rs | 19 +++++++++++++------ languages/codegen-sdk-python/src/lib.rs | 17 ++++++++--------- 14 files changed, 49 insertions(+), 28 deletions(-) create mode 100644 codegen-sdk-common/src/hash.rs diff --git a/Cargo.lock b/Cargo.lock index 00c7cc09..169b7bb7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,7 +331,6 @@ dependencies = [ "dashmap", "env_logger", "glob", - "indexmap", "indicatif", "indicatif-log-bridge", "log", @@ -386,6 +385,7 @@ dependencies = [ "buildid", "bytes", "convert_case", + "indexmap", "indextree", "lazy_static", "mockall", @@ -394,6 +394,7 @@ dependencies = [ "proc-macro2", "quote", "rkyv", + "rustc-hash", "salsa", "serde", "serde_json", @@ -646,10 +647,10 @@ dependencies = [ "codegen-sdk-ast", "codegen-sdk-common", "codegen-sdk-cst", - "indexmap", "indicatif", "log", "salsa", + "smallvec", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index fc56769d..017b4bed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -126,6 +126,7 @@ rstest = "0.25.0" indextree = "4.7.3" thiserror = "2.0.11" indexmap = "2" +smallvec = "1.11.0" [profile.dev] # codegen-backend = "cranelift" # split-debuginfo = "unpacked" @@ -158,5 +159,7 @@ lto = false name = "parse" harness = false required-features = ["stable"] -[profile.release] + +[profile.profiling] +inherits = "release" debug = true diff --git a/codegen-sdk-analyzer/Cargo.toml b/codegen-sdk-analyzer/Cargo.toml index 086fd0cd..76f61720 100644 --- a/codegen-sdk-analyzer/Cargo.toml +++ b/codegen-sdk-analyzer/Cargo.toml @@ -34,7 +34,6 @@ crossbeam-channel = { workspace = true } glob = "0.3.2" rayon = { workspace = true } ambassador = { workspace = true } -indexmap = {workspace = true} [features] python = [ "codegen-sdk-python"] # TODO: Add python support typescript = [ "codegen-sdk-typescript"] diff --git a/codegen-sdk-analyzer/src/codebase/discovery.rs b/codegen-sdk-analyzer/src/codebase/discovery.rs index 1a29f4d4..97fa0e2d 100644 --- a/codegen-sdk-analyzer/src/codebase/discovery.rs +++ b/codegen-sdk-analyzer/src/codebase/discovery.rs @@ -9,7 +9,7 @@ use glob::glob; use crate::database::CodegenDatabase; #[salsa::input] pub struct FilesToParse { - pub files: indexmap::IndexSet, + pub files: codegen_sdk_common::hash::FxHashSet, pub root: PathBuf, } pub fn log_languages() { diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index ac9a47be..6a5abd9e 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -3,14 +3,13 @@ use codegen_sdk_ast::{Definitions, References}; use codegen_sdk_common::serialize::Cache; use codegen_sdk_cst::File; use codegen_sdk_resolution::{Db, Scope}; -use indexmap::IndexSet; use indicatif::{ProgressBar, ProgressStyle}; use super::discovery::{FilesToParse, log_languages}; use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; pub fn execute_op_with_progress( db: &Database, - files: IndexSet, + files: codegen_sdk_common::hash::FxHashSet, name: &str, op: fn(&Database, File) -> T, ) -> Vec { diff --git a/codegen-sdk-analyzer/src/database.rs b/codegen-sdk-analyzer/src/database.rs index 99a2df18..da38d9e2 100644 --- a/codegen-sdk-analyzer/src/database.rs +++ b/codegen-sdk-analyzer/src/database.rs @@ -72,7 +72,7 @@ impl salsa::Database for CodegenDatabase { } #[salsa::db] impl Db for CodegenDatabase { - fn files(&self) -> indexmap::IndexSet { + fn files(&self) -> codegen_sdk_common::hash::FxHashSet { self.files .iter() .map(|entry| entry.value().clone()) diff --git a/codegen-sdk-common/Cargo.toml b/codegen-sdk-common/Cargo.toml index a1b418b4..1347b552 100644 --- a/codegen-sdk-common/Cargo.toml +++ b/codegen-sdk-common/Cargo.toml @@ -39,6 +39,8 @@ syn = { workspace = true } prettyplease = { workspace = true } salsa = { workspace = true } indextree = { workspace = true } +indexmap = { workspace = true } +rustc-hash = "2.1.1" [dev-dependencies] test-log = { workspace = true } [features] diff --git a/codegen-sdk-common/src/hash.rs b/codegen-sdk-common/src/hash.rs new file mode 100644 index 00000000..dd928145 --- /dev/null +++ b/codegen-sdk-common/src/hash.rs @@ -0,0 +1,12 @@ +// Taken from https://github.com/salsa-rs/salsa/blob/9d2a9786c45000f5fa396ad2872391e302a2836a/src/hash.rs#L1 +use std::hash::{BuildHasher, Hash}; + +pub type FxHasher = std::hash::BuildHasherDefault; +pub type FxIndexSet = indexmap::IndexSet; +pub type FxIndexMap = indexmap::IndexMap; +// pub type FxDashMap = dashmap::DashMap; +// pub type FxLinkedHashSet = hashlink::LinkedHashSet; +pub type FxHashSet = std::collections::HashSet; +pub fn hash(t: &T) -> u64 { + FxHasher::default().hash_one(t) +} diff --git a/codegen-sdk-common/src/lib.rs b/codegen-sdk-common/src/lib.rs index 67d9a919..594d23d5 100644 --- a/codegen-sdk-common/src/lib.rs +++ b/codegen-sdk-common/src/lib.rs @@ -1,6 +1,7 @@ #![feature(error_generic_member_access)] #![feature(trivial_bounds)] mod errors; +pub mod hash; pub mod language; pub mod traits; pub mod utils; diff --git a/codegen-sdk-resolution/Cargo.toml b/codegen-sdk-resolution/Cargo.toml index de745b95..130b35de 100644 --- a/codegen-sdk-resolution/Cargo.toml +++ b/codegen-sdk-resolution/Cargo.toml @@ -13,4 +13,4 @@ anyhow = { workspace = true } indicatif = { workspace = true } ambassador = { workspace = true } codegen-sdk-cst = { workspace = true } -indexmap = "2" +smallvec = { workspace = true } diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs index 7e4020ea..fe07b272 100644 --- a/codegen-sdk-resolution/src/database.rs +++ b/codegen-sdk-resolution/src/database.rs @@ -8,9 +8,9 @@ pub trait Db: salsa::Database + Send { fn get_file(&self, path: PathBuf) -> Option; fn multi_progress(&self) -> &MultiProgress; fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; - fn files(&self) -> indexmap::IndexSet; + fn files(&self) -> codegen_sdk_common::hash::FxHashSet; } #[salsa::tracked] -pub fn files(db: &dyn Db) -> indexmap::IndexSet { +pub fn files(db: &dyn Db) -> codegen_sdk_common::hash::FxHashSet { db.files() } diff --git a/codegen-sdk-resolution/src/lib.rs b/codegen-sdk-resolution/src/lib.rs index 7862928e..3bb34662 100644 --- a/codegen-sdk-resolution/src/lib.rs +++ b/codegen-sdk-resolution/src/lib.rs @@ -22,5 +22,3 @@ pub trait HasFile<'db> { fn file(&self, db: &'db dyn Db) -> &'db Self::File<'db>; fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf; } - -pub use indexmap; diff --git a/codegen-sdk-resolution/src/scope.rs b/codegen-sdk-resolution/src/scope.rs index 98a1862a..aed6899e 100644 --- a/codegen-sdk-resolution/src/scope.rs +++ b/codegen-sdk-resolution/src/scope.rs @@ -1,10 +1,12 @@ use std::hash::Hash; -use indexmap::{IndexMap, IndexSet}; - use crate::{Db, FullyQualifiedName, HasId, ResolveType}; pub trait Dependencies<'db, Type, ReferenceType>: Eq + Hash + Clone { - fn get(&'db self, db: &'db dyn Db, key: &Type) -> Option<&'db IndexSet>; + fn get( + &'db self, + db: &'db dyn Db, + key: &Type, + ) -> Option<&'db codegen_sdk_common::hash::FxIndexSet>; } // Resolve a given string name in a scope to a given type pub trait Scope<'db>: Sized { @@ -18,12 +20,17 @@ pub trait Scope<'db>: Sized { fn compute_dependencies( self, db: &'db dyn Db, - ) -> IndexMap, IndexSet> + ) -> codegen_sdk_common::hash::FxIndexMap< + FullyQualifiedName<'db>, + codegen_sdk_common::hash::FxIndexSet, + > where Self: 'db, { - let mut dependencies: IndexMap, IndexSet> = - IndexMap::new(); + let mut dependencies: codegen_sdk_common::hash::FxIndexMap< + FullyQualifiedName<'db>, + codegen_sdk_common::hash::FxIndexSet, + > = codegen_sdk_common::hash::FxIndexMap::default(); for reference in self.resolvables(db) { let resolved = reference.clone().resolve_type(db); for resolved in resolved { diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 8f08073a..27717653 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -25,9 +25,9 @@ pub mod ast { id: codegen_sdk_common::FileNodeId<'db>, #[return_ref] #[tracked] - dependencies: codegen_sdk_resolution::indexmap::IndexMap< + dependencies: codegen_sdk_common::hash::FxIndexMap< codegen_sdk_resolution::FullyQualifiedName<'db>, - codegen_sdk_resolution::indexmap::IndexSet>, + codegen_sdk_common::hash::FxIndexSet>, >, } impl<'db> @@ -41,8 +41,7 @@ pub mod ast { &'db self, db: &'db dyn codegen_sdk_resolution::Db, key: &codegen_sdk_resolution::FullyQualifiedName<'db>, - ) -> Option<&'db codegen_sdk_resolution::indexmap::IndexSet>> - { + ) -> Option<&'db codegen_sdk_common::hash::FxIndexSet>> { self.dependencies(db).get(key) } } @@ -95,17 +94,17 @@ pub mod ast { fn compute_dependencies( self, db: &'db dyn codegen_sdk_resolution::Db, - ) -> codegen_sdk_resolution::indexmap::IndexMap< + ) -> codegen_sdk_common::hash::FxIndexMap< codegen_sdk_resolution::FullyQualifiedName<'db>, - codegen_sdk_resolution::indexmap::IndexSet, + codegen_sdk_common::hash::FxIndexSet, > where Self: 'db, { - let mut dependencies: codegen_sdk_resolution::indexmap::IndexMap< + let mut dependencies: codegen_sdk_common::hash::FxIndexMap< codegen_sdk_resolution::FullyQualifiedName<'db>, - codegen_sdk_resolution::indexmap::IndexSet, - > = codegen_sdk_resolution::indexmap::IndexMap::new(); + codegen_sdk_common::hash::FxIndexSet, + > = codegen_sdk_common::hash::FxIndexMap::default(); for reference in self.resolvables(db) { let resolved = reference.clone().resolve_type(db); for resolved in resolved { From e1af66464f6412a544e2ba1b45b0747b70930fbe Mon Sep 17 00:00:00 2001 From: bagel897 Date: Fri, 7 Mar 2025 15:37:58 -0800 Subject: [PATCH 11/16] Fast references impl --- Cargo.lock | 2 + Cargo.toml | 2 +- codegen-sdk-analyzer/src/codebase.rs | 9 +- codegen-sdk-analyzer/src/codebase/parser.rs | 104 ++++++++++++--- codegen-sdk-analyzer/src/database.rs | 8 +- codegen-sdk-analyzer/src/parser.rs | 7 +- codegen-sdk-ast-generator/src/generator.rs | 5 +- codegen-sdk-ast-generator/src/query.rs | 4 +- codegen-sdk-ast-generator/src/visitor.rs | 26 ++-- codegen-sdk-common/Cargo.toml | 1 + codegen-sdk-common/src/hash.rs | 3 +- codegen-sdk-cst-generator/src/generator.rs | 1 + codegen-sdk-resolution/src/database.rs | 9 +- codegen-sdk-resolution/src/lib.rs | 2 +- codegen-sdk-resolution/src/name.rs | 6 +- codegen-sdk-resolution/src/parse.rs | 4 +- codegen-sdk-resolution/src/references.rs | 36 ++--- codegen-sdk-resolution/src/scope.rs | 6 +- languages/codegen-sdk-python/src/lib.rs | 140 ++++++++++++++++---- src/main.rs | 96 +++++++------- 20 files changed, 324 insertions(+), 147 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 169b7bb7..ec378de8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -385,6 +385,7 @@ dependencies = [ "buildid", "bytes", "convert_case", + "hashbrown 0.15.2", "indexmap", "indextree", "lazy_static", @@ -1228,6 +1229,7 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash", + "rayon", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 017b4bed..4f0e5724 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ codegen-sdk-resolution = { workspace = true} sysinfo = "0.33.1" rkyv.workspace = true [features] -python = [ "codegen-sdk-analyzer/python", "codegen-sdk-python"] # TODO: Add python support +python = [ "codegen-sdk-analyzer/python", "codegen-sdk-python"] typescript = [ "codegen-sdk-analyzer/typescript", "codegen-sdk-typescript"] tsx = [ "codegen-sdk-analyzer/tsx"] jsx = [ "codegen-sdk-analyzer/jsx"] diff --git a/codegen-sdk-analyzer/src/codebase.rs b/codegen-sdk-analyzer/src/codebase.rs index 07ae9a52..540c20c9 100644 --- a/codegen-sdk-analyzer/src/codebase.rs +++ b/codegen-sdk-analyzer/src/codebase.rs @@ -88,12 +88,14 @@ impl Codebase { pub fn execute_op_with_progress( &self, name: &str, - op: fn(&dyn Db, File) -> T, + parallel: bool, + op: fn(&dyn Db, codegen_sdk_common::FileNodeId<'_>) -> T, ) -> Vec { execute_op_with_progress( self._db(), codegen_sdk_resolution::files(self._db()), name, + parallel, op, ) } @@ -118,8 +120,9 @@ impl CodebaseContext for Codebase { fn get_file<'a>(&'a self, path: PathBuf) -> Option<&'a Self::File<'a>> { if let Ok(path) = path.canonicalize() { let file = self.db.files.get(&path); - if let Some(file) = file { - return parse_file(&self.db, file.clone()).file(&self.db).as_ref(); + if let Some(_) = file { + let file_id = codegen_sdk_common::FileNodeId::new(&self.db, path); + return parse_file(&self.db, file_id).file(&self.db).as_ref(); } } None diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index 6a5abd9e..0754902a 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -1,17 +1,23 @@ +use std::path::PathBuf; + use codegen_sdk_ast::{Definitions, References}; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; -use codegen_sdk_cst::File; use codegen_sdk_resolution::{Db, Scope}; use indicatif::{ProgressBar, ProgressStyle}; use super::discovery::{FilesToParse, log_languages}; use crate::{ParsedFile, database::CodegenDatabase, parser::parse_file}; -pub fn execute_op_with_progress( +pub fn execute_op_with_progress< + Database: Db + ?Sized + 'static, + Input: Send + Sync, + T: Send + Sync, +>( db: &Database, - files: codegen_sdk_common::hash::FxHashSet, + files: codegen_sdk_common::hash::FxHashSet, name: &str, - op: fn(&Database, File) -> T, + parallel: bool, + op: fn(&Database, Input) -> T, ) -> Vec { let multi = db.multi_progress(); let style = ProgressStyle::with_template( @@ -27,19 +33,35 @@ pub fn execute_op_with_progress .into_iter() .map(|file| (&pg, file, op)) .collect::>(); - let results: Vec = salsa::par_map(db, inputs, move |db, input| { - let (pg, file, op) = input; - let res = op( - db, - #[cfg(feature = "serialization")] - &cache, - file, - ); - pg.inc(1); - res - }); + let results: Vec<_> = if parallel { + salsa::par_map(db, inputs, move |db, input| { + let (pg, file, op) = input; + let res = op( + db, + #[cfg(feature = "serialization")] + &cache, + file, + ); + pg.inc(1); + res + }) + } else { + inputs + .into_iter() + .map(|input| { + let (pg, file, op) = input; + let res = op( + db, + #[cfg(feature = "serialization")] + &cache, + file, + ); + pg.inc(1); + res + }) + .collect() + }; pg.finish(); - multi.remove(&pg); results } // #[salsa::tracked] @@ -50,7 +72,12 @@ pub fn execute_op_with_progress // } #[salsa::tracked] fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { - let _: Vec<_> = execute_op_with_progress(db, files.files(db), "Parsing Files", |db, input| { + let ids = files + .files(db) + .iter() + .map(|input| codegen_sdk_common::FileNodeId::new(db, input.path(db))) + .collect::>(); + let _: Vec<_> = execute_op_with_progress(db, ids, "Parsing Files", true, |db, input| { let file = parse_file(db, input.clone()); if let Some(parsed) = file.file(db) { #[cfg(feature = "typescript")] @@ -62,12 +89,47 @@ fn parse_files_definitions_par(db: &dyn Db, files: FilesToParse) { if let ParsedFile::Python(parsed) = parsed { parsed.definitions(db); parsed.references(db); - codegen_sdk_python::ast::dependencies(db, input); + // let deps = codegen_sdk_python::ast::dependencies(db, input); + // for dep in deps.dependencies(db).keys() { + // codegen_sdk_resolution::ast::references_impl(db, dep); + // } } } () }); } +#[salsa::tracked] +fn compute_dependencies_par(db: &dyn Db, files: FilesToParse) { + let ids = files + .files(db) + .iter() + .map(|input| codegen_sdk_common::FileNodeId::new(db, input.path(db))) + .collect::>(); + let targets: codegen_sdk_common::hash::FxHashSet<(PathBuf, String)> = + execute_op_with_progress(db, ids, "Computing Dependencies", true, |db, input| { + let file = parse_file(db, input.clone()); + if let Some(parsed) = file.file(db) { + #[cfg(feature = "python")] + if let ParsedFile::Python(parsed) = parsed { + let deps = codegen_sdk_python::ast::dependency_keys(db, input); + return deps + .iter() + .map(|dep| (dep.path(db).path(db).clone(), dep.name(db).clone())) + .collect::>(); + } + } + Vec::new() + }) + .into_iter() + .flatten() + .collect(); + // let _: Vec<_> = execute_op_with_progress(db, targets, "Finding Usages", true, |db, input: (PathBuf, String)| { + // let file_node_id = codegen_sdk_common::FileNodeId::new(db, input.0); + // let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new(db, file_node_id, input.1); + // codegen_sdk_python::ast::references_impl(db, fully_qualified_name); + // }); +} + pub fn parse_files<'db>( db: &'db CodegenDatabase, #[cfg(feature = "serialization")] cache: &'db Cache, @@ -89,6 +151,12 @@ pub fn parse_files<'db>( &cache, files_to_parse, ); + compute_dependencies_par( + db, + #[cfg(feature = "serialization")] + &cache, + files_to_parse, + ); #[cfg(feature = "serialization")] report_cached_count(cached, &files_to_parse.files(db)); } diff --git a/codegen-sdk-analyzer/src/database.rs b/codegen-sdk-analyzer/src/database.rs index da38d9e2..7d004e11 100644 --- a/codegen-sdk-analyzer/src/database.rs +++ b/codegen-sdk-analyzer/src/database.rs @@ -20,7 +20,7 @@ use crate::progress::get_multi_progress; // Basic Database implementation for Query generation. This is not used for anything else. pub struct CodegenDatabase { storage: salsa::Storage, - pub files: DashMap, + pub files: Arc>, dirs: Vec, multi_progress: MultiProgress, file_watcher: Arc>>, @@ -41,7 +41,7 @@ impl CodegenDatabase { file_watcher: get_watcher(tx), storage: salsa::Storage::default(), multi_progress, - files: DashMap::new(), + files: Arc::new(DashMap::new()), dirs: Vec::new(), root, } @@ -72,10 +72,10 @@ impl salsa::Database for CodegenDatabase { } #[salsa::db] impl Db for CodegenDatabase { - fn files(&self) -> codegen_sdk_common::hash::FxHashSet { + fn files(&self) -> codegen_sdk_common::hash::FxHashSet> { self.files .iter() - .map(|entry| entry.value().clone()) + .map(|entry| codegen_sdk_common::FileNodeId::new(self, entry.key().clone())) .collect() } fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()> { diff --git a/codegen-sdk-analyzer/src/parser.rs b/codegen-sdk-analyzer/src/parser.rs index 80177f77..caf91a29 100644 --- a/codegen-sdk-analyzer/src/parser.rs +++ b/codegen-sdk-analyzer/src/parser.rs @@ -14,7 +14,10 @@ pub struct Parsed<'db> { pub file: Option>, } #[salsa::tracked(return_ref)] -pub fn parse_file(db: &dyn salsa::Database, file: codegen_sdk_cst::File) -> Parsed<'_> { +pub fn parse_file<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + file: codegen_sdk_common::FileNodeId<'db>, +) -> Parsed<'db> { parse_language!(); - Parsed::new(db, FileNodeId::new(db, file.path(db)), None) + Parsed::new(db, file, None) } diff --git a/codegen-sdk-ast-generator/src/generator.rs b/codegen-sdk-ast-generator/src/generator.rs index 40d61033..f105d0fe 100644 --- a/codegen-sdk-ast-generator/src/generator.rs +++ b/codegen-sdk-ast-generator/src/generator.rs @@ -72,7 +72,7 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { pub id: codegen_sdk_common::FileNodeId<'db>, } impl<'db> codegen_sdk_resolution::Parse<'db> for #language_struct_name<'db> { - fn parse(db: &'db dyn salsa::Database, input: codegen_sdk_cst::File) -> &'db Self { + fn parse(db: &'db dyn codegen_sdk_resolution::Db, input: codegen_sdk_common::FileNodeId<'db>) -> &'db Self { parse(db, input) } } @@ -82,7 +82,8 @@ pub fn generate_ast(language: &Language) -> anyhow::Result { // }} // }} #[salsa::tracked(return_ref)] - pub fn parse(db: &dyn salsa::Database, input: codegen_sdk_cst::File) -> #language_struct_name<'_> { + pub fn parse<'db>(db: &'db dyn codegen_sdk_resolution::Db, input: codegen_sdk_common::FileNodeId<'db>) -> #language_struct_name<'db> { + let input = db.input(input.path(db)).unwrap(); log::debug!("Parsing {} file: {}", input.path(db).display(), #language_name_str); let ast = crate::cst::parse_program_raw(db, input); let file_id = codegen_sdk_common::FileNodeId::new(db, input.path(db).clone()); diff --git a/codegen-sdk-ast-generator/src/query.rs b/codegen-sdk-ast-generator/src/query.rs index 883f00cd..b1a70b88 100644 --- a/codegen-sdk-ast-generator/src/query.rs +++ b/codegen-sdk-ast-generator/src/query.rs @@ -662,7 +662,7 @@ impl<'a> Query<'a> { let symbol_name = self.symbol_name(); return quote! { let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new(db, node.file_id(),#name.source()); - let symbol = #symbol_name::new(db, fully_qualified_name, id, node.clone(), #(#args.clone().into()),*); + let symbol = #symbol_name::new(db, fully_qualified_name, id, #(#args.clone().into()),*); #to_append.entry(#name.source()).or_default().push(symbol); }; } @@ -913,6 +913,8 @@ impl<'a> Query<'a> { } let name_ident = format_ident!("{}", name); fields.push(parse_quote!( + #[tracked] + #[return_ref] pub #name_ident: crate::cst::#type_name<'db> )); } diff --git a/codegen-sdk-ast-generator/src/visitor.rs b/codegen-sdk-ast-generator/src/visitor.rs index 35e1aca6..f4e15541 100644 --- a/codegen-sdk-ast-generator/src/visitor.rs +++ b/codegen-sdk-ast-generator/src/visitor.rs @@ -84,24 +84,30 @@ pub fn generate_visitor<'db>( _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, - #[tracked] - #[return_ref] - pub node: crate::cst::#type_name<'db>, + // #[tracked] + // #[return_ref] + // pub node: crate::cst::#type_name<'db>, #(#fields),* } + impl<'db> #variant<'db> { + pub fn node(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db crate::cst::#type_name<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } + } impl<'db> codegen_sdk_resolution::HasFile<'db> for #variant<'db> { type File<'db1> = #language_struct<'db1>; fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { - let path = self.node(db).id().file(db).path(db); - let input = db.get_file(path).unwrap(); - parse(db, input) + let path = self._fully_qualified_name(db).path(db); + parse(db, path) } - fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf { + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { self.node(db).id().root(db).path(db) } } impl<'db> codegen_sdk_resolution::HasId<'db> for #variant<'db> { - fn fully_qualified_name(&self, db: &'db dyn codegen_sdk_resolution::Db) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + fn fully_qualified_name(&self, db: &'db dyn salsa::Database) -> codegen_sdk_resolution::FullyQualifiedName<'db> { self._fully_qualified_name(db) } } @@ -125,14 +131,14 @@ pub fn generate_visitor<'db>( #(Self::#symbol_names(symbol) => symbol.file(db),)* } } - fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf { + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { match self { #(Self::#symbol_names(symbol) => symbol.root_path(db),)* } } } impl<'db> codegen_sdk_resolution::HasId<'db> for #symbol_name<'db> { - fn fully_qualified_name(&self, db: &'db dyn codegen_sdk_resolution::Db) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + fn fully_qualified_name(&self, db: &'db dyn salsa::Database) -> codegen_sdk_resolution::FullyQualifiedName<'db> { match self { #(Self::#symbol_names(symbol) => symbol.fully_qualified_name(db),)* } diff --git a/codegen-sdk-common/Cargo.toml b/codegen-sdk-common/Cargo.toml index 1347b552..d8fcd0e1 100644 --- a/codegen-sdk-common/Cargo.toml +++ b/codegen-sdk-common/Cargo.toml @@ -41,6 +41,7 @@ salsa = { workspace = true } indextree = { workspace = true } indexmap = { workspace = true } rustc-hash = "2.1.1" +hashbrown = { version = "0.15.2", features = ["rayon"] } [dev-dependencies] test-log = { workspace = true } [features] diff --git a/codegen-sdk-common/src/hash.rs b/codegen-sdk-common/src/hash.rs index dd928145..ba9e4d93 100644 --- a/codegen-sdk-common/src/hash.rs +++ b/codegen-sdk-common/src/hash.rs @@ -6,7 +6,8 @@ pub type FxIndexSet = indexmap::IndexSet; pub type FxIndexMap = indexmap::IndexMap; // pub type FxDashMap = dashmap::DashMap; // pub type FxLinkedHashSet = hashlink::LinkedHashSet; -pub type FxHashSet = std::collections::HashSet; +pub type FxHashSet = hashbrown::HashSet; +pub type FxHashMap = hashbrown::HashMap; pub fn hash(t: &T) -> u64 { FxHasher::default().hash_one(t) } diff --git a/codegen-sdk-cst-generator/src/generator.rs b/codegen-sdk-cst-generator/src/generator.rs index c9ffdf29..48c6c66d 100644 --- a/codegen-sdk-cst-generator/src/generator.rs +++ b/codegen-sdk-cst-generator/src/generator.rs @@ -53,6 +53,7 @@ fn get_parser(language: &Language) -> TokenStream { #[tracked] #[return_ref] #[no_clone] + #[no_eq] pub tree: Arc>>, pub program: indextree::NodeId, } diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs index fe07b272..b8f29f4e 100644 --- a/codegen-sdk-resolution/src/database.rs +++ b/codegen-sdk-resolution/src/database.rs @@ -8,9 +8,14 @@ pub trait Db: salsa::Database + Send { fn get_file(&self, path: PathBuf) -> Option; fn multi_progress(&self) -> &MultiProgress; fn watch_dir(&mut self, path: PathBuf) -> anyhow::Result<()>; - fn files(&self) -> codegen_sdk_common::hash::FxHashSet; + fn files(&self) -> codegen_sdk_common::hash::FxHashSet>; + fn get_file_for_id(&self, id: codegen_sdk_common::FileNodeId<'_>) -> Option { + self.get_file(id.path(self)) + } } #[salsa::tracked] -pub fn files(db: &dyn Db) -> codegen_sdk_common::hash::FxHashSet { +pub fn files( + db: &dyn Db, +) -> codegen_sdk_common::hash::FxHashSet> { db.files() } diff --git a/codegen-sdk-resolution/src/lib.rs b/codegen-sdk-resolution/src/lib.rs index 3bb34662..9d79bec7 100644 --- a/codegen-sdk-resolution/src/lib.rs +++ b/codegen-sdk-resolution/src/lib.rs @@ -20,5 +20,5 @@ pub use name::{FullyQualifiedName, HasId}; pub trait HasFile<'db> { type File<'db1>; fn file(&self, db: &'db dyn Db) -> &'db Self::File<'db>; - fn root_path(&self, db: &'db dyn salsa::Database) -> PathBuf; + fn root_path(&self, db: &'db dyn Db) -> PathBuf; } diff --git a/codegen-sdk-resolution/src/name.rs b/codegen-sdk-resolution/src/name.rs index 23557b2a..f56a7b42 100644 --- a/codegen-sdk-resolution/src/name.rs +++ b/codegen-sdk-resolution/src/name.rs @@ -4,11 +4,11 @@ use crate::Db; #[salsa::interned] pub struct FullyQualifiedName<'db> { #[id] - path: FileNodeId<'db>, + pub path: FileNodeId<'db>, #[return_ref] - name: String, + pub name: String, } pub trait HasId<'db> { - fn fully_qualified_name(&self, db: &'db dyn Db) -> FullyQualifiedName<'db>; + fn fully_qualified_name(&self, db: &'db dyn salsa::Database) -> FullyQualifiedName<'db>; } diff --git a/codegen-sdk-resolution/src/parse.rs b/codegen-sdk-resolution/src/parse.rs index d9614c4a..ca3ffb3b 100644 --- a/codegen-sdk-resolution/src/parse.rs +++ b/codegen-sdk-resolution/src/parse.rs @@ -1,5 +1,5 @@ -use salsa::Database; +use crate::Db; pub trait Parse<'db> { - fn parse(db: &'db dyn Database, input: codegen_sdk_cst::File) -> &'db Self; + fn parse(db: &'db dyn Db, input: codegen_sdk_common::FileNodeId<'db>) -> &'db Self; } diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index 4437df41..c8e54779 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -10,25 +10,25 @@ pub trait References< Clone + 'db, >: Eq + PartialEq + Hash + HasFile<'db, File<'db> = Scope> + HasId<'db> + Sized + 'db where Self:'db { - fn references(&self, db: &'db dyn Db) -> Vec + fn references(self, db: &'db dyn Db) -> Vec where Self: Sized, - Scope: Parse<'db>, - { - let files = files(db); - log::info!(target: "resolution", "Finding references across {:?} files", files.len()); - let mut results = Vec::new(); - for input in files { - // if !self.filter(db, &input) { - // continue; - // } - let file = Scope::parse(db, input.clone()); - let dependencies = file.clone().compute_dependencies_query(db); - if let Some(references) = dependencies.get(db, &self.fully_qualified_name(db)) { - results.extend(references.iter().cloned()); - } - } - results - } + Scope: Parse<'db>; + // { + // // let files = files(db); + // // log::info!(target: "resolution", "Finding references across {:?} files", files.len()); + // // let mut results = Vec::new(); + // // for input in files { + // // // if !self.filter(db, &input) { + // // // continue; + // // // } + // // let file = Scope::parse(db, input.clone()); + // // let dependencies = file.clone().compute_dependencies_query(db); + // // if let Some(references) = dependencies.get(db, &self.fully_qualified_name(db)) { + // // results.extend(references.iter().cloned()); + // // } + // // } + // results + // } fn filter(&self, db: &'db dyn Db, input: &codegen_sdk_cst::File) -> bool; } diff --git a/codegen-sdk-resolution/src/scope.rs b/codegen-sdk-resolution/src/scope.rs index aed6899e..84c2727d 100644 --- a/codegen-sdk-resolution/src/scope.rs +++ b/codegen-sdk-resolution/src/scope.rs @@ -20,17 +20,17 @@ pub trait Scope<'db>: Sized { fn compute_dependencies( self, db: &'db dyn Db, - ) -> codegen_sdk_common::hash::FxIndexMap< + ) -> codegen_sdk_common::hash::FxHashMap< FullyQualifiedName<'db>, codegen_sdk_common::hash::FxIndexSet, > where Self: 'db, { - let mut dependencies: codegen_sdk_common::hash::FxIndexMap< + let mut dependencies: codegen_sdk_common::hash::FxHashMap< FullyQualifiedName<'db>, codegen_sdk_common::hash::FxIndexSet, - > = codegen_sdk_common::hash::FxIndexMap::default(); + > = codegen_sdk_common::hash::FxHashMap::default(); for reference in self.resolvables(db) { let resolved = reference.clone().resolve_type(db); for resolved in resolved { diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 27717653..6b3945af 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -11,12 +11,18 @@ pub mod ast { #[salsa::tracked] impl<'db> Import<'db> { #[salsa::tracked] - fn resolve_import(self, db: &'db dyn codegen_sdk_resolution::Db) -> Option { + fn resolve_import( + self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> Option> { let root_path = self.root_path(db); let module = self.module(db).source().replace(".", "/"); let target_path = root_path.join(module).with_extension("py"); log::info!(target: "resolution", "Resolving import to path: {:?}", target_path); - target_path.canonicalize().ok() + target_path + .canonicalize() + .ok() + .map(|path| codegen_sdk_common::FileNodeId::new(db, path)) } } #[salsa::tracked] @@ -25,7 +31,7 @@ pub mod ast { id: codegen_sdk_common::FileNodeId<'db>, #[return_ref] #[tracked] - dependencies: codegen_sdk_common::hash::FxIndexMap< + pub dependencies: codegen_sdk_common::hash::FxHashMap< codegen_sdk_resolution::FullyQualifiedName<'db>, codegen_sdk_common::hash::FxIndexSet>, >, @@ -46,13 +52,45 @@ pub mod ast { } } #[salsa::tracked(return_ref)] - pub fn dependencies( - db: &dyn codegen_sdk_resolution::Db, - input: codegen_sdk_cst::File, - ) -> PythonDependencies<'_> { + pub fn dependencies<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + input: codegen_sdk_common::FileNodeId<'db>, + ) -> PythonDependencies<'db> { let file = parse(db, input); PythonDependencies::new(db, file.id(db), file.compute_dependencies(db)) } + #[salsa::tracked(return_ref, no_eq)] + pub fn dependency_keys<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + input: codegen_sdk_common::FileNodeId<'db>, + ) -> codegen_sdk_common::hash::FxHashSet> { + let dependencies = dependencies(db, input); + dependencies.dependencies(db).keys().cloned().collect() + } + #[salsa::tracked] + struct UsagesInput<'db> { + #[id] + input: codegen_sdk_common::FileNodeId<'db>, + name: codegen_sdk_resolution::FullyQualifiedName<'db>, + } + #[salsa::tracked(return_ref)] + pub fn usages<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + input: UsagesInput<'db>, + ) -> codegen_sdk_common::hash::FxIndexSet> { + let file = parse(db, input.input(db)); + let mut results = codegen_sdk_common::hash::FxIndexSet::default(); + for reference in file.resolvables(db) { + let resolved = reference.clone().resolve_type(db); + for resolved in resolved { + if resolved.fully_qualified_name(db) == input.name(db) { + results.insert(reference); + continue; + } + } + } + results + } #[salsa::tracked] impl<'db> Scope<'db> for PythonFile<'db> { type Type = crate::ast::Symbol<'db>; @@ -94,17 +132,17 @@ pub mod ast { fn compute_dependencies( self, db: &'db dyn codegen_sdk_resolution::Db, - ) -> codegen_sdk_common::hash::FxIndexMap< + ) -> codegen_sdk_common::hash::FxHashMap< codegen_sdk_resolution::FullyQualifiedName<'db>, codegen_sdk_common::hash::FxIndexSet, > where Self: 'db, { - let mut dependencies: codegen_sdk_common::hash::FxIndexMap< + let mut dependencies: codegen_sdk_common::hash::FxHashMap< codegen_sdk_resolution::FullyQualifiedName<'db>, codegen_sdk_common::hash::FxIndexSet, - > = codegen_sdk_common::hash::FxIndexMap::default(); + > = codegen_sdk_common::hash::FxHashMap::default(); for reference in self.resolvables(db) { let resolved = reference.clone().resolve_type(db); for resolved in resolved { @@ -132,8 +170,8 @@ pub mod ast { fn resolve_type(self, db: &'db dyn codegen_sdk_resolution::Db) -> Vec { let target_path = self.resolve_import(db); if let Some(target_path) = target_path { - if let Some(input) = db.get_file(target_path) { - return PythonFile::parse(db, input) + if let Some(_) = db.get_file_for_id(target_path) { + return PythonFile::parse(db, target_path) .resolve(db, self.name(db).source()) .to_vec(); } @@ -154,6 +192,67 @@ pub mod ast { } } use codegen_sdk_resolution::{Db, Dependencies, HasId}; + // #[salsa::tracked(return_ref)] + pub fn references_for_file<'db>( + db: &'db dyn Db, + file: codegen_sdk_common::FileNodeId<'db>, + ) -> usize { + let parsed = parse(db, file); + let definitions = parsed.definitions(db); + let functions = definitions.functions(db); + let mut total_references = 0; + let total_functions = functions.len(); + let functions = functions + .into_iter() + .map(|(_, functions)| functions) + .flatten() + .map(|function| function.fully_qualified_name(db)) + .collect::>(); + let files = codegen_sdk_resolution::files(db); + log::info!(target: "resolution", "Finding references across {:?} files", files.len()); + let mut results = 0; + for input in files.into_iter() { + let keys = dependency_keys(db, input.clone()); + if keys.is_disjoint(&functions) { + continue; + } + // if !self.filter(db, &input) { + // continue; + // } + // let input = UsagesInput::new(db, input.clone(), name.clone()); + // results.extend(usages(db, input)); + let dependencies = dependencies(db, input.clone()); + for function in functions.iter() { + if let Some(references) = dependencies.get(db, function) { + results += references.len(); + } + } + } + results + } + pub fn references_impl<'db>( + db: &'db dyn Db, + name: codegen_sdk_resolution::FullyQualifiedName<'db>, + ) -> Vec> { + let files = codegen_sdk_resolution::files(db); + log::info!(target: "resolution", "Finding references across {:?} files", files.len()); + let mut results = Vec::new(); + for input in files.into_iter() { + let keys = dependency_keys(db, input.clone()); + if keys.contains(&name) { + // if !self.filter(db, &input) { + // continue; + // } + // let input = UsagesInput::new(db, input.clone(), name.clone()); + // results.extend(usages(db, input)); + let dependencies = dependencies(db, input.clone()); + if let Some(references) = dependencies.get(db, &name) { + results.extend(references.iter().cloned()); + } + } + } + results + } #[salsa::tracked] impl<'db> codegen_sdk_resolution::References< @@ -163,21 +262,8 @@ pub mod ast { PythonFile<'db>, > for crate::ast::Symbol<'db> { - fn references(&self, db: &'db dyn Db) -> Vec> { - let files = codegen_sdk_resolution::files(db); - log::info!(target: "resolution", "Finding references across {:?} files", files.len()); - let mut results = Vec::new(); - let name = self.fully_qualified_name(db); - for input in files { - // if !self.filter(db, &input) { - // continue; - // } - let dependencies = dependencies(db, input.clone()); - if let Some(references) = dependencies.get(db, &name) { - results.extend(references.iter().cloned()); - } - } - results + fn references(self, db: &'db dyn Db) -> Vec> { + references_impl(db, self.fully_qualified_name(db)) } fn filter( &self, diff --git a/src/main.rs b/src/main.rs index 2bb2c58c..5db74948 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,50 +12,48 @@ use codegen_sdk_resolution::{CodebaseContext, References}; struct Args { input: String, } -fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize, usize, usize)> { - codebase.execute_op_with_progress("Getting Usages", |db, file| { - if let Some(parsed) = parse_file(db, file).file(db) { - #[cfg(feature = "typescript")] - if let ParsedFile::Typescript(file) = parsed { - let definitions = file.definitions(db); - if let Some(node) = file.node(db) { - let tree = node.tree(db); - return ( - definitions.classes(db, &tree).len(), - definitions.functions(db, &tree).len(), - definitions.interfaces(db, &tree).len(), - definitions.methods(db, &tree).len(), - definitions.modules(db, &tree).len(), - 0, - ); - } - } - #[cfg(feature = "python")] - if let ParsedFile::Python(file) = parsed { - let definitions = file.definitions(db); - let functions = definitions.functions(db); - let mut total_references = 0; - let total_functions = functions.len(); - for function in functions - .into_iter() - .map(|(_, functions)| functions) - .flatten() - .map(|function| codegen_sdk_python::ast::Symbol::Function(function.clone())) - { - total_references += function.references(db).len(); - } +// #[salsa::tracked] +fn get_definitions<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + file: codegen_sdk_common::FileNodeId<'db>, +) -> (usize, usize, usize, usize, usize, usize) { + if let Some(parsed) = parse_file(db, file).file(db) { + #[cfg(feature = "typescript")] + if let ParsedFile::Typescript(file) = parsed { + let definitions = file.definitions(db); + if let Some(node) = file.node(db) { + let tree = node.tree(db); return ( definitions.classes(db).len(), - total_functions, - 0, + definitions.functions(db).len(), + definitions.interfaces(db).len(), + definitions.methods(db).len(), + definitions.modules(db).len(), 0, - 0, - total_references, ); } } - (0, 0, 0, 0, 0, 0) - }) + #[cfg(feature = "python")] + if let ParsedFile::Python(file) = parsed { + let definitions = file.definitions(db); + let functions = definitions.functions(db); + let mut total_references = + codegen_sdk_python::ast::references_for_file(db, file.id(db)); + return ( + definitions.classes(db).len(), + functions.len(), + 0, + 0, + 0, + total_references, + ); + } + } + (0, 0, 0, 0, 0, 0) +} +fn get_total_definitions(codebase: &Codebase) -> Vec<(usize, usize, usize, usize, usize, usize)> { + log::info!("Getting total definitions"); + codebase.execute_op_with_progress("Getting Usages", true, |db, file| get_definitions(db, file)) } fn print_definitions(codebase: &Codebase) { let mut total_classes = 0; @@ -89,17 +87,17 @@ fn main() -> anyhow::Result<()> { let dir = args.input; let start = Instant::now(); let mut codebase = Codebase::new(PathBuf::from(&dir)); - let end = Instant::now(); - let duration: std::time::Duration = end.duration_since(start); - let memory = get_memory(); - log::info!( - "{} files parsed in {:?}.{} seconds with {} errors. Using {} MB of memory", - codebase.files().len(), - duration.as_secs(), - duration.subsec_millis(), - codebase.errors().len(), - memory / 1024 / 1024 - ); + // let end = Instant::now(); + // let duration: std::time::Duration = end.duration_since(start); + // let memory = get_memory(); + // log::info!( + // "{} files parsed in {:?}.{} seconds with {} errors. Using {} MB of memory", + // codebase.files().len(), + // duration.as_secs(), + // duration.subsec_millis(), + // codebase.errors().len(), + // memory / 1024 / 1024 + // ); loop { // Compile the code starting at the provided input, this will read other // needed files using the on-demand mechanism. From f3933ca49b32ec70d5a162a9eb45c0696d5b00d2 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 10 Mar 2025 09:38:22 -0700 Subject: [PATCH 12/16] misc fixes --- Cargo.lock | 7 ++++--- Cargo.toml | 2 +- codegen-sdk-resolution/src/database.rs | 6 +++--- languages/codegen-sdk-python/src/lib.rs | 3 ++- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ec378de8..8f3f5427 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2344,7 +2344,7 @@ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "salsa" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#9d2a9786c45000f5fa396ad2872391e302a2836a" +source = "git+https://github.com/salsa-rs/salsa?branch=master#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" dependencies = [ "boxcar", "crossbeam-queue", @@ -2353,6 +2353,7 @@ dependencies = [ "hashlink", "indexmap", "parking_lot", + "portable-atomic", "rayon", "rustc-hash", "salsa-macro-rules", @@ -2364,12 +2365,12 @@ dependencies = [ [[package]] name = "salsa-macro-rules" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#9d2a9786c45000f5fa396ad2872391e302a2836a" +source = "git+https://github.com/salsa-rs/salsa?branch=master#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" [[package]] name = "salsa-macros" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#9d2a9786c45000f5fa396ad2872391e302a2836a" +source = "git+https://github.com/salsa-rs/salsa?branch=master#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" dependencies = [ "heck 0.5.0", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 4f0e5724..c96fe6dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,7 +117,7 @@ insta = "1.42.1" prettyplease = "0.2.29" syn = { version = "2.0.98", features = ["proc-macro", "full"] } derive_more = { version = "2.0.1", features = ["debug", "display"] } -salsa = {git = "https://github.com/salsa-rs/salsa", branch = "master"} +salsa = {git = "https://github.com/salsa-rs/salsa", rev ="dbb0e5f6ab2cd61e42b372f333ab694f24141cf1"} subenum = {git = "https://github.com/mrenow/subenum", branch = "main"} indicatif-log-bridge = "0.2.3" indicatif = { version = "0.17.11", features = ["rayon"] } diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs index b8f29f4e..300af45f 100644 --- a/codegen-sdk-resolution/src/database.rs +++ b/codegen-sdk-resolution/src/database.rs @@ -14,8 +14,8 @@ pub trait Db: salsa::Database + Send { } } #[salsa::tracked] -pub fn files( - db: &dyn Db, -) -> codegen_sdk_common::hash::FxHashSet> { +pub fn files<'db>( + db: &'db dyn Db, +) -> codegen_sdk_common::hash::FxHashSet> { db.files() } diff --git a/languages/codegen-sdk-python/src/lib.rs b/languages/codegen-sdk-python/src/lib.rs index 6b3945af..d9a3d74f 100644 --- a/languages/codegen-sdk-python/src/lib.rs +++ b/languages/codegen-sdk-python/src/lib.rs @@ -31,6 +31,7 @@ pub mod ast { id: codegen_sdk_common::FileNodeId<'db>, #[return_ref] #[tracked] + #[no_eq] pub dependencies: codegen_sdk_common::hash::FxHashMap< codegen_sdk_resolution::FullyQualifiedName<'db>, codegen_sdk_common::hash::FxIndexSet>, @@ -51,7 +52,7 @@ pub mod ast { self.dependencies(db).get(key) } } - #[salsa::tracked(return_ref)] + #[salsa::tracked(return_ref, no_eq)] pub fn dependencies<'db>( db: &'db dyn codegen_sdk_resolution::Db, input: codegen_sdk_common::FileNodeId<'db>, From 19100ca92185b0b9fa8c819c80f913a6f2e15ff5 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 10 Mar 2025 09:47:08 -0700 Subject: [PATCH 13/16] Fix bugs --- Cargo.lock | 6 +-- codegen-sdk-analyzer/src/codebase.rs | 3 +- codegen-sdk-analyzer/src/codebase/parser.rs | 6 +-- codegen-sdk-analyzer/src/parser.rs | 2 - codegen-sdk-resolution/src/name.rs | 1 - codegen-sdk-resolution/src/references.rs | 2 +- .../codegen-sdk-python/tests/test_python.rs | 4 +- src/main.rs | 46 +++++++++---------- 8 files changed, 31 insertions(+), 39 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8f3f5427..e3e0cb09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2344,7 +2344,7 @@ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "salsa" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" +source = "git+https://github.com/salsa-rs/salsa?rev=dbb0e5f6ab2cd61e42b372f333ab694f24141cf1#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" dependencies = [ "boxcar", "crossbeam-queue", @@ -2365,12 +2365,12 @@ dependencies = [ [[package]] name = "salsa-macro-rules" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" +source = "git+https://github.com/salsa-rs/salsa?rev=dbb0e5f6ab2cd61e42b372f333ab694f24141cf1#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" [[package]] name = "salsa-macros" version = "0.18.0" -source = "git+https://github.com/salsa-rs/salsa?branch=master#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" +source = "git+https://github.com/salsa-rs/salsa?rev=dbb0e5f6ab2cd61e42b372f333ab694f24141cf1#dbb0e5f6ab2cd61e42b372f333ab694f24141cf1" dependencies = [ "heck 0.5.0", "proc-macro2", diff --git a/codegen-sdk-analyzer/src/codebase.rs b/codegen-sdk-analyzer/src/codebase.rs index 540c20c9..5f8e7b3c 100644 --- a/codegen-sdk-analyzer/src/codebase.rs +++ b/codegen-sdk-analyzer/src/codebase.rs @@ -3,7 +3,6 @@ use std::path::PathBuf; use anyhow::Context; #[cfg(feature = "serialization")] use codegen_sdk_common::serialization::Cache; -use codegen_sdk_cst::File; use codegen_sdk_resolution::{CodebaseContext, Db}; use discovery::FilesToParse; use notify_debouncer_mini::DebounceEventResult; @@ -66,7 +65,7 @@ impl Codebase { pub fn errors(&self) -> Vec<()> { let mut errors = Vec::new(); - for file in self.discover().files(&self.db) { + for file in self.db.files() { if self.get_file(file.path(&self.db)).is_none() { errors.push(()); } diff --git a/codegen-sdk-analyzer/src/codebase/parser.rs b/codegen-sdk-analyzer/src/codebase/parser.rs index 0754902a..3f2c4900 100644 --- a/codegen-sdk-analyzer/src/codebase/parser.rs +++ b/codegen-sdk-analyzer/src/codebase/parser.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use codegen_sdk_ast::{Definitions, References}; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; -use codegen_sdk_resolution::{Db, Scope}; +use codegen_sdk_resolution::Db; use indicatif::{ProgressBar, ProgressStyle}; use super::discovery::{FilesToParse, log_languages}; @@ -105,12 +105,12 @@ fn compute_dependencies_par(db: &dyn Db, files: FilesToParse) { .iter() .map(|input| codegen_sdk_common::FileNodeId::new(db, input.path(db))) .collect::>(); - let targets: codegen_sdk_common::hash::FxHashSet<(PathBuf, String)> = + let _targets: codegen_sdk_common::hash::FxHashSet<(PathBuf, String)> = execute_op_with_progress(db, ids, "Computing Dependencies", true, |db, input| { let file = parse_file(db, input.clone()); if let Some(parsed) = file.file(db) { #[cfg(feature = "python")] - if let ParsedFile::Python(parsed) = parsed { + if let ParsedFile::Python(_parsed) = parsed { let deps = codegen_sdk_python::ast::dependency_keys(db, input); return deps .iter() diff --git a/codegen-sdk-analyzer/src/parser.rs b/codegen-sdk-analyzer/src/parser.rs index caf91a29..16f3edab 100644 --- a/codegen-sdk-analyzer/src/parser.rs +++ b/codegen-sdk-analyzer/src/parser.rs @@ -1,5 +1,3 @@ -use std::path::PathBuf; - use codegen_sdk_common::FileNodeId; use codegen_sdk_cst::CSTLanguage; use codegen_sdk_macros::{languages_ast, parse_language}; diff --git a/codegen-sdk-resolution/src/name.rs b/codegen-sdk-resolution/src/name.rs index f56a7b42..c81e0a40 100644 --- a/codegen-sdk-resolution/src/name.rs +++ b/codegen-sdk-resolution/src/name.rs @@ -1,6 +1,5 @@ use codegen_sdk_common::FileNodeId; -use crate::Db; #[salsa::interned] pub struct FullyQualifiedName<'db> { #[id] diff --git a/codegen-sdk-resolution/src/references.rs b/codegen-sdk-resolution/src/references.rs index c8e54779..db21c568 100644 --- a/codegen-sdk-resolution/src/references.rs +++ b/codegen-sdk-resolution/src/references.rs @@ -1,6 +1,6 @@ use std::hash::Hash; -use crate::{Db, Dependencies, FullyQualifiedName, HasFile, HasId, Parse, ResolveType, files}; +use crate::{Db, Dependencies, FullyQualifiedName, HasFile, HasId, Parse, ResolveType}; pub trait References< 'db, diff --git a/languages/codegen-sdk-python/tests/test_python.rs b/languages/codegen-sdk-python/tests/test_python.rs index 69fe0ed2..db51c5bd 100644 --- a/languages/codegen-sdk-python/tests/test_python.rs +++ b/languages/codegen-sdk-python/tests/test_python.rs @@ -48,13 +48,13 @@ class Test: #[test_log::test] fn test_python_ast_function() { let temp_dir = tempfile::tempdir().unwrap(); + let root_path = temp_dir.path().to_path_buf(); let content = " def test(): pass"; let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_cst::File::new(&db, file_path, content); + let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path.clone()); let file = codegen_sdk_python::ast::parse_query(&db, input); assert_eq!(file.definitions(&db).functions(&db).len(), 1); } diff --git a/src/main.rs b/src/main.rs index 5db74948..28ef18d6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use codegen_sdk_ast::Definitions; #[cfg(feature = "serialization")] use codegen_sdk_common::serialize::Cache; use codegen_sdk_core::system::get_memory; -use codegen_sdk_resolution::{CodebaseContext, References}; +use codegen_sdk_resolution::CodebaseContext; #[derive(Debug, Parser)] struct Args { input: String, @@ -21,24 +21,20 @@ fn get_definitions<'db>( #[cfg(feature = "typescript")] if let ParsedFile::Typescript(file) = parsed { let definitions = file.definitions(db); - if let Some(node) = file.node(db) { - let tree = node.tree(db); - return ( - definitions.classes(db).len(), - definitions.functions(db).len(), - definitions.interfaces(db).len(), - definitions.methods(db).len(), - definitions.modules(db).len(), - 0, - ); - } + return ( + definitions.classes(db).len(), + definitions.functions(db).len(), + definitions.interfaces(db).len(), + definitions.methods(db).len(), + definitions.modules(db).len(), + 0, + ); } #[cfg(feature = "python")] if let ParsedFile::Python(file) = parsed { let definitions = file.definitions(db); let functions = definitions.functions(db); - let mut total_references = - codegen_sdk_python::ast::references_for_file(db, file.id(db)); + let total_references = codegen_sdk_python::ast::references_for_file(db, file.id(db)); return ( definitions.classes(db).len(), functions.len(), @@ -87,17 +83,17 @@ fn main() -> anyhow::Result<()> { let dir = args.input; let start = Instant::now(); let mut codebase = Codebase::new(PathBuf::from(&dir)); - // let end = Instant::now(); - // let duration: std::time::Duration = end.duration_since(start); - // let memory = get_memory(); - // log::info!( - // "{} files parsed in {:?}.{} seconds with {} errors. Using {} MB of memory", - // codebase.files().len(), - // duration.as_secs(), - // duration.subsec_millis(), - // codebase.errors().len(), - // memory / 1024 / 1024 - // ); + let end = Instant::now(); + let duration: std::time::Duration = end.duration_since(start); + let memory = get_memory(); + log::info!( + "{} files parsed in {:?}.{} seconds with {} errors. Using {} MB of memory", + codebase.files().len(), + duration.as_secs(), + duration.subsec_millis(), + codebase.errors().len(), + memory / 1024 / 1024 + ); loop { // Compile the code starting at the provided input, this will read other // needed files using the on-demand mechanism. From 271001282655d22e078d1706d2f3df765910801a Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 10 Mar 2025 10:23:59 -0700 Subject: [PATCH 14/16] Fix python tests --- codegen-sdk-ast-generator/Cargo.toml | 1 + codegen-sdk-cst-generator/Cargo.toml | 2 +- codegen-sdk-resolution/src/database.rs | 31 ++++++++++ .../codegen-sdk-python/tests/test_python.rs | 61 ++++++++++--------- .../tests/test_typescript.rs | 5 +- 5 files changed, 67 insertions(+), 33 deletions(-) diff --git a/codegen-sdk-ast-generator/Cargo.toml b/codegen-sdk-ast-generator/Cargo.toml index ddf4a180..00fce7a0 100644 --- a/codegen-sdk-ast-generator/Cargo.toml +++ b/codegen-sdk-ast-generator/Cargo.toml @@ -1,6 +1,7 @@ [package] name = "codegen-sdk-ast-generator" version = "0.1.0" +description = "Generates the AST for the given language from the tree-sitter queries" edition = "2024" [dependencies] diff --git a/codegen-sdk-cst-generator/Cargo.toml b/codegen-sdk-cst-generator/Cargo.toml index aecadca1..2581f234 100644 --- a/codegen-sdk-cst-generator/Cargo.toml +++ b/codegen-sdk-cst-generator/Cargo.toml @@ -2,7 +2,7 @@ name = "codegen-sdk-cst-generator" version = "0.1.0" edition = "2024" - +description = "Generates the CST for the given language from the tree-sitter node-types.json" [dependencies] syn = { workspace = true } tree-sitter = { workspace = true } diff --git a/codegen-sdk-resolution/src/database.rs b/codegen-sdk-resolution/src/database.rs index 300af45f..69aa8c29 100644 --- a/codegen-sdk-resolution/src/database.rs +++ b/codegen-sdk-resolution/src/database.rs @@ -19,3 +19,34 @@ pub fn files<'db>( ) -> codegen_sdk_common::hash::FxHashSet> { db.files() } +#[salsa::db] +impl Db for codegen_sdk_cst::CSTDatabase { + fn input(&self, path: PathBuf) -> anyhow::Result { + let content = std::fs::read_to_string(&path)?; + let file = + codegen_sdk_cst::File::new(self, path.canonicalize().unwrap(), content, PathBuf::new()); + Ok(file) + } + fn multi_progress(&self) -> &MultiProgress { + unimplemented!() + } + fn watch_dir(&mut self, _path: PathBuf) -> anyhow::Result<()> { + unimplemented!() + } + fn files(&self) -> codegen_sdk_common::hash::FxHashSet> { + let path = PathBuf::from("."); + let files = std::fs::read_dir(path).unwrap(); + let mut set = codegen_sdk_common::hash::FxHashSet::default(); + for file in files { + let file = file.unwrap(); + let path = file.path().canonicalize().unwrap(); + if path.is_file() { + set.insert(codegen_sdk_common::FileNodeId::new(self, path)); + } + } + set + } + fn get_file(&self, path: PathBuf) -> Option { + self.input(path).ok() + } +} diff --git a/languages/codegen-sdk-python/tests/test_python.rs b/languages/codegen-sdk-python/tests/test_python.rs index db51c5bd..9e3768f2 100644 --- a/languages/codegen-sdk-python/tests/test_python.rs +++ b/languages/codegen-sdk-python/tests/test_python.rs @@ -1,11 +1,8 @@ #![recursion_limit = "512"] -use std::path::PathBuf; +use std::{env, path::PathBuf}; use codegen_sdk_ast::{Definitions, References}; use codegen_sdk_resolution::References as _; -fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { - write_to_temp_file_with_name(content, temp_dir, "test.py") -} fn write_to_temp_file_with_name( content: &str, temp_dir: &tempfile::TempDir, @@ -15,6 +12,18 @@ fn write_to_temp_file_with_name( std::fs::write(&file_path, content).unwrap(); file_path } +fn parse_file<'db>( + db: &'db dyn codegen_sdk_resolution::Db, + content: &str, + temp_dir: &tempfile::TempDir, + name: &str, +) -> &'db codegen_sdk_python::ast::PythonFile<'db> { + let file_path = write_to_temp_file_with_name(content, temp_dir, name); + db.input(file_path.clone()).unwrap(); + let file_node_id = codegen_sdk_common::FileNodeId::new(db, file_path); + let file = codegen_sdk_python::ast::parse(db, file_node_id); + file +} // TODO: Fix queries for classes and functions // #[test_log::test] // fn test_typescript_ast_class() { @@ -38,54 +47,52 @@ fn test_python_ast_class() { let content = " class Test: pass"; - let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let root_path = temp_dir.path().to_path_buf(); - let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path); - let file = codegen_sdk_python::ast::parse(&db, input); + let file = parse_file(&db, content, &temp_dir, "filea.py"); assert_eq!(file.definitions(&db).classes(&db).len(), 1); } #[test_log::test] fn test_python_ast_function() { let temp_dir = tempfile::tempdir().unwrap(); - let root_path = temp_dir.path().to_path_buf(); let content = " def test(): pass"; - let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path.clone()); - let file = codegen_sdk_python::ast::parse_query(&db, input); + let file = parse_file(&db, content, &temp_dir, "filea.py"); assert_eq!(file.definitions(&db).functions(&db).len(), 1); } +// +// for function in codebase.functions(): +// function.rename("test2") +// codebase.commit() + +// 3 bounds +// 1. Codebase updated, everything else is invalidated +// 2. Files + codebase updated, everything else is invalidated +// 3. Everything updated, nothing is invalidated + #[test_log::test] fn test_python_ast_function_usages() { let temp_dir = tempfile::tempdir().unwrap(); + assert!(env::set_current_dir(&temp_dir).is_ok()); let content = " def test(): pass test()"; - let file_path = write_to_temp_file(content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let content = codegen_sdk_cst::Input::new(&db, content.to_string()); - let input = codegen_sdk_cst::File::new(&db, file_path, content); - let file = codegen_sdk_python::ast::parse_query(&db, input); + let file = parse_file(&db, content, &temp_dir, "filea.py"); assert_eq!(file.references(&db).calls(&db).len(), 1); let definitions = file.definitions(&db); let functions = definitions.functions(&db); let function = functions.get("test").unwrap().first().unwrap(); let function = codegen_sdk_python::ast::Symbol::Function(function.clone().clone()); - assert_eq!( - function - .references(&db, temp_dir.path().to_path_buf(), vec![*file], &file) - .len(), - 1 - ); + assert_eq!(function.references(&db).len(), 1); } #[test_log::test] fn test_python_ast_function_usages_cross_file() { let temp_dir = tempfile::tempdir().unwrap(); + assert!(env::set_current_dir(&temp_dir).is_ok()); let content = " def test(): pass @@ -94,15 +101,9 @@ def test(): let usage_file_content = " from filea import test test()"; - let root_path = temp_dir.path().to_path_buf(); - let file_path = write_to_temp_file_with_name(content, &temp_dir, "filea.py"); - let usage_file_path = write_to_temp_file_with_name(usage_file_content, &temp_dir, "fileb.py"); let db = codegen_sdk_cst::CSTDatabase::default(); - let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path.clone()); - let usage_input = - codegen_sdk_cst::File::new(&db, usage_file_path, usage_file_content, root_path.clone()); - let file = codegen_sdk_python::ast::parse_query(&db, input); - let usage_file = codegen_sdk_python::ast::parse_query(&db, usage_input); + let file = parse_file(&db, content, &temp_dir, "filea.py"); + let usage_file = parse_file(&db, usage_file_content, &temp_dir, "fileb.py"); assert_eq!(usage_file.references(&db).calls(&db).len(), 1); let definitions = file.definitions(&db); let functions = definitions.functions(&db); diff --git a/languages/codegen-sdk-typescript/tests/test_typescript.rs b/languages/codegen-sdk-typescript/tests/test_typescript.rs index 8d7a54e8..8bc77c8b 100644 --- a/languages/codegen-sdk-typescript/tests/test_typescript.rs +++ b/languages/codegen-sdk-typescript/tests/test_typescript.rs @@ -28,10 +28,11 @@ fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { fn test_typescript_ast_interface() { let temp_dir = tempfile::tempdir().unwrap(); let content = "interface Test { }".to_string(); - let file_path = write_to_temp_file(content, &temp_dir); + let file_path = write_to_temp_file(&content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); let root_path = temp_dir.path().to_path_buf(); - let input = codegen_sdk_cst::File::new(&db, file_path, content, root_path); + db.input(file_path); + let input = codegen_sdk_cst::FileNodeId::new(&db, file_path); let file = codegen_sdk_typescript::ast::parse(&db, input); assert_eq!(file.definitions(&db).interfaces(&db).len(), 1); } From e0d945e74ecbbf88fbc1197302ce1d970be6e321 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 10 Mar 2025 10:26:57 -0700 Subject: [PATCH 15/16] Fix various tests --- codegen-sdk-cst/src/database.rs | 2 +- languages/codegen-sdk-typescript/tests/test_typescript.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/codegen-sdk-cst/src/database.rs b/codegen-sdk-cst/src/database.rs index f62cb369..e4c3b178 100644 --- a/codegen-sdk-cst/src/database.rs +++ b/codegen-sdk-cst/src/database.rs @@ -5,7 +5,7 @@ use dashmap::{DashMap, mapref::entry::Entry}; use crate::File; #[salsa::db] #[derive(Default, Clone)] -// Basic Database implementation for Query generation. This is not used for anything else. +// Basic Database implementation for Query generation and testing. This is not used for anything else. pub struct CSTDatabase { storage: salsa::Storage, } diff --git a/languages/codegen-sdk-typescript/tests/test_typescript.rs b/languages/codegen-sdk-typescript/tests/test_typescript.rs index 8bc77c8b..2a2b11b5 100644 --- a/languages/codegen-sdk-typescript/tests/test_typescript.rs +++ b/languages/codegen-sdk-typescript/tests/test_typescript.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use codegen_sdk_ast::Definitions; +use codegen_sdk_resolution::Db; fn write_to_temp_file(content: &str, temp_dir: &tempfile::TempDir) -> PathBuf { let file_path = temp_dir.path().join("test.ts"); std::fs::write(&file_path, content).unwrap(); @@ -30,9 +31,8 @@ fn test_typescript_ast_interface() { let content = "interface Test { }".to_string(); let file_path = write_to_temp_file(&content, &temp_dir); let db = codegen_sdk_cst::CSTDatabase::default(); - let root_path = temp_dir.path().to_path_buf(); - db.input(file_path); - let input = codegen_sdk_cst::FileNodeId::new(&db, file_path); + db.input(file_path.clone()).unwrap(); + let input = codegen_sdk_common::FileNodeId::new(&db, file_path); let file = codegen_sdk_typescript::ast::parse(&db, input); assert_eq!(file.definitions(&db).interfaces(&db).len(), 1); } From d31cab60a70be2ff0aacc96bbfe82ca62e6fa692 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 10 Mar 2025 10:28:22 -0700 Subject: [PATCH 16/16] update snapshots --- ...ast_generator__visitor__tests__python.snap | 202 +++++++++++++- ...generator__visitor__tests__typescript.snap | 250 +++++++++++++++++- ...ode__tests__get_struct_tokens_complex.snap | 2 +- ...node__tests__get_struct_tokens_simple.snap | 2 +- ...ests__get_struct_tokens_with_children.snap | 2 +- ..._tests__get_struct_tokens_with_fields.snap | 2 +- ..._struct_tokens_with_single_child_type.snap | 2 +- ...ts__add_field_subenums_missing_node-2.snap | 6 +- ..._generator__state__tests__get_structs.snap | 2 +- ..._tests__test_subtypes__basic_subtypes.snap | 86 +++++- ...test_subtypes__deeply_nested_subtypes.snap | 146 ++++++++-- ...__test_subtypes__subtypes_with_fields.snap | 103 +++++++- ...ypes_children__subtypes_with_children.snap | 112 ++++++-- ...ple_inheritance__multiple_inheritance.snap | 82 +++++- ...ubtypes_recursive__recursive_subtypes.snap | 131 ++++++++- 15 files changed, 1015 insertions(+), 115 deletions(-) diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap index d776675b..bf961cf9 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__python.snap @@ -4,41 +4,159 @@ expression: "codegen_sdk_common::generator::format_code_string(&visitor.to_strin --- #[salsa::tracked] pub struct Class<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::ClassDefinition<'db>, pub name: crate::cst::Identifier<'db>, } +impl<'db> Class<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::ClassDefinition<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Class<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Class<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[salsa::tracked] pub struct Constant<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::Module<'db>, pub name: crate::cst::Identifier<'db>, } +impl<'db> Constant<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::Module<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Constant<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Constant<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[salsa::tracked] pub struct Function<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::FunctionDefinition<'db>, pub name: crate::cst::Identifier<'db>, } +impl<'db> Function<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::FunctionDefinition<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Function<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Function<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[salsa::tracked] pub struct Import<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::ImportFromStatement<'db>, pub module: crate::cst::DottedName<'db>, + #[tracked] + #[return_ref] pub name: crate::cst::DottedName<'db>, } +impl<'db> Import<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::ImportFromStatement<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Import<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Import<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] pub enum Symbol<'db> { Class(Class<'db>), @@ -46,6 +164,38 @@ pub enum Symbol<'db> { Function(Function<'db>), Import(Import<'db>), } +impl<'db> codegen_sdk_resolution::HasFile<'db> for Symbol<'db> { + type File<'db1> = PythonFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + match self { + Self::Class(symbol) => symbol.file(db), + Self::Constant(symbol) => symbol.file(db), + Self::Function(symbol) => symbol.file(db), + Self::Import(symbol) => symbol.file(db), + } + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + match self { + Self::Class(symbol) => symbol.root_path(db), + Self::Constant(symbol) => symbol.root_path(db), + Self::Function(symbol) => symbol.root_path(db), + Self::Import(symbol) => symbol.root_path(db), + } + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Symbol<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + match self { + Self::Class(symbol) => symbol.fully_qualified_name(db), + Self::Constant(symbol) => symbol.fully_qualified_name(db), + Self::Function(symbol) => symbol.fully_qualified_name(db), + Self::Import(symbol) => symbol.fully_qualified_name(db), + } + } +} #[salsa::tracked] pub struct Definitions<'db> { #[return_ref] @@ -73,14 +223,34 @@ impl<'db> Definitions<'db> { ///Code for query: (class_definition name: (identifier) @name) @definition.class ///Code for field: name: (identifier) @name let name = node.name(tree); - let symbol = Class::new(db, id, node.clone(), name.clone()); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Class::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::FunctionDefinition(node) => { ///Code for query: (function_definition name: (identifier) @name) @definition.function ///Code for field: name: (identifier) @name let name = node.name(tree); - let symbol = Function::new(db, id, node.clone(), name.clone()); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Function::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); functions.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::ImportFromStatement(node) => { @@ -95,12 +265,17 @@ impl<'db> Definitions<'db> { if let crate::cst::ImportFromStatementModuleNameRef::DottedName( module_name, ) = module_name { + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); let symbol = Import::new( db, + fully_qualified_name, id, - node.clone(), - module_name.clone(), - name.clone(), + module_name.clone().into(), + name.clone().into(), ); imports.entry(name.source()).or_default().push(symbol); } @@ -122,11 +297,16 @@ impl<'db> Definitions<'db> { ///Code for field: left: (identifier) @name let left = child.left(tree); if let crate::cst::AssignmentLeftRef::Identifier(left) = left { + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + left.source(), + ); let symbol = Constant::new( db, + fully_qualified_name, id, - node.clone(), - left.clone(), + left.clone().into(), ); constants.entry(left.source()).or_default().push(symbol); } diff --git a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap index 87717345..a059f4bf 100644 --- a/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap +++ b/codegen-sdk-ast-generator/src/snapshots/codegen_sdk_ast_generator__visitor__tests__typescript.snap @@ -4,49 +4,194 @@ expression: "codegen_sdk_common::generator::format_code_string(&visitor.to_strin --- #[salsa::tracked] pub struct Class<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::AbstractClassDeclaration<'db>, pub name: crate::cst::TypeIdentifier<'db>, } +impl<'db> Class<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::AbstractClassDeclaration<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Class<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Class<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[salsa::tracked] pub struct Function<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::FunctionSignature<'db>, pub name: crate::cst::Identifier<'db>, } +impl<'db> Function<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::FunctionSignature<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Function<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Function<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[salsa::tracked] pub struct Interface<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::InterfaceDeclaration<'db>, pub name: crate::cst::TypeIdentifier<'db>, } +impl<'db> Interface<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::InterfaceDeclaration<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Interface<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Interface<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[salsa::tracked] pub struct Method<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::AbstractMethodSignature<'db>, pub name: crate::cst::PropertyIdentifier<'db>, } +impl<'db> Method<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::AbstractMethodSignature<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Method<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Method<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[salsa::tracked] pub struct Module<'db> { + #[id] + _fully_qualified_name: codegen_sdk_resolution::FullyQualifiedName<'db>, #[id] node_id: indextree::NodeId, #[tracked] #[return_ref] - pub node: crate::cst::Module<'db>, pub name: crate::cst::Identifier<'db>, } +impl<'db> Module<'db> { + pub fn node( + &self, + db: &'db dyn codegen_sdk_resolution::Db, + ) -> &'db crate::cst::Module<'db> { + let file = self.file(db); + let tree = file.tree(db); + tree.get(&self.node_id(db)).unwrap().as_ref().try_into().unwrap() + } +} +impl<'db> codegen_sdk_resolution::HasFile<'db> for Module<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + let path = self._fully_qualified_name(db).path(db); + parse(db, path) + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + self.node(db).id().root(db).path(db) + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Module<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + self._fully_qualified_name(db) + } +} #[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] pub enum Symbol<'db> { Class(Class<'db>), @@ -55,6 +200,41 @@ pub enum Symbol<'db> { Method(Method<'db>), Module(Module<'db>), } +impl<'db> codegen_sdk_resolution::HasFile<'db> for Symbol<'db> { + type File<'db1> = TypescriptFile<'db1>; + fn file(&self, db: &'db dyn codegen_sdk_resolution::Db) -> &'db Self::File<'db> { + match self { + Self::Class(symbol) => symbol.file(db), + Self::Function(symbol) => symbol.file(db), + Self::Interface(symbol) => symbol.file(db), + Self::Method(symbol) => symbol.file(db), + Self::Module(symbol) => symbol.file(db), + } + } + fn root_path(&self, db: &'db dyn codegen_sdk_resolution::Db) -> PathBuf { + match self { + Self::Class(symbol) => symbol.root_path(db), + Self::Function(symbol) => symbol.root_path(db), + Self::Interface(symbol) => symbol.root_path(db), + Self::Method(symbol) => symbol.root_path(db), + Self::Module(symbol) => symbol.root_path(db), + } + } +} +impl<'db> codegen_sdk_resolution::HasId<'db> for Symbol<'db> { + fn fully_qualified_name( + &self, + db: &'db dyn salsa::Database, + ) -> codegen_sdk_resolution::FullyQualifiedName<'db> { + match self { + Self::Class(symbol) => symbol.fully_qualified_name(db), + Self::Function(symbol) => symbol.fully_qualified_name(db), + Self::Interface(symbol) => symbol.fully_qualified_name(db), + Self::Method(symbol) => symbol.fully_qualified_name(db), + Self::Module(symbol) => symbol.fully_qualified_name(db), + } + } +} #[salsa::tracked] pub struct Definitions<'db> { #[return_ref] @@ -85,7 +265,17 @@ impl<'db> Definitions<'db> { ///Code for query: (abstract_class_declaration name: (type_identifier) @name) @definition.class ///Code for field: name: (type_identifier) @name let name = node.name(tree); - let symbol = Class::new(db, id, node.clone(), name.clone()); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Class::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); classes.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::AbstractMethodSignature(node) => { @@ -95,7 +285,17 @@ impl<'db> Definitions<'db> { if let crate::cst::AbstractMethodSignatureNameRef::PropertyIdentifier( name, ) = name { - let symbol = Method::new(db, id, node.clone(), name.clone()); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Method::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); methods.entry(name.source()).or_default().push(symbol); } } @@ -103,14 +303,34 @@ impl<'db> Definitions<'db> { ///Code for query: (function_signature name: (identifier) @name) @definition.function ///Code for field: name: (identifier) @name let name = node.name(tree); - let symbol = Function::new(db, id, node.clone(), name.clone()); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Function::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); functions.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::InterfaceDeclaration(node) => { ///Code for query: (interface_declaration name: (type_identifier) @name) @definition.interface ///Code for field: name: (type_identifier) @name let name = node.name(tree); - let symbol = Interface::new(db, id, node.clone(), name.clone()); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Interface::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); interfaces.entry(name.source()).or_default().push(symbol); } crate::cst::NodeTypes::Module(node) => { @@ -118,7 +338,17 @@ impl<'db> Definitions<'db> { ///Code for field: name: (identifier) @name let name = node.name(tree); if let crate::cst::ModuleNameRef::Identifier(name) = name { - let symbol = Module::new(db, id, node.clone(), name.clone()); + let fully_qualified_name = codegen_sdk_resolution::FullyQualifiedName::new( + db, + node.file_id(), + name.source(), + ); + let symbol = Module::new( + db, + fully_qualified_name, + id, + name.clone().into(), + ); modules.entry(name.source()).or_default().push(symbol); } } diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap index a5dfaca2..0c3eab2c 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_complex.snap @@ -24,7 +24,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let multiple_field = get_multiple_children_by_field_name::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap index 3488e0ab..8dacc1e0 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_simple.snap @@ -21,7 +21,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap index 71c16c87..5ca94b1b 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_children.snap @@ -22,7 +22,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap index f2c05f62..7a52fb59 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_fields.snap @@ -22,7 +22,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let test_field = get_child_by_field_name::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap index 71c16c87..5ca94b1b 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__node__tests__get_struct_tokens_with_single_child_type.snap @@ -22,7 +22,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for TestNode<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap index a6c909e5..def0094e 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__add_field_subenums_missing_node-2.snap @@ -21,7 +21,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for AnonymousNodeA<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -142,7 +142,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for NodeB<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -264,7 +264,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for NodeC<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let field = get_multiple_children_by_field_name::< NodeTypes<'db>, diff --git a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap index 5c0d90f7..cc9210e4 100644 --- a/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap +++ b/codegen-sdk-cst-generator/src/generator/snapshots/codegen_sdk_cst_generator__generator__state__tests__get_structs.snap @@ -21,7 +21,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for Test<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap index 487887b6..0abe8b50 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__basic_subtypes.snap @@ -93,6 +93,26 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::UnaryExpression(data) => Self::UnaryExpression((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::UnaryExpression(data) => Self::UnaryExpression((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -144,6 +164,30 @@ impl<'db3> From<&'db3 Expression<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Expression<'db3> { + fn from(node: ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::UnaryExpression(data) => { + Self::UnaryExpression((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 ExpressionRef<'db3>> for Expression<'db3> { + fn from(node: &'db3 ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::UnaryExpression(data) => { + Self::UnaryExpression((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ExpressionRef<'db3>) -> Result { @@ -191,7 +235,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for BinaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -312,7 +356,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for UnaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -422,13 +466,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -439,7 +483,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -449,7 +498,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -464,9 +515,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -483,12 +534,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap index 55f053b2..4ca160f1 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__deeply_nested_subtypes.snap @@ -147,6 +147,36 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + NodeTypesRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + NodeTypesRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + NodeTypesRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + NodeTypesRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 ClassDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -212,6 +242,30 @@ impl<'db3> From<&'db3 Declaration<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Declaration<'db3> { + fn from(node: DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + DeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 DeclarationRef<'db3>> for Declaration<'db3> { + fn from(node: &'db3 DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + DeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 ClassDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: DeclarationRef<'db3>) -> Result { @@ -263,6 +317,24 @@ impl<'db3> From<&'db3 FunctionDeclaration<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for FunctionDeclaration<'db3> { + fn from(node: FunctionDeclarationRef<'db3>) -> Self { + match node { + FunctionDeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 FunctionDeclarationRef<'db3>> for FunctionDeclaration<'db3> { + fn from(node: &'db3 FunctionDeclarationRef<'db3>) -> Self { + match node { + FunctionDeclarationRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 MethodDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: FunctionDeclarationRef<'db3>) -> Result { @@ -302,6 +374,36 @@ impl<'db3> From<&'db3 Statement<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Statement<'db3> { + fn from(node: StatementRef<'db3>) -> Self { + match node { + StatementRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + StatementRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + StatementRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 StatementRef<'db3>> for Statement<'db3> { + fn from(node: &'db3 StatementRef<'db3>) -> Self { + match node { + StatementRef::ClassDeclaration(data) => { + Self::ClassDeclaration((*data).clone()) + } + StatementRef::ExpressionStatement(data) => { + Self::ExpressionStatement((*data).clone()) + } + StatementRef::MethodDeclaration(data) => { + Self::MethodDeclaration((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 ClassDeclaration<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: StatementRef<'db3>) -> Result { @@ -363,7 +465,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ClassDeclaration<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -484,7 +586,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ExpressionStatement<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -605,7 +707,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for MethodDeclaration<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -715,13 +817,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -732,7 +834,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -742,7 +849,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -757,9 +866,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -776,12 +885,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap index 30c03a16..4f98e87a 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes__subtypes_with_fields.snap @@ -112,6 +112,26 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -165,6 +185,27 @@ impl<'db3> From<&'db3 BinaryExpressionChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for BinaryExpressionChildren<'db3> { + fn from(node: BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} +impl<'db3> From<&'db3 BinaryExpressionChildrenRef<'db3>> +for BinaryExpressionChildren<'db3> { + fn from(node: &'db3 BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: BinaryExpressionChildrenRef<'db3>) -> Result { @@ -216,6 +257,26 @@ impl<'db3> From<&'db3 Expression<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Expression<'db3> { + fn from(node: ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} +impl<'db3> From<&'db3 ExpressionRef<'db3>> for Expression<'db3> { + fn from(node: &'db3 ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::Literal(data) => Self::Literal((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ExpressionRef<'db3>) -> Result { @@ -265,7 +326,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for BinaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let left = get_child_by_field_name::< NodeTypes<'db>, @@ -416,7 +477,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for Literal<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -526,13 +587,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -543,7 +604,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -553,7 +619,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -568,9 +636,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -587,12 +655,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap index f57a9804..1bd8f1b8 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_children__subtypes_with_children.snap @@ -147,6 +147,24 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::Block(data) => Self::Block((*data).clone()), + NodeTypesRef::IfStatement(data) => Self::IfStatement((*data).clone()), + NodeTypesRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::Block(data) => Self::Block((*data).clone()), + NodeTypesRef::IfStatement(data) => Self::IfStatement((*data).clone()), + NodeTypesRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 Block<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -212,6 +230,26 @@ impl<'db3> From<&'db3 BlockChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for BlockChildren<'db3> { + fn from(node: BlockChildrenRef<'db3>) -> Self { + match node { + BlockChildrenRef::IfStatement(data) => Self::IfStatement((*data).clone()), + BlockChildrenRef::ReturnStatement(data) => { + Self::ReturnStatement((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 BlockChildrenRef<'db3>> for BlockChildren<'db3> { + fn from(node: &'db3 BlockChildrenRef<'db3>) -> Self { + match node { + BlockChildrenRef::IfStatement(data) => Self::IfStatement((*data).clone()), + BlockChildrenRef::ReturnStatement(data) => { + Self::ReturnStatement((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 IfStatement<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: BlockChildrenRef<'db3>) -> Result { @@ -261,6 +299,20 @@ impl<'db3> From<&'db3 IfStatementChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for IfStatementChildren<'db3> { + fn from(node: IfStatementChildrenRef<'db3>) -> Self { + match node { + IfStatementChildrenRef::Block(data) => Self::Block((*data).clone()), + } + } +} +impl<'db3> From<&'db3 IfStatementChildrenRef<'db3>> for IfStatementChildren<'db3> { + fn from(node: &'db3 IfStatementChildrenRef<'db3>) -> Self { + match node { + IfStatementChildrenRef::Block(data) => Self::Block((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 Block<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: IfStatementChildrenRef<'db3>) -> Result { @@ -298,6 +350,22 @@ impl<'db3> From<&'db3 Statement<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Statement<'db3> { + fn from(node: StatementRef<'db3>) -> Self { + match node { + StatementRef::IfStatement(data) => Self::IfStatement((*data).clone()), + StatementRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} +impl<'db3> From<&'db3 StatementRef<'db3>> for Statement<'db3> { + fn from(node: &'db3 StatementRef<'db3>) -> Self { + match node { + StatementRef::IfStatement(data) => Self::IfStatement((*data).clone()), + StatementRef::ReturnStatement(data) => Self::ReturnStatement((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 IfStatement<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: StatementRef<'db3>) -> Result { @@ -346,7 +414,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for Block<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, @@ -482,7 +550,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for IfStatement<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let _children = named_children_without_field_names::< NodeTypes<'db>, @@ -617,7 +685,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ReturnStatement<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -727,13 +795,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -744,7 +812,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -754,7 +827,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -769,9 +844,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -788,12 +863,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap index e7e00f03..e3c2c6a7 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_multiple_inheritance__multiple_inheritance.snap @@ -96,6 +96,20 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 ClassMethod<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -131,6 +145,20 @@ impl<'db3> From<&'db3 ClassMember<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for ClassMember<'db3> { + fn from(node: ClassMemberRef<'db3>) -> Self { + match node { + ClassMemberRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} +impl<'db3> From<&'db3 ClassMemberRef<'db3>> for ClassMember<'db3> { + fn from(node: &'db3 ClassMemberRef<'db3>) -> Self { + match node { + ClassMemberRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 ClassMethod<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ClassMemberRef<'db3>) -> Result { @@ -166,6 +194,20 @@ impl<'db3> From<&'db3 Declaration<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Declaration<'db3> { + fn from(node: DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} +impl<'db3> From<&'db3 DeclarationRef<'db3>> for Declaration<'db3> { + fn from(node: &'db3 DeclarationRef<'db3>) -> Self { + match node { + DeclarationRef::ClassMethod(data) => Self::ClassMethod((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 ClassMethod<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: DeclarationRef<'db3>) -> Result { @@ -199,7 +241,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for ClassMethod<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); Ok(( Self { @@ -309,13 +351,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -326,7 +368,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -336,7 +383,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -351,9 +400,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -370,12 +419,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } } diff --git a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap index 5a68711b..a2bb6f9c 100644 --- a/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap +++ b/codegen-sdk-cst-generator/src/tests/snapshots/codegen_sdk_cst_generator__tests__test_subtypes_recursive__recursive_subtypes.snap @@ -131,6 +131,26 @@ impl<'db3> From<&'db3 NodeTypes<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for NodeTypes<'db3> { + fn from(node: NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} +impl<'db3> From<&'db3 NodeTypesRef<'db3>> for NodeTypes<'db3> { + fn from(node: &'db3 NodeTypesRef<'db3>) -> Self { + match node { + NodeTypesRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + NodeTypesRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: NodeTypesRef<'db3>) -> Result { @@ -186,6 +206,31 @@ impl<'db3> From<&'db3 BinaryExpressionChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for BinaryExpressionChildren<'db3> { + fn from(node: BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 BinaryExpressionChildrenRef<'db3>> +for BinaryExpressionChildren<'db3> { + fn from(node: &'db3 BinaryExpressionChildrenRef<'db3>) -> Self { + match node { + BinaryExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + BinaryExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: BinaryExpressionChildrenRef<'db3>) -> Result { @@ -239,6 +284,30 @@ impl<'db3> From<&'db3 CallExpressionChildren<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for CallExpressionChildren<'db3> { + fn from(node: CallExpressionChildrenRef<'db3>) -> Self { + match node { + CallExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + CallExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} +impl<'db3> From<&'db3 CallExpressionChildrenRef<'db3>> for CallExpressionChildren<'db3> { + fn from(node: &'db3 CallExpressionChildrenRef<'db3>) -> Self { + match node { + CallExpressionChildrenRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + CallExpressionChildrenRef::CallExpression(data) => { + Self::CallExpression((*data).clone()) + } + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: CallExpressionChildrenRef<'db3>) -> Result { @@ -290,6 +359,26 @@ impl<'db3> From<&'db3 Expression<'db3>> for NodeTypesRef<'db3> { node.as_ref().into() } } +impl<'db3> From> for Expression<'db3> { + fn from(node: ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} +impl<'db3> From<&'db3 ExpressionRef<'db3>> for Expression<'db3> { + fn from(node: &'db3 ExpressionRef<'db3>) -> Self { + match node { + ExpressionRef::BinaryExpression(data) => { + Self::BinaryExpression((*data).clone()) + } + ExpressionRef::CallExpression(data) => Self::CallExpression((*data).clone()), + } + } +} impl<'db3> TryFrom> for &'db3 BinaryExpression<'db3> { type Error = codegen_sdk_cst::ConversionError; fn try_from(node: ExpressionRef<'db3>) -> Result { @@ -339,7 +428,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for BinaryExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let left = get_child_by_field_name::< NodeTypes<'db>, @@ -492,7 +581,7 @@ impl<'db> FromNode<'db, NodeTypes<'db>> for CallExpression<'db> { ) -> Result<(Self, Vec), ParseError> { let start_position = Point::from(context.db, node.start_position()); let end_position = Point::from(context.db, node.end_position()); - let id = CSTNodeId::new(context.db, context.file_id, node.id()); + let id = CSTNodeId::new(context.db, context.file_id, node.id(), context.root); let mut ids = Vec::new(); let callee = get_child_by_field_name::< NodeTypes<'db>, @@ -632,13 +721,13 @@ pub struct Parsed<'db> { #[tracked] #[return_ref] #[no_clone] - pub tree: Tree>, + #[no_eq] + pub tree: Arc>>, pub program: indextree::NodeId, } pub fn parse_program_raw<'db>( db: &'db dyn salsa::Database, - input: codegen_sdk_cst::Input, - path: PathBuf, + input: codegen_sdk_cst::File, ) -> Option> { let buffer = Bytes::from(input.content(db).as_bytes().to_vec()); let tree = codegen_sdk_common::language::language::Language @@ -649,7 +738,12 @@ pub fn parse_program_raw<'db>( ParseError::SyntaxError.report(db); None } else { - let mut context = ParseContext::new(db, path, buffer); + let mut context = ParseContext::new( + db, + input.path(db), + input.root(db), + buffer, + ); let root_id = Program::orphaned(&mut context, tree.root_node()) .map_or_else( |e| { @@ -659,7 +753,9 @@ pub fn parse_program_raw<'db>( |program| { Some(program) }, ); if let Some(program) = root_id { - Some(Parsed::new(db, context.file_id, context.tree, program)) + Some( + Parsed::new(db, context.file_id, Arc::new(context.tree), program), + ) } else { None } @@ -674,9 +770,9 @@ pub fn parse_program_raw<'db>( #[salsa::tracked(return_ref)] pub fn parse_program( db: &dyn salsa::Database, - input: codegen_sdk_cst::Input, + input: codegen_sdk_cst::File, ) -> Parsed<'_> { - let raw = parse_program_raw(db, input, std::path::PathBuf::new()); + let raw = parse_program_raw(db, input); if let Some(parsed) = raw { parsed } else { @@ -693,12 +789,19 @@ impl CSTLanguage for Language { fn parse<'db>( db: &'db dyn salsa::Database, content: std::string::String, - ) -> Option<(&'db Self::Program<'db>, &'db Tree>)> { - let input = codegen_sdk_cst::Input::new(db, content); + ) -> Option< + (&'db Self::Program<'db>, &'db Tree>, indextree::NodeId), + > { + let input = codegen_sdk_cst::File::new( + db, + std::path::PathBuf::new(), + content, + std::path::PathBuf::new(), + ); let parsed = parse_program(db, input); - let program = parsed.program(db); + let program_id = parsed.program(db); let tree = parsed.tree(db); - let program = tree.get(&program).unwrap().as_ref(); - Some((program.try_into().unwrap(), tree)) + let program = tree.get(&program_id).unwrap().as_ref(); + Some((program.try_into().unwrap(), tree, program_id)) } }