diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index dc4eac178..b1a9fb7f0 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -1,11 +1,12 @@ -use anyhow::anyhow; +use anyhow::{Context, anyhow}; use log::{error, trace}; use regex::{Matches, Regex}; use std::collections::HashSet; use std::sync::LazyLock; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, ops::Range, sync::Arc}; use unicase::UniCase; +use crate::ops::sdk::RangeValue; use crate::ops::shared::split::{Position, set_output_positions}; use crate::{fields_value, ops::sdk::*}; @@ -23,8 +24,6 @@ struct Spec { custom_languages: Vec, } -const TREESITTER_MAX_RECURSION_DEPTH: usize = 128; - const SYNTAX_LEVEL_GAP_COST: usize = 512; const MISSING_OVERLAP_COST: usize = 512; const PER_LINE_BREAK_LEVEL_GAP_COST: usize = 64; @@ -60,12 +59,12 @@ struct TreesitterLanguageConfig { terminal_node_kind_ids: HashSet, } -fn add_treesitter_language<'a>( - output: &'a mut HashMap, Arc>, +fn add_treesitter_language( + output: &mut HashMap, Arc>, name: &'static str, aliases: impl IntoIterator, lang_fn: impl Into, - terminal_node_kinds: impl IntoIterator, + terminal_node_kinds: impl IntoIterator, ) { let tree_sitter_lang: tree_sitter::Language = lang_fn.into(); let terminal_node_kind_ids = terminal_node_kinds @@ -103,7 +102,7 @@ static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock< &mut map, "C++", [".cpp", ".cc", ".cxx", ".h", ".hpp", "cpp"], - tree_sitter_c::LANGUAGE, + tree_sitter_cpp::LANGUAGE, [], ); add_treesitter_language( @@ -260,15 +259,10 @@ struct Chunk<'t, 's: 't> { kind: ChunkKind<'t>, } -impl<'t, 's: 't> Chunk<'t, 's> { - fn text(&self) -> &'s str { - self.range.extract_str(self.full_text) - } -} - struct TextChunksIter<'t, 's: 't> { lang_config: &'t SimpleLanguageConfig, - parent: &'t Chunk<'t, 's>, + full_text: &'s str, + range: RangeValue, matches_iter: Matches<'t, 's>, regexp_sep_id: usize, next_start_pos: Option, @@ -277,15 +271,19 @@ struct TextChunksIter<'t, 's: 't> { impl<'t, 's: 't> TextChunksIter<'t, 's> { fn new( lang_config: &'t SimpleLanguageConfig, - parent: &'t Chunk<'t, 's>, + full_text: &'s str, + range: RangeValue, regexp_sep_id: usize, ) -> Self { + let std_range = range.start..range.end; Self { lang_config, - parent, - matches_iter: lang_config.separator_regex[regexp_sep_id].find_iter(parent.text()), + full_text, + range, + matches_iter: lang_config.separator_regex[regexp_sep_id] + .find_iter(&full_text[std_range.clone()]), regexp_sep_id, - next_start_pos: Some(parent.range.start), + next_start_pos: Some(std_range.start), } } } @@ -297,19 +295,19 @@ impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> { let start_pos = self.next_start_pos?; let end_pos = match self.matches_iter.next() { Some(grp) => { - self.next_start_pos = Some(self.parent.range.start + grp.end()); - self.parent.range.start + grp.start() + self.next_start_pos = Some(self.range.start + grp.end()); + self.range.start + grp.start() } None => { self.next_start_pos = None; - if start_pos >= self.parent.range.end { + if start_pos >= self.range.end { return None; } - self.parent.range.end + self.range.end } }; Some(Chunk { - full_text: self.parent.full_text, + full_text: self.full_text, range: RangeValue::new(start_pos, end_pos), kind: ChunkKind::RegexpSepChunk { lang_config: self.lang_config, @@ -380,6 +378,24 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> { } } +enum ChunkIterator<'t, 's: 't> { + TreeSitter(TreeSitterNodeIter<'t, 's>), + Text(TextChunksIter<'t, 's>), + Once(std::iter::Once>), +} + +impl<'t, 's: 't> Iterator for ChunkIterator<'t, 's> { + type Item = Chunk<'t, 's>; + + fn next(&mut self) -> Option { + match self { + ChunkIterator::TreeSitter(iter) => iter.next(), + ChunkIterator::Text(iter) => iter.next(), + ChunkIterator::Once(iter) => iter.next(), + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] enum LineBreakLevel { Inline, @@ -422,18 +438,17 @@ const INLINE_SPACE_CHARS: [char; 2] = [' ', '\t']; struct AtomChunk { range: RangeValue, boundary_syntax_level: usize, - internal_lb_level: LineBreakLevel, boundary_lb_level: LineBreakLevel, } struct AtomChunksCollector<'s> { full_text: &'s str, - curr_level: usize, min_level: usize, atom_chunks: Vec, } + impl<'s> AtomChunksCollector<'s> { fn collect(&mut self, range: RangeValue) { // Trim trailing whitespaces. @@ -492,78 +507,75 @@ struct RecursiveChunker<'s> { } impl<'t, 's: 't> RecursiveChunker<'s> { - fn collect_atom_chunks_from_iter( - &self, - sub_chunks_iter: impl Iterator>, - atom_collector: &mut AtomChunksCollector<'s>, - ) -> Result<()> { - atom_collector.curr_level += 1; - for sub_chunk in sub_chunks_iter { - let range = sub_chunk.range; - if range.len() <= self.min_chunk_size { - atom_collector.collect(range); - } else { - self.collect_atom_chunks(sub_chunk, atom_collector)?; - } - } - atom_collector.curr_level -= 1; - if atom_collector.curr_level < atom_collector.min_level { - atom_collector.min_level = atom_collector.curr_level; - } - Ok(()) - } - fn collect_atom_chunks( &self, chunk: Chunk<'t, 's>, atom_collector: &mut AtomChunksCollector<'s>, ) -> Result<()> { - match chunk.kind { - ChunkKind::TreeSitterNode { lang_config, node } => { - if !lang_config.terminal_node_kind_ids.contains(&node.kind_id()) - && atom_collector.curr_level < TREESITTER_MAX_RECURSION_DEPTH - { - let mut cursor = node.walk(); - if cursor.goto_first_child() { - return self.collect_atom_chunks_from_iter( - TreeSitterNodeIter { - lang_config, + let mut iter_stack: Vec> = + vec![ChunkIterator::Once(std::iter::once(chunk))]; + + while !iter_stack.is_empty() { + atom_collector.curr_level = iter_stack.len(); + + if let Some(current_chunk) = iter_stack.last_mut().unwrap().next() { + if current_chunk.range.len() <= self.min_chunk_size { + atom_collector.collect(current_chunk.range); + } else { + match current_chunk.kind { + ChunkKind::TreeSitterNode { lang_config, node } => { + if !lang_config.terminal_node_kind_ids.contains(&node.kind_id()) { + let mut cursor = node.walk(); + if cursor.goto_first_child() { + iter_stack.push(ChunkIterator::TreeSitter( + TreeSitterNodeIter { + lang_config, + full_text: self.full_text, + cursor: Some(cursor), + next_start_pos: node.start_byte(), + end_pos: node.end_byte(), + }, + )); + continue; + } + } + iter_stack.push(ChunkIterator::Once(std::iter::once(Chunk { full_text: self.full_text, - cursor: Some(cursor), - next_start_pos: node.start_byte(), - end_pos: node.end_byte(), - }, - atom_collector, - ); + range: current_chunk.range, + kind: ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, + next_regexp_sep_id: 0, + }, + }))); + } + ChunkKind::RegexpSepChunk { + lang_config, + next_regexp_sep_id, + } => { + if next_regexp_sep_id >= lang_config.separator_regex.len() { + atom_collector.collect(current_chunk.range); + } else { + iter_stack.push(ChunkIterator::Text(TextChunksIter::new( + lang_config, + current_chunk.full_text, + current_chunk.range, + next_regexp_sep_id, + ))); + } + } } } - self.collect_atom_chunks( - Chunk { - full_text: self.full_text, - range: chunk.range, - kind: ChunkKind::RegexpSepChunk { - lang_config: &DEFAULT_LANGUAGE_CONFIG, - next_regexp_sep_id: 0, - }, - }, - atom_collector, - ) - } - ChunkKind::RegexpSepChunk { - lang_config, - next_regexp_sep_id, - } => { - if next_regexp_sep_id >= lang_config.separator_regex.len() { - atom_collector.collect(chunk.range); - Ok(()) - } else { - self.collect_atom_chunks_from_iter( - TextChunksIter::new(lang_config, &chunk, next_regexp_sep_id), - atom_collector, - ) + } else { + iter_stack.pop(); + let level_after_pop = iter_stack.len(); + atom_collector.curr_level = level_after_pop; + if level_after_pop < atom_collector.min_level { + atom_collector.min_level = level_after_pop; } } } + atom_collector.curr_level = 0; + Ok(()) } fn get_overlap_cost_base(&self, offset: usize) -> usize { @@ -940,10 +952,9 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> { mod tests { use super::*; use crate::ops::functions::test_utils::test_flow_function; - use crate::ops::sdk::{BasicValueType, KeyValue, RangeValue, make_output_type}; + use crate::ops::sdk::{BasicValueType, KeyPart, KeyValue, make_output_type}; use crate::ops::shared::split::OutputPosition; - // Helper function to build the standard input argument schemas for split_recursively tests fn build_split_recursively_arg_schemas() -> Vec<(Option<&'static str>, EnrichedValueType)> { vec![ ( @@ -1014,7 +1025,7 @@ mod tests { scope_value_ref.0.fields[0].as_str().unwrap_or_else(|_| { panic!("Chunk text not a string for key {key:?}") }); - assert_eq!(**chunk_text, *expected_text); + assert_eq!(*chunk_text, expected_text.into()); } None => panic!("Expected row value for key {key:?}, not found"), } @@ -1159,11 +1170,6 @@ mod tests { ], ) .await; - assert!( - result.is_ok(), - "test_flow_function failed: {:?}", - result.err() - ); let value = result.unwrap(); match value { Value::KTable(table) => { @@ -1177,11 +1183,8 @@ mod tests { let key = KeyValue::from_single_part(range); match table.get(&key) { Some(scope_value_ref) => { - let chunk_text = - scope_value_ref.0.fields[0].as_str().unwrap_or_else(|_| { - panic!("Chunk text not a string for key {key:?}") - }); - assert_eq!(**chunk_text, *expected_text); + let chunk_text = scope_value_ref.0.fields[0].as_str().unwrap(); + assert_eq!(*chunk_text, expected_text.into()); } None => panic!("Expected row value for key {key:?}, not found"), } @@ -1207,25 +1210,16 @@ mod tests { ], ) .await; - assert!( - result.is_ok(), - "test_flow_function failed: {:?}", - result.err() - ); let value = result.unwrap(); match value { Value::KTable(table) => { - // Expect multiple chunks, likely split by spaces due to chunk_size. assert!(table.len() > 1); let key = KeyValue::from_single_part(RangeValue::new(0, 16)); match table.get(&key) { Some(scope_value_ref) => { - let chunk_text = - scope_value_ref.0.fields[0].as_str().unwrap_or_else(|_| { - panic!("Chunk text not a string for key {key:?}") - }); - assert_eq!(&**chunk_text, "A very very long"); + let chunk_text = scope_value_ref.0.fields[0].as_str().unwrap(); + assert_eq!(*chunk_text, "A very very long".into()); assert!(chunk_text.len() <= 20); } None => panic!("Expected row value for key {key:?}, not found"), @@ -1259,26 +1253,21 @@ mod tests { ], ) .await; - assert!( - result.is_ok(), - "test_flow_function failed: {:?}", - result.err() - ); let value = result.unwrap(); match value { Value::KTable(table) => { assert!(table.len() > 1); - // Check first chunk length if table.len() >= 2 { let first_key = table.keys().next().unwrap(); match table.get(first_key) { Some(scope_value_ref) => { - let chunk_text = - scope_value_ref.0.fields[0].as_str().unwrap_or_else(|_| { - panic!("Chunk text not a string for key {first_key:?}") - }); - assert!(chunk_text.len() <= 25); + let chunk_text = scope_value_ref.0.fields[0].as_str().unwrap(); + assert!( + chunk_text.len() <= 25, + "Chunk was too long: '{}'", + chunk_text + ); } None => panic!("Expected row value for first key, not found"), }