@@ -261,8 +261,7 @@ struct Chunk<'t, 's: 't> {
261261
262262struct TextChunksIter < ' t , ' s : ' t > {
263263 lang_config : & ' t SimpleLanguageConfig ,
264- full_text : & ' s str ,
265- range : RangeValue ,
264+ parent : & ' t Chunk < ' t , ' s > ,
266265 matches_iter : Matches < ' t , ' s > ,
267266 regexp_sep_id : usize ,
268267 next_start_pos : Option < usize > ,
@@ -271,19 +270,16 @@ struct TextChunksIter<'t, 's: 't> {
271270impl < ' t , ' s : ' t > TextChunksIter < ' t , ' s > {
272271 fn new (
273272 lang_config : & ' t SimpleLanguageConfig ,
274- full_text : & ' s str ,
275- range : RangeValue ,
273+ parent : & ' t Chunk < ' t , ' s > ,
276274 regexp_sep_id : usize ,
277275 ) -> Self {
278- let std_range = range. start ..range. end ;
279276 Self {
280277 lang_config,
281- full_text,
282- range,
278+ parent,
283279 matches_iter : lang_config. separator_regex [ regexp_sep_id]
284- . find_iter ( & full_text[ std_range . clone ( ) ] ) ,
280+ . find_iter ( & parent . full_text [ parent . range . start ..parent . range . end ] ) ,
285281 regexp_sep_id,
286- next_start_pos : Some ( std_range . start ) ,
282+ next_start_pos : Some ( parent . range . start ) ,
287283 }
288284 }
289285}
@@ -295,19 +291,19 @@ impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> {
295291 let start_pos = self . next_start_pos ?;
296292 let end_pos = match self . matches_iter . next ( ) {
297293 Some ( grp) => {
298- self . next_start_pos = Some ( self . range . start + grp. end ( ) ) ;
299- self . range . start + grp. start ( )
294+ self . next_start_pos = Some ( self . parent . range . start + grp. end ( ) ) ;
295+ self . parent . range . start + grp. start ( )
300296 }
301297 None => {
302298 self . next_start_pos = None ;
303- if start_pos >= self . range . end {
299+ if start_pos >= self . parent . range . end {
304300 return None ;
305301 }
306- self . range . end
302+ self . parent . range . end
307303 }
308304 } ;
309305 Some ( Chunk {
310- full_text : self . full_text ,
306+ full_text : self . parent . full_text ,
311307 range : RangeValue :: new ( start_pos, end_pos) ,
312308 kind : ChunkKind :: RegexpSepChunk {
313309 lang_config : self . lang_config ,
@@ -378,6 +374,24 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
378374 }
379375}
380376
377+ enum ChunkIterator < ' t , ' s : ' t > {
378+ TreeSitter ( TreeSitterNodeIter < ' t , ' s > ) ,
379+ Text ( TextChunksIter < ' t , ' s > ) ,
380+ Once ( std:: iter:: Once < Chunk < ' t , ' s > > ) ,
381+ }
382+
383+ impl < ' t , ' s : ' t > Iterator for ChunkIterator < ' t , ' s > {
384+ type Item = Chunk < ' t , ' s > ;
385+
386+ fn next ( & mut self ) -> Option < Self :: Item > {
387+ match self {
388+ ChunkIterator :: TreeSitter ( iter) => iter. next ( ) ,
389+ ChunkIterator :: Text ( iter) => iter. next ( ) ,
390+ ChunkIterator :: Once ( iter) => iter. next ( ) ,
391+ }
392+ }
393+ }
394+
381395#[ derive( Debug , Clone , Copy , PartialEq , Eq , PartialOrd , Ord ) ]
382396enum LineBreakLevel {
383397 Inline ,
@@ -434,8 +448,7 @@ struct AtomChunksCollector<'s> {
434448impl < ' s > AtomChunksCollector < ' s > {
435449 fn collect ( & mut self , range : RangeValue ) {
436450 // Trim trailing whitespaces.
437- let std_range = range. start ..range. end ;
438- let end_trimmed_text = & self . full_text [ std_range] . trim_end ( ) ;
451+ let end_trimmed_text = & self . full_text [ range. start ..range. end ] . trim_end ( ) ;
439452 if end_trimmed_text. is_empty ( ) {
440453 return ;
441454 }
@@ -449,7 +462,7 @@ impl<'s> AtomChunksCollector<'s> {
449462 let prev_end = self . atom_chunks . last ( ) . map_or ( 0 , |chunk| chunk. range . end ) ;
450463 let gap = & self . full_text [ prev_end..new_start] ;
451464 let boundary_lb_level = line_break_level ( gap) ;
452- let range: RangeValue = if boundary_lb_level != LineBreakLevel :: Inline {
465+ let range = if boundary_lb_level != LineBreakLevel :: Inline {
453466 let trimmed_gap = gap. trim_end_matches ( INLINE_SPACE_CHARS ) ;
454467 RangeValue :: new ( prev_end + trimmed_gap. len ( ) , new_end)
455468 } else {
@@ -495,8 +508,8 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
495508 chunk : Chunk < ' t , ' s > ,
496509 atom_collector : & mut AtomChunksCollector < ' s > ,
497510 ) -> Result < ( ) > {
498- let mut iter_stack: Vec < Box < dyn Iterator < Item = Chunk < ' t , ' s > > > > =
499- vec ! [ Box :: new ( std:: iter:: once( chunk) ) ] ;
511+ let mut iter_stack: Vec < ChunkIterator < ' t , ' s > > =
512+ vec ! [ ChunkIterator :: Once ( std:: iter:: once( chunk) ) ] ;
500513
501514 while !iter_stack. is_empty ( ) {
502515 atom_collector. curr_level = iter_stack. len ( ) ;
@@ -510,17 +523,19 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
510523 if !lang_config. terminal_node_kind_ids . contains ( & node. kind_id ( ) ) {
511524 let mut cursor = node. walk ( ) ;
512525 if cursor. goto_first_child ( ) {
513- iter_stack. push ( Box :: new ( TreeSitterNodeIter {
514- lang_config,
515- full_text : self . full_text ,
516- cursor : Some ( cursor) ,
517- next_start_pos : node. start_byte ( ) ,
518- end_pos : node. end_byte ( ) ,
519- } ) ) ;
526+ iter_stack. push ( ChunkIterator :: TreeSitter (
527+ TreeSitterNodeIter {
528+ lang_config,
529+ full_text : self . full_text ,
530+ cursor : Some ( cursor) ,
531+ next_start_pos : node. start_byte ( ) ,
532+ end_pos : node. end_byte ( ) ,
533+ } ,
534+ ) ) ;
520535 continue ;
521536 }
522537 }
523- iter_stack. push ( Box :: new ( std:: iter:: once ( Chunk {
538+ iter_stack. push ( ChunkIterator :: Once ( std:: iter:: once ( Chunk {
524539 full_text : self . full_text ,
525540 range : current_chunk. range ,
526541 kind : ChunkKind :: RegexpSepChunk {
@@ -536,10 +551,9 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
536551 if next_regexp_sep_id >= lang_config. separator_regex . len ( ) {
537552 atom_collector. collect ( current_chunk. range ) ;
538553 } else {
539- iter_stack. push ( Box :: new ( TextChunksIter :: new (
554+ iter_stack. push ( ChunkIterator :: Text ( TextChunksIter :: new (
540555 lang_config,
541- current_chunk. full_text ,
542- current_chunk. range ,
556+ & current_chunk, // Pass reference to chunk
543557 next_regexp_sep_id,
544558 ) ) ) ;
545559 }
@@ -559,12 +573,22 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
559573 Ok ( ( ) )
560574 }
561575
576+ fn get_overlap_cost_base ( & self , offset : usize ) -> usize {
577+ if self . chunk_overlap == 0 {
578+ 0
579+ } else {
580+ ( self . full_text . len ( ) - offset) * MISSING_OVERLAP_COST / self . chunk_overlap
581+ }
582+ }
583+
562584 fn merge_atom_chunks ( & self , atom_chunks : Vec < AtomChunk > ) -> Vec < ChunkOutput < ' s > > {
563585 struct AtomRoutingPlan {
564586 start_idx : usize ,
565587 prev_plan_idx : usize ,
566588 cost : usize ,
589+ overlap_cost_base : usize ,
567590 }
591+ type PrevPlanCandidate = ( std:: cmp:: Reverse < usize > , usize ) ; // (cost, start_idx)
568592
569593 if atom_chunks. is_empty ( ) || atom_chunks. len ( ) == 1 {
570594 return Vec :: new ( ) ;
@@ -575,7 +599,9 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
575599 start_idx : 0 ,
576600 prev_plan_idx : 0 ,
577601 cost : 0 ,
602+ overlap_cost_base : self . get_overlap_cost_base ( 0 ) ,
578603 } ) ;
604+ let mut prev_plan_candidates = std:: collections:: BinaryHeap :: < PrevPlanCandidate > :: new ( ) ;
579605
580606 let mut gap_cost_cache = vec ! [ 0 ] ;
581607 let mut syntax_level_gap_cost = |boundary : usize , internal : usize | -> usize {
@@ -590,7 +616,7 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
590616 }
591617 } ;
592618
593- for i in 0 ..atom_chunks. len ( ) - 1 {
619+ for ( i , chunk ) in atom_chunks [ 0 ..atom_chunks. len ( ) - 1 ] . iter ( ) . enumerate ( ) {
594620 let mut min_cost = usize:: MAX ;
595621 let mut arg_min_start_idx: usize = 0 ;
596622 let mut arg_min_prev_plan_idx: usize = 0 ;
@@ -609,11 +635,9 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
609635 0
610636 }
611637 }
612-
613638 loop {
614639 let start_chunk = & atom_chunks[ start_idx] ;
615- let current_chunk_end = atom_chunks[ i] . range . end ;
616- let chunk_size = current_chunk_end - start_chunk. range . start ;
640+ let chunk_size = chunk. range . end - start_chunk. range . start ;
617641
618642 let mut cost = 0 ;
619643 cost +=
@@ -635,42 +659,41 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
635659 break ;
636660 }
637661
638- let mut best_prev_plan_idx = start_idx;
639- if self . chunk_overlap > 0 {
640- let mut min_prev_plan_cost = plans[ start_idx] . cost ;
641- for k_idx in ( 0 ..start_idx) . rev ( ) {
642- let end_of_prev_chunk = if k_idx > 0 {
643- atom_chunks[ k_idx - 1 ] . range . end
644- } else {
645- 0
646- } ;
647- let overlap = end_of_prev_chunk. saturating_sub ( start_chunk. range . start ) ;
648- if overlap > self . chunk_overlap {
662+ let prev_plan_idx = if self . chunk_overlap > 0 {
663+ while let Some ( top_prev_plan) = prev_plan_candidates. peek ( ) {
664+ let overlap_size =
665+ atom_chunks[ top_prev_plan. 1 ] . range . end - start_chunk. range . start ;
666+ if overlap_size <= self . chunk_overlap {
649667 break ;
650668 }
651- if plans[ k_idx] . cost < min_prev_plan_cost {
652- min_prev_plan_cost = plans[ k_idx] . cost ;
653- best_prev_plan_idx = k_idx;
654- }
669+ prev_plan_candidates. pop ( ) ;
655670 }
656- }
657-
658- let prev_plan = & plans[ best_prev_plan_idx ] ;
659- cost += prev_plan . cost ;
660-
661- let end_of_prev_chunk = if best_prev_plan_idx > 0 {
662- atom_chunks [ best_prev_plan_idx - 1 ] . range . end
671+ prev_plan_candidates . push ( (
672+ std :: cmp :: Reverse (
673+ plans[ start_idx ] . cost + plans [ start_idx ] . overlap_cost_base ,
674+ ) ,
675+ start_idx ,
676+ ) ) ;
677+ prev_plan_candidates . peek ( ) . unwrap ( ) . 1
663678 } else {
664- 0
679+ start_idx
665680 } ;
666- if end_of_prev_chunk < start_chunk. range . start {
667- cost += MISSING_OVERLAP_COST ;
681+ let prev_plan = & plans[ prev_plan_idx] ;
682+ cost += prev_plan. cost ;
683+ if self . chunk_overlap == 0 {
684+ cost += MISSING_OVERLAP_COST / 2 ;
685+ } else {
686+ let start_cost_base = self . get_overlap_cost_base ( start_chunk. range . start ) ;
687+ cost += if prev_plan. overlap_cost_base < start_cost_base {
688+ MISSING_OVERLAP_COST + prev_plan. overlap_cost_base - start_cost_base
689+ } else {
690+ MISSING_OVERLAP_COST
691+ } ;
668692 }
669-
670693 if cost < min_cost {
671694 min_cost = cost;
672695 arg_min_start_idx = start_idx;
673- arg_min_prev_plan_idx = best_prev_plan_idx ;
696+ arg_min_prev_plan_idx = prev_plan_idx ;
674697 }
675698
676699 if start_idx == 0 {
@@ -679,14 +702,17 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
679702
680703 start_idx -= 1 ;
681704 internal_syntax_level =
682- internal_syntax_level. min ( atom_chunks[ start_idx] . boundary_syntax_level ) ;
683- internal_lb_level = internal_lb_level. max ( atom_chunks[ start_idx] . internal_lb_level ) ;
705+ internal_syntax_level. min ( atom_chunks[ start_idx + 1 ] . boundary_syntax_level ) ;
706+ internal_lb_level =
707+ internal_lb_level. max ( atom_chunks[ start_idx + 1 ] . internal_lb_level ) ;
684708 }
685709 plans. push ( AtomRoutingPlan {
686710 start_idx : arg_min_start_idx,
687711 prev_plan_idx : arg_min_prev_plan_idx,
688712 cost : min_cost,
713+ overlap_cost_base : self . get_overlap_cost_base ( chunk. range . end ) ,
689714 } ) ;
715+ prev_plan_candidates. clear ( ) ;
690716 }
691717
692718 let mut output = Vec :: new ( ) ;
@@ -695,11 +721,10 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
695721 let plan = & plans[ plan_idx] ;
696722 let start_chunk = & atom_chunks[ plan. start_idx ] ;
697723 let end_chunk = & atom_chunks[ plan_idx - 1 ] ;
698- let std_range: Range < usize > = start_chunk. range . start ..end_chunk. range . end ;
699724 output. push ( ChunkOutput {
700725 start_pos : Position :: new ( start_chunk. range . start ) ,
701726 end_pos : Position :: new ( end_chunk. range . end ) ,
702- text : & self . full_text [ std_range ] ,
727+ text : & self . full_text [ start_chunk . range . start ..end_chunk . range . end ] ,
703728 } ) ;
704729 plan_idx = plan. prev_plan_idx ;
705730 }
@@ -710,7 +735,7 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
710735 fn split_root_chunk ( & self , kind : ChunkKind < ' t > ) -> Result < Vec < ChunkOutput < ' s > > > {
711736 let mut atom_collector = AtomChunksCollector {
712737 full_text : self . full_text ,
713- min_level : usize :: MAX ,
738+ min_level : 0 ,
714739 curr_level : 0 ,
715740 atom_chunks : Vec :: new ( ) ,
716741 } ;
0 commit comments