Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions ai/gen-ai-agents/custom_rag_agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,24 @@

# embeddings
# model used to embed both the loaded documents and the search queries
EMBED_MODEL_ID = "cohere.embed-multilingual-v3.0"
# alternative embedding model (disabled):
# EMBED_MODEL_ID = "cohere.embed-multilingual-image-v3.0"

# LLM
# this is the default model
LLM_MODEL_ID = "meta.llama-3.3-70b-instruct"
TEMPERATURE = 0.1
# raised from 1024; the earlier, shadowed assignment has been removed
MAX_TOKENS = 2048

# for the UI
LANGUAGE_LIST = ["same as the question", "en", "fr", "it", "es"]
# replaced command-r with command-a
MODEL_LIST = ["meta.llama-3.3-70b-instruct", "cohere.command-a-03-2025"]

ENABLE_USER_FEEDBACK = True

# semantic search
TOP_K = 6
# single definition of the collections (the older two-item list was dead code)
COLLECTION_LIST = ["DEV_COACHING", "BOOKS", "CNAF"]

# OCI general
COMPARTMENT_ID = "ocid1.compartment.oc1..aaaaaaaaushuwb2evpuf7rcpl4r7ugmqoe7ekmaiik3ra3m7gec3d234eknq"
Expand All @@ -69,3 +71,8 @@
# for loading
# NOTE(review): presumably the chunk size/overlap used when splitting documents
# at load time; the unit (characters vs tokens) is not visible here - confirm
CHUNK_SIZE = 2000
CHUNK_OVERLAP = 100

# for MCP server
# these three values are passed to mcp.run() in mcp_semantic_search.py
TRANSPORT = "streamable-http"
# 0.0.0.0 binds the server to all network interfaces
HOST = "0.0.0.0"
PORT = 9000
122 changes: 122 additions & 0 deletions ai/gen-ai-agents/custom_rag_agent/mcp_semantic_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Semantic Search exposed as an MCP tool

Author: L. Saetta
License: MIT
"""

from typing import Annotated
from pydantic import Field
import oracledb
from fastmcp import FastMCP
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.embeddings import OCIGenAIEmbeddings
from langchain_community.vectorstores.oraclevs import OracleVS
from utils import get_console_logger

from config import DEBUG
from config import AUTH, EMBED_MODEL_ID, SERVICE_ENDPOINT, COMPARTMENT_ID
from config import TRANSPORT, HOST, PORT
from config_private import CONNECT_ARGS

# console logger used by the tools defined in this module
logger = get_console_logger()

# FastMCP server instance; the functions decorated with @mcp.tool register on it
mcp = FastMCP("Demo Semantic Search as MCP server")


#
# Helper functions
#
def get_connection():
    """
    Open a new connection to the DB, using the arguments in CONNECT_ARGS
    """
    db_connection = oracledb.connect(**CONNECT_ARGS)
    return db_connection


def get_embedding_model():
    """
    Build the OCI GenAI embedding client from the settings in config
    """
    return OCIGenAIEmbeddings(
        compartment_id=COMPARTMENT_ID,
        service_endpoint=SERVICE_ENDPOINT,
        model_id=EMBED_MODEL_ID,
        auth_type=AUTH,
    )


@mcp.tool
def semantic_search(
    query: Annotated[
        str, Field(description="The search query to find relevant documents.")
    ],
    top_k: Annotated[int, Field(description="TOP_K parameter for search")] = 5,
    collection_name: Annotated[
        str, Field(description="The name of DB table")
    ] = "BOOKS",
) -> dict:
    """
    Perform a semantic search based on the provided query.
    Args:
        query (str): The search query.
        top_k (int): The number of top results to return.
        collection_name (str): The name of the DB table to search in.
    Returns:
        dict: a dictionary containing the relevant documents under the key
        "relevant_docs", or the error message under the key "error".
    """
    # local import, so that this function does not depend on a new file-level import
    import re

    # collection_name comes from the MCP client and is handed to OracleVS as a
    # table name (langchain_community places it in the SQL it builds), so only
    # plain, optionally schema-qualified, identifiers are accepted
    identifier = r"[A-Za-z][A-Za-z0-9_$#]*"
    if not re.fullmatch(rf"{identifier}(\.{identifier})?", collection_name):
        logger.error("Invalid collection name: %r", collection_name)
        return {"error": f"Invalid collection name: {collection_name!r}"}

    try:
        # must be the same embedding model used during load in the Vector Store
        embed_model = get_embedding_model()

        # get a connection to the DB and init VS
        with get_connection() as conn:
            v_store = OracleVS(
                client=conn,
                table_name=collection_name,
                distance_strategy=DistanceStrategy.COSINE,
                embedding_function=embed_model,
            )

            # the search must run while the connection is still open
            relevant_docs = v_store.similarity_search(query=query, k=top_k)

            if DEBUG:
                logger.info("Result from similarity search:")
                logger.info(relevant_docs)

    except Exception as e:
        # tool boundary: report the failure to the client instead of raising
        logger.error("Error in semantic_search: %s", e)
        return {"error": str(e)}

    return {"relevant_docs": relevant_docs}

@mcp.tool
def get_collections() -> list:
    """
    Get the list of collections (DB tables) available in the Oracle Vector Store.
    Returns:
        list: A list of collection names.
    """
    # both the connection and the cursor are released when the block exits,
    # also if the query raises
    with get_connection() as conn, conn.cursor() as cursor:
        # a collection is any table of the current user with a VECTOR column
        cursor.execute(
            """SELECT DISTINCT utc.table_name
            FROM user_tab_columns utc
            WHERE utc.data_type = 'VECTOR'
            ORDER BY 1 ASC"""
        )
        collections = [row[0] for row in cursor.fetchall()]
    return collections

if __name__ == "__main__":
    # server settings come from config; HOST is meant to bind to all interfaces
    server_options = {
        "transport": TRANSPORT,
        "host": HOST,
        "port": PORT,
        "log_level": "INFO",
    }
    mcp.run(**server_options)