From 20cc6a0f76a045cbd5a0037c06e3c55614501a46 Mon Sep 17 00:00:00 2001 From: LJ Date: Fri, 14 Mar 2025 00:30:44 -0700 Subject: [PATCH] Extract regexp based chunk logic into a trait - for TreeSitter reuse. --- src/base/value.rs | 9 ++ src/ops/functions/split_recursively.rs | 179 ++++++++++++++++++------- 2 files changed, 136 insertions(+), 52 deletions(-) diff --git a/src/base/value.rs b/src/base/value.rs index c3df94498..21bf12a86 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -20,6 +20,15 @@ impl RangeValue { pub fn new(start: usize, end: usize) -> Self { RangeValue { start, end } } + + pub fn len(&self) -> usize { + self.end - self.start + } + + pub fn extract_str<'s>(&self, s: &'s (impl AsRef + ?Sized)) -> &'s str { + let s = s.as_ref(); + &s[self.start..self.end] + } } impl Serialize for RangeValue { diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index 35a302bcf..a7be39c34 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -1,4 +1,4 @@ -use regex::Regex; +use regex::{Matches, Regex}; use std::sync::LazyLock; use std::{collections::HashMap, sync::Arc}; @@ -91,24 +91,100 @@ static SEPARATORS_BY_LANG: LazyLock>> = LazyLoc .collect() }); -struct SplitTask { +trait NestedChunk: Sized { + fn range(&self) -> &RangeValue; + + fn sub_chunks(&self) -> Option>; +} + +struct SplitTarget<'s> { separators: &'static [Regex], + text: &'s str, +} + +struct Chunk<'s> { + target: &'s SplitTarget<'s>, + range: RangeValue, + next_sep_id: usize, +} + +struct SubChunksIter<'a, 's: 'a> { + parent: &'a Chunk<'s>, + matches_iter: Matches<'static, 's>, + next_start_pos: Option, +} + +impl<'a, 's: 'a> SubChunksIter<'a, 's> { + fn new(parent: &'a Chunk<'s>, matches_iter: Matches<'static, 's>) -> Self { + Self { + parent, + matches_iter, + next_start_pos: Some(parent.range.start), + } + } +} + +impl<'a, 's: 'a> Iterator for SubChunksIter<'a, 's> { + type Item = Chunk<'s>; + + fn next(&mut self) -> Option { + if let Some(start_pos) = self.next_start_pos { + let end_pos = match self.matches_iter.next() { + Some(grp) => { + self.next_start_pos = Some(self.parent.range.start + grp.end()); + self.parent.range.start + grp.start() + } + None => { + self.next_start_pos = None; + self.parent.range.end + } + }; + Some(Chunk { + target: self.parent.target, + range: RangeValue::new(start_pos, end_pos), + next_sep_id: self.parent.next_sep_id + 1, + }) + } else { + None + } + } +} + +impl<'s> NestedChunk for Chunk<'s> { + fn range(&self) -> &RangeValue { + &self.range + } + + fn sub_chunks(&self) -> Option> { + if self.next_sep_id >= self.target.separators.len() { + None + } else { + let sub_text = self.range.extract_str(&self.target.text); + Some(SubChunksIter::new( + self, + self.target.separators[self.next_sep_id].find_iter(sub_text), + )) + } + } +} + +struct RecursiveChunker<'s> { + text: &'s str, chunk_size: usize, chunk_overlap: usize, } -impl SplitTask { - fn split_substring<'s>( - &self, - s: &'s str, - base_pos: usize, - next_sep_id: usize, - output: &mut Vec<(RangeValue, &'s str)>, - ) { - if next_sep_id >= self.separators.len() { - self.add_output(base_pos, s, output); +impl<'s> RecursiveChunker<'s> { + fn split_substring(&self, chunk: Chk, output: &mut Vec<(RangeValue, &'s str)>) + where + Chk: NestedChunk, + { + let sub_chunks_iter = if let Some(sub_chunks_iter) = chunk.sub_chunks() { + sub_chunks_iter + } else { + self.add_output(*chunk.range(), output); return; - } + }; let flush_small_chunks = |chunks: &[RangeValue], output: &mut Vec<(RangeValue, &'s str)>| { @@ -119,7 +195,7 @@ impl SplitTask { for i in 1..chunks.len() - 1 { let chunk = &chunks[i]; if chunk.end - start_pos > self.chunk_size { - self.add_output(base_pos + start_pos, &s[start_pos..chunk.end], output); + self.add_output(RangeValue::new(start_pos, chunk.end), output); // Find the new start position, allowing overlap within the threshold. let mut new_start_idx = i + 1; @@ -139,37 +215,27 @@ impl SplitTask { } let last_chunk = &chunks[chunks.len() - 1]; - self.add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output); + self.add_output(RangeValue::new(start_pos, last_chunk.end), output); }; let mut small_chunks = Vec::new(); - let mut process_chunk = - |start: usize, end: usize, output: &mut Vec<(RangeValue, &'s str)>| { - let chunk = &s[start..end]; - if chunk.len() <= self.chunk_size { - small_chunks.push(RangeValue::new(start, start + chunk.len())); - } else { - flush_small_chunks(&small_chunks, output); - small_chunks.clear(); - self.split_substring(chunk, base_pos + start, next_sep_id + 1, output); - } - }; - - let mut next_start_pos = 0; - for cap in self.separators[next_sep_id].find_iter(s) { - process_chunk(next_start_pos, cap.start(), output); - next_start_pos = cap.end(); - } - if next_start_pos < s.len() { - process_chunk(next_start_pos, s.len(), output); + for sub_chunk in sub_chunks_iter { + let sub_range = sub_chunk.range(); + if sub_range.len() <= self.chunk_size { + small_chunks.push(*sub_chunk.range()); + } else { + flush_small_chunks(&small_chunks, output); + small_chunks.clear(); + self.split_substring(sub_chunk, output); + } } - flush_small_chunks(&small_chunks, output); } - fn add_output<'s>(&self, pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) { + fn add_output(&self, range: RangeValue, output: &mut Vec<(RangeValue, &'s str)>) { + let text = range.extract_str(self.text); if !text.trim().is_empty() { - output.push((RangeValue::new(pos, pos + text.len()), text)); + output.push((range, text)); } } } @@ -217,19 +283,9 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator) -> Result { - let task = SplitTask { - separators: self - .args - .language - .value(&input)? - .map(|v| v.as_str()) - .transpose()? - .and_then(|lang| { - SEPARATORS_BY_LANG - .get(lang.to_lowercase().as_str()) - .map(|v| v.as_slice()) - }) - .unwrap_or(DEFAULT_SEPARATORS.as_slice()), + let text = self.args.text.value(&input)?.as_str()?; + let recursive_chunker = RecursiveChunker { + text, chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize, chunk_overlap: self .args @@ -240,9 +296,28 @@ impl SimpleFunctionExecutor for Executor { .unwrap_or(0) as usize, }; - let text = self.args.text.value(&input)?.as_str()?; + let separators = self + .args + .language + .value(&input)? + .map(|v| v.as_str()) + .transpose()? + .and_then(|lang| { + SEPARATORS_BY_LANG + .get(lang.to_lowercase().as_str()) + .map(|v| v.as_slice()) + }) + .unwrap_or(DEFAULT_SEPARATORS.as_slice()); + let mut output = Vec::new(); - task.split_substring(text, 0, 0, &mut output); + recursive_chunker.split_substring( + Chunk { + target: &SplitTarget { separators, text }, + range: RangeValue::new(0, text.len()), + next_sep_id: 0, + }, + &mut output, + ); translate_bytes_to_chars( text,