Skip to content

Commit 95df05c

Browse files
Refactor: Apply review suggestions for SplitRecursively
1 parent e390cbc commit 95df05c

File tree

1 file changed

+91
-66
lines changed

1 file changed

+91
-66
lines changed

src/ops/functions/split_recursively.rs

Lines changed: 91 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,7 @@ struct Chunk<'t, 's: 't> {
261261

262262
struct TextChunksIter<'t, 's: 't> {
263263
lang_config: &'t SimpleLanguageConfig,
264-
full_text: &'s str,
265-
range: RangeValue,
264+
parent: &'t Chunk<'t, 's>,
266265
matches_iter: Matches<'t, 's>,
267266
regexp_sep_id: usize,
268267
next_start_pos: Option<usize>,
@@ -271,19 +270,16 @@ struct TextChunksIter<'t, 's: 't> {
271270
// Constructor for `TextChunksIter` (shown as a diff: `-` lines are the old
// version, `+` lines the new one).
//
// Old form: took `full_text` and `range` separately and stored both.
// New form: takes a borrowed `parent` chunk and reads `full_text` / `range`
// through it.
//
// In both forms the separator regex selected by `regexp_sep_id` is run over
// the slice of the text covered by the chunk's range, and `next_start_pos`
// is seeded with the range start.
//
// NOTE(review): `parent` is borrowed for `'t`, and the call site passes
// `&current_chunk` while pushing the iterator onto `iter_stack` — confirm the
// referenced chunk outlives the iterator.
impl<'t, 's: 't> TextChunksIter<'t, 's> {
272271
fn new(
273272
lang_config: &'t SimpleLanguageConfig,
274-
full_text: &'s str,
275-
range: RangeValue,
273+
parent: &'t Chunk<'t, 's>,
276274
regexp_sep_id: usize,
277275
) -> Self {
278-
let std_range = range.start..range.end;
279276
Self {
280277
lang_config,
281-
full_text,
282-
range,
278+
parent,
283279
matches_iter: lang_config.separator_regex[regexp_sep_id]
284-
.find_iter(&full_text[std_range.clone()]),
280+
.find_iter(&parent.full_text[parent.range.start..parent.range.end]),
285281
regexp_sep_id,
286-
next_start_pos: Some(std_range.start),
282+
next_start_pos: Some(parent.range.start),
287283
}
288284
}
289285
}
@@ -295,19 +291,19 @@ impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> {
295291
let start_pos = self.next_start_pos?;
296292
let end_pos = match self.matches_iter.next() {
297293
Some(grp) => {
298-
self.next_start_pos = Some(self.range.start + grp.end());
299-
self.range.start + grp.start()
294+
self.next_start_pos = Some(self.parent.range.start + grp.end());
295+
self.parent.range.start + grp.start()
300296
}
301297
None => {
302298
self.next_start_pos = None;
303-
if start_pos >= self.range.end {
299+
if start_pos >= self.parent.range.end {
304300
return None;
305301
}
306-
self.range.end
302+
self.parent.range.end
307303
}
308304
};
309305
Some(Chunk {
310-
full_text: self.full_text,
306+
full_text: self.parent.full_text,
311307
range: RangeValue::new(start_pos, end_pos),
312308
kind: ChunkKind::RegexpSepChunk {
313309
lang_config: self.lang_config,
@@ -378,6 +374,24 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
378374
}
379375
}
380376

377+
/// Closed set of the iterator types that can yield `Chunk`s.
///
/// This commit uses it as the element type of `iter_stack`, replacing
/// `Box<dyn Iterator<Item = Chunk<'t, 's>>>`.
///
/// Variants:
/// - `TreeSitter`: wraps a `TreeSitterNodeIter`.
/// - `Text`: wraps a `TextChunksIter`.
/// - `Once`: wraps a `std::iter::Once` holding a single chunk.
enum ChunkIterator<'t, 's: 't> {
378+
TreeSitter(TreeSitterNodeIter<'t, 's>),
379+
Text(TextChunksIter<'t, 's>),
380+
Once(std::iter::Once<Chunk<'t, 's>>),
381+
}
382+
383+
// `Iterator` implementation that forwards `next()` to whichever concrete
// iterator the current variant wraps; every variant yields `Chunk<'t, 's>`.
impl<'t, 's: 't> Iterator for ChunkIterator<'t, 's> {
384+
type Item = Chunk<'t, 's>;
385+
386+
fn next(&mut self) -> Option<Self::Item> {
387+
match self {
388+
ChunkIterator::TreeSitter(iter) => iter.next(),
389+
ChunkIterator::Text(iter) => iter.next(),
390+
ChunkIterator::Once(iter) => iter.next(),
391+
}
392+
}
393+
}
394+
381395
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
382396
enum LineBreakLevel {
383397
Inline,
@@ -434,8 +448,7 @@ struct AtomChunksCollector<'s> {
434448
impl<'s> AtomChunksCollector<'s> {
435449
fn collect(&mut self, range: RangeValue) {
436450
// Trim trailing whitespaces.
437-
let std_range = range.start..range.end;
438-
let end_trimmed_text = &self.full_text[std_range].trim_end();
451+
let end_trimmed_text = &self.full_text[range.start..range.end].trim_end();
439452
if end_trimmed_text.is_empty() {
440453
return;
441454
}
@@ -449,7 +462,7 @@ impl<'s> AtomChunksCollector<'s> {
449462
let prev_end = self.atom_chunks.last().map_or(0, |chunk| chunk.range.end);
450463
let gap = &self.full_text[prev_end..new_start];
451464
let boundary_lb_level = line_break_level(gap);
452-
let range: RangeValue = if boundary_lb_level != LineBreakLevel::Inline {
465+
let range = if boundary_lb_level != LineBreakLevel::Inline {
453466
let trimmed_gap = gap.trim_end_matches(INLINE_SPACE_CHARS);
454467
RangeValue::new(prev_end + trimmed_gap.len(), new_end)
455468
} else {
@@ -495,8 +508,8 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
495508
chunk: Chunk<'t, 's>,
496509
atom_collector: &mut AtomChunksCollector<'s>,
497510
) -> Result<()> {
498-
let mut iter_stack: Vec<Box<dyn Iterator<Item = Chunk<'t, 's>>>> =
499-
vec![Box::new(std::iter::once(chunk))];
511+
let mut iter_stack: Vec<ChunkIterator<'t, 's>> =
512+
vec![ChunkIterator::Once(std::iter::once(chunk))];
500513

501514
while !iter_stack.is_empty() {
502515
atom_collector.curr_level = iter_stack.len();
@@ -510,17 +523,19 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
510523
if !lang_config.terminal_node_kind_ids.contains(&node.kind_id()) {
511524
let mut cursor = node.walk();
512525
if cursor.goto_first_child() {
513-
iter_stack.push(Box::new(TreeSitterNodeIter {
514-
lang_config,
515-
full_text: self.full_text,
516-
cursor: Some(cursor),
517-
next_start_pos: node.start_byte(),
518-
end_pos: node.end_byte(),
519-
}));
526+
iter_stack.push(ChunkIterator::TreeSitter(
527+
TreeSitterNodeIter {
528+
lang_config,
529+
full_text: self.full_text,
530+
cursor: Some(cursor),
531+
next_start_pos: node.start_byte(),
532+
end_pos: node.end_byte(),
533+
},
534+
));
520535
continue;
521536
}
522537
}
523-
iter_stack.push(Box::new(std::iter::once(Chunk {
538+
iter_stack.push(ChunkIterator::Once(std::iter::once(Chunk {
524539
full_text: self.full_text,
525540
range: current_chunk.range,
526541
kind: ChunkKind::RegexpSepChunk {
@@ -536,10 +551,9 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
536551
if next_regexp_sep_id >= lang_config.separator_regex.len() {
537552
atom_collector.collect(current_chunk.range);
538553
} else {
539-
iter_stack.push(Box::new(TextChunksIter::new(
554+
iter_stack.push(ChunkIterator::Text(TextChunksIter::new(
540555
lang_config,
541-
current_chunk.full_text,
542-
current_chunk.range,
556+
&current_chunk, // Pass reference to chunk
543557
next_regexp_sep_id,
544558
)));
545559
}
@@ -559,12 +573,22 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
559573
Ok(())
560574
}
561575

576+
/// Returns the overlap cost base for a byte `offset` into `self.full_text`.
///
/// When `self.chunk_overlap` is zero the result is `0`, which also avoids
/// dividing by zero. Otherwise it is
/// `(full_text.len() - offset) * MISSING_OVERLAP_COST / chunk_overlap`, so
/// the value gets smaller as `offset` moves toward the end of the text.
///
/// NOTE(review): assumes `offset <= self.full_text.len()` and that the
/// product fits in `usize`; neither is checked here — confirm against
/// callers.
fn get_overlap_cost_base(&self, offset: usize) -> usize {
577+
if self.chunk_overlap == 0 {
578+
0
579+
} else {
580+
(self.full_text.len() - offset) * MISSING_OVERLAP_COST / self.chunk_overlap
581+
}
582+
}
583+
562584
fn merge_atom_chunks(&self, atom_chunks: Vec<AtomChunk>) -> Vec<ChunkOutput<'s>> {
563585
struct AtomRoutingPlan {
564586
start_idx: usize,
565587
prev_plan_idx: usize,
566588
cost: usize,
589+
overlap_cost_base: usize,
567590
}
591+
type PrevPlanCandidate = (std::cmp::Reverse<usize>, usize); // (cost, start_idx)
568592

569593
if atom_chunks.is_empty() || atom_chunks.len() == 1 {
570594
return Vec::new();
@@ -575,7 +599,9 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
575599
start_idx: 0,
576600
prev_plan_idx: 0,
577601
cost: 0,
602+
overlap_cost_base: self.get_overlap_cost_base(0),
578603
});
604+
let mut prev_plan_candidates = std::collections::BinaryHeap::<PrevPlanCandidate>::new();
579605

580606
let mut gap_cost_cache = vec![0];
581607
let mut syntax_level_gap_cost = |boundary: usize, internal: usize| -> usize {
@@ -590,7 +616,7 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
590616
}
591617
};
592618

593-
for i in 0..atom_chunks.len() - 1 {
619+
for (i, chunk) in atom_chunks[0..atom_chunks.len() - 1].iter().enumerate() {
594620
let mut min_cost = usize::MAX;
595621
let mut arg_min_start_idx: usize = 0;
596622
let mut arg_min_prev_plan_idx: usize = 0;
@@ -609,11 +635,9 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
609635
0
610636
}
611637
}
612-
613638
loop {
614639
let start_chunk = &atom_chunks[start_idx];
615-
let current_chunk_end = atom_chunks[i].range.end;
616-
let chunk_size = current_chunk_end - start_chunk.range.start;
640+
let chunk_size = chunk.range.end - start_chunk.range.start;
617641

618642
let mut cost = 0;
619643
cost +=
@@ -635,42 +659,41 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
635659
break;
636660
}
637661

638-
let mut best_prev_plan_idx = start_idx;
639-
if self.chunk_overlap > 0 {
640-
let mut min_prev_plan_cost = plans[start_idx].cost;
641-
for k_idx in (0..start_idx).rev() {
642-
let end_of_prev_chunk = if k_idx > 0 {
643-
atom_chunks[k_idx - 1].range.end
644-
} else {
645-
0
646-
};
647-
let overlap = end_of_prev_chunk.saturating_sub(start_chunk.range.start);
648-
if overlap > self.chunk_overlap {
662+
let prev_plan_idx = if self.chunk_overlap > 0 {
663+
while let Some(top_prev_plan) = prev_plan_candidates.peek() {
664+
let overlap_size =
665+
atom_chunks[top_prev_plan.1].range.end - start_chunk.range.start;
666+
if overlap_size <= self.chunk_overlap {
649667
break;
650668
}
651-
if plans[k_idx].cost < min_prev_plan_cost {
652-
min_prev_plan_cost = plans[k_idx].cost;
653-
best_prev_plan_idx = k_idx;
654-
}
669+
prev_plan_candidates.pop();
655670
}
656-
}
657-
658-
let prev_plan = &plans[best_prev_plan_idx];
659-
cost += prev_plan.cost;
660-
661-
let end_of_prev_chunk = if best_prev_plan_idx > 0 {
662-
atom_chunks[best_prev_plan_idx - 1].range.end
671+
prev_plan_candidates.push((
672+
std::cmp::Reverse(
673+
plans[start_idx].cost + plans[start_idx].overlap_cost_base,
674+
),
675+
start_idx,
676+
));
677+
prev_plan_candidates.peek().unwrap().1
663678
} else {
664-
0
679+
start_idx
665680
};
666-
if end_of_prev_chunk < start_chunk.range.start {
667-
cost += MISSING_OVERLAP_COST;
681+
let prev_plan = &plans[prev_plan_idx];
682+
cost += prev_plan.cost;
683+
if self.chunk_overlap == 0 {
684+
cost += MISSING_OVERLAP_COST / 2;
685+
} else {
686+
let start_cost_base = self.get_overlap_cost_base(start_chunk.range.start);
687+
cost += if prev_plan.overlap_cost_base < start_cost_base {
688+
MISSING_OVERLAP_COST + prev_plan.overlap_cost_base - start_cost_base
689+
} else {
690+
MISSING_OVERLAP_COST
691+
};
668692
}
669-
670693
if cost < min_cost {
671694
min_cost = cost;
672695
arg_min_start_idx = start_idx;
673-
arg_min_prev_plan_idx = best_prev_plan_idx;
696+
arg_min_prev_plan_idx = prev_plan_idx;
674697
}
675698

676699
if start_idx == 0 {
@@ -679,14 +702,17 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
679702

680703
start_idx -= 1;
681704
internal_syntax_level =
682-
internal_syntax_level.min(atom_chunks[start_idx].boundary_syntax_level);
683-
internal_lb_level = internal_lb_level.max(atom_chunks[start_idx].internal_lb_level);
705+
internal_syntax_level.min(atom_chunks[start_idx + 1].boundary_syntax_level);
706+
internal_lb_level =
707+
internal_lb_level.max(atom_chunks[start_idx + 1].internal_lb_level);
684708
}
685709
plans.push(AtomRoutingPlan {
686710
start_idx: arg_min_start_idx,
687711
prev_plan_idx: arg_min_prev_plan_idx,
688712
cost: min_cost,
713+
overlap_cost_base: self.get_overlap_cost_base(chunk.range.end),
689714
});
715+
prev_plan_candidates.clear();
690716
}
691717

692718
let mut output = Vec::new();
@@ -695,11 +721,10 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
695721
let plan = &plans[plan_idx];
696722
let start_chunk = &atom_chunks[plan.start_idx];
697723
let end_chunk = &atom_chunks[plan_idx - 1];
698-
let std_range: Range<usize> = start_chunk.range.start..end_chunk.range.end;
699724
output.push(ChunkOutput {
700725
start_pos: Position::new(start_chunk.range.start),
701726
end_pos: Position::new(end_chunk.range.end),
702-
text: &self.full_text[std_range],
727+
text: &self.full_text[start_chunk.range.start..end_chunk.range.end],
703728
});
704729
plan_idx = plan.prev_plan_idx;
705730
}
@@ -710,7 +735,7 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
710735
fn split_root_chunk(&self, kind: ChunkKind<'t>) -> Result<Vec<ChunkOutput<'s>>> {
711736
let mut atom_collector = AtomChunksCollector {
712737
full_text: self.full_text,
713-
min_level: usize::MAX,
738+
min_level: 0,
714739
curr_level: 0,
715740
atom_chunks: Vec::new(),
716741
};

0 commit comments

Comments
 (0)