diff --git a/docs/docs/ops/functions.md b/docs/docs/ops/functions.md
index 769017d3f..54e1f3386 100644
--- a/docs/docs/ops/functions.md
+++ b/docs/docs/ops/functions.md
@@ -26,6 +26,17 @@ Input data:
 
 * `text` (type: `str`, required): The text to split.
 * `chunk_size` (type: `int`, required): The maximum size of each chunk, in bytes.
+* `min_chunk_size` (type: `int`, optional): The minimum size of each chunk, in bytes. If not provided, defaults to `chunk_size / 2`.
+
+  :::note
+
+  `SplitRecursively` does its best to keep output chunks sized between `min_chunk_size` and `chunk_size`.
+  However, in rare cases some chunks may be smaller than `min_chunk_size` or larger than `chunk_size`, e.g. when the input text is very short or a large piece of text cannot be split further.
+
+  Avoid setting `min_chunk_size` too close to `chunk_size`, to leave more room for the function to plan optimal chunking.
+
+  :::
+
 * `chunk_overlap` (type: `int`, optional): The maximum overlap size between adjacent chunks, in bytes.
 * `language` (type: `str`, optional): The language of the document. Can be a langauge name (e.g. `Python`, `Javascript`, `Markdown`) or a file extension (e.g. `.py`, `.js`, `.md`).
 
diff --git a/examples/code_embedding/main.py b/examples/code_embedding/main.py
index 43d2e9f62..577bc8e66 100644
--- a/examples/code_embedding/main.py
+++ b/examples/code_embedding/main.py
@@ -27,7 +27,7 @@ def code_to_embedding(
 @cocoindex.flow_def(name="CodeEmbedding")
 def code_embedding_flow(
     flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
-):
+) -> None:
     """
     Define an example flow that embeds files into a vector database.
     """
@@ -46,6 +46,7 @@ def code_embedding_flow(
             cocoindex.functions.SplitRecursively(),
             language=file["extension"],
             chunk_size=1000,
+            min_chunk_size=300,
             chunk_overlap=300,
         )
         with file["chunks"].row() as chunk:
diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs
index f90ebbfb3..b10d111ca 100644
--- a/src/ops/functions/split_recursively.rs
+++ b/src/ops/functions/split_recursively.rs
@@ -3,6 +3,7 @@ use log::{error, trace};
 use regex::{Matches, Regex};
 use std::collections::HashSet;
 use std::sync::LazyLock;
+use std::usize;
 use std::{collections::HashMap, sync::Arc};
 use unicase::UniCase;
 
@@ -11,9 +12,15 @@ use crate::{fields_value, ops::sdk::*};
 
 type Spec = EmptySpec;
 
+const SYNTAX_LEVEL_GAP_COST: usize = 512;
+const MISSING_OVERLAP_COST: usize = 512;
+const PER_LINE_BREAK_LEVEL_GAP_COST: usize = 64;
+const TOO_SMALL_CHUNK_COST: usize = 1048576;
+
 pub struct Args {
     text: ResolvedOpArg,
     chunk_size: ResolvedOpArg,
+    min_chunk_size: Option<ResolvedOpArg>,
     chunk_overlap: Option<ResolvedOpArg>,
     language: Option<ResolvedOpArg>,
 }
@@ -317,69 +324,138 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
     }
 }
 
-struct RecursiveChunker<'s> {
-    full_text: &'s str,
-    lang_config: Option<&'s LanguageConfig>,
-    chunk_size: usize,
-    chunk_overlap: usize,
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+enum LineBreakLevel {
+    Inline,
+    Newline,
+    DoubleNewline,
 }
 
-impl<'t, 's: 't> RecursiveChunker<'s> {
-    fn flush_small_chunks(&self, chunks: &[RangeValue], output: &mut Vec<(RangeValue, &'s str)>) {
-        if chunks.is_empty() {
-            return;
+impl LineBreakLevel {
+    fn ord(self) -> usize {
+        match self {
+            LineBreakLevel::Inline => 0,
+            LineBreakLevel::Newline => 1,
+            LineBreakLevel::DoubleNewline => 2,
         }
-        let mut start_pos = chunks[0].start;
-        for i in 0..chunks.len() - 1 {
-            let next_chunk = &chunks[i + 1];
-            if next_chunk.end - start_pos > self.chunk_size {
-                let chunk = &chunks[i];
-                self.add_output(RangeValue::new(start_pos, chunk.end), output);
-
-                // Find the new start position, allowing overlap within the threshold.
-                let mut new_start_idx = i + 1;
-                while new_start_idx > 0 {
-                    let prev_pos = chunks[new_start_idx - 1].start;
-                    if prev_pos <= start_pos
-                        || chunk.end - prev_pos > self.chunk_overlap
-                        || next_chunk.end - prev_pos > self.chunk_size
-                    {
-                        break;
+    }
+}
+
+fn line_break_level(c: &str) -> LineBreakLevel {
+    let mut lb_level = LineBreakLevel::Inline;
+    let mut iter = c.chars();
+    while let Some(c) = iter.next() {
+        if c == '\n' || c == '\r' {
+            lb_level = LineBreakLevel::Newline;
+            while let Some(c2) = iter.next() {
+                if c2 == '\n' || c2 == '\r' {
+                    if c == c2 {
+                        return LineBreakLevel::DoubleNewline;
                     }
-                    new_start_idx -= 1;
+                } else {
+                    break;
                 }
-                start_pos = chunks[new_start_idx].start;
             }
         }
+    }
+    lb_level
+}
+
+const INLINE_SPACE_CHARS: [char; 2] = [' ', '\t'];
+
+struct AtomChunk {
+    range: RangeValue,
+    boundary_syntax_level: usize,
+
+    internal_lb_level: LineBreakLevel,
+    boundary_lb_level: LineBreakLevel,
+}
+
+struct AtomChunksCollector<'s> {
+    full_text: &'s str,
+
+    curr_level: usize,
+    min_level: usize,
+    atom_chunks: Vec<AtomChunk>,
+}
+impl<'s> AtomChunksCollector<'s> {
+    fn collect(&mut self, range: RangeValue) {
+        // Trim trailing whitespaces.
+        let end_trimmed_text = &self.full_text[range.start..range.end].trim_end();
+        if end_trimmed_text.is_empty() {
+            return;
+        }
+
+        // Trim leading whitespaces.
+        let trimmed_text = end_trimmed_text.trim_start();
+        let new_start = range.start + (end_trimmed_text.len() - trimmed_text.len());
+        let new_end = new_start + trimmed_text.len();
+
+        // Align to beginning of the line if possible.
+        let prev_end = self.atom_chunks.last().map_or(0, |chunk| chunk.range.end);
+        let gap = &self.full_text[prev_end..new_start];
+        let boundary_lb_level = line_break_level(gap);
+        let range = if boundary_lb_level != LineBreakLevel::Inline {
+            let trimmed_gap = gap.trim_end_matches(INLINE_SPACE_CHARS);
+            RangeValue::new(prev_end + trimmed_gap.len(), new_end)
+        } else {
+            RangeValue::new(new_start, new_end)
+        };
 
-        let last_chunk = &chunks[chunks.len() - 1];
-        self.add_output(RangeValue::new(start_pos, last_chunk.end), output);
+        self.atom_chunks.push(AtomChunk {
+            range,
+            boundary_syntax_level: self.min_level,
+            internal_lb_level: line_break_level(trimmed_text),
+            boundary_lb_level,
+        });
+        self.min_level = self.curr_level;
     }
 
-    fn process_sub_chunks(
+    fn into_atom_chunks(mut self) -> Vec<AtomChunk> {
+        self.atom_chunks.push(AtomChunk {
+            range: RangeValue::new(self.full_text.len(), self.full_text.len()),
+            boundary_syntax_level: self.min_level,
+            internal_lb_level: LineBreakLevel::Inline,
+            boundary_lb_level: LineBreakLevel::DoubleNewline,
+        });
+        self.atom_chunks
+    }
+}
+
+struct RecursiveChunker<'s> {
+    full_text: &'s str,
+    lang_config: Option<&'s LanguageConfig>,
+    chunk_size: usize,
+    chunk_overlap: usize,
+    min_chunk_size: usize,
+}
+
+impl<'t, 's: 't> RecursiveChunker<'s> {
+    fn collect_atom_chunks_from_iter(
         &self,
         sub_chunks_iter: impl Iterator<Item = Chunk<'t, 's>>,
-        output: &mut Vec<(RangeValue, &'s str)>,
+        atom_collector: &mut AtomChunksCollector<'s>,
     ) -> Result<()> {
-        let mut small_chunks = Vec::new();
+        atom_collector.curr_level += 1;
         for sub_chunk in sub_chunks_iter {
-            let sub_range = sub_chunk.range;
-            if sub_range.len() <= self.chunk_size {
-                small_chunks.push(sub_range);
+            let range = sub_chunk.range;
+            if range.len() <= self.min_chunk_size {
+                atom_collector.collect(range);
             } else {
-                self.flush_small_chunks(&small_chunks, output);
-                small_chunks.clear();
-                self.split_substring(sub_chunk, output)?;
+                self.collect_atom_chunks(sub_chunk, atom_collector)?;
             }
         }
-        self.flush_small_chunks(&small_chunks, output);
+        atom_collector.curr_level -= 1;
+        if atom_collector.curr_level < atom_collector.min_level {
+            atom_collector.min_level = atom_collector.curr_level;
+        }
         Ok(())
     }
 
-    fn split_substring(
+    fn collect_atom_chunks(
         &self,
         chunk: Chunk<'t, 's>,
-        output: &mut Vec<(RangeValue, &'s str)>,
+        atom_collector: &mut AtomChunksCollector<'s>,
     ) -> Result<()> {
         match chunk.kind {
             ChunkKind::TreeSitterNode { node } => {
@@ -391,68 +467,220 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
                 {
                     let mut cursor = node.walk();
                     if cursor.goto_first_child() {
-                        self.process_sub_chunks(
+                        self.collect_atom_chunks_from_iter(
                             TreeSitterNodeIter {
                                 full_text: self.full_text,
                                 cursor: Some(cursor),
                                 next_start_pos: node.start_byte(),
                                 end_pos: node.end_byte(),
                             },
-                            output,
+                            atom_collector,
                         )?;
                         return Ok(());
                     }
                 }
-                self.add_output(chunk.range, output);
             }
             ChunkKind::RegexpSepChunk { next_regexp_sep_id } => {
                 if next_regexp_sep_id >= TEXT_SEPARATOR.len() {
-                    self.add_output(chunk.range, output);
+                    atom_collector.collect(chunk.range);
                 } else {
-                    self.process_sub_chunks(
+                    self.collect_atom_chunks_from_iter(
                         TextChunksIter::new(&chunk, next_regexp_sep_id),
-                        output,
+                        atom_collector,
                     )?;
                 }
+                return Ok(());
             }
-            ChunkKind::LeafText => {
-                self.add_output(chunk.range, output);
-            }
+            ChunkKind::LeafText => {}
+        }
+        if chunk.range.len() > self.chunk_size {
+            self.collect_atom_chunks(
+                Chunk {
+                    full_text: self.full_text,
+                    range: chunk.range,
+                    kind: ChunkKind::RegexpSepChunk {
+                        next_regexp_sep_id: 0,
+                    },
+                },
+                atom_collector,
+            )?;
+        } else {
+            atom_collector.collect(chunk.range);
         }
         Ok(())
     }
 
-    fn split_root_chunk(&self, kind: ChunkKind<'t>) -> Result<Vec<(RangeValue, &'s str)>> {
+    fn get_overlap_cost_base(&self, offset: usize) -> usize {
+        if self.chunk_overlap == 0 {
+            0
+        } else {
+            (self.full_text.len() - offset) * MISSING_OVERLAP_COST / self.chunk_overlap
+        }
+    }
+
+    fn merge_atom_chunks(&self, atom_chunks: Vec<AtomChunk>) -> Vec<(RangeValue, &'s str)> {
+        struct AtomRoutingPlan {
+            start_idx: usize,     // index of `atom_chunks` for the start chunk
+            prev_plan_idx: usize, // index of `plans` for the previous plan
+            cost: usize,
+            overlap_cost_base: usize,
+        }
+        type PrevPlanCandidate = (std::cmp::Reverse<usize>, usize); // (cost, start_idx)
+
+        let mut plans = Vec::with_capacity(atom_chunks.len());
+        // Sentinel: the base plan before any atom chunk is consumed.
+        plans.push(AtomRoutingPlan {
+            start_idx: 0,
+            prev_plan_idx: 0,
+            cost: 0,
+            overlap_cost_base: self.get_overlap_cost_base(0),
+        });
+        let mut prev_plan_candidates = std::collections::BinaryHeap::<PrevPlanCandidate>::new();
+
+        let mut gap_cost_cache = vec![0];
+        let mut syntax_level_gap_cost = |boundary: usize, internal: usize| -> usize {
+            if boundary > internal {
+                let gap = boundary - internal;
+                for i in gap_cost_cache.len()..=gap {
+                    gap_cost_cache.push(gap_cost_cache[i - 1] + SYNTAX_LEVEL_GAP_COST / i);
+                }
+                gap_cost_cache[gap]
+            } else {
+                0
+            }
+        };
+
+        for (i, chunk) in (&atom_chunks[0..atom_chunks.len() - 1]).iter().enumerate() {
+            let mut min_cost = usize::MAX;
+            let mut arg_min_start_idx: usize = 0;
+            let mut arg_min_prev_plan_idx: usize = 0;
+            let mut start_idx = i;
+
+            let end_syntax_level = atom_chunks[i + 1].boundary_syntax_level;
+            let end_lb_level = atom_chunks[i + 1].boundary_lb_level;
+
+            let mut internal_syntax_level = usize::MAX;
+            let mut internal_lb_level = LineBreakLevel::Inline;
+
+            fn lb_level_gap(boundary: LineBreakLevel, internal: LineBreakLevel) -> usize {
+                if boundary.ord() < internal.ord() {
+                    internal.ord() - boundary.ord()
+                } else {
+                    0
+                }
+            }
+            loop {
+                let start_chunk = &atom_chunks[start_idx];
+                let chunk_size = chunk.range.end - start_chunk.range.start;
+
+                let mut cost = 0;
+                cost +=
+                    syntax_level_gap_cost(start_chunk.boundary_syntax_level, internal_syntax_level);
+                cost += syntax_level_gap_cost(end_syntax_level, internal_syntax_level);
+                cost += (lb_level_gap(start_chunk.boundary_lb_level, internal_lb_level)
+                    + lb_level_gap(end_lb_level, internal_lb_level))
+                    * PER_LINE_BREAK_LEVEL_GAP_COST;
+                if chunk_size < self.min_chunk_size {
+                    cost += TOO_SMALL_CHUNK_COST;
+                }
+
+                if chunk_size > self.chunk_size {
+                    if min_cost == usize::MAX {
+                        min_cost = cost + plans[start_idx].cost;
+                        arg_min_start_idx = start_idx;
+                        arg_min_prev_plan_idx = start_idx;
+                    }
+                    break;
+                }
+
+                let prev_plan_idx = if self.chunk_overlap > 0 {
+                    while let Some(top_prev_plan) = prev_plan_candidates.peek() {
+                        let overlap_size =
+                            atom_chunks[top_prev_plan.1].range.end - start_chunk.range.start;
+                        if overlap_size <= self.chunk_overlap {
+                            break;
+                        }
+                        prev_plan_candidates.pop();
+                    }
+                    prev_plan_candidates.push((
+                        std::cmp::Reverse(
+                            plans[start_idx].cost + plans[start_idx].overlap_cost_base,
+                        ),
+                        start_idx,
+                    ));
+                    prev_plan_candidates.peek().unwrap().1
+                } else {
+                    start_idx
+                };
+                let prev_plan = &plans[prev_plan_idx];
+                cost += prev_plan.cost;
+                if self.chunk_overlap == 0 {
+                    cost += MISSING_OVERLAP_COST / 2;
+                } else {
+                    let start_cost_base = self.get_overlap_cost_base(start_chunk.range.start);
+                    cost += if prev_plan.overlap_cost_base < start_cost_base {
+                        MISSING_OVERLAP_COST + prev_plan.overlap_cost_base - start_cost_base
+                    } else {
+                        MISSING_OVERLAP_COST
+                    };
+                }
+                if cost < min_cost {
+                    min_cost = cost;
+                    arg_min_start_idx = start_idx;
+                    arg_min_prev_plan_idx = prev_plan_idx;
+                }
+
+                if start_idx == 0 {
+                    break;
+                }
+
+                start_idx -= 1;
+                internal_syntax_level =
+                    internal_syntax_level.min(start_chunk.boundary_syntax_level);
+                internal_lb_level = internal_lb_level.max(start_chunk.internal_lb_level);
+            }
+            plans.push(AtomRoutingPlan {
+                start_idx: arg_min_start_idx,
+                prev_plan_idx: arg_min_prev_plan_idx,
+                cost: min_cost,
+                overlap_cost_base: self.get_overlap_cost_base(chunk.range.end),
+            });
+            prev_plan_candidates.clear();
+        }
+
         let mut output = Vec::new();
-        self.split_substring(
+        let mut plan_idx = plans.len() - 1;
+        while plan_idx > 0 {
+            let plan = &plans[plan_idx];
+            let start_chunk = &atom_chunks[plan.start_idx];
+            let end_chunk = &atom_chunks[plan_idx - 1];
+            let range = RangeValue::new(start_chunk.range.start, end_chunk.range.end);
+            output.push((range, &self.full_text[range.start..range.end]));
+            plan_idx = plan.prev_plan_idx;
+        }
+        output.reverse();
+        output
+    }
+
+    fn split_root_chunk(&self, kind: ChunkKind<'t>) -> Result<Vec<(RangeValue, &'s str)>> {
+        let mut atom_collector = AtomChunksCollector {
+            full_text: self.full_text,
+            min_level: 0,
+            curr_level: 0,
+            atom_chunks: Vec::new(),
+        };
+        self.collect_atom_chunks(
             Chunk {
                 full_text: self.full_text,
                 range: RangeValue::new(0, self.full_text.len()),
                 kind,
             },
-            &mut output,
+            &mut atom_collector,
         )?;
+        let atom_chunks = atom_collector.into_atom_chunks();
+        let output = self.merge_atom_chunks(atom_chunks);
         Ok(output)
     }
-
-    fn add_output(&self, range: RangeValue, output: &mut Vec<(RangeValue, &'s str)>) {
-        let text = range.extract_str(self.full_text);
-
-        // Trim leading new lines.
- let trimmed_text = text.trim_start_matches(['\n', '\r']); - let adjusted_start = range.start + (text.len() - trimmed_text.len()); - - // Trim trailing whitespaces - let trimmed_text = trimmed_text.trim_end(); - - // Only record chunks with alphanumeric characters. - if trimmed_text.chars().any(|ch| ch.is_alphanumeric()) { - output.push(( - RangeValue::new(adjusted_start, adjusted_start + trimmed_text.len()), - trimmed_text, - )); - } - } } struct Executor { @@ -508,19 +736,21 @@ impl SimpleFunctionExecutor for Executor { .transpose()? .and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(&UniCase::new(lang))) }; - + let chunk_size = self.args.chunk_size.value(&input)?.as_int64()?; let recursive_chunker = RecursiveChunker { full_text, lang_config: lang_config.map(|c| c.as_ref()), - chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize, - chunk_overlap: self - .args - .chunk_overlap - .value(&input)? + chunk_size: chunk_size as usize, + chunk_overlap: (self.args.chunk_overlap.value(&input)?) .optional() .map(|v| v.as_int64()) .transpose()? .unwrap_or(0) as usize, + min_chunk_size: (self.args.min_chunk_size.value(&input)?) + .optional() + .map(|v| v.as_int64()) + .transpose()? + .unwrap_or(chunk_size / 2) as usize, }; let mut output = if let Some(lang_config) = lang_config { @@ -578,6 +808,9 @@ impl SimpleFunctionFactoryBase for Factory { chunk_size: args_resolver .next_arg("chunk_size")? .expect_type(&ValueType::Basic(BasicValueType::Int64))?, + min_chunk_size: args_resolver + .next_optional_arg("min_chunk_size")? + .expect_type(&ValueType::Basic(BasicValueType::Int64))?, chunk_overlap: args_resolver .next_optional_arg("chunk_overlap")? .expect_type(&ValueType::Basic(BasicValueType::Int64))?, @@ -645,6 +878,7 @@ mod tests { fn create_test_chunker( text: &str, chunk_size: usize, + min_chunk_size: usize, chunk_overlap: usize, ) -> RecursiveChunker { RecursiveChunker { @@ -652,6 +886,7 @@ mod tests { lang_config: None, chunk_size, chunk_overlap, + min_chunk_size, } } @@ -690,7 +925,7 @@ mod tests { #[test] fn test_basic_split_no_overlap() { let text = "Linea 1.\nLinea 2.\n\nLinea 3."; - let chunker = create_test_chunker(text, 15, 0); + let chunker = create_test_chunker(text, 15, 5, 0); let result = chunker.split_root_chunk(ChunkKind::RegexpSepChunk { next_regexp_sep_id: 0, @@ -706,7 +941,7 @@ mod tests { // Test splitting when chunk_size forces breaks within segments. let text2 = "A very very long text that needs to be split."; - let chunker2 = create_test_chunker(text2, 20, 0); + let chunker2 = create_test_chunker(text2, 20, 12, 0); let result2 = chunker2.split_root_chunk(ChunkKind::RegexpSepChunk { next_regexp_sep_id: 0, }); @@ -722,7 +957,7 @@ mod tests { #[test] fn test_basic_split_with_overlap() { let text = "This is a test text that is a bit longer to see how the overlap works."; - let chunker = create_test_chunker(text, 20, 5); + let chunker = create_test_chunker(text, 20, 10, 5); let result = chunker.split_root_chunk(ChunkKind::RegexpSepChunk { next_regexp_sep_id: 0, @@ -743,7 +978,7 @@ mod tests { #[test] fn test_split_trims_whitespace() { let text = " \n First chunk. \n\n Second chunk with spaces at the end. 
\n"; - let chunker = create_test_chunker(text, 30, 0); + let chunker = create_test_chunker(text, 30, 10, 0); let result = chunker.split_root_chunk(ChunkKind::RegexpSepChunk { next_regexp_sep_id: 0, @@ -757,34 +992,15 @@ mod tests { assert_chunk_text_consistency( text, &chunks[0], - " \n First chunk.", + " First chunk.", "Whitespace Test, Chunk 0", ); assert_chunk_text_consistency( text, &chunks[1], - " Second chunk with spaces at", + " Second chunk with spaces", "Whitespace Test, Chunk 1", ); - assert_chunk_text_consistency(text, &chunks[2], "the end.", "Whitespace Test, Chunk 2"); - } - #[test] - fn test_split_discards_empty_chunks() { - let text = "Chunk 1.\n\n \n\nChunk 2.\n\n------\n\nChunk 3."; - let chunker = create_test_chunker(text, 10, 0); - - let result = chunker.split_root_chunk(ChunkKind::RegexpSepChunk { - next_regexp_sep_id: 0, - }); - - assert!(result.is_ok()); - let chunks = result.unwrap(); - - assert_eq!(chunks.len(), 3); - - // Expect only the chunks with actual alphanumeric content. - assert_chunk_text_consistency(text, &chunks[0], "Chunk 1.", "Discard Test, Chunk 0"); - assert_chunk_text_consistency(text, &chunks[1], "Chunk 2.", "Discard Test, Chunk 1"); - assert_chunk_text_consistency(text, &chunks[2], "Chunk 3.", "Discard Test, Chunk 2"); + assert_chunk_text_consistency(text, &chunks[2], "at the end.", "Whitespace Test, Chunk 2"); } }
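
Usage sketch: with this change, `min_chunk_size` is passed alongside `chunk_size` and `chunk_overlap`. The snippet below is a minimal sketch, assuming the flow and field names from `examples/code_embedding/main.py` above; the rest of that flow (source, embedding, collector, export) is omitted.

```python
# Sketch only: mirrors the transform call in examples/code_embedding/main.py.
# data_scope["files"], file["content"], and file["extension"] come from that example.
with data_scope["files"].row() as file:
    file["chunks"] = file["content"].transform(
        cocoindex.functions.SplitRecursively(),
        language=file["extension"],
        chunk_size=1000,      # maximum chunk size, in bytes
        min_chunk_size=300,   # optional; defaults to chunk_size / 2 when omitted
        chunk_overlap=300,    # maximum overlap between adjacent chunks, in bytes
    )
```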
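Implementation note: the new `merge_atom_chunks` chooses chunk boundaries by minimizing a cost built from `SYNTAX_LEVEL_GAP_COST`, `PER_LINE_BREAK_LEVEL_GAP_COST`, `MISSING_OVERLAP_COST`, and `TOO_SMALL_CHUNK_COST`. Its syntax-level penalty grows harmonically with how far a boundary's syntax level sits below the shallowest level spanned by the candidate chunk. The following is a small Python re-expression of that penalty (the actual implementation is the `syntax_level_gap_cost` closure and `gap_cost_cache` in `split_recursively.rs`), included only to clarify the cost model:

```python
# Sketch only: Python re-expression of the harmonic syntax-level gap penalty
# used inside merge_atom_chunks in split_recursively.rs.
SYNTAX_LEVEL_GAP_COST = 512  # mirrors the Rust constant

# _gap_cost_cache[g] = 512//1 + 512//2 + ... + 512//g, like `gap_cost_cache`.
_gap_cost_cache = [0]

def syntax_level_gap_cost(boundary_level: int, internal_level: int) -> int:
    # No penalty when the boundary's level does not exceed the chunk's
    # shallowest internal level (larger numbers mean deeper nesting).
    if boundary_level <= internal_level:
        return 0
    gap = boundary_level - internal_level
    for i in range(len(_gap_cost_cache), gap + 1):
        _gap_cost_cache.append(_gap_cost_cache[i - 1] + SYNTAX_LEVEL_GAP_COST // i)
    return _gap_cost_cache[gap]
```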