Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ console-subscriber = "0.4.1"
env_logger = "0.11.7"
reqwest = { version = "0.12.13", features = ["json"] }
async-openai = "0.28.0"

tree-sitter = "0.25.3"
tree-sitter-language = "0.1.5"
tree-sitter-python = "0.23.6"
tree-sitter-javascript = "0.23.1"
tree-sitter-typescript = "0.23.2"
Expand Down
151 changes: 113 additions & 38 deletions src/ops/functions/split_recursively.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use anyhow::anyhow;
use log::{error, trace};
use regex::{Matches, Regex};
use std::collections::HashSet;
use std::sync::LazyLock;
use std::{collections::HashMap, sync::Arc};

Expand All @@ -22,20 +24,85 @@ static TEXT_SEPARATOR: LazyLock<Vec<Regex>> = LazyLock::new(|| {
.collect()
});

static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock<HashMap<&'static str, tree_sitter::Language>> =
LazyLock::new(|| {
[
("python", tree_sitter_python::LANGUAGE.into()),
("javascript", tree_sitter_javascript::LANGUAGE.into()),
(
"typescript",
tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
),
("tsx", tree_sitter_typescript::LANGUAGE_TSX.into()),
("markdown", tree_sitter_md::LANGUAGE.into()),
]
/// Per-language settings used when splitting text with a tree-sitter grammar.
struct LanguageConfig {
/// Canonical language name; reported in the error raised when parsing fails.
name: &'static str,
/// Grammar that is installed on the tree-sitter parser for this language.
tree_sitter_lang: tree_sitter::Language,
/// Kind ids of syntax nodes the chunker treats as leaves: a node whose
/// `kind_id()` is in this set is emitted as-is instead of being split
/// further through its children.
terminal_node_kind_ids: HashSet<u16>,
}

fn add_language<'a>(
output: &'a mut HashMap<&'static str, Arc<LanguageConfig>>,
name: &'static str,
aliases: impl IntoIterator<Item = &'static str>,
lang_fn: tree_sitter_language::LanguageFn,
terminal_node_kinds: impl IntoIterator<Item = &'a str>,
) {
let tree_sitter_lang: tree_sitter::Language = lang_fn.into();
let terminal_node_kind_ids = terminal_node_kinds
.into_iter()
.collect()
.filter_map(|kind| {
let id = tree_sitter_lang.id_for_node_kind(kind, true);
if id != 0 {
trace!("Got id for node kind: `{kind}` -> {id}");
Some(id)
} else {
error!("Failed in getting id for node kind: `{kind}`");
None
}
})
.collect();

let config = Arc::new(LanguageConfig {
name,
tree_sitter_lang,
terminal_node_kind_ids,
});
for name in std::iter::once(name).chain(aliases.into_iter()) {
if output.insert(name, config.clone()).is_some() {
panic!("Language `{name}` already exists");
}
}
}

/// Lazily-built lookup table from a language key to its shared [`LanguageConfig`].
///
/// Every language is reachable through its canonical name and through each
/// listed alias. Keys are matched exactly (case-sensitively), so each accepted
/// spelling is listed explicitly. The last argument of each `add_language`
/// call names the node kinds that are kept whole rather than split further;
/// only Markdown sets one (`inline`).
static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock<HashMap<&'static str, Arc<LanguageConfig>>> =
    LazyLock::new(|| {
        let mut map = HashMap::new();
        add_language(
            &mut map,
            "Python",
            ["py", "python"],
            tree_sitter_python::LANGUAGE,
            [],
        );
        add_language(
            &mut map,
            "JavaScript",
            ["JS", "js", "Javascript", "javascript"],
            tree_sitter_javascript::LANGUAGE,
            [],
        );
        add_language(
            &mut map,
            "TypeScript",
            ["TS", "ts", "Typescript", "typescript"],
            tree_sitter_typescript::LANGUAGE_TYPESCRIPT,
            [],
        );
        add_language(
            &mut map,
            "TSX",
            ["tsx"],
            tree_sitter_typescript::LANGUAGE_TSX,
            [],
        );
        // `add_language` takes the `LanguageFn` directly, so the grammar
        // constant is passed as-is, like every other entry.
        add_language(
            &mut map,
            "Markdown",
            ["md", "markdown"],
            tree_sitter_md::LANGUAGE,
            ["inline"],
        );
        map
    });

enum ChunkKind<'t> {
Expand Down Expand Up @@ -162,6 +229,7 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {

/// Inputs shared by every step of recursively splitting one text into chunks.
struct RecursiveChunker<'s> {
/// The complete text being split; node iterators borrow from it.
full_text: &'s str,
/// Settings of the selected language, or `None` when no known language was
/// given. Processing a tree-sitter node without it fails with
/// "Language not set.".
lang_config: Option<&'s LanguageConfig>,
/// Requested chunk size, taken from the `chunk_size` argument.
/// NOTE(review): unit (bytes vs. chars) is not visible here — confirm.
chunk_size: usize,
/// Requested overlap between chunks; the caller falls back to 0 when the
/// argument is absent.
chunk_overlap: usize,
}
Expand Down Expand Up @@ -225,20 +293,27 @@ impl<'t, 's: 't> RecursiveChunker<'s> {
) -> Result<()> {
match chunk.kind {
ChunkKind::TreeSitterNode { node } => {
let mut cursor = node.walk();
if cursor.goto_first_child() {
self.process_sub_chunks(
TreeSitterNodeIter {
full_text: self.full_text,
cursor: Some(cursor),
next_start_pos: node.start_byte(),
end_pos: node.end_byte(),
},
output,
)?;
} else {
self.add_output(chunk.range, output);
if !self
.lang_config
.ok_or_else(|| anyhow!("Language not set."))?
.terminal_node_kind_ids
.contains(&node.kind_id())
{
let mut cursor = node.walk();
if cursor.goto_first_child() {
self.process_sub_chunks(
TreeSitterNodeIter {
full_text: self.full_text,
cursor: Some(cursor),
next_start_pos: node.start_byte(),
end_pos: node.end_byte(),
},
output,
)?;
return Ok(());
}
}
self.add_output(chunk.range, output);
}
ChunkKind::RegexpSepChunk { next_regexp_sep_id } => {
if next_regexp_sep_id >= TEXT_SEPARATOR.len() {
Expand Down Expand Up @@ -335,8 +410,17 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
impl SimpleFunctionExecutor for Executor {
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
let full_text = self.args.text.value(&input)?.as_str()?;
let lang_config = {
let language = self.args.language.value(&input)?;
language
.map(|v| anyhow::Ok(v.as_str()?.as_ref()))
.transpose()?
.and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(lang))
};

let recursive_chunker = RecursiveChunker {
full_text,
lang_config: lang_config.map(|c| c.as_ref()),
chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize,
chunk_overlap: self
.args
Expand All @@ -347,20 +431,11 @@ impl SimpleFunctionExecutor for Executor {
.unwrap_or(0) as usize,
};

let language = self.args.language.value(&input)?;
let language = language
.map(|v| anyhow::Ok(v.as_str()?.as_ref()))
.transpose()?;
let tree_sitter_language = language.and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(lang));

let mut output = if let Some(tree_sitter_lang) = tree_sitter_language {
let mut output = if let Some(lang_config) = lang_config {
let mut parser = tree_sitter::Parser::new();
parser.set_language(tree_sitter_lang)?;
parser.set_language(&lang_config.tree_sitter_lang)?;
let tree = parser.parse(full_text.as_ref(), None).ok_or_else(|| {
anyhow!(
"failed in parsing text in language: {}",
language.unwrap_or_default()
)
anyhow!("failed in parsing text in language: {}", lang_config.name)
})?;
recursive_chunker.split_root_chunk(ChunkKind::TreeSitterNode {
node: tree.root_node(),
Expand Down