diff --git a/examples/code_embedding/main.py b/examples/code_embedding/main.py index e06c089e9..d136bebd5 100644 --- a/examples/code_embedding/main.py +++ b/examples/code_embedding/main.py @@ -1,6 +1,19 @@ from dotenv import load_dotenv import cocoindex +import os + +class ExtractExtension(cocoindex.op.FunctionSpec): + """Extract the extension of a filename.""" + +@cocoindex.op.executor_class() +class ExtractExtensionExecutor: + """Executor for ExtractExtension.""" + + spec: ExtractExtension + + def __call__(self, filename: str) -> str: + return os.path.splitext(filename)[1] def code_to_embedding(text: cocoindex.DataSlice) -> cocoindex.DataSlice: """ @@ -17,14 +30,15 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind """ data_scope["files"] = flow_builder.add_source( cocoindex.sources.LocalFile(path="../..", - included_patterns=["*.py"], - excluded_patterns=[".*"])) + included_patterns=["*.py", "*.rs", "*.toml", "*.md", "*.mdx"], + excluded_patterns=[".*", "target", "**/node_modules"])) code_embeddings = data_scope.add_collector() with data_scope["files"].row() as file: + file["extension"] = file["filename"].transform(ExtractExtension()) file["chunks"] = file["content"].transform( cocoindex.functions.SplitRecursively(), - language="python", chunk_size=1000, chunk_overlap=300) + language=file["extension"], chunk_size=1000, chunk_overlap=300) with file["chunks"].row() as chunk: chunk["embedding"] = chunk["text"].call(code_to_embedding) code_embeddings.collect(filename=file["filename"], location=chunk["location"], diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index 150109be9..1fa3bf4b4 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -117,7 +117,7 @@ static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock, Arc add_language( &mut map, "Markdown", - [".md", "md"], + [".md", ".mdx", "md"], tree_sitter_md::LANGUAGE, ["inline"], );