diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 64b2a96..6038954 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -43,7 +43,7 @@ jobs:
       - name: Run Flake8
         run: |
           source venv/bin/activate
-          flake8
+          flake8 src/
 
   actionlint:
     runs-on: ubuntu-latest
diff --git a/.gitignore b/.gitignore
index 9b69bfd..b21e5a1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -70,3 +70,5 @@ venv.bak/
 ehthumbs.db
 Thumbs.db
 CLAUDE.md
+db/
+vtk-examples.json
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..b121779
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "rag-components"]
+	path = rag-components
+	url = git@gitlab.kitware.com:christos.tsolakis/rag-components.git
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 40e7bd5..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,28 +0,0 @@
-LABEL org.opencontainers.image.title="VTK MCP Server"
-LABEL org.opencontainers.image.description="Model Context Protocol server for VTK class documentation"
-LABEL org.opencontainers.image.source="https://github.com/kitware/vtk-mcp"
-LABEL org.opencontainers.image.authors="Vicente Adolfo Bolea Sanchez "
-LABEL org.opencontainers.image.licenses="MIT"
-LABEL org.opencontainers.image.documentation="https://github.com/kitware/vtk-mcp/blob/main/README.md"
-
-FROM python:3.12-slim
-
-ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
-    PIP_NO_CACHE_DIR=1 \
-    PYTHONDONTWRITEBYTECODE=1 \
-    PYTHONUNBUFFERED=1
-
-# Install system dependencies for VTK
-RUN apt update && \
-    apt install --no-install-recommends --no-install-suggests -y \
-    libgl1-mesa-dev \
-    libxrender-dev/stable
-
-WORKDIR /app
-COPY . .
-RUN pip install --upgrade pip && \
-    pip install --verbose .
-
-EXPOSE 8000
-
-CMD ["vtk-mcp-server", "--transport", "http", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/README.md b/README.md
index 0232fc8..520b9b2 100644
--- a/README.md
+++ b/README.md
@@ -51,10 +51,50 @@ vtk-mcp-client --host localhost --port 8000 info-cpp vtkActor
 
 ## MCP Tools
 
-The server provides three MCP tools:
+The server provides four MCP tools:
 
 - `get_vtk_class_info_cpp(class_name)` - Get detailed C++ documentation for a VTK class from online documentation
 - `get_vtk_class_info_python(class_name)` - Get Python API documentation using help() function
 - `search_vtk_classes(search_term)` - Search for VTK classes containing a term
+- `vector_search_vtk_examples(query)` - Search VTK examples using vector similarity (requires embeddings database)
+
+## Vector Search with RAG
+
+The server supports semantic search over VTK Python examples using vector embeddings. This requires the embeddings database.
+
+### Downloading the Embeddings Database
+
+The pre-built embeddings database is available as a container image on GitHub Container Registry:
+
+```bash
+# Using Docker
+docker create --name vtk-embeddings ghcr.io/kitware/vtk-mcp/embeddings-database:latest
+docker cp vtk-embeddings:/vtk-examples-embeddings.tar.gz .
+docker rm vtk-embeddings
+
+# Using Podman
+podman create --name vtk-embeddings ghcr.io/kitware/vtk-mcp/embeddings-database:latest
+podman cp vtk-embeddings:/vtk-examples-embeddings.tar.gz .
+podman rm vtk-embeddings
+
+# Extract the database
+tar -xzf vtk-examples-embeddings.tar.gz
+```
+
+### Using Vector Search
+
+After downloading and extracting the database, start the server with the database path:
+
+```bash
+# Install RAG dependencies
+pip install -r rag-components/requirements.txt
+
+# Start server with vector search enabled
+vtk-mcp-server --transport http --database-path ./db/vtk-examples
+
+# Use vector search with the client
+vtk-mcp-client vector-search "render a sphere"
+vtk-mcp-client vector-search "read DICOM files" --top-k 10
+```
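+
+You can also query the extracted database directly from Python through the
+`rag-components` helper, bypassing the server. A minimal sketch (it assumes you
+run it from the repository root with `rag-components` on `PYTHONPATH`; the
+database path and collection name match the server defaults above):
+
+```python
+from query_db import query_db_interactive  # provided by rag-components
+
+# Retrieve and rerank the most relevant example chunks
+results = query_db_interactive(
+    query="render a sphere",
+    database_path="./db/vtk-examples",
+    collection_name="vtk-examples",
+    top_k=5,
+)
+for snippet, score in zip(results["code_documents"], results["code_scores"]):
+    print(f"score={score:.3f}\n{snippet}\n")
+```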
 
 ## Docker
diff --git a/deploy.Dockerfile b/deploy.Dockerfile
new file mode 100644
index 0000000..50fe1b3
--- /dev/null
+++ b/deploy.Dockerfile
@@ -0,0 +1,48 @@
+FROM python:3.12-slim AS embeddings
+
+# Download embeddings database from GHCR
+COPY --from=ghcr.io/kitware/vtk-mcp/embeddings-database:latest /vtk-examples-embeddings.tar.gz /tmp/
+
+# Extract the database
+RUN mkdir -p /app/db && \
+    tar -xzf /tmp/vtk-examples-embeddings.tar.gz -C /app/db && \
+    rm /tmp/vtk-examples-embeddings.tar.gz
+
+FROM python:3.12-slim
+
+LABEL org.opencontainers.image.title="VTK MCP Server with Embeddings"
+LABEL org.opencontainers.image.description="Model Context Protocol server for VTK with vector search embeddings"
+LABEL org.opencontainers.image.source="https://github.com/kitware/vtk-mcp"
+LABEL org.opencontainers.image.authors="Vicente Adolfo Bolea Sanchez "
+LABEL org.opencontainers.image.licenses="MIT"
+LABEL org.opencontainers.image.documentation="https://github.com/kitware/vtk-mcp/blob/main/README.md"
+
+ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
+    PIP_NO_CACHE_DIR=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1
+
+# Install system dependencies for VTK
+RUN apt update && \
+    apt install --no-install-recommends --no-install-suggests -y \
+    libgl1-mesa-dev \
+    libxrender-dev/stable && \
+    rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+# Copy application code
+COPY . .
+
+# Copy embeddings database from first stage
+COPY --from=embeddings /app/db /app/db
+
+# Install Python dependencies (including RAG dependencies)
+RUN pip install --upgrade pip && \
+    pip install --verbose . && \
+    pip install -r rag-components/requirements.txt
+
+EXPOSE 8000
+
+# Start server with database path configured
+CMD ["vtk-mcp-server", "--transport", "http", "--host", "0.0.0.0", "--port", "8000", "--database-path", "/app/db/vtk-examples"]
diff --git a/pyproject.toml b/pyproject.toml
index b08009b..0e1726d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,6 +60,7 @@ markers = [
     "integration: Integration tests that require server/client interaction",
     "http: HTTP transport integration tests",
     "stdio: Stdio transport integration tests",
+    "vector_search: Vector search integration tests (requires podman and embeddings database)",
     "slow: Tests that take longer to run",
 ]
 filterwarnings = [
diff --git a/rag-components/.gitignore b/rag-components/.gitignore
new file mode 100644
index 0000000..52693c6
--- /dev/null
+++ b/rag-components/.gitignore
@@ -0,0 +1,13 @@
+# OS files
+.DS_Store
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Environments
+.env
+.venv
+env/
+venv/
diff --git a/rag-components/LICENSE b/rag-components/LICENSE
new file mode 100644
index 0000000..7e38aea
--- /dev/null
+++ b/rag-components/LICENSE
@@ -0,0 +1,13 @@
+Copyright 2025 Kitware Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/rag-components/README.md b/rag-components/README.md
new file mode 100644
index 0000000..e54ac66
--- /dev/null
+++ b/rag-components/README.md
@@ -0,0 +1,67 @@
+# A simple RAG for VTK
+
+This project builds a database out of the existing Python examples of VTK and lets you ask questions about VTK.
+
+## Set up
+
+1. By default this project uses the OpenAI API. Make sure you get an API key and set
+your environment variable appropriately. To use another model, see
+[below](#supported-llm-models).
+
+2. Get the code of the vtk-examples. We will use this to generate our database.
+
+```bash
+git clone https://gitlab.kitware.com/vtk/vtk-examples
+```
+
+3. Create a virtual environment and install the dependencies.
+
+```bash
+python -m venv env
+source env/bin/activate
+pip install -r requirements.txt
+```
+
+4. Populate the database. This is required only once, or whenever you want to experiment with a different embedding function.
+It will take some time, depending on the hardware you are using.
+
+```bash
+python populate_db.py --dir ./vtk-examples/src/Python
+```
+
+5. Now ask your question!
+
+```bash
+$ python chat.py --database ./db/codesage-codesage-large-v2
+User: How to read a vti file
+ To read a VTK image data file (.vti), you can use the `vtkXMLImageDataReader` class. Here is a basic example:
+
+import vtk
+
+# Create a reader for your vti file
+reader = vtk.vtkXMLImageDataReader()
+reader.SetFileName('your_file.vti')
+reader.Update()
+
+# The output of reader.GetOutput() is your vtkImageData object
+image_data = reader.GetOutput()
+
+In this code, replace `'your_file.vti'` with the path to your .vti file. The
+`reader.Update()` call is necessary to actually perform the reading operation.
+After this, you can use `reader.GetOutput()` to get the `vtkImageData` object
+that was read from the file.
+
+References:
+https://examples.vtk.org/site/Python/Medical/GenerateModelsFromLabels
+https://examples.vtk.org/site/Python/ImageData/WriteReadVtkImageData
+...
+```
+
+### Supported LLM models
+
+`chat.py` uses the "gpt-4" model by default; to switch to a different one, pass the name of the model via the `--model=` parameter.
+Currently supported models:
+- OpenAI models. See exact model names [here](https://platform.openai.com/docs/models#current-model-aliases). To use them you need an OpenAI API [key](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key).
+- Anthropic models. See exact names [here](https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-names). To use them you need an Anthropic API [key](https://docs.anthropic.com/en/api/getting-started#accessing-the-api).
+- Models supported by the Ollama framework. To use these models, make sure you have [ollama](https://github.com/ollama/ollama) installed, that it
+  is running in another terminal (via `ollama serve`), and that you have already
+  pulled the model you want to use (via `ollama pull <model>`). You can find available models [here](https://ollama.com/).
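+
+### Querying the database without an LLM
+
+To debug the retrieval step you can also query the database directly; this is
+a plain retrieval-plus-rerank run with no chat model involved (the database
+path below assumes the default location produced by `populate_db.py`):
+
+```bash
+python query_db.py "How to read a vti file" \
+    --database ./db/codesage-codesage-large-v2 \
+    --collection-name python-examples \
+    --top-k 5
+```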
diff --git a/rag-components/chat.py b/rag-components/chat.py
new file mode 100644
index 0000000..f2d7f7e
--- /dev/null
+++ b/rag-components/chat.py
@@ -0,0 +1,170 @@
+from llama_index.core.llms import ChatMessage
+from llama_index.llms.openai import OpenAI
+from llama_index.llms.anthropic import Anthropic
+from llama_index.llms.ollama import Ollama
+import argparse
+import query_db
+import json
+from string import Template
+from pathlib import Path
+from typing import List
+import os
+
+PROMPT = Template(
+    """
+You are an AI assistant specializing in VTK (Visualization Toolkit)
+documentation. Your primary task is to provide accurate, concise, and helpful
+responses to user queries about VTK, including relevant code snippets.
+
+Here is the context information you should use to answer queries:
+
+$CONTEXT
+
+Here's the user's query:
+
+$QUERY
+
+When responding to a user query, follow these guidelines:
+
+1. Relevance Check:
+
+   - If the query is not relevant to VTK, respond with "This question is not relevant to VTK."
+
+2. Answer Formulation:
+
+   - If you don't know the answer, clearly state that.
+   - If uncertain, ask the user for clarification.
+   - Respond in the same language as the user's query.
+   - Be concise while providing complete information.
+   - If the answer isn't in the context but you have the knowledge, explain this to the user and provide the answer based on your understanding.
+"""
+)
+
+# History of a chat
+HISTORY = [
+    ChatMessage(
+        role="system", content="You are a helpful assistant"
+    ),  # TODO what else to add?
+]
+
+llm = None
+client = None
+
+
+def init(model: str, database: str) -> None:
+    global llm, client
+    try:
+        if any(
+            x in model for x in ["gpt", "o1", "o3"]
+        ):  # TODO this is too simplistic and error-prone
+            llm = OpenAI(model=model)
+        elif "claude" in model:
+            llm = Anthropic(model=model)
+        else:  # assuming an Ollama model
+            llm = Ollama(
+                model=model, request_timeout=3000.0, base_url="http://localhost:11434"
+            )
+    except Exception:
+        raise RuntimeError(f"Unsupported model {model}")
+
+    client = query_db.initialize_db(database_path=database)
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def ask(query: str, collection_name: str, top_k: int, streaming: bool = False):
+    # corpus = json.load(open(collection_name+".json"))
+    HISTORY.append(ChatMessage(role="user", content=query))
+
+    # Query the index to retrieve relevant documents
+    results = query_db.query_db(query, collection_name, top_k, client)
+    relevant_examples = [item["original_id"] for item in results["code_metadata"]] + [
+        item["code"] for item in results["text_metadata"]
+    ]
+    snippets = [item for item in results["code_documents"]]
+    relevant_examples = list(set(relevant_examples))
+
+    # Combine the retrieved documents into a single text
+    retrieved_text = "\n\n## Next example:\n\n".join(snippets)
+    content = PROMPT.substitute(CONTEXT=retrieved_text, QUERY=query.rstrip())
+    # print(content)
+
+    # Add the retrieved text as a new message
+    HISTORY.append(ChatMessage(role="assistant", content=content.rstrip()))
+
+    # Generate a response using the LLM
+    if streaming:
+        response = llm.stream_chat(HISTORY)
+    else:
+        response = llm.chat(HISTORY)
+
+    return {"response": response, "references": relevant_examples}
+
+
+def _generate_urls_from_references(references: List[str]):
+    urls = []
+    for ref in references:
+        ref = Path(ref)
+        # this transformation alters
+        # vtk-examples.git/src/Python/PolyData/CurvaturesAdjustEdges.py
+        # to
+        # https://examples.vtk.org/site/Python/PolyData/CurvaturesAdjustEdges
+        url = "https://examples.vtk.org/site/{}".format(
+            (ref.relative_to(ref.parents[-3])).with_suffix("")
+        )
+        urls.append(url)
+    return urls
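+
+
+# Worked example of the transformation above (hypothetical input; the path
+# layout matches what populate_db.py records when run on a vtk-examples checkout):
+#   _generate_urls_from_references(["vtk-examples/src/Python/PolyData/CurvaturesAdjustEdges.py"])
+#   returns ["https://examples.vtk.org/site/Python/PolyData/CurvaturesAdjustEdges"]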
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Query database for code snippets",
+    )
+    parser.add_argument(
+        "--database",
+        type=str,
+        help="Path to the database.",
+        default="./db/codesage-codesage-large-v2",  # this corresponds to the default path of the default model in populate_db.py
+    )
+    parser.add_argument(
+        "--collection-name",
+        type=str,
+        help="Name for the collection in the database. Used to distinguish different families of documents during retrieval",
+        default="python-examples",
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        help="Retrieve the top `k` examples from the database when composing a context",
+        default=15,
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="LLM model to use",
+        default="gpt-4",
+    )
+    args = parser.parse_args()
+    init(args.model, args.database)
+    print(
+        "Welcome to VTK's assistant! What would you like to know?\ntype 'exit' to quit"
+    )
+    while True:
+        user_input = input("User: ")
+        if len(user_input) == 0:
+            continue
+        full_reply = ""
+        if user_input.lower() == "exit":
+            print("Bye!")
+            break
+        reply = ask(user_input, args.collection_name, args.top_k, streaming=True)
+        for item in reply["response"]:
+            print(item.delta, end="", flush=True)
+            full_reply += item.delta
+        print(
+            "\n Here are some relevant references:\n",
+            "\n".join(_generate_urls_from_references(reply["references"])),
+        )
+        # add the reply to the chat history
+        HISTORY.append(ChatMessage(role="assistant", content=full_reply.rstrip()))
diff --git a/rag-components/populate_db.py b/rag-components/populate_db.py
new file mode 100644
index 0000000..01b96a7
--- /dev/null
+++ b/rag-components/populate_db.py
@@ -0,0 +1,191 @@
+import chromadb
+from chromadb.api import ClientAPI
+from chromadb.utils import embedding_functions
+from pathlib import Path
+from llama_index.core.node_parser import CodeSplitter, SentenceSplitter, TextSplitter
+from sentence_transformers import SentenceTransformer
+import hashlib
+
+from tqdm import tqdm
+import json
+
+import argparse
+from typing import Optional, List
+
+EMBEDDING_MODEL_TEXT = "all-MiniLM-L6-v2"  # "nomic-ai/nomic-embed-text-v2-moe"
+
+LANGUAGE_SUFFIX = {
+    "markdown": "md",
+    "python": "py",
+}
+
+
+def _create_collection(
+    client: ClientAPI, name: str, embedding_model: str, splitter: TextSplitter
+):
+    collection = client.create_collection(
+        name=name,
+        embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=embedding_model, trust_remote_code=True
+        ),
+        metadata={
+            # chromadb does not store the embedding model, but it is required
+            # when performing queries, so we store it as metadata
+            "embedding_model": embedding_model,
+            # not all models return normalized embeddings, so the metric matters;
+            # let's stick with cosine similarity
+            # https://wietsevenema.eu/is-cosine-similarity-always-the-best-choice-for-text-embedding-search
+            "hnsw:space": "cosine",
+            "splitter": splitter.json(),
+        },
+    )
+    return collection
+
+
+def md5_for_file(path: Path):
+    with open(path, "rb") as f:
+        # TODO if we at some point have larger files we will need to read incrementally
+        data = f.read()
+    return hashlib.md5(data).hexdigest()
+
+
+def _get_unique(files: List[Path]):
+    unique = {}
+    for file in files:
+        if not file.exists():
+            continue
+        md5 = md5_for_file(file)
+        if md5 not in unique:
+            unique[md5] = file
+    return list(unique.values())
+
+
+def fill_database(
+    files: List[Path],
+    database_path: str,
+    embedding_model: str,
+    language: Optional[str] = None,
+    collection_name: str = "",
+    corpus_name: Optional[str] = None,
+):
+    if corpus_name is None:
+        corpus_name = collection_name + ".json"
+
+    code_splitter = None
+    if language is not None:
+        # A larger chunk_size means more context per chunk. Remember, however,
+        # that LLMs have a limited context anyway, so you need to trade off
+        # between a few big examples and many small ones.
+        code_splitter = CodeSplitter(
+            language=language,
+            chunk_lines=20,
+            chunk_lines_overlap=5,  # this does not seem to have any effect
+            max_chars=700,
+        )
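+        # Rough arithmetic for the settings above (assuming the line-based
+        # window dominates): a 20-line window advancing by 20 - 5 = 15 lines
+        # turns a 100-line example into about (100 - 20) / 15 + 1 ≈ 6 chunks,
+        # each capped at 700 characters.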
+    text_splitter = SentenceSplitter()
+
+    print("Setting up database ...")
+    client = chromadb.PersistentClient(path=database_path)
+
+    # Create a collection with the appropriate embedding function
+    collection_code = _create_collection(
+        client, collection_name, embedding_model, code_splitter
+    )
+    # create a collection with descriptions of the examples
+    collection_text = _create_collection(
+        client, collection_name + "_text", EMBEDDING_MODEL_TEXT, text_splitter
+    )
+
+    code_files = [file for file in files if file.suffix == ".py"]
+    markdown_files = [file.parent / (file.stem + ".md") for file in code_files]
+    # keep only unique files, otherwise we may cause overfitting; duplicates also produce warnings in chromadb
+    code_files = _get_unique(code_files)
+    markdown_files = _get_unique(markdown_files)
+
+    # TODO how to use the following, which processes snippets in parallel, while keeping track of the association between chunk and path?
+    # text_embeddings = text_model.encode(text, batch_size=4, show_progress_bar=True)
+
+    # corpus is the dictionary of filenames->content. We save it as json and
+    # use it in the retrieval phase to get the original documents
+    corpus = {}
+
+    for fcode, ftext in tqdm(zip(code_files, markdown_files)):
+        print(fcode)
+        with open(fcode, "r") as code_file:
+            try:
+                code = code_file.read()
+                code_chunks = code_splitter.split_text(code)
+                collection_code.add(
+                    documents=code_chunks,
+                    ids=[f"snippet_{fcode.name}_{i}" for i in range(len(code_chunks))],
+                    metadatas=[
+                        {
+                            "original_id": str(fcode),
+                            "text": ftext.name if ftext.exists() else "null",
+                        }  # chromadb does not accept None as a metadata value
+                        for _ in range(len(code_chunks))
+                    ],
+                )
+                corpus[str(fcode)] = code
+                if ftext.exists():
+                    with open(ftext, "r") as text_file:
+                        text = text_file.read()
+                    text_chunks = text_splitter.split_text(text)
+                    collection_text.add(
+                        documents=text_chunks,
+                        ids=[
+                            f"snippet_{ftext.name}_{i}" for i in range(len(text_chunks))
+                        ],
+                        metadatas=[
+                            {"original_id": str(ftext), "code": str(fcode)}
+                            for _ in range(len(text_chunks))
+                        ],
+                    )
+                    corpus[str(ftext)] = text
+            except ValueError as error:
+                print(f"Skipping {fcode}")
+                print(error)
+
+    with open(corpus_name, "w") as file:
+        json.dump(corpus, file)
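+
+
+# With the defaults used below (model codesage/codesage-large-v2 and
+# collection "python-examples"), a run leaves the chroma database in
+# ./db/codesage-codesage-large-v2 and the corpus file in ./python-examples.json.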
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Populate database with embeddings of documents based on a model",
+    )
+    parser.add_argument(
+        "--language", type=str, help="Language of the documents", default="python"
+    )
+    parser.add_argument(
+        "--dir",
+        required=True,
+        type=str,
+        help="Base directory of documents. The script will collect all documents recursively",
+    )
+    parser.add_argument(
+        "--database",
+        type=str,
+        help="Name of the database. Default is ./db/model-name",
+        default=None,
+    )
+    parser.add_argument(
+        "--model",
+        default="codesage/codesage-large-v2",
+        help="Model name for embedding. For now, we use sentence transformer models. See https://www.sbert.net/",
+    )
+    parser.add_argument(
+        "--collection-name",
+        type=str,
+        help="Name for the collection in the database. Used to distinguish different families of documents during retrieval",
+        default="python-examples",
+    )
+    args = parser.parse_args()
+
+    files = [file for file in Path(args.dir).rglob("*") if not file.is_dir()]
+
+    language = args.language.lower()
+    database_path = args.database
+    if database_path is None:
+        model_prefix = args.model.replace("/", "-")
+        database_path = str(Path("./db") / Path(model_prefix))
+
+    fill_database(files, database_path, args.model, language, args.collection_name)
diff --git a/rag-components/query_db.py b/rag-components/query_db.py
new file mode 100644
index 0000000..ceb65b2
--- /dev/null
+++ b/rag-components/query_db.py
@@ -0,0 +1,146 @@
+import chromadb
+from chromadb.api import ClientAPI
+from sentence_transformers import CrossEncoder
+from chromadb.utils import embedding_functions
+import argparse
+
+
+def _get_collection(client, collection_name: str):
+    collection = client.get_collection(name=collection_name)
+    embedding_model = collection.metadata["embedding_model"]
+
+    collection = client.get_collection(
+        name=collection_name,
+        embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=embedding_model, trust_remote_code=True
+        ),
+    )
+    return collection
+
+
+def initialize_db(database_path: str) -> ClientAPI:
+    return chromadb.PersistentClient(path=database_path)
+
+
+def query_db_interactive(
+    query: str,
+    database_path: str,
+    collection_name: str,
+    top_k: int = 5,
+):
+    client = initialize_db(database_path)
+    return query_db(query, collection_name, top_k, client)
+
+
+def query_db(
+    query: str,
+    collection_name: str,
+    top_k: int = 5,
+    client: ClientAPI = None,
+):
+    # How many results to get from the database? This is the first selection;
+    # once we get them, we rerank them and select the top k.
+    num_code_query_results = 50
+    num_text_query_results = 25
+
+    # Get collection by name and the appropriate embedding function
+    code_collection = _get_collection(client, collection_name)
+    text_collection = _get_collection(client, collection_name + "_text")
+
+    # Perform search
+    code_results = code_collection.query(
+        query_texts=[query], n_results=num_code_query_results
+    )
+    text_results = text_collection.query(
+        query_texts=[query], n_results=num_text_query_results
+    )
+    code_snippets = code_results["documents"][0]
+    text_snippets = text_results["documents"][0]
+
+    # rerank results and keep the top k; these can be the context of the LLM
+    rerank_model = CrossEncoder(
+        "jinaai/jina-reranker-v1-turbo-en", trust_remote_code=True
+    )
+    code_reranked_results = rerank_model.rank(query, code_snippets, top_k=top_k)
+    text_reranked_results = rerank_model.rank(query, text_snippets, top_k=top_k)
+
+    code_documents = []
+    code_metadata = []
+    code_scores = []
+    text_documents = []
+    text_metadata = []
+    text_scores = []
+    for citem, titem in zip(code_reranked_results, text_reranked_results):
+        code_corpus_id = citem["corpus_id"]
+        code_scores.append(citem["score"])
+        code_documents.append(code_snippets[code_corpus_id])
+        code_metadata.append(code_results["metadatas"][0][code_corpus_id])
+
+        text_corpus_id = titem["corpus_id"]
+        text_scores.append(titem["score"])
+        text_documents.append(text_snippets[text_corpus_id])
+        text_metadata.append(text_results["metadatas"][0][text_corpus_id])
+
+    return {
+        "code_documents": code_documents,
+        "code_metadata": code_metadata,
+        "code_scores": code_scores,
+        "text_documents": text_documents,
+        "text_metadata": text_metadata,
+        "text_scores": text_scores,
+    }
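+
+
+# Shape of the dictionary returned above (values depend on the query):
+#   code_documents: the top_k reranked code chunks
+#   code_metadata:  [{"original_id": <code file>, "text": <markdown file or "null">}, ...]
+#   text_documents: the top_k reranked markdown chunks
+#   text_metadata:  [{"original_id": <markdown file>, "code": <code file>}, ...]
+#   code_scores / text_scores: the corresponding cross-encoder scores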
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Query database for code snippets",
+    )
+    parser.add_argument(
+        "query",
+        type=str,
+        help="Query for the search",
+        default=None,
+    )
+    parser.add_argument(
+        "--database",
+        type=str,
+        help="Path to the database.",
+        default=None,
+    )
+    parser.add_argument(
+        "--collection-name",
+        type=str,
+        help="Name for the collection in the database. Used to distinguish different families of documents during retrieval",
+        default="python-examples",
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        help="Return the top k matches to the query",
+        default=5,
+    )
+    args = parser.parse_args()
+
+    print("***********************************************************************")
+    reranked_results = query_db_interactive(
+        args.query, args.database, args.collection_name, args.top_k
+    )
+    print(f"Query: {args.query}\n")
+    print(f"Top {args.top_k} most similar code chunks:")
+    for document, metadata, score in zip(
+        reranked_results["code_documents"],
+        reranked_results["code_metadata"],
+        reranked_results["code_scores"],
+    ):
+        print(f"Score: {score}")
+        print(f"snippet {document}:")
+        print(f"from {metadata}:")
+        print("-------------------------------------------------------")
+    for document, metadata, score in zip(
+        reranked_results["text_documents"],
+        reranked_results["text_metadata"],
+        reranked_results["text_scores"],
+    ):
+        print(f"Score: {score}")
+        print(f"snippet {document}:")
+        print(f"from {metadata}:")
+        print("-------------------------------------------------------")
diff --git a/rag-components/requirements.txt b/rag-components/requirements.txt
new file mode 100644
index 0000000..165b00e
--- /dev/null
+++ b/rag-components/requirements.txt
@@ -0,0 +1,11 @@
+accelerate==1.3.0
+chromadb==0.6.3
+llama-index-core==0.12.17
+llama-index-llms-anthropic==0.6.7
+llama-index-llms-ollama==0.5.2
+llama-index-llms-openai==0.3.20
+llama-index-readers-file==0.4.5
+sentence-transformers==3.4.1
+tree_sitter_languages==1.9.1
+tree_sitter<0.22
+tqdm==4.67.1
diff --git a/deploy/README.md b/share/README.md
similarity index 100%
rename from deploy/README.md
rename to share/README.md
diff --git a/deploy/vtk-mcp-server-compose.yml b/share/vtk-mcp-server-compose.yml
similarity index 100%
rename from deploy/vtk-mcp-server-compose.yml
rename to share/vtk-mcp-server-compose.yml
diff --git a/deploy/vtk-mcp-server-deploy.yml b/share/vtk-mcp-server-deploy.yml
similarity index 90%
rename from deploy/vtk-mcp-server-deploy.yml
rename to share/vtk-mcp-server-deploy.yml
index 19fb06b..6a5351a 100644
--- a/deploy/vtk-mcp-server-deploy.yml
+++ b/share/vtk-mcp-server-deploy.yml
@@ -19,7 +19,7 @@
         src: "{{ item[0]}}"
         dest: "{{ item[1] }}"
       loop:
-        - ['./vtk-mcp-server-compose.yml', '/var/opt/vtk-mcp-server/']
+        - ['./vtk-mcp-server-compose.yml', '/var/opt/vtk-mcp-server/']
         - ['./vtk-mcp-server.service', '/etc/systemd/system/']
 
     - name: restart vtk-mcp-server services
diff --git a/deploy/vtk-mcp-server.service b/share/vtk-mcp-server.service
similarity index 100%
rename from deploy/vtk-mcp-server.service
rename to share/vtk-mcp-server.service
diff --git a/src/vtk_mcp_server/server.py b/src/vtk_mcp_server/server.py
index b4affca..53a9393 100644
--- a/src/vtk_mcp_server/server.py
+++ b/src/vtk_mcp_server/server.py
@@ -3,7 +3,9 @@
 import asyncio
 import click
 import io
+import sys
 from contextlib import redirect_stdout
+from pathlib import Path
 
 from fastmcp import FastMCP
 from .vtk_scraper import VTKClassScraper
@@ -11,6 +13,9 @@
 mcp = FastMCP("VTK MCP Server")
 scraper = VTKClassScraper()
 
+# Global database path (set via CLI)
+_database_path = None
+
 
 @mcp.tool()
 def get_vtk_class_info_cpp(class_name: str) -> str:
@@ -90,6 +95,145 @@ def get_vtk_class_info_python(class_name: str) -> str:
         return f"Error getting Python help for '{class_name}': {str(e)}"
 
 
+@mcp.tool()
+def vector_search_vtk_examples(
+    query: str,
+    collection_name: str = "vtk-examples",
+    top_k: int = 5,
+) -> str:
+    """Search VTK examples using vector similarity search with RAG.
+
+    This function performs semantic search over VTK code examples using embeddings
+    and returns the most relevant code snippets and documentation.
+
+    Args:
+        query: The search query describing what you're looking for
+        collection_name: Name of the collection in the database
+            (default: vtk-examples)
+        top_k: Number of top results to return (default: 5)
+
+    Returns:
+        Formatted string with code snippets, documentation, and relevance scores
+    """
+    if not query:
+        return "Error: query is required"
+
+    if not _database_path:
+        return (
+            "Error: Database path not configured. "
+            "Start the server with --database-path option."
+        )
+
+    try:
+        # Check dependencies
+        import importlib.util
+
+        required_modules = ["chromadb", "sentence_transformers"]
+        missing_modules = [
+            module
+            for module in required_modules
+            if importlib.util.find_spec(module) is None
+        ]
+
+        if missing_modules:
+            return (
+                f"Error: Missing required dependencies: {', '.join(missing_modules)}\n"
+                "Install with: pip install chromadb sentence-transformers\n\n"
+                "Or ensure the rag-components submodule is initialized:\n"
+                "git submodule update --init --recursive"
+            )
+
+        # Setup path and import
+        script_dir = Path(__file__).resolve().parent
+        project_root = script_dir.parent.parent
+        rag_path = str(project_root / "rag-components")
+
+        if rag_path not in sys.path:
+            sys.path.insert(0, rag_path)
+
+        from query_db import query_db, initialize_db
+
+        # Initialize database client with configured path
+        client = initialize_db(_database_path)
+
+        # Perform vector search with reranking
+        results = query_db(
+            query=query, collection_name=collection_name, top_k=top_k, client=client
+        )
+
+        # Format results
+        return _format_vector_search_results(query, results, top_k)
+
+    except Exception as e:
+        return f"Error performing vector search: {str(e)}"
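+
+
+# Example tool invocation, as issued by the bundled client
+# (vtk-mcp-client vector-search "render a sphere"):
+#   vector_search_vtk_examples("render a sphere", collection_name="vtk-examples", top_k=5)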
+
+
+def _format_vector_search_results(query, results, top_k):
+    """Format vector search results into readable markdown."""
+    lines = [
+        f"# Vector Search Results for: {query}",
+        "",
+        f"Found {top_k} most relevant code examples and documentation snippets.",
+        "",
+    ]
+
+    # Format code results
+    if results["code_documents"]:
+        lines.append("## Code Examples")
+        lines.append("")
+
+        for i, (document, metadata, score) in enumerate(
+            zip(
+                results["code_documents"],
+                results["code_metadata"],
+                results["code_scores"],
+            ),
+            1,
+        ):
+            lines.append(f"### Result {i} (Relevance: {score:.4f})")
+            lines.append("")
+
+            # Add metadata if available
+            if metadata:
+                source = metadata.get("original_id", "Unknown")
+                lines.append(f"**Source:** {source}")
+                lines.append("")
+
+            # Add code snippet
+            lines.append("```python")
+            lines.append(document.strip())
+            lines.append("```")
+            lines.append("")
+
+    # Format text/documentation results
+    if results["text_documents"]:
+        lines.append("## Documentation Snippets")
+        lines.append("")
+
+        for i, (document, metadata, score) in enumerate(
+            zip(
+                results["text_documents"],
+                results["text_metadata"],
+                results["text_scores"],
+            ),
+            1,
+        ):
+            lines.append(f"### Snippet {i} (Relevance: {score:.4f})")
+            lines.append("")
+
+            # Add metadata if available
+            if metadata:
+                source = metadata.get("original_id", "Unknown")
+                lines.append(f"**Source:** {source}")
+                lines.append("")
+
+            # Add text content
+            lines.append(document.strip())
+            lines.append("")
+
+    return "\n".join(lines)
+
+
 def _format_class_info(info):
     """Format class info into readable markdown."""
     lines = [f"# {info['class_name']}", ""]
@@ -130,8 +274,20 @@ def _format_class_info(info):
 )
 @click.option("--host", default="127.0.0.1", help="Host (HTTP only)")
 @click.option("--port", default=8000, type=int, help="Port (HTTP only)")
-def main(transport, host, port):
+@click.option(
+    "--database-path",
+    type=click.Path(exists=True),
+    help="Path to the RAG database (required for vector search)",
+)
+def main(transport, host, port, database_path):
     """Run the VTK MCP Server"""
+    global _database_path
+
+    # Set global database path if provided
+    if database_path:
+        _database_path = database_path
+        click.echo(f"Database path configured: {database_path}")
+
     if transport == "http":
         click.echo(f"Starting VTK MCP Server on http://{host}:{port}")
         asyncio.run(mcp.run_http_async(host=host, port=port))
diff --git a/src/vtk_mcp_server/simple_client.py b/src/vtk_mcp_server/simple_client.py
index d0a8519..bd443b2 100644
--- a/src/vtk_mcp_server/simple_client.py
+++ b/src/vtk_mcp_server/simple_client.py
@@ -133,6 +133,28 @@ def get_class_info_python(self, class_name):
             result = self._parse_response(response)
             self._handle_tool_response(result)
 
+    def vector_search_vtk_examples(
+        self, query, collection_name="vtk-examples", top_k=5
+    ):
+        """Search VTK examples using vector search"""
+        payload = {
+            "jsonrpc": "2.0",
+            "id": "4",
+            "method": "tools/call",
+            "params": {
+                "name": "vector_search_vtk_examples",
+                "arguments": {
+                    "query": query,
+                    "collection_name": collection_name,
+                    "top_k": top_k,
+                },
+            },
+        }
+        response = self._make_request(payload)
+        if response:
+            result = self._parse_response(response)
+            self._handle_tool_response(result)
+
     def list_tools(self):
         """List available MCP tools"""
         payload = {"jsonrpc": "2.0", "id": "3", "method": "tools/list"}
@@ -208,6 +230,19 @@ def info_python(ctx, class_name):
     client.get_class_info_python(class_name)
 
 
+@cli.command()
+@click.argument("query")
+@click.option("--collection", default="vtk-examples", help="Collection name")
+@click.option("--top-k", default=5, type=int, help="Number of results to return")
+@click.pass_context
+def vector_search(ctx, query, collection, top_k):
+    """Search for VTK examples using vector search"""
+    client = ctx.obj["client"]
+    click.echo(f"Searching for VTK examples: '{query}'...")
+    click.echo()
+    client.vector_search_vtk_examples(query, collection, top_k)
+
+
 @cli.command()
 @click.pass_context
 def list_tools(ctx):
diff --git a/tests/test_vector_search_integration.py b/tests/test_vector_search_integration.py
new file mode 100644
index 0000000..9bacb3f
--- /dev/null
+++ b/tests/test_vector_search_integration.py
@@ -0,0 +1,240 @@
+"""Integration tests for vector search with RAG database."""
+
+import pytest
+import asyncio
+import time
+import threading
+import socket
+import subprocess
+import tarfile
+import shutil
+from pathlib import Path
+from vtk_mcp_server.simple_client import SimpleVTKClient
+
+
+pytestmark = [pytest.mark.integration, pytest.mark.vector_search]
+
+
+@pytest.fixture(scope="module")
+def embeddings_database(tmp_path_factory):
+    """Download and extract the embeddings database from the container image, or use a local copy."""
+    # First check if we have a local database
+    local_db = Path(__file__).parent.parent / "db" / "vtk-examples"
+    if local_db.exists() and (local_db / "chroma.sqlite3").exists():
+        yield str(local_db)
+        return
+
+    tmpdir = tmp_path_factory.mktemp("embeddings")
+    db_path = tmpdir / "db" / "vtk-examples"
+    tarball_path = tmpdir / "vtk-examples-embeddings.tar.gz"
"--version"], + check=True, + capture_output=True, + ) + except (subprocess.CalledProcessError, FileNotFoundError): + pytest.skip("Podman not available and no local database found") + + try: + # Create temporary container to extract database + container_name = "vtk-embeddings-test" + image_name = "ghcr.io/kitware/vtk-mcp/embeddings-database:latest" + + # Pull the image + subprocess.run( + ["podman", "pull", image_name], + check=True, + capture_output=True, + ) + + # Create container + subprocess.run( + ["podman", "create", "--name", container_name, image_name], + check=True, + capture_output=True, + ) + + # Copy tarball from container + subprocess.run( + [ + "podman", + "cp", + f"{container_name}:/vtk-examples-embeddings.tar.gz", + str(tarball_path), + ], + check=True, + capture_output=True, + ) + + # Remove container + subprocess.run( + ["podman", "rm", container_name], + check=True, + capture_output=True, + ) + + # Extract tarball + tmpdir.joinpath("db").mkdir(parents=True, exist_ok=True) + with tarfile.open(tarball_path, "r:gz") as tar: + tar.extractall(path=tmpdir / "db") + + # Verify database exists + if not db_path.exists(): + pytest.skip(f"Database not found at {db_path}") + + yield str(db_path) + + except subprocess.CalledProcessError as e: + pytest.skip(f"Failed to extract embeddings database: {e}") + finally: + # Cleanup + if tmpdir.exists(): + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.fixture +def vector_search_server(embeddings_database): + """Start HTTP server with vector search enabled.""" + # Import here to set database path before server starts + from vtk_mcp_server import server + + # Set the database path globally + server._database_path = embeddings_database + + # Find a free port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + server_thread = None + server_started = threading.Event() + + def run_server(): + try: + # Run server with database path configured + asyncio.run(server.mcp.run_http_async(host="127.0.0.1", port=port)) + except Exception: + pass # Server might be stopped + finally: + server_started.set() + + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + # Wait for server to start + time.sleep(1.0) + + yield f"http://127.0.0.1:{port}" + + # Cleanup - server thread will die when test ends due to daemon=True + + +class TestVectorSearchIntegration: + """Test vector search functionality end-to-end.""" + + def test_vector_search_red_sphere(self, vector_search_server): + """Test vector search for 'a red vtkSphere'.""" + client = SimpleVTKClient(base_url=vector_search_server) + + # Call vector search + payload = { + "jsonrpc": "2.0", + "id": "vector-search-test", + "method": "tools/call", + "params": { + "name": "vector_search_vtk_examples", + "arguments": { + "query": "a red vtkSphere", + "collection_name": "vtk-examples", + "top_k": 5, + }, + }, + } + + response = client._make_request(payload) + assert response is not None, "Vector search request failed" + assert ( + response.status_code == 200 + ), f"Unexpected status code: {response.status_code}" + + result = client._parse_response(response) + assert result is not None, "Failed to parse response" + assert "result" in result, "No result in response" + assert "content" in result["result"], "No content in result" + + content = result["result"]["content"][0]["text"] + assert content is not None, "Empty response content" + assert "Vector Search Results" in content, "Unexpected response format" + 
assert "vtk" in content.lower(), "Results don't mention VTK" + + # Verify we got actual results + assert ( + "Code Examples" in content or "Documentation Snippets" in content + ), "No code examples or documentation found" + + def test_vector_search_with_different_top_k(self, vector_search_server): + """Test vector search with different top_k values.""" + client = SimpleVTKClient(base_url=vector_search_server) + + payload = { + "jsonrpc": "2.0", + "id": "vector-search-topk", + "method": "tools/call", + "params": { + "name": "vector_search_vtk_examples", + "arguments": { + "query": "render a sphere", + "collection_name": "vtk-examples", + "top_k": 3, + }, + }, + } + + response = client._make_request(payload) + assert response is not None + assert response.status_code == 200 + + result = client._parse_response(response) + assert result is not None + assert "result" in result + + content = result["result"]["content"][0]["text"] + assert "Found 3 most relevant" in content, "Expected 3 results" + + def test_vector_search_query_variations(self, vector_search_server): + """Test vector search with various queries.""" + client = SimpleVTKClient(base_url=vector_search_server) + + test_queries = [ + "read DICOM files", + "create a cylinder", + "texture mapping", + ] + + for query in test_queries: + payload = { + "jsonrpc": "2.0", + "id": f"query-{query}", + "method": "tools/call", + "params": { + "name": "vector_search_vtk_examples", + "arguments": { + "query": query, + "top_k": 2, + }, + }, + } + + response = client._make_request(payload) + assert response is not None, f"Query '{query}' failed" + assert response.status_code == 200 + + result = client._parse_response(response) + assert result is not None + assert "result" in result + + content = result["result"]["content"][0]["text"] + assert "Vector Search Results" in content diff --git a/vtk-examples-db.dockerfile b/vtk-examples-db.dockerfile new file mode 100644 index 0000000..aad763c --- /dev/null +++ b/vtk-examples-db.dockerfile @@ -0,0 +1,10 @@ +LABEL org.opencontainers.image.title="VTK MCP Embeddings Database" +LABEL org.opencontainers.image.description="Vector search embeddings database for VTK examples" +LABEL org.opencontainers.image.source="https://github.com/kitware/vtk-mcp" +LABEL org.opencontainers.image.authors="Vicente Adolfo Bolea Sanchez " +LABEL org.opencontainers.image.licenses="MIT" +LABEL org.opencontainers.image.documentation="https://github.com/kitware/vtk-mcp/blob/main/README.md" + +FROM scratch + +COPY vtk-examples-embeddings.tar.gz /