diff --git a/Cargo.toml b/Cargo.toml index dba36cc0e..5faf94f25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,9 @@ console-subscriber = "0.4.1" env_logger = "0.11.7" reqwest = { version = "0.12.13", features = ["json"] } async-openai = "0.28.0" + tree-sitter = "0.25.3" +tree-sitter-language = "0.1.5" tree-sitter-python = "0.23.6" tree-sitter-javascript = "0.23.1" tree-sitter-typescript = "0.23.2" diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index 1646a6293..ca305508e 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -1,5 +1,7 @@ use anyhow::anyhow; +use log::{error, trace}; use regex::{Matches, Regex}; +use std::collections::HashSet; use std::sync::LazyLock; use std::{collections::HashMap, sync::Arc}; @@ -22,20 +24,85 @@ static TEXT_SEPARATOR: LazyLock> = LazyLock::new(|| { .collect() }); -static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock> = - LazyLock::new(|| { - [ - ("python", tree_sitter_python::LANGUAGE.into()), - ("javascript", tree_sitter_javascript::LANGUAGE.into()), - ( - "typescript", - tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(), - ), - ("tsx", tree_sitter_typescript::LANGUAGE_TSX.into()), - ("markdown", tree_sitter_md::LANGUAGE.into()), - ] +struct LanguageConfig { + name: &'static str, + tree_sitter_lang: tree_sitter::Language, + terminal_node_kind_ids: HashSet, +} + +fn add_language<'a>( + output: &'a mut HashMap<&'static str, Arc>, + name: &'static str, + aliases: impl IntoIterator, + lang_fn: tree_sitter_language::LanguageFn, + terminal_node_kinds: impl IntoIterator, +) { + let tree_sitter_lang: tree_sitter::Language = lang_fn.into(); + let terminal_node_kind_ids = terminal_node_kinds .into_iter() - .collect() + .filter_map(|kind| { + let id = tree_sitter_lang.id_for_node_kind(kind, true); + if id != 0 { + trace!("Got id for node kind: `{kind}` -> {id}"); + Some(id) + } else { + error!("Failed in getting id for node kind: `{kind}`"); + None + } + }) + .collect(); + + let config = Arc::new(LanguageConfig { + name, + tree_sitter_lang, + terminal_node_kind_ids, + }); + for name in std::iter::once(name).chain(aliases.into_iter()) { + if output.insert(name, config.clone()).is_some() { + panic!("Language `{name}` already exists"); + } + } +} + +static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock>> = + LazyLock::new(|| { + let mut map = HashMap::new(); + add_language( + &mut map, + "Python", + ["py", "python"], + tree_sitter_python::LANGUAGE, + [], + ); + add_language( + &mut map, + "JavaScript", + ["JS", "js", "Javascript", "javascript"], + tree_sitter_javascript::LANGUAGE, + [], + ); + add_language( + &mut map, + "TypeScript", + ["TS", "ts", "Typescript", "typescript"], + tree_sitter_typescript::LANGUAGE_TYPESCRIPT, + [], + ); + add_language( + &mut map, + "TSX", + ["tsx"], + tree_sitter_typescript::LANGUAGE_TSX, + [], + ); + add_language( + &mut map, + "Markdown", + ["md", "markdown"], + tree_sitter_md::LANGUAGE.into(), + ["inline"], + ); + map }); enum ChunkKind<'t> { @@ -162,6 +229,7 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> { struct RecursiveChunker<'s> { full_text: &'s str, + lang_config: Option<&'s LanguageConfig>, chunk_size: usize, chunk_overlap: usize, } @@ -225,20 +293,27 @@ impl<'t, 's: 't> RecursiveChunker<'s> { ) -> Result<()> { match chunk.kind { ChunkKind::TreeSitterNode { node } => { - let mut cursor = node.walk(); - if cursor.goto_first_child() { - self.process_sub_chunks( - TreeSitterNodeIter { - full_text: self.full_text, - cursor: Some(cursor), - next_start_pos: node.start_byte(), - end_pos: node.end_byte(), - }, - output, - )?; - } else { - self.add_output(chunk.range, output); + if !self + .lang_config + .ok_or_else(|| anyhow!("Language not set."))? + .terminal_node_kind_ids + .contains(&node.kind_id()) + { + let mut cursor = node.walk(); + if cursor.goto_first_child() { + self.process_sub_chunks( + TreeSitterNodeIter { + full_text: self.full_text, + cursor: Some(cursor), + next_start_pos: node.start_byte(), + end_pos: node.end_byte(), + }, + output, + )?; + return Ok(()); + } } + self.add_output(chunk.range, output); } ChunkKind::RegexpSepChunk { next_regexp_sep_id } => { if next_regexp_sep_id >= TEXT_SEPARATOR.len() { @@ -335,8 +410,17 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator) -> Result { let full_text = self.args.text.value(&input)?.as_str()?; + let lang_config = { + let language = self.args.language.value(&input)?; + language + .map(|v| anyhow::Ok(v.as_str()?.as_ref())) + .transpose()? + .and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(lang)) + }; + let recursive_chunker = RecursiveChunker { full_text, + lang_config: lang_config.map(|c| c.as_ref()), chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize, chunk_overlap: self .args @@ -347,20 +431,11 @@ impl SimpleFunctionExecutor for Executor { .unwrap_or(0) as usize, }; - let language = self.args.language.value(&input)?; - let language = language - .map(|v| anyhow::Ok(v.as_str()?.as_ref())) - .transpose()?; - let tree_sitter_language = language.and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(lang)); - - let mut output = if let Some(tree_sitter_lang) = tree_sitter_language { + let mut output = if let Some(lang_config) = lang_config { let mut parser = tree_sitter::Parser::new(); - parser.set_language(tree_sitter_lang)?; + parser.set_language(&lang_config.tree_sitter_lang)?; let tree = parser.parse(full_text.as_ref(), None).ok_or_else(|| { - anyhow!( - "failed in parsing text in language: {}", - language.unwrap_or_default() - ) + anyhow!("failed in parsing text in language: {}", lang_config.name) })?; recursive_chunker.split_root_chunk(ChunkKind::TreeSitterNode { node: tree.root_node(),