diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py index 27a501717..7ae3f1a44 100644 --- a/python/cocoindex/functions.py +++ b/python/cocoindex/functions.py @@ -1,6 +1,7 @@ """All builtin functions.""" from typing import Annotated, Any, TYPE_CHECKING +import dataclasses from .typing import Float32, Vector, TypeAttr from . import op, llm @@ -14,9 +15,20 @@ class ParseJson(op.FunctionSpec): """Parse a text into a JSON object.""" +@dataclasses.dataclass +class CustomLanguageSpec: + """Custom language specification.""" + + language_name: str + separators_regex: list[str] + aliases: list[str] = dataclasses.field(default_factory=list) + + class SplitRecursively(op.FunctionSpec): """Split a document (in string) recursively.""" + custom_languages: list[CustomLanguageSpec] = dataclasses.field(default_factory=list) + class ExtractByLlm(op.FunctionSpec): """Extract information from a text using a LLM.""" diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index b10d111ca..e04ef003e 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -8,9 +8,22 @@ use std::{collections::HashMap, sync::Arc}; use unicase::UniCase; use crate::base::field_attrs; +use crate::ops::registry::ExecutorFactoryRegistry; use crate::{fields_value, ops::sdk::*}; -type Spec = EmptySpec; +#[derive(Deserialize)] +struct CustomLanguageSpec { + language_name: String, + #[serde(default)] + aliases: Vec, + separators_regex: Vec, +} + +#[derive(Deserialize)] +struct Spec { + #[serde(default)] + custom_languages: Vec, +} const SYNTAX_LEVEL_GAP_COST: usize = 512; const MISSING_OVERLAP_COST: usize = 512; @@ -25,21 +38,30 @@ pub struct Args { language: Option, } -static TEXT_SEPARATOR: LazyLock> = LazyLock::new(|| { - [r"\n\n+", r"\n", r"\s+"] - .into_iter() - .map(|s| Regex::new(s).unwrap()) - .collect() -}); +struct SimpleLanguageConfig { + name: String, + aliases: Vec, + separator_regex: Vec, +} -struct LanguageConfig { - name: &'static str, +static DEFAULT_LANGUAGE_CONFIG: LazyLock = + LazyLock::new(|| SimpleLanguageConfig { + name: "_DEFAULT".to_string(), + aliases: vec![], + separator_regex: [r"\n\n+", r"\n", r"\s+"] + .into_iter() + .map(|s| Regex::new(s).unwrap()) + .collect(), + }); + +struct TreesitterLanguageConfig { + name: String, tree_sitter_lang: tree_sitter::Language, terminal_node_kind_ids: HashSet, } -fn add_language<'a>( - output: &'a mut HashMap, Arc>, +fn add_treesitter_language<'a>( + output: &'a mut HashMap, Arc>, name: &'static str, aliases: impl IntoIterator, lang_fn: impl Into, @@ -60,8 +82,8 @@ fn add_language<'a>( }) .collect(); - let config = Arc::new(LanguageConfig { - name, + let config = Arc::new(TreesitterLanguageConfig { + name: name.to_string(), tree_sitter_lang, terminal_node_kind_ids, }); @@ -72,144 +94,150 @@ fn add_language<'a>( } } -static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock, Arc>> = - LazyLock::new(|| { - let mut map = HashMap::new(); - add_language(&mut map, "C", [".c"], tree_sitter_c::LANGUAGE, []); - add_language( - &mut map, - "C++", - [".cpp", ".cc", ".cxx", ".h", ".hpp", "cpp"], - tree_sitter_c::LANGUAGE, - [], - ); - add_language( - &mut map, - "C#", - [".cs", "cs"], - tree_sitter_c_sharp::LANGUAGE, - [], - ); - add_language( - &mut map, - "CSS", - [".css", ".scss"], - tree_sitter_css::LANGUAGE, - [], - ); - add_language( - &mut map, - "Fortran", - [".f", ".f90", ".f95", ".f03", "f", "f90", "f95", "f03"], - tree_sitter_fortran::LANGUAGE, - [], - ); - add_language( - &mut map, - "Go", - [".go", "golang"], - tree_sitter_go::LANGUAGE, - [], - ); - add_language( - &mut map, - "HTML", - [".html", ".htm"], - tree_sitter_html::LANGUAGE, - [], - ); - add_language(&mut map, "Java", [".java"], tree_sitter_java::LANGUAGE, []); - add_language( - &mut map, - "JavaScript", - [".js", "js"], - tree_sitter_javascript::LANGUAGE, - [], - ); - add_language(&mut map, "JSON", [".json"], tree_sitter_json::LANGUAGE, []); - add_language( - &mut map, - "Markdown", - [".md", ".mdx", "md"], - tree_sitter_md::LANGUAGE, - ["inline"], - ); - add_language( - &mut map, - "Pascal", - [".pas", "pas", ".dpr", "dpr", "Delphi"], - tree_sitter_pascal::LANGUAGE, - [], - ); - add_language(&mut map, "PHP", [".php"], tree_sitter_php::LANGUAGE_PHP, []); - add_language( - &mut map, - "Python", - [".py"], - tree_sitter_python::LANGUAGE, - [], - ); - add_language(&mut map, "R", [".r"], tree_sitter_r::LANGUAGE, []); - add_language(&mut map, "Ruby", [".rb"], tree_sitter_ruby::LANGUAGE, []); - add_language( - &mut map, - "Rust", - [".rs", "rs"], - tree_sitter_rust::LANGUAGE, - [], - ); - add_language( - &mut map, - "Scala", - [".scala"], - tree_sitter_scala::LANGUAGE, - [], - ); - add_language(&mut map, "SQL", [".sql"], tree_sitter_sequel::LANGUAGE, []); - add_language( - &mut map, - "Swift", - [".swift"], - tree_sitter_swift::LANGUAGE, - [], - ); - add_language( - &mut map, - "TOML", - [".toml"], - tree_sitter_toml_ng::LANGUAGE, - [], - ); - add_language( - &mut map, - "TSX", - [".tsx"], - tree_sitter_typescript::LANGUAGE_TSX, - [], - ); - add_language( - &mut map, - "TypeScript", - [".ts", "ts"], - tree_sitter_typescript::LANGUAGE_TYPESCRIPT, - [], - ); - add_language(&mut map, "XML", [".xml"], tree_sitter_xml::LANGUAGE_XML, []); - add_language(&mut map, "DTD", [".dtd"], tree_sitter_xml::LANGUAGE_DTD, []); - add_language( - &mut map, - "YAML", - [".yaml", ".yml"], - tree_sitter_yaml::LANGUAGE, - [], - ); - map - }); +static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock< + HashMap, Arc>, +> = LazyLock::new(|| { + let mut map = HashMap::new(); + add_treesitter_language(&mut map, "C", [".c"], tree_sitter_c::LANGUAGE, []); + add_treesitter_language( + &mut map, + "C++", + [".cpp", ".cc", ".cxx", ".h", ".hpp", "cpp"], + tree_sitter_c::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "C#", + [".cs", "cs"], + tree_sitter_c_sharp::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "CSS", + [".css", ".scss"], + tree_sitter_css::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "Fortran", + [".f", ".f90", ".f95", ".f03", "f", "f90", "f95", "f03"], + tree_sitter_fortran::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "Go", + [".go", "golang"], + tree_sitter_go::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "HTML", + [".html", ".htm"], + tree_sitter_html::LANGUAGE, + [], + ); + add_treesitter_language(&mut map, "Java", [".java"], tree_sitter_java::LANGUAGE, []); + add_treesitter_language( + &mut map, + "JavaScript", + [".js", "js"], + tree_sitter_javascript::LANGUAGE, + [], + ); + add_treesitter_language(&mut map, "JSON", [".json"], tree_sitter_json::LANGUAGE, []); + add_treesitter_language( + &mut map, + "Markdown", + [".md", ".mdx", "md"], + tree_sitter_md::LANGUAGE, + ["inline"], + ); + add_treesitter_language( + &mut map, + "Pascal", + [".pas", "pas", ".dpr", "dpr", "Delphi"], + tree_sitter_pascal::LANGUAGE, + [], + ); + add_treesitter_language(&mut map, "PHP", [".php"], tree_sitter_php::LANGUAGE_PHP, []); + add_treesitter_language( + &mut map, + "Python", + [".py"], + tree_sitter_python::LANGUAGE, + [], + ); + add_treesitter_language(&mut map, "R", [".r"], tree_sitter_r::LANGUAGE, []); + add_treesitter_language(&mut map, "Ruby", [".rb"], tree_sitter_ruby::LANGUAGE, []); + add_treesitter_language( + &mut map, + "Rust", + [".rs", "rs"], + tree_sitter_rust::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "Scala", + [".scala"], + tree_sitter_scala::LANGUAGE, + [], + ); + add_treesitter_language(&mut map, "SQL", [".sql"], tree_sitter_sequel::LANGUAGE, []); + add_treesitter_language( + &mut map, + "Swift", + [".swift"], + tree_sitter_swift::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "TOML", + [".toml"], + tree_sitter_toml_ng::LANGUAGE, + [], + ); + add_treesitter_language( + &mut map, + "TSX", + [".tsx"], + tree_sitter_typescript::LANGUAGE_TSX, + [], + ); + add_treesitter_language( + &mut map, + "TypeScript", + [".ts", "ts"], + tree_sitter_typescript::LANGUAGE_TYPESCRIPT, + [], + ); + add_treesitter_language(&mut map, "XML", [".xml"], tree_sitter_xml::LANGUAGE_XML, []); + add_treesitter_language(&mut map, "DTD", [".dtd"], tree_sitter_xml::LANGUAGE_DTD, []); + add_treesitter_language( + &mut map, + "YAML", + [".yaml", ".yml"], + tree_sitter_yaml::LANGUAGE, + [], + ); + map +}); enum ChunkKind<'t> { - TreeSitterNode { node: tree_sitter::Node<'t> }, - RegexpSepChunk { next_regexp_sep_id: usize }, - LeafText, + TreeSitterNode { + lang_config: &'t TreesitterLanguageConfig, + node: tree_sitter::Node<'t>, + }, + RegexpSepChunk { + lang_config: &'t SimpleLanguageConfig, + next_regexp_sep_id: usize, + }, } struct Chunk<'t, 's: 't> { @@ -225,17 +253,23 @@ impl<'t, 's: 't> Chunk<'t, 's> { } struct TextChunksIter<'t, 's: 't> { + lang_config: &'t SimpleLanguageConfig, parent: &'t Chunk<'t, 's>, - matches_iter: Matches<'static, 's>, + matches_iter: Matches<'t, 's>, regexp_sep_id: usize, next_start_pos: Option, } impl<'t, 's: 't> TextChunksIter<'t, 's> { - fn new(parent: &'t Chunk<'t, 's>, regexp_sep_id: usize) -> Self { + fn new( + lang_config: &'t SimpleLanguageConfig, + parent: &'t Chunk<'t, 's>, + regexp_sep_id: usize, + ) -> Self { Self { + lang_config, parent, - matches_iter: TEXT_SEPARATOR[regexp_sep_id].find_iter(parent.text()), + matches_iter: lang_config.separator_regex[regexp_sep_id].find_iter(parent.text()), regexp_sep_id, next_start_pos: Some(parent.range.start), } @@ -264,6 +298,7 @@ impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> { full_text: self.parent.full_text, range: RangeValue::new(start_pos, end_pos), kind: ChunkKind::RegexpSepChunk { + lang_config: self.lang_config, next_regexp_sep_id: self.regexp_sep_id + 1, }, }) @@ -271,6 +306,7 @@ impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> { } struct TreeSitterNodeIter<'t, 's: 't> { + lang_config: &'t TreesitterLanguageConfig, full_text: &'s str, cursor: Option>, next_start_pos: usize, @@ -289,7 +325,10 @@ impl<'t, 's: 't> TreeSitterNodeIter<'t, 's> { Some(Chunk { full_text, range: RangeValue::new(start_pos, gap_end_pos), - kind: ChunkKind::LeafText, + kind: ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, + next_regexp_sep_id: 0, + }, }) } else { None @@ -319,7 +358,10 @@ impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> { Some(Chunk { full_text: self.full_text, range: RangeValue::new(node.start_byte(), node.end_byte()), - kind: ChunkKind::TreeSitterNode { node }, + kind: ChunkKind::TreeSitterNode { + lang_config: self.lang_config, + node, + }, }) } } @@ -424,7 +466,6 @@ impl<'s> AtomChunksCollector<'s> { struct RecursiveChunker<'s> { full_text: &'s str, - lang_config: Option<&'s LanguageConfig>, chunk_size: usize, chunk_overlap: usize, min_chunk_size: usize, @@ -458,56 +499,48 @@ impl<'t, 's: 't> RecursiveChunker<'s> { atom_collector: &mut AtomChunksCollector<'s>, ) -> Result<()> { match chunk.kind { - ChunkKind::TreeSitterNode { node } => { - if !self - .lang_config - .ok_or_else(|| anyhow!("Language not set."))? - .terminal_node_kind_ids - .contains(&node.kind_id()) - { + ChunkKind::TreeSitterNode { lang_config, node } => { + if !lang_config.terminal_node_kind_ids.contains(&node.kind_id()) { let mut cursor = node.walk(); if cursor.goto_first_child() { - self.collect_atom_chunks_from_iter( + return self.collect_atom_chunks_from_iter( TreeSitterNodeIter { + lang_config, full_text: self.full_text, cursor: Some(cursor), next_start_pos: node.start_byte(), end_pos: node.end_byte(), }, atom_collector, - )?; - return Ok(()); + ); } } + self.collect_atom_chunks( + Chunk { + full_text: self.full_text, + range: chunk.range, + kind: ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, + next_regexp_sep_id: 0, + }, + }, + atom_collector, + ) } - ChunkKind::RegexpSepChunk { next_regexp_sep_id } => { - if next_regexp_sep_id >= TEXT_SEPARATOR.len() { - atom_collector.collect(chunk.range); + ChunkKind::RegexpSepChunk { + lang_config, + next_regexp_sep_id, + } => { + if next_regexp_sep_id >= lang_config.separator_regex.len() { + Ok(atom_collector.collect(chunk.range)) } else { self.collect_atom_chunks_from_iter( - TextChunksIter::new(&chunk, next_regexp_sep_id), + TextChunksIter::new(lang_config, &chunk, next_regexp_sep_id), atom_collector, - )?; + ) } - return Ok(()); } - ChunkKind::LeafText => {} - } - if chunk.range.len() > self.chunk_size { - self.collect_atom_chunks( - Chunk { - full_text: self.full_text, - range: chunk.range, - kind: ChunkKind::RegexpSepChunk { - next_regexp_sep_id: 0, - }, - }, - atom_collector, - )?; - } else { - atom_collector.collect(chunk.range); } - Ok(()) } fn get_overlap_cost_base(&self, offset: usize) -> usize { @@ -685,11 +718,54 @@ impl<'t, 's: 't> RecursiveChunker<'s> { struct Executor { args: Args, + custom_languages: HashMap, Arc>, } impl Executor { - fn new(args: Args) -> Result { - Ok(Self { args }) + fn new(args: Args, spec: Spec) -> Result { + let mut custom_languages = HashMap::new(); + for lang in spec.custom_languages { + let separator_regex = lang + .separators_regex + .iter() + .map(|s| Regex::new(s)) + .collect::>() + .with_context(|| { + format!( + "failed in parsing regexp for language `{}`", + lang.language_name + ) + })?; + let language_config = Arc::new(SimpleLanguageConfig { + name: lang.language_name, + aliases: lang.aliases, + separator_regex, + }); + if custom_languages + .insert( + UniCase::new(language_config.name.clone()), + language_config.clone(), + ) + .is_some() + { + api_bail!( + "duplicate language name / alias: `{}`", + language_config.name + ); + } + for alias in &language_config.aliases { + if custom_languages + .insert(UniCase::new(alias.clone()), language_config.clone()) + .is_some() + { + api_bail!("duplicate language name / alias: `{}`", alias); + } + } + } + Ok(Self { + args, + custom_languages, + }) } } @@ -728,18 +804,9 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator) -> Result { let full_text = self.args.text.value(&input)?.as_str()?; - let lang_config = { - let language = self.args.language.value(&input)?; - language - .optional() - .map(|v| anyhow::Ok(v.as_str()?.as_ref())) - .transpose()? - .and_then(|lang| TREE_SITTER_LANGUAGE_BY_LANG.get(&UniCase::new(lang))) - }; let chunk_size = self.args.chunk_size.value(&input)?.as_int64()?; let recursive_chunker = RecursiveChunker { full_text, - lang_config: lang_config.map(|c| c.as_ref()), chunk_size: chunk_size as usize, chunk_overlap: (self.args.chunk_overlap.value(&input)?) .optional() @@ -753,17 +820,32 @@ impl SimpleFunctionExecutor for Executor { .unwrap_or(chunk_size / 2) as usize, }; - let mut output = if let Some(lang_config) = lang_config { + let language = UniCase::new( + (if let Some(language) = self.args.language.value(&input)?.optional() { + language.as_str()? + } else { + "" + }) + .to_string(), + ); + let mut output = if let Some(lang_config) = self.custom_languages.get(&language) { + recursive_chunker.split_root_chunk(ChunkKind::RegexpSepChunk { + lang_config, + next_regexp_sep_id: 0, + })? + } else if let Some(lang_config) = TREE_SITTER_LANGUAGE_BY_LANG.get(&language) { let mut parser = tree_sitter::Parser::new(); parser.set_language(&lang_config.tree_sitter_lang)?; let tree = parser.parse(full_text.as_ref(), None).ok_or_else(|| { anyhow!("failed in parsing text in language: {}", lang_config.name) })?; recursive_chunker.split_root_chunk(ChunkKind::TreeSitterNode { + lang_config, node: tree.root_node(), })? } else { recursive_chunker.split_root_chunk(ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, next_regexp_sep_id: 0, })? }; @@ -784,7 +866,7 @@ impl SimpleFunctionExecutor for Executor { } } -pub struct Factory; +struct Factory; #[async_trait] impl SimpleFunctionFactoryBase for Factory { @@ -839,14 +921,18 @@ impl SimpleFunctionFactoryBase for Factory { async fn build_executor( self: Arc, - _spec: Spec, + spec: Spec, args: Args, _context: Arc, ) -> Result> { - Ok(Box::new(Executor::new(args)?)) + Ok(Box::new(Executor::new(args, spec)?)) } } +pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> { + Factory.register(registry) +} + #[cfg(test)] mod tests { use super::*; @@ -883,7 +969,6 @@ mod tests { ) -> RecursiveChunker { RecursiveChunker { full_text: text, - lang_config: None, chunk_size, chunk_overlap, min_chunk_size, @@ -928,6 +1013,7 @@ mod tests { let chunker = create_test_chunker(text, 15, 5, 0); let result = chunker.split_root_chunk(ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, next_regexp_sep_id: 0, }); @@ -943,6 +1029,7 @@ mod tests { let text2 = "A very very long text that needs to be split."; let chunker2 = create_test_chunker(text2, 20, 12, 0); let result2 = chunker2.split_root_chunk(ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, next_regexp_sep_id: 0, }); @@ -960,6 +1047,7 @@ mod tests { let chunker = create_test_chunker(text, 20, 10, 5); let result = chunker.split_root_chunk(ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, next_regexp_sep_id: 0, }); @@ -981,6 +1069,7 @@ mod tests { let chunker = create_test_chunker(text, 30, 10, 0); let result = chunker.split_root_chunk(ChunkKind::RegexpSepChunk { + lang_config: &DEFAULT_LANGUAGE_CONFIG, next_regexp_sep_id: 0, }); diff --git a/src/ops/registration.rs b/src/ops/registration.rs index c7ac01f92..00a805698 100644 --- a/src/ops/registration.rs +++ b/src/ops/registration.rs @@ -13,7 +13,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result sources::amazon_s3::Factory.register(registry)?; functions::parse_json::Factory.register(registry)?; - functions::split_recursively::Factory.register(registry)?; + functions::split_recursively::register(registry)?; functions::extract_by_llm::Factory.register(registry)?; storages::postgres::Factory::default().register(registry)?;