Skip to content

Commit 82e07de

Browse files
committed
Add per-language information and terminate at "inline" for Markdown.
1 parent 08bce3a commit 82e07de

File tree

2 files changed

+115
-38
lines changed

2 files changed

+115
-38
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ console-subscriber = "0.4.1"
4949
env_logger = "0.11.7"
5050
reqwest = { version = "0.12.13", features = ["json"] }
5151
async-openai = "0.28.0"
52+
5253
tree-sitter = "0.25.3"
54+
tree-sitter-language = "0.1.5"
5355
tree-sitter-python = "0.23.6"
5456
tree-sitter-javascript = "0.23.1"
5557
tree-sitter-typescript = "0.23.2"

src/ops/functions/split_recursively.rs

Lines changed: 113 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use anyhow::anyhow;
2+
use log::{error, trace};
23
use regex::{Matches, Regex};
4+
use std::collections::HashSet;
35
use std::sync::LazyLock;
46
use std::{collections::HashMap, sync::Arc};
57

@@ -22,20 +24,85 @@ static TEXT_SEPARATOR: LazyLock<Vec<Regex>> = LazyLock::new(|| {
2224
.collect()
2325
});
2426

25-
static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock<HashMap<&'static str, tree_sitter::Language>> =
26-
LazyLock::new(|| {
27-
[
28-
("python", tree_sitter_python::LANGUAGE.into()),
29-
("javascript", tree_sitter_javascript::LANGUAGE.into()),
30-
(
31-
"typescript",
32-
tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
33-
),
34-
("tsx", tree_sitter_typescript::LANGUAGE_TSX.into()),
35-
("markdown", tree_sitter_md::LANGUAGE.into()),
36-
]
27+
struct LanguageConfig {
28+
name: &'static str,
29+
tree_sitter_lang: tree_sitter::Language,
30+
terminal_node_kind_ids: HashSet<u16>,
31+
}
32+
33+
fn add_language<'a>(
34+
output: &'a mut HashMap<&'static str, Arc<LanguageConfig>>,
35+
name: &'static str,
36+
aliases: impl IntoIterator<Item = &'static str>,
37+
lang_fn: tree_sitter_language::LanguageFn,
38+
terminal_node_kinds: impl IntoIterator<Item = &'a str>,
39+
) {
40+
let tree_sitter_lang: tree_sitter::Language = lang_fn.into();
41+
let terminal_node_kind_ids = terminal_node_kinds
3742
.into_iter()
38-
.collect()
43+
.filter_map(|kind| {
44+
let id = tree_sitter_lang.id_for_node_kind(kind, true);
45+
if id != 0 {
46+
trace!("Got id for node kind: `{kind}` -> {id}");
47+
Some(id)
48+
} else {
49+
error!("Failed in getting id for node kind: `{kind}`");
50+
None
51+
}
52+
})
53+
.collect();
54+
55+
let config = Arc::new(LanguageConfig {
56+
name,
57+
tree_sitter_lang,
58+
terminal_node_kind_ids,
59+
});
60+
for name in std::iter::once(name).chain(aliases.into_iter()) {
61+
if output.insert(name, config.clone()).is_some() {
62+
panic!("Language `{name}` already exists");
63+
}
64+
}
65+
}
66+
67+
static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock<HashMap<&'static str, Arc<LanguageConfig>>> =
68+
LazyLock::new(|| {
69+
let mut map = HashMap::new();
70+
add_language(
71+
&mut map,
72+
"Python",
73+
["py", "python"],
74+
tree_sitter_python::LANGUAGE,
75+
[],
76+
);
77+
add_language(
78+
&mut map,
79+
"JavaScript",
80+
["JS", "js", "Javascript", "javascript"],
81+
tree_sitter_javascript::LANGUAGE,
82+
[],
83+
);
84+
add_language(
85+
&mut map,
86+
"TypeScript",
87+
["TS", "ts", "Typescript", "typescript"],
88+
tree_sitter_typescript::LANGUAGE_TYPESCRIPT,
89+
[],
90+
);
91+
add_language(
92+
&mut map,
93+
"TSX",
94+
["tsx"],
95+
tree_sitter_typescript::LANGUAGE_TSX,
96+
[],
97+
);
98+
add_language(
99+
&mut map,
100+
"Markdown",
101+
["md", "markdown"],
102+
tree_sitter_md::LANGUAGE.into(),
103+
["inline"],
104+
);
105+
map
39106
});
40107

41108
enum ChunkKind<'t> {
@@ -162,6 +229,7 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
162229

163230
struct RecursiveChunker<'s> {
164231
full_text: &'s str,
232+
lang_config: Option<&'s LanguageConfig>,
165233
chunk_size: usize,
166234
chunk_overlap: usize,
167235
}
@@ -225,20 +293,27 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
225293
) -> Result<()> {
226294
match chunk.kind {
227295
ChunkKind::TreeSitterNode { node } => {
228-
let mut cursor = node.walk();
229-
if cursor.goto_first_child() {
230-
self.process_sub_chunks(
231-
TreeSitterNodeIter {
232-
full_text: self.full_text,
233-
cursor: Some(cursor),
234-
next_start_pos: node.start_byte(),
235-
end_pos: node.end_byte(),
236-
},
237-
output,
238-
)?;
239-
} else {
240-
self.add_output(chunk.range, output);
296+
if !self
297+
.lang_config
298+
.ok_or_else(|| anyhow!("Language not set."))?
299+
.terminal_node_kind_ids
300+
.contains(&node.kind_id())
301+
{
302+
let mut cursor = node.walk();
303+
if cursor.goto_first_child() {
304+
self.process_sub_chunks(
305+
TreeSitterNodeIter {
306+
full_text: self.full_text,
307+
cursor: Some(cursor),
308+
next_start_pos: node.start_byte(),
309+
end_pos: node.end_byte(),
310+
},
311+
output,
312+
)?;
313+
return Ok(());
314+
}
241315
}
316+
self.add_output(chunk.range, output);
242317
}
243318
ChunkKind::RegexpSepChunk { next_regexp_sep_id } => {
244319
if next_regexp_sep_id >= TEXT_SEPARATOR.len() {
@@ -335,8 +410,17 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
335410
impl SimpleFunctionExecutor for Executor {
336411
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
337412
let full_text = self.args.text.value(&input)?.as_str()?;
413+
let lang_config = {
414+
let language = self.args.language.value(&input)?;
415+
language
416+
.map(|v| anyhow::Ok(v.as_str()?.as_ref()))
417+
.transpose()?
418+
.and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(lang))
419+
};
420+
338421
let recursive_chunker = RecursiveChunker {
339422
full_text,
423+
lang_config: lang_config.map(|c| c.as_ref()),
340424
chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize,
341425
chunk_overlap: self
342426
.args
@@ -347,20 +431,11 @@ impl SimpleFunctionExecutor for Executor {
347431
.unwrap_or(0) as usize,
348432
};
349433

350-
let language = self.args.language.value(&input)?;
351-
let language = language
352-
.map(|v| anyhow::Ok(v.as_str()?.as_ref()))
353-
.transpose()?;
354-
let tree_sitter_language = language.and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(lang));
355-
356-
let mut output = if let Some(tree_sitter_lang) = tree_sitter_language {
434+
let mut output = if let Some(lang_config) = lang_config {
357435
let mut parser = tree_sitter::Parser::new();
358-
parser.set_language(tree_sitter_lang)?;
436+
parser.set_language(&lang_config.tree_sitter_lang)?;
359437
let tree = parser.parse(full_text.as_ref(), None).ok_or_else(|| {
360-
anyhow!(
361-
"failed in parsing text in language: {}",
362-
language.unwrap_or_default()
363-
)
438+
anyhow!("failed in parsing text in language: {}", lang_config.name)
364439
})?;
365440
recursive_chunker.split_root_chunk(ChunkKind::TreeSitterNode {
366441
node: tree.root_node(),

0 commit comments

Comments
 (0)