diff --git a/Cargo.toml b/Cargo.toml
index b21738d03..dba36cc0e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -49,3 +49,8 @@ console-subscriber = "0.4.1"
 env_logger = "0.11.7"
 reqwest = { version = "0.12.13", features = ["json"] }
 async-openai = "0.28.0"
+tree-sitter = "0.25.3"
+tree-sitter-python = "0.23.6"
+tree-sitter-javascript = "0.23.1"
+tree-sitter-typescript = "0.23.2"
+tree-sitter-md = "0.3.2"
diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs
index ef803d0d3..cc5fc6d1f 100644
--- a/src/ops/functions/split_recursively.rs
+++ b/src/ops/functions/split_recursively.rs
@@ -1,3 +1,4 @@
+use anyhow::anyhow;
 use regex::{Matches, Regex};
 use std::sync::LazyLock;
 use std::{collections::HashMap, sync::Arc};
@@ -14,118 +15,67 @@ pub struct Args {
     language: Option<ResolvedOpArg>,
 }

-static DEFAULT_SEPARATORS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
+static TEXT_SEPARATOR: LazyLock<Vec<Regex>> = LazyLock::new(|| {
     [r"\n\n+", r"\n", r"\s+"]
         .into_iter()
         .map(|s| Regex::new(s).unwrap())
         .collect()
 });

-static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLock::new(|| {
-    [
-        (
-            "markdown",
-            vec![
-                r"(^|\n)\n*# ",
-                r"(^|\n)\n*## ",
-                r"(^|\n)\n*### ",
-                r"(^|\n)\n*#### ",
-                r"(^|\n)\n*##### ",
-                r"(^|\n)\n*###### ",
-                // Code block
-                r"(^|\n\n)\n*```\S*\n|\n\s*```\n*(\n\n|$)",
-                // Horizontal lines
-                r"(^|\n\n)\n*(\*\*\*+|---+|___+)\n*(\n\n|$)",
-                r"\n\n+",
-                r"(\.|!|\?)\s*(\s|$)",
-                r":\s*(\s|$)",
-                r";\s*(\s|$)",
-                r"\n",
-                r"\s+",
-            ],
-        ),
-        (
-            "python",
-            vec![
-                // First, try to split along class definitions
-                r"\nclass ",
-                r"\n    def ",
-                r"\n        def ",
-                r"\n            def ",
-                r"\n                def ",
-                // Now split by the normal type of lines
-                r"\n\n",
-                r"\n",
-                r"\s+",
-            ],
-        ),
-        (
-            "javascript",
-            vec![
-                // Split along function definitions
-                r"\nfunction ",
-                r"\nconst ",
-                r"\nlet ",
-                r"\nvar ",
-                r"\nclass ",
-                // Split along control flow statements
-                r"\n\s*if ",
-                r"\n\s*for ",
-                r"\n\s*while ",
-                r"\n\s*switch ",
-                r"\n\s*case ",
-                r"\n\s*default ",
-                // Split by the normal type of lines
-                r"\n\n",
-                r"\n",
-            ],
-        ),
-    ]
-    .into_iter()
-    .map(|(lang, separators)| {
-        let regexs = separators
-            .into_iter()
-            .map(|s| Regex::new(s).unwrap())
-            .collect();
-        (lang, regexs)
-    })
-    .collect()
-});
-trait NestedChunk {
-    fn range(&self) -> &RangeValue;
+static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock<HashMap<&'static str, tree_sitter::Language>> =
+    LazyLock::new(|| {
+        [
+            ("python", tree_sitter_python::LANGUAGE.into()),
+            ("javascript", tree_sitter_javascript::LANGUAGE.into()),
+            (
+                "typescript",
+                tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
+            ),
+            ("tsx", tree_sitter_typescript::LANGUAGE_TSX.into()),
+            ("markdown", tree_sitter_md::LANGUAGE.into()),
+        ]
+        .into_iter()
+        .collect()
+    });

-    fn sub_chunks(&self) -> Option<impl Iterator<Item = Self>>;
+enum ChunkKind<'t> {
+    TreeSitterNode { node: tree_sitter::Node<'t> },
+    RegexpSepChunk { next_regexp_sep_id: usize },
+    LeafText,
 }

-struct SplitTarget<'s> {
-    separators: &'static [Regex],
-    text: &'s str,
+struct Chunk<'t, 's: 't> {
+    full_text: &'s str,
+    range: RangeValue,
+    kind: ChunkKind<'t>,
 }

-struct Chunk<'s> {
-    target: &'s SplitTarget<'s>,
-    range: RangeValue,
-    next_sep_id: usize,
+impl<'t, 's: 't> Chunk<'t, 's> {
+    fn text(&self) -> &'s str {
+        self.range.extract_str(self.full_text)
+    }
 }

-struct SubChunksIter<'a, 's: 'a> {
-    parent: &'a Chunk<'s>,
+struct TextChunksIter<'t, 's: 't> {
+    parent: &'t Chunk<'t, 's>,
     matches_iter: Matches<'static, 's>,
+    regexp_sep_id: usize,
     next_start_pos: Option<usize>,
 }

-impl<'a, 's: 'a> SubChunksIter<'a, 's> {
-    fn new(parent: &'a Chunk<'s>, matches_iter: Matches<'static, 's>) -> Self {
+impl<'t, 's: 't> TextChunksIter<'t, 's> {
+    fn new(parent: &'t Chunk<'t, 's>, regexp_sep_id: usize) -> Self {
         Self {
             parent,
-            matches_iter,
+            matches_iter: TEXT_SEPARATOR[regexp_sep_id].find_iter(parent.text()),
+            regexp_sep_id,
             next_start_pos: Some(parent.range.start),
         }
     }
 }

-impl<'a, 's: 'a> Iterator for SubChunksIter<'a, 's> {
-    type Item = Chunk<'s>;
+impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> {
+    type Item = Chunk<'t, 's>;

     fn next(&mut self) -> Option<Self::Item> {
         let start_pos = if let Some(start_pos) = self.next_start_pos {
@@ -147,99 +97,187 @@ impl<'a, 's: 'a> Iterator for SubChunksIter<'a, 's> {
             }
         };
         Some(Chunk {
-            target: self.parent.target,
+            full_text: self.parent.full_text,
             range: RangeValue::new(start_pos, end_pos),
-            next_sep_id: self.parent.next_sep_id + 1,
+            kind: ChunkKind::RegexpSepChunk {
+                next_regexp_sep_id: self.regexp_sep_id + 1,
+            },
         })
     }
 }

-impl<'s> NestedChunk for Chunk<'s> {
-    fn range(&self) -> &RangeValue {
-        &self.range
-    }
+struct TreeSitterNodeIter<'t, 's: 't> {
+    full_text: &'s str,
+    cursor: Option<tree_sitter::TreeCursor<'t>>,
+    next_start_pos: usize,
+    end_pos: usize,
+}

-    fn sub_chunks(&self) -> Option<impl Iterator<Item = Self>> {
+impl<'t, 's: 't> TreeSitterNodeIter<'t, 's> {
+    fn fill_gap(
+        next_start_pos: &mut usize,
+        gap_end_pos: usize,
+        full_text: &'s str,
+    ) -> Option<Chunk<'t, 's>> {
+        let start_pos = *next_start_pos;
+        if start_pos < gap_end_pos {
+            *next_start_pos = gap_end_pos;
+            Some(Chunk {
+                full_text,
+                range: RangeValue::new(start_pos, gap_end_pos),
+                kind: ChunkKind::LeafText,
+            })
+        } else {
-        if self.next_sep_id >= self.target.separators.len() {
             None
+        }
+    }
+}
+
+impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
+    type Item = Chunk<'t, 's>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let cursor = if let Some(cursor) = &mut self.cursor {
+            cursor
         } else {
-            let sub_text = self.range.extract_str(&self.target.text);
-            Some(SubChunksIter::new(
-                self,
-                self.target.separators[self.next_sep_id].find_iter(sub_text),
-            ))
+            return Self::fill_gap(&mut self.next_start_pos, self.end_pos, self.full_text);
+        };
+        let node = cursor.node();
+        if let Some(gap) =
+            Self::fill_gap(&mut self.next_start_pos, node.start_byte(), self.full_text)
+        {
+            return Some(gap);
         }
+        if !cursor.goto_next_sibling() {
+            self.cursor = None;
+        }
+        self.next_start_pos = node.end_byte();
+        Some(Chunk {
+            full_text: self.full_text,
+            range: RangeValue::new(node.start_byte(), node.end_byte()),
+            kind: ChunkKind::TreeSitterNode { node },
+        })
     }
 }

 struct RecursiveChunker<'s> {
-    text: &'s str,
+    full_text: &'s str,
     chunk_size: usize,
     chunk_overlap: usize,
 }

-impl<'s> RecursiveChunker<'s> {
-    fn split_substring<Chk>(&self, chunk: Chk, output: &mut Vec<(RangeValue, &'s str)>)
-    where
-        Chk: NestedChunk + Sized,
-    {
-        let sub_chunks_iter = if let Some(sub_chunks_iter) = chunk.sub_chunks() {
-            sub_chunks_iter
-        } else {
-            self.add_output(*chunk.range(), output);
+impl<'t, 's: 't> RecursiveChunker<'s> {
+    fn flush_small_chunks(&self, chunks: &[RangeValue], output: &mut Vec<(RangeValue, &'s str)>) {
+        if chunks.is_empty() {
             return;
-        };
-
-        let flush_small_chunks =
-            |chunks: &[RangeValue], output: &mut Vec<(RangeValue, &'s str)>| {
-                if chunks.is_empty() {
-                    return;
-                }
-                let mut start_pos = chunks[0].start;
-                for i in 1..chunks.len() - 1 {
-                    let chunk = &chunks[i];
-                    if chunk.end - start_pos > self.chunk_size {
-                        self.add_output(RangeValue::new(start_pos, chunk.end), output);
-
-                        // Find the new start position, allowing overlap within the threshold.
-                        let mut new_start_idx = i + 1;
-                        let next_chunk = &chunks[i + 1];
-                        while new_start_idx > 0 {
-                            let prev_pos = chunks[new_start_idx - 1].start;
-                            if prev_pos <= start_pos
-                                || chunk.end - prev_pos > self.chunk_overlap
-                                || next_chunk.end - prev_pos > self.chunk_size
-                            {
-                                break;
-                            }
-                            new_start_idx -= 1;
-                        }
-                        start_pos = chunks[new_start_idx].start;
+        }
+        let mut start_pos = chunks[0].start;
+        for i in 1..chunks.len() - 1 {
+            let chunk = &chunks[i];
+            if chunk.end - start_pos > self.chunk_size {
+                self.add_output(RangeValue::new(start_pos, chunk.end), output);
+
+                // Find the new start position, allowing overlap within the threshold.
+                let mut new_start_idx = i + 1;
+                let next_chunk = &chunks[i + 1];
+                while new_start_idx > 0 {
+                    let prev_pos = chunks[new_start_idx - 1].start;
+                    if prev_pos <= start_pos
+                        || chunk.end - prev_pos > self.chunk_overlap
+                        || next_chunk.end - prev_pos > self.chunk_size
+                    {
+                        break;
                     }
+                    new_start_idx -= 1;
                 }
+                start_pos = chunks[new_start_idx].start;
+            }
+        }

-                let last_chunk = &chunks[chunks.len() - 1];
-                self.add_output(RangeValue::new(start_pos, last_chunk.end), output);
-            };
+        let last_chunk = &chunks[chunks.len() - 1];
+        self.add_output(RangeValue::new(start_pos, last_chunk.end), output);
+    }

+    fn process_sub_chunks(
+        &self,
+        sub_chunks_iter: impl Iterator<Item = Chunk<'t, 's>>,
+        output: &mut Vec<(RangeValue, &'s str)>,
+    ) -> Result<()> {
         let mut small_chunks = Vec::new();
         for sub_chunk in sub_chunks_iter {
-            let sub_range = sub_chunk.range();
+            let sub_range = sub_chunk.range;
             if sub_range.len() <= self.chunk_size {
-                small_chunks.push(*sub_chunk.range());
+                small_chunks.push(sub_range);
             } else {
-                flush_small_chunks(&small_chunks, output);
+                self.flush_small_chunks(&small_chunks, output);
                 small_chunks.clear();
-                self.split_substring(sub_chunk, output);
+                self.split_substring(sub_chunk, output)?;
+            }
+        }
+        self.flush_small_chunks(&small_chunks, output);
+        Ok(())
+    }
+
+    fn split_substring(
+        &self,
+        chunk: Chunk<'t, 's>,
+        output: &mut Vec<(RangeValue, &'s str)>,
+    ) -> Result<()> {
+        match chunk.kind {
+            ChunkKind::TreeSitterNode { node } => {
+                let mut cursor = node.walk();
+                if cursor.goto_first_child() {
+                    self.process_sub_chunks(
+                        TreeSitterNodeIter {
+                            full_text: self.full_text,
+                            cursor: Some(cursor),
+                            next_start_pos: node.start_byte(),
+                            end_pos: node.end_byte(),
+                        },
+                        output,
+                    )?;
+                } else {
+                    self.add_output(chunk.range, output);
+                }
+            }
+            ChunkKind::RegexpSepChunk { next_regexp_sep_id } => {
+                if next_regexp_sep_id >= TEXT_SEPARATOR.len() {
+                    self.add_output(chunk.range, output);
+                } else {
+                    self.process_sub_chunks(
+                        TextChunksIter::new(&chunk, next_regexp_sep_id),
+                        output,
+                    )?;
+                }
+            }
+            ChunkKind::LeafText => {
+                self.add_output(chunk.range, output);
             }
         }
-        flush_small_chunks(&small_chunks, output);
+        Ok(())
+    }
+
+    fn split_root_chunk(&self, kind: ChunkKind<'t>) -> Result<Vec<(RangeValue, &'s str)>> {
+        let mut output = Vec::new();
+        self.split_substring(
+            Chunk {
+                full_text: self.full_text,
+                range: RangeValue::new(0, self.full_text.len()),
+                kind,
+            },
+            &mut output,
+        )?;
+        Ok(output)
     }

     fn add_output(&self, range: RangeValue, output: &mut Vec<(RangeValue, &'s str)>) {
-        let text = range.extract_str(self.text);
-        if !text.trim().is_empty() {
-            output.push((range, text));
+        let text = range.extract_str(self.full_text);
+        let trimmed_text = text.trim_end();
+        if !trimmed_text.is_empty() {
+            output.push((
+                RangeValue::new(range.start, range.start + trimmed_text.len()),
+                trimmed_text,
+            ));
         }
     }
 }
@@ -287,9 +325,9 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator) -> Result {
-        let text = self.args.text.value(&input)?.as_str()?;
+        let full_text = self.args.text.value(&input)?.as_str()?;
         let recursive_chunker = RecursiveChunker {
-            text,
+            full_text,
             chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize,
             chunk_overlap: self
                 .args
@@ -300,31 +338,32 @@ impl SimpleFunctionExecutor for Executor {
                 .unwrap_or(0) as usize,
         };

-        let separators = self
-            .args
-            .language
-            .value(&input)?
-            .map(|v| v.as_str())
-            .transpose()?
-            .and_then(|lang| {
-                SEPARATORS_BY_LANG
-                    .get(lang.to_lowercase().as_str())
-                    .map(|v| v.as_slice())
-            })
-            .unwrap_or(DEFAULT_SEPARATORS.as_slice());
-
-        let mut output = Vec::new();
-        recursive_chunker.split_substring(
-            Chunk {
-                target: &SplitTarget { separators, text },
-                range: RangeValue::new(0, text.len()),
-                next_sep_id: 0,
-            },
-            &mut output,
-        );
+        let language = self.args.language.value(&input)?;
+        let language = language
+            .map(|v| anyhow::Ok(v.as_str()?.as_ref()))
+            .transpose()?;
+        let tree_sitter_language = language.and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(lang));
+
+        let mut output = if let Some(tree_sitter_lang) = tree_sitter_language {
+            let mut parser = tree_sitter::Parser::new();
+            parser.set_language(tree_sitter_lang)?;
+            let tree = parser.parse(full_text.as_ref(), None).ok_or_else(|| {
+                anyhow!(
+                    "failed in parsing text in language: {}",
+                    language.unwrap_or_default()
+                )
+            })?;
+            recursive_chunker.split_root_chunk(ChunkKind::TreeSitterNode {
+                node: tree.root_node(),
+            })?
+        } else {
+            recursive_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
+                next_regexp_sep_id: 0,
+            })?
+        };

         translate_bytes_to_chars(
-            text,
+            full_text,
             output
                 .iter_mut()
                 .map(|(range, _)| [&mut range.start, &mut range.end].into_iter())
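
Note (not part of the patch): the new code path leans on a small set of tree-sitter calls -- set a language, parse the full text, then walk the root node's direct children with a TreeCursor, which is the traversal TreeSitterNodeIter performs. Below is a minimal standalone sketch of that traversal, assuming the tree-sitter and tree-sitter-python crate versions pinned in the Cargo.toml hunk and an anyhow dependency (the patch already imports anyhow::anyhow). The sample source string and the printing are illustrative only.

use anyhow::anyhow;

fn main() -> anyhow::Result<()> {
    let full_text = "def foo():\n    return 1\n\nclass Bar:\n    pass\n";

    let mut parser = tree_sitter::Parser::new();
    parser.set_language(&tree_sitter_python::LANGUAGE.into())?;
    let tree = parser
        .parse(full_text, None)
        .ok_or_else(|| anyhow!("failed in parsing text"))?;

    // Walk the root's direct children, the same way TreeSitterNodeIter does.
    let root = tree.root_node();
    let mut cursor = root.walk();
    if cursor.goto_first_child() {
        loop {
            let node = cursor.node();
            // Byte offsets are what the chunker stores in RangeValue before
            // translate_bytes_to_chars converts them to character offsets.
            println!(
                "{}: bytes {}..{}",
                node.kind(),
                node.start_byte(),
                node.end_byte()
            );
            if !cursor.goto_next_sibling() {
                break;
            }
        }
    }
    Ok(())
}

Gaps between siblings (whitespace, or text the grammar does not cover) are what fill_gap turns into ChunkKind::LeafText chunks, so no byte of the input is dropped.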
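For inputs with no registered tree-sitter language, the patch falls back to the TEXT_SEPARATOR cascade (r"\n\n+", then r"\n", then r"\s+"): a chunk that is still too large is re-split with the next separator in the list, mirroring ChunkKind::RegexpSepChunk { next_regexp_sep_id + 1 }. The sketch below is a simplified, hypothetical illustration of that cascade only -- it keeps each separator with the preceding piece and omits the small-chunk merging and overlap handling that flush_small_chunks adds.

use regex::Regex;

// Split `text` with the first separator; any piece still longer than `max_len`
// is re-split with the remaining separators. Pieces that fit are emitted as-is.
fn split_with_cascade(text: &str, seps: &[Regex], max_len: usize, out: &mut Vec<String>) {
    if text.len() <= max_len || seps.is_empty() {
        if !text.trim().is_empty() {
            out.push(text.to_string());
        }
        return;
    }
    let mut pieces: Vec<&str> = Vec::new();
    let mut last = 0;
    for m in seps[0].find_iter(text) {
        pieces.push(&text[last..m.end()]);
        last = m.end();
    }
    pieces.push(&text[last..]);
    for piece in pieces {
        split_with_cascade(piece, &seps[1..], max_len, out);
    }
}

fn main() {
    let seps: Vec<Regex> = [r"\n\n+", r"\n", r"\s+"]
        .into_iter()
        .map(|s| Regex::new(s).unwrap())
        .collect();
    let mut out = Vec::new();
    split_with_cascade("first para\n\nsecond para line one\nline two", &seps, 16, &mut out);
    for chunk in &out {
        println!("{chunk:?}");
    }
}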