diff --git a/ai/gen-ai-agents/custom_rag_agent/LICENSE b/ai/gen-ai-agents/custom_rag_agent/LICENSE
new file mode 100644
index 000000000..ba2a03c8e
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Oracle
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/ai/gen-ai-agents/custom_rag_agent/README.md b/ai/gen-ai-agents/custom_rag_agent/README.md
new file mode 100644
index 000000000..362e6c264
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/README.md
@@ -0,0 +1,37 @@
+![UI](images/ui_image.png)
+
+# Custom RAG agent
+This repository contains all the code for the development of a **custom RAG agent**, based on OCI Generative AI, Oracle Database 23ai, and **LangGraph**.
+
+## Design and implementation
+* The agent is implemented using **LangGraph**
+* Vector Search is implemented with LangChain on top of Oracle Database 23ai
+* A **reranker** can be used to refine the search
+
+Design decisions:
+* Every node of the graph has a dedicated Python class (a **Runnable**, such as QueryRewriter)
+* The reranker is implemented using an LLM. As an alternative, it is easy to plug in, for example, the Cohere reranker
+* The agent is integrated with **OCI APM** for **Observability**; the integration uses **py-zipkin**
+* The UI is implemented using **Streamlit**
+
+Streaming:
+* Support for streaming events from the agent: as soon as a step is completed (Vector Search, Reranking, ...) the UI is updated. For example, links to the documentation chunks are displayed before the final answer is ready.
+* Streaming of the final answer.
+
+## Status
+It is **work in progress**.
+
+## References
+* [Integration with OCI APM](https://luigi-saetta.medium.com/enhancing-observability-in-rag-solutions-with-oracle-cloud-6f93b2675f40)
+
+## Advantages of the Agentic approach
+One of the primary advantages of the agentic approach is its modularity.
+Customer requirements often surpass the simplicity of typical Retrieval-Augmented Generation (RAG) demonstrations. A framework like **LangGraph** requires organizing the code into a modular sequence of steps, which makes it easy to integrate additional features at the appropriate places.
+
+For example, to ensure that final responses do not disclose Personally Identifiable Information (PII) present in the knowledge base, one can simply append a node at the end of the graph. This node would process the generated answer, detect any PII, and anonymize it accordingly, as in the sketch below.
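+A minimal sketch of such a node follows. It is illustrative only: the `PIIAnonymizer` class and its regex-based masking are hypothetical, and it assumes the final answer is available as a plain string (in this repository it is a streaming generator, so a real node would first have to consume the stream):
+
+```python
+import re
+
+from langchain_core.runnables import Runnable
+
+
+class PIIAnonymizer(Runnable):
+    """Hypothetical node: masks PII in the generated answer."""
+
+    # simple e-mail pattern; a real implementation could use a PII-detection model
+    EMAIL = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")
+
+    def invoke(self, input, config=None, **kwargs):
+        answer = input["final_answer"]
+        return {"final_answer": self.EMAIL.sub("[REDACTED]", answer)}
+
+
+# wiring: insert the node between Answer and END
+# workflow.add_node("AnonymizePII", PIIAnonymizer())
+# workflow.add_edge("Answer", "AnonymizePII")
+# workflow.add_edge("AnonymizePII", END)
+```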
+
+## Configuration
+* Use Python 3.11.
+* Install the dependencies listed in requirements.txt.
+* Create your config_private.py using the template provided.
diff --git a/ai/gen-ai-agents/custom_rag_agent/agent_state.py b/ai/gen-ai-agents/custom_rag_agent/agent_state.py
new file mode 100644
index 000000000..b8018a836
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/agent_state.py
@@ -0,0 +1,52 @@
+"""
+File name: agent_state.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    This module defines the class that handles the agent's State.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        from agent_state import State
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+"""
+
+from typing_extensions import TypedDict, Optional
+
+
+class State(TypedDict):
+    """
+    The state of the graph
+    """
+
+    # note: a TypedDict body contains only annotations;
+    # default values are not supported
+
+    # the original user request
+    user_request: str
+    chat_history: list
+
+    # the question reformulated using chat_history
+    standalone_question: str
+
+    # similarity_search
+    retriever_docs: Optional[list]
+    # reranker
+    reranker_docs: Optional[list]
+    # Answer
+    final_answer: str
+    # Citations
+    citations: list
+
+    # if any step encounters an error
+    error: Optional[str]
diff --git a/ai/gen-ai-agents/custom_rag_agent/answer_generator.py b/ai/gen-ai-agents/custom_rag_agent/answer_generator.py
new file mode 100644
index 000000000..f28c3e4a7
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/answer_generator.py
@@ -0,0 +1,135 @@
+"""
+File name: answer_generator.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    This module implements the last step in the workflow: generation
+    of the answer from the LLM.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        from answer_generator import AnswerGenerator
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+""" + +from langchain_core.runnables import Runnable +from langchain_core.messages import SystemMessage, HumanMessage +from langchain.prompts import PromptTemplate + +# integration with APM +from py_zipkin.zipkin import zipkin_span + +from agent_state import State +from oci_models import get_llm +from prompts import ( + ANSWER_PROMPT_TEMPLATE, +) + +from utils import get_console_logger +from config import AGENT_NAME, DEBUG + +logger = get_console_logger() + + +class AnswerGenerator(Runnable): + """ + Takes the user request and the chat history and rewrite the user query + in a standalone question that is used for the semantic search + """ + + def __init__(self): + """ + Init + """ + self.dict_languages = { + "en": "English", + "fr": "French", + "it": "Italian", + "es": "Spanish", + } + + def build_context_for_llm(self, docs: list): + """ + Build the context for the final answer from LLM + + docs: list[Documents] + """ + _context = "" + + for doc in docs: + _context += doc.page_content + "\n\n" + + return _context + + @zipkin_span(service_name=AGENT_NAME, span_name="answer_generation") + def invoke(self, input: State, config=None, **kwargs): + """ + Generate the final answer + """ + # get the config + model_id = config["configurable"]["model_id"] + + if config["configurable"]["main_language"] in self.dict_languages: + # want to change language + main_language = self.dict_languages.get( + config["configurable"]["main_language"] + ) + else: + # "same as the question" (default) + # answer will be in the same language as the question + main_language = None + + if DEBUG: + logger.info("AnswerGenerator, model_id: %s", model_id) + logger.info("AnswerGenerator, main_language: %s", main_language) + + final_answer = "" + error = None + + try: + llm = get_llm(model_id=model_id) + + _context = self.build_context_for_llm(input["reranker_docs"]) + + system_prompt = PromptTemplate( + input_variables=["context"], + template=ANSWER_PROMPT_TEMPLATE, + ).format(context=_context) + + messages = [ + SystemMessage(content=system_prompt), + ] + # add the chat history + for msg in input["chat_history"]: + messages.append(msg) + + # to force the answer in the selected language + if main_language is not None: + the_question = f"{input['user_request']}. Answer in {main_language}." + else: + # no cross language + the_question = input["user_request"] + + messages.append(HumanMessage(content=the_question)) + + # here we invoke the LLM and we return the generator + final_answer = llm.stream(messages) + + except Exception as e: + logger.error("Error in generate_answer: %s", e) + error = str(e) + + return {"final_answer": final_answer, "error": error} diff --git a/ai/gen-ai-agents/custom_rag_agent/assistant_ui_langgraph.py b/ai/gen-ai-agents/custom_rag_agent/assistant_ui_langgraph.py new file mode 100644 index 000000000..5f85e3109 --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/assistant_ui_langgraph.py @@ -0,0 +1,285 @@ +""" +File name: assistant_ui.py +Author: Luigi Saetta +Date created: 2024-12-04 +Date last modified: 2025-03-31 +Python Version: 3.11 + +Description: + This module provides the UI for the RAG demo + +Usage: + streamlit run assistant_ui_langgraph.py + +License: + This code is released under the MIT License. + +Notes: + This is part of a demo fro a RAG solution implemented + using LangGraph + +Warnings: + This module is in development, may change in future versions. 
+""" + +import uuid +from typing import List, Union +import time +import streamlit as st + +from langchain_core.messages import HumanMessage, AIMessage + +# for APM integration +from py_zipkin.zipkin import zipkin_span +from py_zipkin import Encoding + +from rag_agent import State, create_workflow +from rag_feedback import RagFeedback +from transport import http_transport +from utils import get_console_logger + +# changed to better manage ENABLE_TRACING +import config + +# Constant + +# name for the roles +USER = "user" +ASSISTANT = "assistant" + +logger = get_console_logger() + + +# Initialize session state +if "chat_history" not in st.session_state: + st.session_state.chat_history = [] +if "workflow" not in st.session_state: + # the agent instance + st.session_state.workflow = create_workflow() +if "thread_id" not in st.session_state: + # generate a new thread_Id + st.session_state.thread_id = str(uuid.uuid4()) +if "model_id" not in st.session_state: + st.session_state.model_id = "meta.llama3.3-70B" +if "main_language" not in st.session_state: + st.session_state.main_language = "en" +if "enable_reranker" not in st.session_state: + st.session_state.enable_reranker = True +if "collection_name" not in st.session_state: + st.session_state.collection_name = config.COLLECTION_LIST[0] + +# to manage feedback +if "get_feedback" not in st.session_state: + st.session_state.get_feedback = False + + +# +# supporting functions +# +def display_msg_on_rerun(chat_hist: List[Union[HumanMessage, AIMessage]]) -> None: + """Display all messages on rerun.""" + for msg in chat_hist: + role = USER if isinstance(msg, HumanMessage) else ASSISTANT + with st.chat_message(role): + st.markdown(msg.content) + + +# when push the button reset the chat_history +def reset_conversation(): + """Reset the chat history.""" + st.session_state.chat_history = [] + + # change thread_id + st.session_state.thread_id = str(uuid.uuid4()) + + +def add_to_chat_history(msg): + """ + add the msg to chat history + """ + st.session_state.chat_history.append(msg) + + +def get_chat_history(): + """return the chat history from the session""" + return ( + st.session_state.chat_history[-config.MAX_MSGS_IN_HISTORY :] + if config.MAX_MSGS_IN_HISTORY > 0 + else st.session_state.chat_history + ) + + +def register_feedback(): + """ + Register the feedback. 
+ """ + # number of stars, start at 0 + n_stars = st.session_state.feedback + 1 + logger.info("Feedback: %d %s", n_stars, "stars") + logger.info("") + + # register the feedback in DB + rag_feedback = RagFeedback() + + rag_feedback.insert_feedback( + question=st.session_state.chat_history[-2].content, + answer=st.session_state.chat_history[-1].content, + feedback=n_stars, + ) + + st.session_state.get_feedback = False + + +# +# Main +# +st.title("OCI Custom RAG Agent") + +# Reset button +if st.sidebar.button("Clear Chat History"): + reset_conversation() + + +st.sidebar.header("Options") + +# the collection used for semantic search +st.session_state.collection_name = st.sidebar.selectbox( + "Collection name", + config.COLLECTION_LIST, +) + +# add the choice of LLM (not used for now) +st.session_state.main_language = st.sidebar.selectbox( + "Select the language for the answer", + config.LANGUAGE_LIST, +) +st.session_state.model_id = st.sidebar.selectbox( + "Select the Chat Model", + config.MODEL_LIST, +) +st.session_state.enable_reranker = st.sidebar.checkbox( + "Enable Reranker", value=True, disabled=False +) +config.ENABLE_TRACING = st.sidebar.checkbox( + "Enable tracing", value=False, disabled=False +) + + +# +# Here the code where react to user input +# + +# Display chat messages from history on app rerun +display_msg_on_rerun(get_chat_history()) + +if question := st.chat_input("Hello, how can I help you?"): + # Display user message in chat message container + st.chat_message(USER).markdown(question) + + try: + with st.spinner("Calling AI..."): + time_start = time.time() + + # get the chat history to give as input to LLM + _chat_history = get_chat_history() + + # modified to be more responsive, show result asap + try: + input_state = State( + user_request=question, + chat_history=_chat_history, + error=None, + ) + + # collect the results of all steps + results = [] + ERROR = None + + # integration with tracing, start the trace + with zipkin_span( + service_name=config.AGENT_NAME, + span_name="stream", + transport_handler=http_transport, + encoding=Encoding.V2_JSON, + sample_rate=100, + ) as span: + # loop to manage streaming + # set the agent config + agent_config = { + "configurable": { + "model_id": st.session_state.model_id, + "enable_reranker": st.session_state.enable_reranker, + "enable_tracing": config.ENABLE_TRACING, + "main_language": st.session_state.main_language, + "collection_name": st.session_state.collection_name, + "thread_id": st.session_state.thread_id, + } + } + + if config.DEBUG: + logger.info("Agent config: %s", agent_config) + + for event in st.session_state.workflow.stream( + input_state, + config=agent_config, + ): + for key, value in event.items(): + MSG = f"Completed: {key}!" 
+                            logger.info(MSG)
+                            st.toast(MSG)
+                            results.append(value)
+
+                            # to see if there has been an error
+                            ERROR = value["error"]
+
+                            # update UI asap
+                            if key == "QueryRewrite":
+                                st.sidebar.header("Standalone question:")
+                                st.sidebar.write(value["standalone_question"])
+                            if key == "Rerank":
+                                st.sidebar.header("References:")
+                                st.sidebar.write(value["citations"])
+
+                # process final result from agent
+                if ERROR is None:
+                    # visualize the output
+                    answer_generator = results[-1]["final_answer"]
+
+                    # Stream
+                    with st.chat_message(ASSISTANT):
+                        response_container = st.empty()
+                        full_response = ""
+
+                        for chunk in answer_generator:
+                            full_response += chunk.content
+                            response_container.markdown(full_response + "▌")
+
+                        response_container.markdown(full_response)
+
+                    elapsed_time = round((time.time() - time_start), 1)
+                    logger.info("Elapsed time: %s sec.", elapsed_time)
+                    logger.info("")
+
+                    # add user/assistant messages to the chat history
+                    # (only successful exchanges are stored, so full_response
+                    # is always defined here)
+                    add_to_chat_history(HumanMessage(content=question))
+                    add_to_chat_history(AIMessage(content=full_response))
+
+                    if config.ENABLE_USER_FEEDBACK:
+                        st.session_state.get_feedback = True
+
+                else:
+                    st.error(ERROR)
+
+                # get the feedback
+                if st.session_state.get_feedback:
+                    st.feedback("stars", key="feedback", on_change=register_feedback)
+
+            except Exception as e:
+                ERR_MSG = f"Error in assistant_ui, generate_and_exec {e}"
+                logger.error(ERR_MSG)
+                st.error(ERR_MSG)
+
+    except Exception as e:
+        ERR_MSG = "An error occurred: " + str(e)
+        logger.error(ERR_MSG)
+        st.error(ERR_MSG)
diff --git a/ai/gen-ai-agents/custom_rag_agent/bm25_search.py b/ai/gen-ai-agents/custom_rag_agent/bm25_search.py
new file mode 100644
index 000000000..75cb634ee
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/bm25_search.py
@@ -0,0 +1,133 @@
+"""
+BM25 Search Engine with Oracle Database
+"""
+
+import re
+import oracledb
+import numpy as np
+from rank_bm25 import BM25Okapi
+from utils import get_console_logger
+from config_private import CONNECT_ARGS
+
+logger = get_console_logger()
+
+
+class BM25OracleSearch:
+    """
+    Implements BM25 search over text stored in an Oracle table
+    """
+
+    def __init__(self, table_name, text_column, batch_size=40):
+        """
+        Initializes the BM25 search engine with data from an Oracle database.
+
+        :param table_name: Name of the table containing the text
+        :param text_column: Column name that contains the text to index
+        """
+        self.table_name = table_name
+        self.text_column = text_column
+        self.batch_size = batch_size
+        self.texts = []
+        self.tokenized_texts = []
+        self.bm25 = None
+        self.index_data()
+
+    def connect_db(self):
+        """
+        Establishes a connection to the Oracle database.
+        """
+        try:
+            connection = oracledb.connect(**CONNECT_ARGS)
+            return connection
+        except oracledb.DatabaseError as e:
+            logger.error("Database connection error: %s", e)
+            return None
+
+    def fetch_text_data(self):
+        """
+        Fetches text data from the specified table and column.
+        Used to initialize the index.
+ """ + + _results = [] + + with self.connect_db() as conn: + with conn.cursor() as cursor: + query = f"SELECT {self.text_column} FROM {self.table_name}" + cursor.execute(query) + + while True: + rows = cursor.fetchmany(self.batch_size) # Fetch records in batches + if not rows: + break # Exit loop when no more data + + for row in rows: + lob_data = row[0] # This is a CLOB object + + if isinstance(lob_data, oracledb.LOB): + _results.append(lob_data.read()) # Read LOB content + else: + _results.append(str(lob_data)) # Fallback for non-LOB data + + return _results + + def simple_tokenize(self, text): + """ + Tokenizes a string by extracting words (alphanumeric sequences). + + :param text: Input text string + :return: List of lowercase tokens + """ + return re.findall(r"\w+", text.lower()) + + def index_data(self): + """ + Reads text from the database and prepares BM25 index. + """ + logger.info("Creating BM25 index...") + + self.texts = self.fetch_text_data() + self.tokenized_texts = [self.simple_tokenize(text) for text in self.texts] + self.bm25 = BM25Okapi(self.tokenized_texts) + + logger.info("BM25 index created successfully!") + logger.info("") + + def search(self, query, top_n=5): + """ + Performs a BM25 search on the indexed documents. + + :param query: Search query string + :param top_n: Number of top results to return + :return: List of tuples (text, score) + """ + if not self.bm25: + print("BM25 index not initialized. Please check data indexing.") + return [] + + query_tokens = self.simple_tokenize(query) + scores = self.bm25.get_scores(query_tokens) + ranked_indices = np.argsort(scores)[::-1][:top_n] # Get top_n results + + _results = [(self.texts[i], scores[i]) for i in ranked_indices] + + return _results + + +# Example Usage: +# credential are packed in CONNECT_ARGS +table_name = "BOOKS" +text_column = "TEXT" + +# create the index +bm25_search = BM25OracleSearch(table_name, text_column) + +questions = ["Chi è Luigi Saetta?", "What are the main innovation produced by GPT-4?"] + +for _question in questions: + results = bm25_search.search(_question, top_n=2) + + # Print search results + for text, score in results: + print(f"Score: {score:.2f} - Text: {text}") + print("") diff --git a/ai/gen-ai-agents/custom_rag_agent/check_code.sh b/ai/gen-ai-agents/custom_rag_agent/check_code.sh new file mode 100755 index 000000000..b5ca27212 --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/check_code.sh @@ -0,0 +1,3 @@ +black *.py +pylint *.py + diff --git a/ai/gen-ai-agents/custom_rag_agent/chunk_index_utils.py b/ai/gen-ai-agents/custom_rag_agent/chunk_index_utils.py new file mode 100644 index 000000000..bc860e179 --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/chunk_index_utils.py @@ -0,0 +1,120 @@ +""" +Author: Luigi Saetta +Date created: 2024-04-27 +Date last modified: 2024-04-30 +Python Version: 3.11 + +Usage: contains the functions to split in chunks and create the index +""" + +from collections import defaultdict +from langchain.schema import Document +from langchain_community.document_loaders import PyPDFLoader +from langchain_unstructured import UnstructuredLoader + +from langchain_text_splitters import RecursiveCharacterTextSplitter + + +from utils import get_console_logger, remove_path_from_ref +from config import ( + CHUNK_SIZE, + CHUNK_OVERLAP, +) + +logger = get_console_logger() + + +def get_recursive_text_splitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP): + """ + return a recursive text splitter + """ + text_splitter = RecursiveCharacterTextSplitter( + 
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+        length_function=len,
+        is_separator_regex=False,
+    )
+    return text_splitter
+
+
+def load_and_split_pdf(book_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
+    """
+    Load a single PDF book and split it in chunks.
+    """
+    text_splitter = get_recursive_text_splitter(chunk_size, chunk_overlap)
+
+    loader = PyPDFLoader(file_path=book_path)
+
+    docs = loader.load_and_split(text_splitter=text_splitter)
+
+    chunk_header = ""
+
+    if len(docs) > 0:
+        doc_name = remove_path_from_ref(book_path)
+        # split to remove the extension
+        doc_title = doc_name.split(".")[0]
+        chunk_header = f"# Doc. title: {doc_title}\n"
+
+    # remove path from source and reduce the metadata (16/03/2025)
+    for doc in docs:
+        # add more context to the chunk
+        doc.page_content = chunk_header + doc.page_content
+        doc.metadata = {
+            "source": remove_path_from_ref(book_path),
+            "page_label": doc.metadata["page_label"],
+        }
+
+    logger.info("Loaded %s chunks...", len(docs))
+
+    return docs
+
+
+def load_and_split_docx(file_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
+    """
+    To load docx files
+    """
+    loader = UnstructuredLoader(file_path)
+    docs = loader.load()
+
+    # group the elements by page number (or another metadata field)
+    grouped_text = defaultdict(list)
+
+    chunk_header = ""
+
+    if len(docs) > 0:
+        doc_name = remove_path_from_ref(file_path)
+        # split to remove the extension
+        doc_title = doc_name.split(".")[0]
+        chunk_header = f"# Doc. title: {doc_title}\n"
+
+    for doc in docs:
+        # fallback to 0 if not available
+        page = doc.metadata.get("page_number", 0)
+        grouped_text[page].append(doc.page_content)
+
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap
+    )
+
+    final_chunks = []
+
+    # for each page (or group), join the text and split it
+    for page, texts in grouped_text.items():
+        full_text = "\n".join(texts)
+        splits = splitter.split_text(full_text)
+
+        for chunk in splits:
+            final_chunks.append(
+                Document(
+                    # add more context
+                    page_content=chunk_header + chunk,
+                    metadata={
+                        "source": doc_name,
+                        "page_label": str(page),
+                    },
+                )
+            )
+
+    logger.info("Loaded %s chunks...", len(final_chunks))
+
+    return final_chunks
diff --git a/ai/gen-ai-agents/custom_rag_agent/config.py b/ai/gen-ai-agents/custom_rag_agent/config.py
new file mode 100644
index 000000000..81f2a4470
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/config.py
@@ -0,0 +1,71 @@
+"""
+File name: config.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    This module provides general configurations.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        import config
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+""" + +DEBUG = False + +# type of auth +AUTH = "API_KEY" + +# embeddings +EMBED_MODEL_ID = "cohere.embed-multilingual-v3.0" + +# LLM +# this is the default model +LLM_MODEL_ID = "meta.llama-3.3-70b-instruct" +TEMPERATURE = 0.1 +MAX_TOKENS = 1024 + +# for the UI +LANGUAGE_LIST = ["same as the question", "en", "fr", "it", "es"] +MODEL_LIST = ["meta.llama-3.3-70b-instruct", "cohere.command-r-plus-08-2024"] + +ENABLE_USER_FEEDBACK = True + +# semantic search +TOP_K = 6 +COLLECTION_LIST = ["BOOKS", "CNAF"] + +# OCI general +COMPARTMENT_ID = "ocid1.compartment.oc1..aaaaaaaaushuwb2evpuf7rcpl4r7ugmqoe7ekmaiik3ra3m7gec3d234eknq" +SERVICE_ENDPOINT = "https://inference.generativeai.eu-frankfurt-1.oci.oraclecloud.com" + +# history management (put -1 if you want to disable trimming) +# consider that we have pair (human, ai) so use an even (ex: 6) value +MAX_MSGS_IN_HISTORY = 6 + +# reranking enabled or disabled from UI + +# Integration with APM +ENABLE_TRACING = False +AGENT_NAME = "OCI_CUSTOM_RAG_AGENT" + +# lsaetta-apm compartment +APM_BASE_URL = "https://aaaadec2jjn3maaaaaaaaach4e.apm-agt.eu-frankfurt-1.oci.oraclecloud.com/20200101" +APM_CONTENT_TYPE = "application/json" + +# for loading +CHUNK_SIZE = 2000 +CHUNK_OVERLAP = 100 diff --git a/ai/gen-ai-agents/custom_rag_agent/config_private_template.py b/ai/gen-ai-agents/custom_rag_agent/config_private_template.py new file mode 100644 index 000000000..cad157ce8 --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/config_private_template.py @@ -0,0 +1,45 @@ +""" +File name: config_private.py +Author: Luigi Saetta +Date last modified: 2025-03-31 +Python Version: 3.11 + +Description: + All the security and private configs here. + +Usage: + Import this module into other scripts to use its functions. + Example: + from config_private import ... + + +License: + This code is released under the MIT License. + +Notes: + This is a part of a demo showing how to implement an advanced + RAG solution as a LangGraph agent. + +Warnings: + This module is in development, may change in future versions. +""" + +# Oracle Vector Store +VECTOR_DB_USER = "XXXXX" +VECTOR_DB_PWD = "YYYY" +VECTOR_WALLET_PWD = "welcome1" +VECTOR_DSN = "aidb_medium" +VECTOR_WALLET_DIR = "/Users/xxxx/Progetti/yyyyy/WALLET_VECTOR" + + +CONNECT_ARGS = { + "user": VECTOR_DB_USER, + "password": VECTOR_DB_PWD, + "dsn": VECTOR_DSN, + "config_dir": VECTOR_WALLET_DIR, + "wallet_location": VECTOR_WALLET_DIR, + "wallet_password": VECTOR_WALLET_PWD, +} + +# integration with APM +APM_PUBLIC_KEY = "123456789PM" diff --git a/ai/gen-ai-agents/custom_rag_agent/images/ui_image.png b/ai/gen-ai-agents/custom_rag_agent/images/ui_image.png new file mode 100644 index 000000000..b8e0dd78a Binary files /dev/null and b/ai/gen-ai-agents/custom_rag_agent/images/ui_image.png differ diff --git a/ai/gen-ai-agents/custom_rag_agent/oci_models.py b/ai/gen-ai-agents/custom_rag_agent/oci_models.py new file mode 100644 index 000000000..ad011705b --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/oci_models.py @@ -0,0 +1,61 @@ +""" +File name: oci_models.py +Author: Luigi Saetta +Date last modified: 2025-03-31 +Python Version: 3.11 + +Description: + This module enables easy access to OCI GenAI LLM. + + +Usage: + Import this module into other scripts to use its functions. + Example: + from oci_models import get_llm + +License: + This code is released under the MIT License. + +Notes: + This is a part of a demo showing how to implement an advanced + RAG solution as a LangGraph agent. 
+
+Warnings:
+    This module is in development, may change in future versions.
+"""
+
+from langchain_community.chat_models import ChatOCIGenAI
+
+from utils import get_console_logger
+from config import (
+    AUTH,
+    SERVICE_ENDPOINT,
+    LLM_MODEL_ID,
+    COMPARTMENT_ID,
+    TEMPERATURE,
+    MAX_TOKENS,
+)
+
+
+logger = get_console_logger()
+
+
+def get_llm(model_id=LLM_MODEL_ID, temperature=TEMPERATURE, max_tokens=MAX_TOKENS):
+    """
+    Initialize and return an instance of ChatOCIGenAI with the specified configuration.
+
+    Returns:
+        ChatOCIGenAI: An instance of the OCI GenAI language model.
+    """
+    llm = ChatOCIGenAI(
+        auth_type=AUTH,
+        model_id=model_id,
+        service_endpoint=SERVICE_ENDPOINT,
+        compartment_id=COMPARTMENT_ID,
+        is_stream=True,
+        model_kwargs={
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        },
+    )
+    return llm
diff --git a/ai/gen-ai-agents/custom_rag_agent/pages/loader_ui.py b/ai/gen-ai-agents/custom_rag_agent/pages/loader_ui.py
new file mode 100644
index 000000000..19bd0a002
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/pages/loader_ui.py
@@ -0,0 +1,141 @@
+"""
+UI for file loading
+"""
+
+import sys
+import os
+import tempfile
+import pandas as pd
+import streamlit as st
+
+# add parent dir
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from config import DEBUG, COLLECTION_LIST
+from vector_search import SemanticSearch
+from chunk_index_utils import load_and_split_pdf, load_and_split_docx
+from utils import get_console_logger
+
+# init session
+if COLLECTION_LIST:
+    if "collection_name" not in st.session_state:
+        st.session_state.collection_name = COLLECTION_LIST[0]
+else:
+    st.error("No collections available.")
+
+if "show_documents" not in st.session_state:
+    st.session_state.show_documents = False
+
+
+header_area = st.container()
+table_area = st.container()
+
+with header_area:
+    st.header("Loading Utility")
+
+logger = get_console_logger()
+
+
+#
+# Supporting functions
+#
+def list_books(_collection_name):
+    """
+    Return the list of books in the given collection.
+    """
+    _search = SemanticSearch()
+
+    _books_list = _search.list_books_in_collection(collection_name=_collection_name)
+
+    # reorder
+    return sorted(_books_list)
+
+
+def show_documents_in_collection(_collection_name):
+    """
+    Show the documents in the given collection.
+    """
+    if st.session_state.show_documents:
+        with st.spinner():
+            _books_list = list_books(_collection_name)
+
+            books_names = [item[0] for item in _books_list]
+            books_chunks = [item[1] for item in _books_list]
+
+            # convert to a Pandas DataFrame for visualization
+            df_list = pd.DataFrame(
+                {"Document": books_names, "Num. chunks": books_chunks}
+            )
+            # index starting at 1
+            df_list.index = range(1, len(df_list) + 1)
+            # visualize
+            with table_area:
+                st.table(df_list)
+
+
+def on_selection_change():
+    """
+    React to the selection of the collection.
+    """
+    selected = st.session_state["name_selected"]
+
+    logger.info("Collection list selected: %s", selected)
+
+    show_documents_in_collection(selected)
+
+
+st.session_state.collection_name = st.sidebar.selectbox(
+    "Collection name",
+    COLLECTION_LIST,
+    key="name_selected",
+    on_change=on_selection_change,
+)
+
+st.session_state.show_documents = st.sidebar.checkbox("Show documents")
+uploaded_file = st.sidebar.file_uploader("Upload a file", type=["pdf", "docx"])
+
+# added a button for loading
+load_file = st.sidebar.button("Load file")
+
+if uploaded_file is not None and load_file:
+    # identify file type
+    only_name = os.path.basename(uploaded_file.name)
+    file_ext = uploaded_file.name.split(".")[-1]
+
+    if DEBUG:
+        logger.info(file_ext)
+
+    # save as a temporary file
+    path_file_temp = os.path.join(tempfile.gettempdir(), only_name)
+
+    # write the temp file
+    with open(path_file_temp, "wb") as tmp_file:
+        tmp_file.write(uploaded_file.read())
+
+    # check that the file is not already in the collection;
+    # list_books returns (name, n_chunks) tuples, so keep only the names
+    books_list = list_books(st.session_state.collection_name)
+    books_names = [item[0] for item in books_list]
+
+    if DEBUG:
+        logger.info(books_list)
+
+    if only_name not in books_names:
+        logger.info("Loading %s ...", only_name)
+
+        docs = []
+
+        if file_ext == "pdf":
+            docs = load_and_split_pdf(path_file_temp)
+
+        elif file_ext == "docx":
+            docs = load_and_split_docx(path_file_temp)
+
+        if len(docs) > 0:
+            SemanticSearch().add_documents(
+                docs, collection_name=st.session_state.collection_name
+            )
+
+            st.success("Document loaded")
+
+    else:
+        st.error(f"{only_name} already in collection")
diff --git a/ai/gen-ai-agents/custom_rag_agent/prompts.py b/ai/gen-ai-agents/custom_rag_agent/prompts.py
new file mode 100644
index 000000000..23e7f9223
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/prompts.py
@@ -0,0 +1,95 @@
+"""
+File name: prompts.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    All the prompts are defined here.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        from prompts import ...
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+"""
+
+REFORMULATE_PROMPT_TEMPLATE = """
+You're an AI assistant. Given a user request and a chat history, reformulate the user question
+into a standalone question, using the chat history.
+
+Constraints:
+- return only the standalone question, do not add any other text.
+
+User request: {user_request}
+Chat history: {chat_history}
+
+Standalone question:
+"""
+
+ANSWER_PROMPT_TEMPLATE = """
+You're an AI assistant. Your task is to answer user questions based on the provided context
+and the history of previous messages.
+Always use a friendly and polite tone.
+
+## Constraints:
+- Answer based only on the provided context. If you don't know the answer, simply say **I don't know the answer.**
+- Always return the answer as properly formatted markdown.
+
+Context: {context}
+
+"""
+
+RERANKER_TEMPLATE = """
+You are an intelligent ranking assistant. Your task is to rank and filter text chunks
+based on their relevance to a given user query. You will receive:
+
+1. A user query.
+2. A list of text chunks.
+
+Your goal is to:
+- Rank the text chunks in order of relevance to the user query.
+- Remove any text chunks that are completely irrelevant to the query.
+
+### Instructions:
+- Assign a **relevance score** to each chunk based on how well it answers or relates to the query.
+- Return only the **top-ranked** chunks, filtering out those that are completely irrelevant.
+- The output should be a **sorted list** of relevant chunks, from most to least relevant.
+- Return only the JSON, don't add other text.
+- Don't return the text of the chunk, only the index and the score.
+
+### Input Format:
+User Query:
+{query}
+
+Text Chunks (list indexed from 0):
+{chunks}
+
+### **Output Format:**
+Return a **JSON object** with the following format:
+```json
+{{
+    "ranked_chunks": [
+        {{"index": 0, "score": X.X}},
+        {{"index": 2, "score": Y.Y}},
+        ...
+    ]
+}}
+```
+Where:
+- "index" is the original position of the chunk in the input list. Index starts from 0.
+- "score" is the relevance score (higher is better).
+
+Ensure that only relevant chunks are included in the output. If no chunk is relevant, return an empty list.
+
+"""
diff --git a/ai/gen-ai-agents/custom_rag_agent/query_rewriter.py b/ai/gen-ai-agents/custom_rag_agent/query_rewriter.py
new file mode 100644
index 000000000..1fb94d8fd
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/query_rewriter.py
@@ -0,0 +1,93 @@
+"""
+File name: query_rewriter.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    This module implements the first step in the agent's workflow:
+    query rewriting.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        from query_rewriter import QueryRewriter
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+""" + +from langchain_core.runnables import Runnable +from langchain_core.messages import HumanMessage +from langchain.prompts import PromptTemplate + +# integration with APM +from py_zipkin.zipkin import zipkin_span + +from agent_state import State +from prompts import REFORMULATE_PROMPT_TEMPLATE +from oci_models import get_llm +from utils import get_console_logger +from config import DEBUG, AGENT_NAME + +logger = get_console_logger() + + +class QueryRewriter(Runnable): + """ + Takes the user request and the chat history and rewrite the user query + in a standalone question that is used for the semantic search + """ + + def __init__(self): + """ + Init + """ + + @zipkin_span(service_name=AGENT_NAME, span_name="query_rewriting") + def invoke(self, input: State, config=None, **kwargs): + """ + Rewrite the query + + Reformulate the question in a standalone question, using the chat_history + """ + user_request = input["user_request"] + error = None + + if len(input["chat_history"]) > 0: + logger.info("Reformulating the question...") + + try: + llm = get_llm(temperature=0) + + _prompt_template = PromptTemplate( + input_variables=["user_request", "chat_history"], + template=REFORMULATE_PROMPT_TEMPLATE, + ) + + prompt = _prompt_template.format( + user_request=user_request, chat_history=input["chat_history"] + ) + + messages = [ + HumanMessage(content=prompt), + ] + + standalone_question = llm.invoke(messages).content + + if DEBUG: + logger.info("Standalone question: %s", standalone_question) + except Exception as e: + logger.error("Error in query_rewriting: %s", e) + error = str(e) + else: + standalone_question = user_request + + return {"standalone_question": standalone_question, "error": error} diff --git a/ai/gen-ai-agents/custom_rag_agent/rag_agent.py b/ai/gen-ai-agents/custom_rag_agent/rag_agent.py new file mode 100644 index 000000000..6717ad9c4 --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/rag_agent.py @@ -0,0 +1,68 @@ +""" +File name: rag_agent.py +Author: Luigi Saetta +Date last modified: 2025-03-31 +Python Version: 3.11 + +Description: + This module implements the orchestration in the Agent. + Based on LanGraph. + +Usage: + Import this module into other scripts to use its functions. + Example: + + +License: + This code is released under the MIT License. + +Notes: + This is a part of a demo showing how to implement an advanced + RAG solution as a LangGraph agent. + +Warnings: + This module is in development, may change in future versions. 
+""" + +from langgraph.graph import StateGraph, START, END + +from agent_state import State + +from query_rewriter import QueryRewriter +from vector_search import SemanticSearch +from reranker import Reranker +from answer_generator import AnswerGenerator +from utils import get_console_logger + +logger = get_console_logger() + + +def create_workflow(): + """ + Create the entire workflow + """ + workflow = StateGraph(State) + + # create nodes + query_rewriter = QueryRewriter() + semantic_search = SemanticSearch() + reranker = Reranker() + answer_generator = AnswerGenerator() + + # Add nodes + workflow.add_node("QueryRewrite", query_rewriter) + workflow.add_node("Search", semantic_search) + workflow.add_node("Rerank", reranker) + workflow.add_node("Answer", answer_generator) + + # define edges + workflow.add_edge(START, "QueryRewrite") + workflow.add_edge("QueryRewrite", "Search") + workflow.add_edge("Search", "Rerank") + workflow.add_edge("Rerank", "Answer") + workflow.add_edge("Answer", END) + + # create workflow executor + workflow_app = workflow.compile() + + return workflow_app diff --git a/ai/gen-ai-agents/custom_rag_agent/rag_feedback.py b/ai/gen-ai-agents/custom_rag_agent/rag_feedback.py new file mode 100644 index 000000000..0823c8f32 --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/rag_feedback.py @@ -0,0 +1,70 @@ +""" +File name: rag_feedback.py +Author: Luigi Saetta +Date last modified: 2025-03-31 +Python Version: 3.11 + +Description: + This module implements handling of user's feedback. + Based on LangGraph. + +Usage: + Import this module into other scripts to use its functions. + Example: + + +License: + This code is released under the MIT License. + +Notes: + This is a part of a demo showing how to implement an advanced + RAG solution as a LangGraph agent. + +Warnings: + This module is in development, may change in future versions. +""" + +from datetime import datetime +import oracledb + +from config_private import CONNECT_ARGS + + +class RagFeedback: + """ + To register user feedback in the RAG_FEEDBACK table. 
+ """ + + def __init__(self): + """ + Init + """ + + def get_connection(self): + """Establish the Oracle DB connection.""" + return oracledb.connect(**CONNECT_ARGS) + + def insert_feedback(self, question: str, answer: str, feedback: int): + """Insert a new feedback record into RAG_FEEDBACK table.""" + if feedback < 1 or feedback > 5: + raise ValueError("Feedback must be a number between 1 and 5.") + + sql = """ + INSERT INTO RAG_FEEDBACK (CREATED_AT, QUESTION, ANSWER, FEEDBACK) + VALUES (:created_at, :question, :answer, :feedback) + """ + + with self.get_connection() as conn: + cursor = conn.cursor() + + cursor.execute( + sql, + { + "created_at": datetime.now(), + "question": question, + "answer": answer, + "feedback": feedback, + }, + ) + conn.commit() + cursor.close() diff --git a/ai/gen-ai-agents/custom_rag_agent/requirements.txt b/ai/gen-ai-agents/custom_rag_agent/requirements.txt new file mode 100644 index 000000000..215e76e4a --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/requirements.txt @@ -0,0 +1,179 @@ +aiohappyeyeballs==2.4.8 +aiohttp==3.11.13 +aiosignal==1.3.2 +altair==5.5.0 +annotated-types==0.7.0 +anyio==4.8.0 +appnope==0.1.4 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +astroid==3.3.8 +asttokens==3.0.0 +async-lru==2.0.4 +attrs==25.1.0 +babel==2.17.0 +beautifulsoup4==4.13.3 +black==25.1.0 +bleach==6.2.0 +blinker==1.9.0 +cachetools==5.5.2 +certifi==2025.1.31 +cffi==1.17.1 +charset-normalizer==3.4.1 +circuitbreaker==2.0.0 +click==8.1.8 +comm==0.2.2 +contourpy==1.3.1 +cryptography==44.0.2 +cycler==0.12.1 +dataclasses-json==0.6.7 +debugpy==1.8.13 +decorator==5.2.1 +defusedxml==0.7.1 +dill==0.3.9 +executing==2.2.0 +fastjsonschema==2.21.1 +fonttools==4.56.0 +fqdn==1.5.1 +frozenlist==1.5.0 +gitdb==4.0.12 +GitPython==3.1.44 +h11==0.14.0 +httpcore==1.0.7 +httpx==0.28.1 +httpx-sse==0.4.0 +idna==3.10 +ipykernel==6.29.5 +ipython==9.0.1 +ipython_pygments_lexers==1.1.1 +ipywidgets==8.1.5 +isoduration==20.11.0 +isort==6.0.1 +jedi==0.19.2 +Jinja2==3.1.5 +json5==0.10.0 +jsonpatch==1.33 +jsonpointer==3.0.0 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +jupyter==1.1.1 +jupyter-console==6.6.3 +jupyter-events==0.12.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.3 +jupyter_core==5.7.2 +jupyter_server==2.15.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.3.5 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +jupyterlab_widgets==3.0.13 +kiwisolver==1.4.8 +lab==8.4 +langchain==0.3.20 +langchain-community==0.3.19 +langchain-core==0.3.41 +langchain-text-splitters==0.3.6 +langgraph==0.3.5 +langgraph-checkpoint==2.0.17 +langgraph-prebuilt==0.1.2 +langgraph-sdk==0.1.55 +langsmith==0.3.11 +MarkupSafe==3.0.2 +marshmallow==3.26.1 +matplotlib==3.10.1 +matplotlib-inline==0.1.7 +mccabe==0.7.0 +mistune==3.1.2 +msgpack==1.1.0 +multidict==6.1.0 +mypy-extensions==1.0.0 +narwhals==1.29.0 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +notebook==7.3.2 +notebook_shim==0.2.4 +numpy==2.2.3 +oci==2.148.0 +oracledb==3.0.0 +orjson==3.10.15 +overrides==7.7.0 +packaging==24.2 +pandas==2.2.3 +pandocfilters==1.5.1 +parso==0.8.4 +pathspec==0.12.1 +pdfminer.six==20231228 +pdfplumber==0.11.5 +pexpect==4.9.0 +pillow==11.1.0 +platformdirs==4.3.6 +prometheus_client==0.21.1 +prompt_toolkit==3.0.50 +propcache==0.3.0 +protobuf==5.29.3 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +py-zipkin==1.2.8 +pyarrow==19.0.1 +pycparser==2.22 +pydantic==2.10.6 +pydantic-settings==2.8.1 +pydantic_core==2.27.2 +pydeck==0.9.1 +Pygments==2.19.1 +pylint==3.3.4 +PyMuPDF==1.25.4 
+pyOpenSSL==24.3.0
+pyparsing==3.2.1
+pypdfium2==4.30.1
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-json-logger==3.3.0
+pytz==2025.1
+PyYAML==6.0.2
+pyzmq==26.2.1
+rank-bm25==0.2.2
+referencing==0.36.2
+requests==2.32.3
+requests-toolbelt==1.0.0
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rpds-py==0.23.1
+Send2Trash==1.8.3
+simplejson==3.20.1
+six==1.17.0
+smmap==5.0.2
+sniffio==1.3.1
+soupsieve==2.6
+SQLAlchemy==2.0.38
+stack-data==0.6.3
+streamlit==1.43.0
+tabulate==0.9.0
+tenacity==9.0.0
+terminado==0.18.1
+tinycss2==1.4.0
+tokenize_rt==6.1.0
+toml==0.10.2
+tomlkit==0.13.2
+tornado==6.4.2
+traitlets==5.14.3
+txt2tags==3.9
+types-python-dateutil==2.9.0.20241206
+typing-inspect==0.9.0
+typing_extensions==4.12.2
+tzdata==2025.1
+uri-template==1.3.0
+urllib3==2.3.0
+watchdog==6.0.0
+wcwidth==0.2.13
+webcolors==24.11.1
+webencodings==0.5.1
+websocket-client==1.8.0
+widgetsnbextension==4.0.13
+yarl==1.18.3
+zstandard==0.23.0
diff --git a/ai/gen-ai-agents/custom_rag_agent/reranker.py b/ai/gen-ai-agents/custom_rag_agent/reranker.py
new file mode 100644
index 000000000..d5d8aaa2b
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/reranker.py
@@ -0,0 +1,145 @@
+"""
+File name: reranker.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    This module implements filtering and reranking of the documents
+    returned by the Similarity Search, using an LLM.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        import config
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+"""
+
+# Import traceback for better error logging
+import traceback
+from langchain_core.runnables import Runnable
+from langchain_core.messages import HumanMessage
+from langchain.prompts import PromptTemplate
+
+# integration with APM
+from py_zipkin.zipkin import zipkin_span
+
+from agent_state import State
+from prompts import (
+    RERANKER_TEMPLATE,
+)
+from oci_models import get_llm
+from utils import get_console_logger, extract_json_from_text
+from config import DEBUG, AGENT_NAME, TOP_K
+
+logger = get_console_logger()
+
+
+class Reranker(Runnable):
+    """
+    Implements a reranker using an LLM
+    """
+
+    def __init__(self):
+        """
+        Init
+        """
+
+    def generate_refs(self, docs: list):
+        """
+        Returns a list of reference dictionaries used in the reranker.
+        """
+        return [
+            {"source": doc.metadata["source"], "page": doc.metadata["page_label"]}
+            for doc in docs
+        ]
+
+    @staticmethod
+    def get_reranked_docs(llm, query, retriever_docs):
+        """
+        Rerank documents using LLM based on user request.
+
+        query: the search query (can be reformulated)
+        retriever_docs: list of Langchain Documents
+        """
+        # Prepare chunk texts
+        chunks = [doc.page_content for doc in retriever_docs]
+
+        _prompt = PromptTemplate(
+            input_variables=["query", "chunks"],
+            template=RERANKER_TEMPLATE,
+        ).format(query=query, chunks=chunks)
+
+        messages = [HumanMessage(content=_prompt)]
+
+        reranker_output = llm.invoke(messages).content
+
+        # Extract ranking order
+        json_dict = extract_json_from_text(reranker_output)
+
+        if DEBUG:
+            logger.info(json_dict.get("ranked_chunks", "No ranked chunks found."))
+
+        # Get the indexes of the relevant chunks, in relevance order;
+        # the bound check below guards against indexes hallucinated by the LLM
+        indexes = [
+            chunk["index"]
+            for chunk in json_dict.get("ranked_chunks", [])
+            if chunk["index"] < len(retriever_docs)
+        ]
+
+        return [retriever_docs[i] for i in indexes]
+
+    @zipkin_span(service_name=AGENT_NAME, span_name="reranking")
+    def invoke(self, input: State, config=None, **kwargs):
+        """
+        Implements reranking logic.
+
+        input: The agent state.
+        """
+        enable_reranker = config["configurable"]["enable_reranker"]
+
+        user_request = input.get("standalone_question", "")
+        retriever_docs = input.get("retriever_docs", [])
+        error = None
+
+        if DEBUG:
+            logger.info("Reranker input state: %s", input)
+
+        try:
+            if retriever_docs:
+                # there is something to rerank!
+                if enable_reranker:
+                    # do reranking
+                    llm = get_llm(temperature=0.0)
+
+                    reranked_docs = self.get_reranked_docs(
+                        llm, user_request, retriever_docs
+                    )
+
+                else:
+                    reranked_docs = retriever_docs
+            else:
+                reranked_docs = []
+
+        except Exception as e:
+            logger.error("Error in reranker: %s", e)
+            # Log the full stack trace for debugging
+            logger.debug(traceback.format_exc())
+            error = str(e)
+            # Fallback to original documents
+            reranked_docs = retriever_docs
+
+        # Get reference citations
+        citations = self.generate_refs(reranked_docs)
+
+        return {"reranker_docs": reranked_docs, "citations": citations, "error": error}
diff --git a/ai/gen-ai-agents/custom_rag_agent/run_assistant.sh b/ai/gen-ai-agents/custom_rag_agent/run_assistant.sh
new file mode 100755
index 000000000..96113e173
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/run_assistant.sh
@@ -0,0 +1,3 @@
+streamlit run assistant_ui_langgraph.py
+
+
diff --git a/ai/gen-ai-agents/custom_rag_agent/transport.py b/ai/gen-ai-agents/custom_rag_agent/transport.py
new file mode 100644
index 000000000..5ce876408
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/transport.py
@@ -0,0 +1,81 @@
+"""
+File name: transport.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    This code provides the HTTP transport support for the integration with OCI APM.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        ...
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+"""
+
+import requests
+from utils import get_console_logger
+
+# changed to handle ENABLE_TRACING from UI
+import config
+from config_private import APM_PUBLIC_KEY
+
+
+logger = get_console_logger()
+
+
+def http_transport(encoded_span):
+    """
+    Sends encoded tracing data to OCI APM using py-zipkin.
+
+    Args:
+        encoded_span (bytes): The encoded span data to send.
+
+    Returns:
+        requests.Response or None: The response from the APM service or None if tracing is disabled.
+    """
+    try:
+        # Load config inside the function to avoid global dependency issues
+        base_url = config.APM_BASE_URL
+        content_type = config.APM_CONTENT_TYPE
+
+        # Validate configuration
+        if not base_url:
+            raise ValueError("APM base URL is not configured")
+        if not APM_PUBLIC_KEY:
+            raise ValueError("APM public key is missing")
+
+        # If tracing is disabled, do nothing
+        if not config.ENABLE_TRACING:
+            logger.info("Tracing is disabled. No data sent to APM.")
+            return None
+
+        # Construct endpoint dynamically
+        apm_url = f"{base_url}/observations/public-span?dataFormat=zipkin&dataFormatVersion=2&dataKey={APM_PUBLIC_KEY}"
+
+        response = requests.post(
+            apm_url,
+            data=encoded_span,
+            headers={"Content-Type": content_type},
+            timeout=30,
+        )
+        response.raise_for_status()  # Raise exception for HTTP errors
+
+        return response
+    except requests.RequestException as e:
+        logger.error("Failed to send span to APM: %s", str(e))
+        return None
+    except Exception as e:
+        logger.error("Unexpected error in http_transport: %s", str(e))
+        return None
diff --git a/ai/gen-ai-agents/custom_rag_agent/utils.py b/ai/gen-ai-agents/custom_rag_agent/utils.py
new file mode 100644
index 000000000..1825574df
--- /dev/null
+++ b/ai/gen-ai-agents/custom_rag_agent/utils.py
@@ -0,0 +1,111 @@
+"""
+File name: utils.py
+Author: Luigi Saetta
+Date last modified: 2025-03-31
+Python Version: 3.11
+
+Description:
+    Utility functions here.
+
+Usage:
+    Import this module into other scripts to use its functions.
+    Example:
+        from utils import ...
+
+License:
+    This code is released under the MIT License.
+
+Notes:
+    This is a part of a demo showing how to implement an advanced
+    RAG solution as a LangGraph agent.
+
+Warnings:
+    This module is in development, may change in future versions.
+"""
+
+import os
+import logging
+import re
+import json
+
+
+def get_console_logger():
+    """
+    To get a logger to print on console
+    """
+    logger = logging.getLogger("ConsoleLogger")
+
+    # to avoid duplication of logging
+    if not logger.handlers:
+        logger.setLevel(logging.INFO)
+
+        handler = logging.StreamHandler()
+        handler.setLevel(logging.INFO)
+
+        formatter = logging.Formatter("%(asctime)s - %(message)s")
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+
+    logger.propagate = False
+
+    return logger
+
+
+def extract_text_triple_backticks(_text):
+    """
+    Extracts all text enclosed between triple backticks (```) from a string.
+
+    :param text: The input string to analyze.
+    :return: A list containing the texts found between triple backticks.
+    """
+    logger = get_console_logger()
+
+    pattern = r"```(.*?)```"  # Uses (.*?) to capture text between backticks in a non-greedy way
+    # re.DOTALL allows capturing multiline content
+
+    try:
+        _result = [block.strip() for block in re.findall(pattern, _text, re.DOTALL)][0]
+    except Exception as e:
+        logger.info("no triple backticks in extract_text_triple_backticks: %s", e)
+
+        # try to be resilient, return the entire text
+        _result = _text
+
+    return _result
+
+
+def extract_json_from_text(text):
+    """
+    Extracts JSON content from a given text and returns it as a Python dictionary.
+
+    Args:
+        text (str): The input text containing JSON content.
+
+    Returns:
+        dict: Parsed JSON data.
+ """ + try: + # Use regex to extract JSON content (contained between {}) + json_match = re.search(r"\{.*\}", text, re.DOTALL) + if json_match: + json_content = json_match.group(0) + return json.loads(json_content) + + # If no JSON content is found, raise an error + raise ValueError("No JSON content found in the text.") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format: {e}") + + +# for the loading utility +def remove_path_from_ref(ref_pathname): + """ + remove the path from source (ref) + """ + ref = ref_pathname + # check if / or \ is contained + if len(ref_pathname.split(os.sep)) > 0: + ref = ref_pathname.split(os.sep)[-1] + + return ref diff --git a/ai/gen-ai-agents/custom_rag_agent/vector_search.py b/ai/gen-ai-agents/custom_rag_agent/vector_search.py new file mode 100644 index 000000000..ce9510a53 --- /dev/null +++ b/ai/gen-ai-agents/custom_rag_agent/vector_search.py @@ -0,0 +1,178 @@ +""" +File name: vector_search.py +Author: Luigi Saetta +Date last modified: 2025-03-31 +Python Version: 3.11 + +Description: + This module implements the Semantic Sesarch in the agent + using 23Ai Vector Search + + +Usage: + Import this module into other scripts to use its functions. + Example: + import config + +License: + This code is released under the MIT License. + +Notes: + This is a part of a demo showing how to implement an advanced + RAG solution as a LangGraph agent. + +Warnings: + This module is in development, may change in future versions. +""" + +import oracledb +from langchain_core.runnables import Runnable +from langchain_community.vectorstores.utils import DistanceStrategy +from langchain_community.embeddings import OCIGenAIEmbeddings +from langchain_community.vectorstores.oraclevs import OracleVS + +# integration with APM +from py_zipkin.zipkin import zipkin_span + +from agent_state import State +from utils import get_console_logger + +from config import ( + AGENT_NAME, + DEBUG, + AUTH, + EMBED_MODEL_ID, + SERVICE_ENDPOINT, + COMPARTMENT_ID, + TOP_K, +) + +from config_private import CONNECT_ARGS + +logger = get_console_logger() + + +class SemanticSearch(Runnable): + """ + Implements Semantic Search for the Agent + """ + + def __init__(self): + """ + Init + """ + + def get_connection(self): + """ + get a connection to the DB + """ + return oracledb.connect(**CONNECT_ARGS) + + def get_embedding_model(self): + """ + Create the Embedding Model + """ + embed_model = OCIGenAIEmbeddings( + auth_type=AUTH, + model_id=EMBED_MODEL_ID, + service_endpoint=SERVICE_ENDPOINT, + compartment_id=COMPARTMENT_ID, + ) + + return embed_model + + @zipkin_span(service_name=AGENT_NAME, span_name="similarity_search") + def invoke(self, input: State, config=None, **kwargs): + """ + This method invokes the vector search + + input: the agent state + """ + collection_name = config["configurable"]["collection_name"] + + relevant_docs = [] + error = None + + standalone_question = input["standalone_question"] + + if DEBUG: + logger.info("Search question: %s", standalone_question) + + try: + embed_model = self.get_embedding_model() + + # get a connection to the DB and init VS + with self.get_connection() as conn: + v_store = OracleVS( + client=conn, + table_name=collection_name, + distance_strategy=DistanceStrategy.COSINE, + embedding_function=embed_model, + ) + + relevant_docs = v_store.similarity_search( + query=standalone_question, k=TOP_K + ) + + if DEBUG: + logger.info("Result from similarity search:") + logger.info(relevant_docs) + + except Exception as e: + logger.error("Error in 
vector_store.invoke: %s", e) + error = str(e) + + return {"retriever_docs": relevant_docs, "error": error} + + # + # Helper functions + # + def list_books_in_collection(self, collection_name: str) -> list: + """ + get the list of books/documents names in the collection + taken from metadata + expect metadata contains the field source + + modified to return also the numb. of chunks + """ + query = f""" + SELECT DISTINCT json_value(METADATA, '$.source') AS books, + count(*) as n_chunks + FROM {collection_name} + group by books + ORDER by books ASC + """ + with self.get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query) + + rows = cursor.fetchall() + + list_books = [] + for row in rows: + list_books.append((row[0], row[1])) + + return list_books + + def add_documents(self, docs: list, collection_name: str): + """ + Add the chunks to an existing collection + + docs is a list of Langchain documents + """ + try: + embed_model = self.get_embedding_model() + + with self.get_connection() as conn: + v_store = OracleVS( + client=conn, + table_name=collection_name, + distance_strategy=DistanceStrategy.COSINE, + embedding_function=embed_model, + ) + v_store.add_documents(docs) + logger.info("Added docs to collection %s", collection_name) + + except Exception as e: + logger.error("Error in vector_store.add_documents: %s", e) + raise e
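
For reference, a minimal sketch of how the compiled workflow can be driven outside Streamlit, mirroring what assistant_ui_langgraph.py does. The question, collection name, and thread id are illustrative values; the config keys are the ones the nodes read via config["configurable"], and the sketch assumes config_private.py is filled in and the 23ai schema is reachable:

```python
# minimal driver sketch (illustrative, not part of the diff above)
from rag_agent import create_workflow
from agent_state import State

workflow = create_workflow()

state = State(user_request="What is OCI APM?", chat_history=[], error=None)
agent_config = {
    "configurable": {
        "model_id": "meta.llama-3.3-70b-instruct",
        "enable_reranker": True,
        "enable_tracing": False,
        "main_language": "en",
        "collection_name": "BOOKS",  # assumption: first entry of COLLECTION_LIST
        "thread_id": "demo-thread",
    }
}

# each event is {node_name: partial_state}, emitted as a node completes
value = None
for event in workflow.stream(state, config=agent_config):
    for node_name, value in event.items():
        print(f"completed node: {node_name}")

# the Answer node returns a streaming generator as "final_answer"
if value is not None and value["error"] is None:
    for chunk in value["final_answer"]:
        print(chunk.content, end="", flush=True)
```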