11use anyhow:: anyhow;
2+ use log:: { error, trace} ;
23use regex:: { Matches , Regex } ;
4+ use std:: collections:: HashSet ;
35use std:: sync:: LazyLock ;
46use std:: { collections:: HashMap , sync:: Arc } ;
57
@@ -22,20 +24,85 @@ static TEXT_SEPARATOR: LazyLock<Vec<Regex>> = LazyLock::new(|| {
2224 . collect ( )
2325} ) ;
2426
25- static TREE_SITTER_LANGUAGE_BY_LANG : LazyLock < HashMap < & ' static str , tree_sitter:: Language > > =
26- LazyLock :: new ( || {
27- [
28- ( "python" , tree_sitter_python:: LANGUAGE . into ( ) ) ,
29- ( "javascript" , tree_sitter_javascript:: LANGUAGE . into ( ) ) ,
30- (
31- "typescript" ,
32- tree_sitter_typescript:: LANGUAGE_TYPESCRIPT . into ( ) ,
33- ) ,
34- ( "tsx" , tree_sitter_typescript:: LANGUAGE_TSX . into ( ) ) ,
35- ( "markdown" , tree_sitter_md:: LANGUAGE . into ( ) ) ,
36- ]
27+ struct LanguageConfig {
28+ name : & ' static str ,
29+ tree_sitter_lang : tree_sitter:: Language ,
30+ terminal_node_kind_ids : HashSet < u16 > ,
31+ }
32+
33+ fn add_language < ' a > (
34+ output : & ' a mut HashMap < & ' static str , Arc < LanguageConfig > > ,
35+ name : & ' static str ,
36+ aliases : impl IntoIterator < Item = & ' static str > ,
37+ lang_fn : tree_sitter_language:: LanguageFn ,
38+ terminal_node_kinds : impl IntoIterator < Item = & ' a str > ,
39+ ) {
40+ let tree_sitter_lang: tree_sitter:: Language = lang_fn. into ( ) ;
41+ let terminal_node_kind_ids = terminal_node_kinds
3742 . into_iter ( )
38- . collect ( )
43+ . filter_map ( |kind| {
44+ let id = tree_sitter_lang. id_for_node_kind ( kind, true ) ;
45+ if id != 0 {
46+ trace ! ( "Got id for node kind: `{kind}` -> {id}" ) ;
47+ Some ( id)
48+ } else {
49+ error ! ( "Failed in getting id for node kind: `{kind}`" ) ;
50+ None
51+ }
52+ } )
53+ . collect ( ) ;
54+
55+ let config = Arc :: new ( LanguageConfig {
56+ name,
57+ tree_sitter_lang,
58+ terminal_node_kind_ids,
59+ } ) ;
60+ for name in std:: iter:: once ( name) . chain ( aliases. into_iter ( ) ) {
61+ if output. insert ( name, config. clone ( ) ) . is_some ( ) {
62+ panic ! ( "Language `{name}` already exists" ) ;
63+ }
64+ }
65+ }
66+
67+ static TREE_SITTER_LANGUAGE_BY_LANG : LazyLock < HashMap < & ' static str , Arc < LanguageConfig > > > =
68+ LazyLock :: new ( || {
69+ let mut map = HashMap :: new ( ) ;
70+ add_language (
71+ & mut map,
72+ "Python" ,
73+ [ "py" , "python" ] ,
74+ tree_sitter_python:: LANGUAGE ,
75+ [ ] ,
76+ ) ;
77+ add_language (
78+ & mut map,
79+ "JavaScript" ,
80+ [ "JS" , "js" , "Javascript" , "javascript" ] ,
81+ tree_sitter_javascript:: LANGUAGE ,
82+ [ ] ,
83+ ) ;
84+ add_language (
85+ & mut map,
86+ "TypeScript" ,
87+ [ "TS" , "ts" , "Typescript" , "typescript" ] ,
88+ tree_sitter_typescript:: LANGUAGE_TYPESCRIPT ,
89+ [ ] ,
90+ ) ;
91+ add_language (
92+ & mut map,
93+ "TSX" ,
94+ [ "tsx" ] ,
95+ tree_sitter_typescript:: LANGUAGE_TSX ,
96+ [ ] ,
97+ ) ;
98+ add_language (
99+ & mut map,
100+ "Markdown" ,
101+ [ "md" , "markdown" ] ,
102+ tree_sitter_md:: LANGUAGE . into ( ) ,
103+ [ "inline" ] ,
104+ ) ;
105+ map
39106 } ) ;
40107
41108enum ChunkKind < ' t > {
@@ -162,6 +229,7 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
162229
163230struct RecursiveChunker < ' s > {
164231 full_text : & ' s str ,
232+ lang_config : Option < & ' s LanguageConfig > ,
165233 chunk_size : usize ,
166234 chunk_overlap : usize ,
167235}
@@ -225,20 +293,27 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
225293 ) -> Result < ( ) > {
226294 match chunk. kind {
227295 ChunkKind :: TreeSitterNode { node } => {
228- let mut cursor = node. walk ( ) ;
229- if cursor. goto_first_child ( ) {
230- self . process_sub_chunks (
231- TreeSitterNodeIter {
232- full_text : self . full_text ,
233- cursor : Some ( cursor) ,
234- next_start_pos : node. start_byte ( ) ,
235- end_pos : node. end_byte ( ) ,
236- } ,
237- output,
238- ) ?;
239- } else {
240- self . add_output ( chunk. range , output) ;
296+ if !self
297+ . lang_config
298+ . ok_or_else ( || anyhow ! ( "Language not set." ) ) ?
299+ . terminal_node_kind_ids
300+ . contains ( & node. kind_id ( ) )
301+ {
302+ let mut cursor = node. walk ( ) ;
303+ if cursor. goto_first_child ( ) {
304+ self . process_sub_chunks (
305+ TreeSitterNodeIter {
306+ full_text : self . full_text ,
307+ cursor : Some ( cursor) ,
308+ next_start_pos : node. start_byte ( ) ,
309+ end_pos : node. end_byte ( ) ,
310+ } ,
311+ output,
312+ ) ?;
313+ return Ok ( ( ) ) ;
314+ }
241315 }
316+ self . add_output ( chunk. range , output) ;
242317 }
243318 ChunkKind :: RegexpSepChunk { next_regexp_sep_id } => {
244319 if next_regexp_sep_id >= TEXT_SEPARATOR . len ( ) {
@@ -335,8 +410,17 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
335410impl SimpleFunctionExecutor for Executor {
336411 async fn evaluate ( & self , input : Vec < Value > ) -> Result < Value > {
337412 let full_text = self . args . text . value ( & input) ?. as_str ( ) ?;
413+ let lang_config = {
414+ let language = self . args . language . value ( & input) ?;
415+ language
416+ . map ( |v| anyhow:: Ok ( v. as_str ( ) ?. as_ref ( ) ) )
417+ . transpose ( ) ?
418+ . and_then ( |lang| TREE_SITTER_LANGUAGE_BY_LANG . get ( lang) )
419+ } ;
420+
338421 let recursive_chunker = RecursiveChunker {
339422 full_text,
423+ lang_config : lang_config. map ( |c| c. as_ref ( ) ) ,
340424 chunk_size : self . args . chunk_size . value ( & input) ?. as_int64 ( ) ? as usize ,
341425 chunk_overlap : self
342426 . args
@@ -347,20 +431,11 @@ impl SimpleFunctionExecutor for Executor {
347431 . unwrap_or ( 0 ) as usize ,
348432 } ;
349433
350- let language = self . args . language . value ( & input) ?;
351- let language = language
352- . map ( |v| anyhow:: Ok ( v. as_str ( ) ?. as_ref ( ) ) )
353- . transpose ( ) ?;
354- let tree_sitter_language = language. and_then ( |lang| TREE_SITTER_LANGUAGE_BY_LANG . get ( lang) ) ;
355-
356- let mut output = if let Some ( tree_sitter_lang) = tree_sitter_language {
434+ let mut output = if let Some ( lang_config) = lang_config {
357435 let mut parser = tree_sitter:: Parser :: new ( ) ;
358- parser. set_language ( tree_sitter_lang) ?;
436+ parser. set_language ( & lang_config . tree_sitter_lang ) ?;
359437 let tree = parser. parse ( full_text. as_ref ( ) , None ) . ok_or_else ( || {
360- anyhow ! (
361- "failed in parsing text in language: {}" ,
362- language. unwrap_or_default( )
363- )
438+ anyhow ! ( "failed in parsing text in language: {}" , lang_config. name)
364439 } ) ?;
365440 recursive_chunker. split_root_chunk ( ChunkKind :: TreeSitterNode {
366441 node : tree. root_node ( ) ,
0 commit comments