Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/base/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ impl RangeValue {
pub fn new(start: usize, end: usize) -> Self {
RangeValue { start, end }
}

/// Returns the size of the range, computed as `end - start`.
///
/// This is a plain `usize` subtraction, so callers are expected to keep
/// `start <= end`.
pub fn len(&self) -> usize {
    let (lo, hi) = (self.start, self.end);
    hi - lo
}

/// Returns the portion of `s` covered by this range.
///
/// `s` may be anything that can be viewed as a `str`; the returned slice
/// borrows from it for the lifetime `'s`.
///
/// # Panics
///
/// Follows the standard `str` slicing rules: panics if `start..end` is out of
/// bounds for `s`, if `start > end`, or if either end does not lie on a UTF-8
/// character boundary.
pub fn extract_str<'s>(&self, s: &'s (impl AsRef<str> + ?Sized)) -> &'s str {
    let span = self.start..self.end;
    &s.as_ref()[span]
}
}

impl Serialize for RangeValue {
Expand Down
179 changes: 127 additions & 52 deletions src/ops/functions/split_recursively.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use regex::Regex;
use regex::{Matches, Regex};
use std::sync::LazyLock;
use std::{collections::HashMap, sync::Arc};

Expand Down Expand Up @@ -91,24 +91,100 @@ static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLoc
.collect()
});

struct SplitTask {
/// A chunk that covers a range and may be broken down into smaller chunks of
/// the same type.
trait NestedChunk: Sized {
/// The range covered by this chunk.
fn range(&self) -> &RangeValue;

/// Returns an iterator over the next-level pieces of this chunk, or `None`
/// when the implementation has no further split to offer.
fn sub_chunks(&self) -> Option<impl Iterator<Item = Self>>;
}

/// The text being split together with the separator regexes used to split it.
struct SplitTarget<'s> {
// Separator regexes; a chunk selects one of them by index (`next_sep_id`).
separators: &'static [Regex],
// The full text that chunk ranges refer into.
text: &'s str,
}

/// A range of the target text, plus the index of the separator to use when
/// splitting it further.
struct Chunk<'s> {
// Shared text and separator list this chunk belongs to.
target: &'s SplitTarget<'s>,
// Range within `target.text` covered by this chunk.
range: RangeValue,
// Index into `target.separators` of the regex used to split this chunk.
next_sep_id: usize,
}

/// Iterator that yields the sub-chunks of a parent `Chunk`, delimited by
/// successive separator matches.
struct SubChunksIter<'a, 's: 'a> {
// The chunk being split.
parent: &'a Chunk<'s>,
// Separator matches found in the parent's text.
matches_iter: Matches<'static, 's>,
// Start position of the next sub-chunk to yield; `None` once the final
// sub-chunk has been produced.
next_start_pos: Option<usize>,
}

impl<'a, 's: 'a> SubChunksIter<'a, 's> {
fn new(parent: &'a Chunk<'s>, matches_iter: Matches<'static, 's>) -> Self {
Self {
parent,
matches_iter,
next_start_pos: Some(parent.range.start),
}
}
}

impl<'a, 's: 'a> Iterator for SubChunksIter<'a, 's> {
    type Item = Chunk<'s>;

    /// Yields the next sub-chunk of the parent.
    ///
    /// Each sub-chunk runs from the end of the previous separator match (or the
    /// parent's start) up to the start of the next match. Once the matches are
    /// exhausted, one final sub-chunk reaching the parent's end is produced and
    /// the iterator then stays finished.
    fn next(&mut self) -> Option<Self::Item> {
        // `None` here means the trailing piece has already been handed out.
        let start_pos = self.next_start_pos?;

        // Match offsets are relative to the parent's text, so shift them by the
        // parent's start to get positions in the full text.
        let base = self.parent.range.start;
        let (end_pos, following_start) = match self.matches_iter.next() {
            Some(sep) => (base + sep.start(), Some(base + sep.end())),
            None => (self.parent.range.end, None),
        };
        self.next_start_pos = following_start;

        Some(Chunk {
            target: self.parent.target,
            range: RangeValue::new(start_pos, end_pos),
            next_sep_id: self.parent.next_sep_id + 1,
        })
    }
}

impl<'s> NestedChunk for Chunk<'s> {
    /// The range of the target text covered by this chunk.
    fn range(&self) -> &RangeValue {
        &self.range
    }

    /// Splits this chunk with the separator at index `next_sep_id`.
    ///
    /// Returns `None` when `next_sep_id` is past the end of the separator list,
    /// i.e. there is no separator left to split with.
    fn sub_chunks(&self) -> Option<impl Iterator<Item = Self>> {
        let separator = self.target.separators.get(self.next_sep_id)?;
        let sub_text = self.range.extract_str(&self.target.text);
        Some(SubChunksIter::new(self, separator.find_iter(sub_text)))
    }
}

/// Holds the source text and the size settings used while splitting it into
/// output chunks.
struct RecursiveChunker<'s> {
// The full text; output ranges are resolved against it.
text: &'s str,
// Size threshold: sub-chunks whose range length is at most this value are
// collected as "small chunks" instead of being split further.
chunk_size: usize,
// NOTE(review): the code that consumes this is not fully visible in this
// excerpt — presumably the overlap allowed between adjacent output chunks;
// confirm.
chunk_overlap: usize,
}

impl SplitTask {
fn split_substring<'s>(
&self,
s: &'s str,
base_pos: usize,
next_sep_id: usize,
output: &mut Vec<(RangeValue, &'s str)>,
) {
if next_sep_id >= self.separators.len() {
self.add_output(base_pos, s, output);
impl<'s> RecursiveChunker<'s> {
fn split_substring<Chk>(&self, chunk: Chk, output: &mut Vec<(RangeValue, &'s str)>)
where
Chk: NestedChunk,
{
let sub_chunks_iter = if let Some(sub_chunks_iter) = chunk.sub_chunks() {
sub_chunks_iter
} else {
self.add_output(*chunk.range(), output);
return;
}
};

let flush_small_chunks =
|chunks: &[RangeValue], output: &mut Vec<(RangeValue, &'s str)>| {
Expand All @@ -119,7 +195,7 @@ impl SplitTask {
for i in 1..chunks.len() - 1 {
let chunk = &chunks[i];
if chunk.end - start_pos > self.chunk_size {
self.add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
self.add_output(RangeValue::new(start_pos, chunk.end), output);

// Find the new start position, allowing overlap within the threshold.
let mut new_start_idx = i + 1;
Expand All @@ -139,37 +215,27 @@ impl SplitTask {
}

let last_chunk = &chunks[chunks.len() - 1];
self.add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output);
self.add_output(RangeValue::new(start_pos, last_chunk.end), output);
};

let mut small_chunks = Vec::new();
let mut process_chunk =
|start: usize, end: usize, output: &mut Vec<(RangeValue, &'s str)>| {
let chunk = &s[start..end];
if chunk.len() <= self.chunk_size {
small_chunks.push(RangeValue::new(start, start + chunk.len()));
} else {
flush_small_chunks(&small_chunks, output);
small_chunks.clear();
self.split_substring(chunk, base_pos + start, next_sep_id + 1, output);
}
};

let mut next_start_pos = 0;
for cap in self.separators[next_sep_id].find_iter(s) {
process_chunk(next_start_pos, cap.start(), output);
next_start_pos = cap.end();
}
if next_start_pos < s.len() {
process_chunk(next_start_pos, s.len(), output);
for sub_chunk in sub_chunks_iter {
let sub_range = sub_chunk.range();
if sub_range.len() <= self.chunk_size {
small_chunks.push(*sub_chunk.range());
} else {
flush_small_chunks(&small_chunks, output);
small_chunks.clear();
self.split_substring(sub_chunk, output);
}
}

flush_small_chunks(&small_chunks, output);
}

fn add_output<'s>(&self, pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
fn add_output(&self, range: RangeValue, output: &mut Vec<(RangeValue, &'s str)>) {
let text = range.extract_str(self.text);
if !text.trim().is_empty() {
output.push((RangeValue::new(pos, pos + text.len()), text));
output.push((range, text));
}
}
}
Expand Down Expand Up @@ -217,19 +283,9 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
#[async_trait]
impl SimpleFunctionExecutor for Executor {
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
let task = SplitTask {
separators: self
.args
.language
.value(&input)?
.map(|v| v.as_str())
.transpose()?
.and_then(|lang| {
SEPARATORS_BY_LANG
.get(lang.to_lowercase().as_str())
.map(|v| v.as_slice())
})
.unwrap_or(DEFAULT_SEPARATORS.as_slice()),
let text = self.args.text.value(&input)?.as_str()?;
let recursive_chunker = RecursiveChunker {
text,
chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize,
chunk_overlap: self
.args
Expand All @@ -240,9 +296,28 @@ impl SimpleFunctionExecutor for Executor {
.unwrap_or(0) as usize,
};

let text = self.args.text.value(&input)?.as_str()?;
let separators = self
.args
.language
.value(&input)?
.map(|v| v.as_str())
.transpose()?
.and_then(|lang| {
SEPARATORS_BY_LANG
.get(lang.to_lowercase().as_str())
.map(|v| v.as_slice())
})
.unwrap_or(DEFAULT_SEPARATORS.as_slice());

let mut output = Vec::new();
task.split_substring(text, 0, 0, &mut output);
recursive_chunker.split_substring(
Chunk {
target: &SplitTarget { separators, text },
range: RangeValue::new(0, text.len()),
next_sep_id: 0,
},
&mut output,
);

translate_bytes_to_chars(
text,
Expand Down