Skip to content

Commit 20cc6a0

Browse files
committed
Extract regexp based chunk logic into a trait - for TreeSitter reuse.
1 parent 6953931 commit 20cc6a0

File tree

2 files changed

+136
-52
lines changed

2 files changed

+136
-52
lines changed

src/base/value.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ impl RangeValue {
2020
pub fn new(start: usize, end: usize) -> Self {
2121
RangeValue { start, end }
2222
}
23+
24+
pub fn len(&self) -> usize {
25+
self.end - self.start
26+
}
27+
28+
pub fn extract_str<'s>(&self, s: &'s (impl AsRef<str> + ?Sized)) -> &'s str {
29+
let s = s.as_ref();
30+
&s[self.start..self.end]
31+
}
2332
}
2433

2534
impl Serialize for RangeValue {

src/ops/functions/split_recursively.rs

Lines changed: 127 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use regex::Regex;
1+
use regex::{Matches, Regex};
22
use std::sync::LazyLock;
33
use std::{collections::HashMap, sync::Arc};
44

@@ -91,24 +91,100 @@ static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLoc
9191
.collect()
9292
});
9393

94-
struct SplitTask {
94+
trait NestedChunk: Sized {
95+
fn range(&self) -> &RangeValue;
96+
97+
fn sub_chunks(&self) -> Option<impl Iterator<Item = Self>>;
98+
}
99+
100+
struct SplitTarget<'s> {
95101
separators: &'static [Regex],
102+
text: &'s str,
103+
}
104+
105+
struct Chunk<'s> {
106+
target: &'s SplitTarget<'s>,
107+
range: RangeValue,
108+
next_sep_id: usize,
109+
}
110+
111+
struct SubChunksIter<'a, 's: 'a> {
112+
parent: &'a Chunk<'s>,
113+
matches_iter: Matches<'static, 's>,
114+
next_start_pos: Option<usize>,
115+
}
116+
117+
impl<'a, 's: 'a> SubChunksIter<'a, 's> {
118+
fn new(parent: &'a Chunk<'s>, matches_iter: Matches<'static, 's>) -> Self {
119+
Self {
120+
parent,
121+
matches_iter,
122+
next_start_pos: Some(parent.range.start),
123+
}
124+
}
125+
}
126+
127+
impl<'a, 's: 'a> Iterator for SubChunksIter<'a, 's> {
128+
type Item = Chunk<'s>;
129+
130+
fn next(&mut self) -> Option<Self::Item> {
131+
if let Some(start_pos) = self.next_start_pos {
132+
let end_pos = match self.matches_iter.next() {
133+
Some(grp) => {
134+
self.next_start_pos = Some(self.parent.range.start + grp.end());
135+
self.parent.range.start + grp.start()
136+
}
137+
None => {
138+
self.next_start_pos = None;
139+
self.parent.range.end
140+
}
141+
};
142+
Some(Chunk {
143+
target: self.parent.target,
144+
range: RangeValue::new(start_pos, end_pos),
145+
next_sep_id: self.parent.next_sep_id + 1,
146+
})
147+
} else {
148+
None
149+
}
150+
}
151+
}
152+
153+
impl<'s> NestedChunk for Chunk<'s> {
154+
fn range(&self) -> &RangeValue {
155+
&self.range
156+
}
157+
158+
fn sub_chunks(&self) -> Option<impl Iterator<Item = Self>> {
159+
if self.next_sep_id >= self.target.separators.len() {
160+
None
161+
} else {
162+
let sub_text = self.range.extract_str(&self.target.text);
163+
Some(SubChunksIter::new(
164+
self,
165+
self.target.separators[self.next_sep_id].find_iter(sub_text),
166+
))
167+
}
168+
}
169+
}
170+
171+
struct RecursiveChunker<'s> {
172+
text: &'s str,
96173
chunk_size: usize,
97174
chunk_overlap: usize,
98175
}
99176

100-
impl SplitTask {
101-
fn split_substring<'s>(
102-
&self,
103-
s: &'s str,
104-
base_pos: usize,
105-
next_sep_id: usize,
106-
output: &mut Vec<(RangeValue, &'s str)>,
107-
) {
108-
if next_sep_id >= self.separators.len() {
109-
self.add_output(base_pos, s, output);
177+
impl<'s> RecursiveChunker<'s> {
178+
fn split_substring<Chk>(&self, chunk: Chk, output: &mut Vec<(RangeValue, &'s str)>)
179+
where
180+
Chk: NestedChunk,
181+
{
182+
let sub_chunks_iter = if let Some(sub_chunks_iter) = chunk.sub_chunks() {
183+
sub_chunks_iter
184+
} else {
185+
self.add_output(*chunk.range(), output);
110186
return;
111-
}
187+
};
112188

113189
let flush_small_chunks =
114190
|chunks: &[RangeValue], output: &mut Vec<(RangeValue, &'s str)>| {
@@ -119,7 +195,7 @@ impl SplitTask {
119195
for i in 1..chunks.len() - 1 {
120196
let chunk = &chunks[i];
121197
if chunk.end - start_pos > self.chunk_size {
122-
self.add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
198+
self.add_output(RangeValue::new(start_pos, chunk.end), output);
123199

124200
// Find the new start position, allowing overlap within the threshold.
125201
let mut new_start_idx = i + 1;
@@ -139,37 +215,27 @@ impl SplitTask {
139215
}
140216

141217
let last_chunk = &chunks[chunks.len() - 1];
142-
self.add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output);
218+
self.add_output(RangeValue::new(start_pos, last_chunk.end), output);
143219
};
144220

145221
let mut small_chunks = Vec::new();
146-
let mut process_chunk =
147-
|start: usize, end: usize, output: &mut Vec<(RangeValue, &'s str)>| {
148-
let chunk = &s[start..end];
149-
if chunk.len() <= self.chunk_size {
150-
small_chunks.push(RangeValue::new(start, start + chunk.len()));
151-
} else {
152-
flush_small_chunks(&small_chunks, output);
153-
small_chunks.clear();
154-
self.split_substring(chunk, base_pos + start, next_sep_id + 1, output);
155-
}
156-
};
157-
158-
let mut next_start_pos = 0;
159-
for cap in self.separators[next_sep_id].find_iter(s) {
160-
process_chunk(next_start_pos, cap.start(), output);
161-
next_start_pos = cap.end();
162-
}
163-
if next_start_pos < s.len() {
164-
process_chunk(next_start_pos, s.len(), output);
222+
for sub_chunk in sub_chunks_iter {
223+
let sub_range = sub_chunk.range();
224+
if sub_range.len() <= self.chunk_size {
225+
small_chunks.push(*sub_chunk.range());
226+
} else {
227+
flush_small_chunks(&small_chunks, output);
228+
small_chunks.clear();
229+
self.split_substring(sub_chunk, output);
230+
}
165231
}
166-
167232
flush_small_chunks(&small_chunks, output);
168233
}
169234

170-
fn add_output<'s>(&self, pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
235+
fn add_output(&self, range: RangeValue, output: &mut Vec<(RangeValue, &'s str)>) {
236+
let text = range.extract_str(self.text);
171237
if !text.trim().is_empty() {
172-
output.push((RangeValue::new(pos, pos + text.len()), text));
238+
output.push((range, text));
173239
}
174240
}
175241
}
@@ -217,19 +283,9 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
217283
#[async_trait]
218284
impl SimpleFunctionExecutor for Executor {
219285
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
220-
let task = SplitTask {
221-
separators: self
222-
.args
223-
.language
224-
.value(&input)?
225-
.map(|v| v.as_str())
226-
.transpose()?
227-
.and_then(|lang| {
228-
SEPARATORS_BY_LANG
229-
.get(lang.to_lowercase().as_str())
230-
.map(|v| v.as_slice())
231-
})
232-
.unwrap_or(DEFAULT_SEPARATORS.as_slice()),
286+
let text = self.args.text.value(&input)?.as_str()?;
287+
let recursive_chunker = RecursiveChunker {
288+
text,
233289
chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize,
234290
chunk_overlap: self
235291
.args
@@ -240,9 +296,28 @@ impl SimpleFunctionExecutor for Executor {
240296
.unwrap_or(0) as usize,
241297
};
242298

243-
let text = self.args.text.value(&input)?.as_str()?;
299+
let separators = self
300+
.args
301+
.language
302+
.value(&input)?
303+
.map(|v| v.as_str())
304+
.transpose()?
305+
.and_then(|lang| {
306+
SEPARATORS_BY_LANG
307+
.get(lang.to_lowercase().as_str())
308+
.map(|v| v.as_slice())
309+
})
310+
.unwrap_or(DEFAULT_SEPARATORS.as_slice());
311+
244312
let mut output = Vec::new();
245-
task.split_substring(text, 0, 0, &mut output);
313+
recursive_chunker.split_substring(
314+
Chunk {
315+
target: &SplitTarget { separators, text },
316+
range: RangeValue::new(0, text.len()),
317+
next_sep_id: 0,
318+
},
319+
&mut output,
320+
);
246321

247322
translate_bytes_to_chars(
248323
text,

0 commit comments

Comments
 (0)