3 changes: 3 additions & 0 deletions ai/gen-ai-agents/custom_rag_agent/agent_state.py
@@ -40,6 +40,9 @@ class State(TypedDict):
standalone_question: str = ""

# similarity_search
# 30/06: modified, now each doc is a dict with
# page_content and metadata
# populated with docs_serializable (utils.py)
retriever_docs: Optional[list] = []
# reranker
reranker_docs: Optional[list] = []
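The comment above points to docs_serializable in utils.py, which is not part of this diff. A minimal sketch of what such a helper presumably does (name taken from the comment, implementation assumed, not confirmed by this PR):

def docs_serializable(docs: list) -> list[dict]:
    """Convert LangChain Document objects into plain dicts (assumed shape)."""
    return [
        {"page_content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]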
11 changes: 5 additions & 6 deletions ai/gen-ai-agents/custom_rag_agent/answer_generator.py
@@ -1,7 +1,7 @@
"""
File name: answer_generator.py
Author: Luigi Saetta
Date last modified: 2025-03-31
Date last modified: 2025-04-02
Python Version: 3.11

Description:
@@ -67,10 +67,8 @@ def build_context_for_llm(self, docs: list):

docs: list[dict], each with page_content and metadata
"""
_context = ""

for doc in docs:
_context += doc.page_content + "\n\n"
# more Pythonic
_context = "\n\n".join(doc["page_content"] for doc in docs)

return _context
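A quick illustration of the new dict-based behavior, with hypothetical docs matching the shape described in agent_state.py:

docs = [
    {"page_content": "First chunk.", "metadata": {"source": "a.pdf"}},
    {"page_content": "Second chunk.", "metadata": {"source": "a.pdf"}},
]
# build_context_for_llm(docs) now returns: "First chunk.\n\nSecond chunk."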

@@ -79,7 +77,7 @@ def invoke(self, input: State, config=None, **kwargs):
"""
Generate the final answer
"""
# get the config
# get the model_id from config
model_id = config["configurable"]["model_id"]

if config["configurable"]["main_language"] in self.dict_languages:
@@ -102,6 +100,7 @@ def invoke(self, input: State, config=None, **kwargs):
try:
llm = get_llm(model_id=model_id)

# docs are returned from the reranker
_context = self.build_context_for_llm(input["reranker_docs"])

system_prompt = PromptTemplate(
22 changes: 12 additions & 10 deletions ai/gen-ai-agents/custom_rag_agent/assistant_ui_langgraph.py
@@ -2,7 +2,7 @@
File name: assistant_ui.py
Author: Luigi Saetta
Date created: 2024-12-04
Date last modified: 2025-03-31
Date last modified: 2025-07-01
Python Version: 3.11

Description:
@@ -15,7 +15,7 @@
This code is released under the MIT License.

Notes:
This is part of a demo fro a RAG solution implemented
This is part of a demo for a RAG solution implemented
using LangGraph

Warnings:
@@ -38,7 +38,7 @@
from transport import http_transport
from utils import get_console_logger

# changed to better manage ENABLE_TRACING
# changed to better manage ENABLE_TRACING (can be enabled from UI)
import config

# Constant
@@ -142,13 +142,14 @@ def register_feedback():

st.sidebar.header("Options")

st.sidebar.text_input(label="Region", value=config.REGION, disabled=True)

# the collection used for semantic search
st.session_state.collection_name = st.sidebar.selectbox(
"Collection name",
config.COLLECTION_LIST,
)

# add the choice of LLM (not used for now)
st.session_state.main_language = st.sidebar.selectbox(
"Select the language for the answer",
config.LANGUAGE_LIST,
@@ -203,11 +204,11 @@ def register_feedback():
encoding=Encoding.V2_JSON,
sample_rate=100,
) as span:
# loop to manage streaming
# set the agent config
agent_config = {
"configurable": {
"model_id": st.session_state.model_id,
"embed_model_type": config.EMBED_MODEL_TYPE,
"enable_reranker": st.session_state.enable_reranker,
"enable_tracing": config.ENABLE_TRACING,
"main_language": st.session_state.main_language,
@@ -219,6 +220,7 @@ def register_feedback():
if config.DEBUG:
logger.info("Agent config: %s", agent_config)

# loop to manage streaming
for event in st.session_state.workflow.stream(
input_state,
config=agent_config,
@@ -248,13 +250,13 @@ def register_feedback():
# Stream
with st.chat_message(ASSISTANT):
response_container = st.empty()
full_response = ""
FULL_RESPONSE = ""

for chunk in answer_generator:
full_response += chunk.content
response_container.markdown(full_response + "▌")
FULL_RESPONSE += chunk.content
response_container.markdown(FULL_RESPONSE + "▌")

response_container.markdown(full_response)
response_container.markdown(FULL_RESPONSE)

elapsed_time = round((time.time() - time_start), 1)
logger.info("Elapsed time: %s sec.", elapsed_time)
@@ -268,7 +270,7 @@

# Add user/assistant message to chat history
add_to_chat_history(HumanMessage(content=question))
add_to_chat_history(AIMessage(content=full_response))
add_to_chat_history(AIMessage(content=FULL_RESPONSE))

# get the feedback
if st.session_state.get_feedback:
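A minimal sketch of the streaming pattern this file relies on (names taken from the diff; the per-event shape depends on LangGraph's stream_mode, which is assumed here and not confirmed by this PR):

for event in st.session_state.workflow.stream(input_state, config=agent_config):
    # assuming stream_mode="updates": each event maps the node that just
    # ran to the state update it produced
    for node_name, update in event.items():
        # e.g. pick up reranker_docs or the generated answer from update
        ...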
52 changes: 36 additions & 16 deletions ai/gen-ai-agents/custom_rag_agent/bm25_search.py
@@ -57,17 +57,22 @@ def fetch_text_data(self):
cursor.execute(query)

while True:
rows = cursor.fetchmany(self.batch_size) # Fetch records in batches
# Fetch records in batches
rows = cursor.fetchmany(self.batch_size)
if not rows:
break # Exit loop when no more data
# Exit loop when no more data
break

for row in rows:
lob_data = row[0] # This is a CLOB object
# This is a CLOB object
lob_data = row[0]

if isinstance(lob_data, oracledb.LOB):
_results.append(lob_data.read()) # Read LOB content
# Read LOB content
_results.append(lob_data.read())
else:
_results.append(str(lob_data)) # Fallback for non-LOB data
# Fallback for non-LOB data
_results.append(str(lob_data))

return _results

@@ -116,18 +121,33 @@ def search(self, query, top_n=5):

# Example Usage:
# credentials are packed in CONNECT_ARGS
table_name = "BOOKS"
text_column = "TEXT"

# create the index
bm25_search = BM25OracleSearch(table_name, text_column)

questions = ["Chi è Luigi Saetta?", "What are the main innovation produced by GPT-4?"]
def run_test():
"""
To run a quick test.
"""
table_name = "BOOKS"
text_column = "TEXT"

# create the index
bm25_search = BM25OracleSearch(table_name, text_column)

questions = [
"Chi è Luigi Saetta?",
"What are the main innovation produced by GPT-4?",
]

for _question in questions:
results = bm25_search.search(_question, top_n=2)

# Print search results
for text, score in results:
print(f"Score: {score:.2f} - Text: {text}")
print("")

for _question in questions:
results = bm25_search.search(_question, top_n=2)

# Print search results
for text, score in results:
print(f"Score: {score:.2f} - Text: {text}")
print("")
#
# Main
#
run_test()
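CONNECT_ARGS is referenced in the comment above but not shown in this diff. A minimal sketch of what it presumably contains, using standard python-oracledb connect() parameters (all values assumed):

CONNECT_ARGS = {
    "user": "MY_USER",           # assumed; read from env/secret store in practice
    "password": "MY_PASSWORD",   # assumed
    "dsn": "dbhost:1521/mypdb",  # assumed host:port/service_name
}

# presumably consumed elsewhere in the module as:
# connection = oracledb.connect(**CONNECT_ARGS)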
48 changes: 35 additions & 13 deletions ai/gen-ai-agents/custom_rag_agent/chunk_index_utils.py
@@ -24,6 +24,17 @@
logger = get_console_logger()


def get_chunk_header(file_path):
"""
Generate a header for the chunk.
"""
doc_name = remove_path_from_ref(file_path)
# split to remove the extension
doc_title = doc_name.split(".")[0]

return f"# Doc. title: {doc_title}\n", doc_name


def get_recursive_text_splitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
"""
return a recursive text splitter
@@ -39,7 +50,15 @@ def get_recursive_text_splitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):

def load_and_split_pdf(book_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
"""
load a single book
Loads and splits a PDF document into chunks using a recursive character text splitter.

Args:
book_path (str): The file path of the PDF document.
chunk_size (int): Size of each text chunk.
chunk_overlap (int): Overlap between chunks.

Returns:
List[Document]: A list of LangChain Document objects with metadata.
"""
text_splitter = get_recursive_text_splitter(chunk_size, chunk_overlap)

@@ -50,28 +69,33 @@ def load_and_split_pdf(book_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
chunk_header = ""

if len(docs) > 0:
doc_name = remove_path_from_ref(book_path)
# split to remove the extension
doc_title = doc_name.split(".")[0]
chunk_header = f"# Doc. title: {doc_title}\n"
chunk_header, _ = get_chunk_header(book_path)

# remove path from source and reduce the metadata (16/03/2025)
for doc in docs:
# add more context to the chunk
doc.page_content = chunk_header + doc.page_content
doc.metadata = {
"source": remove_path_from_ref(book_path),
"page_label": doc.metadata["page_label"],
"page_label": doc.metadata.get("page_label", ""),
}

logger.info("Loaded %s chunks...", len(docs))
logger.info("Successfully loaded and split %d chunks from %s", len(docs), book_path)

return docs


def load_and_split_docx(file_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
"""
To load docx files
Loads and splits a docx document into chunks using a recursive character text splitter.

Args:
file_path (str): The file path of the document.
chunk_size (int): Size of each text chunk.
chunk_overlap (int): Overlap between chunks.

Returns:
List[Document]: A list of LangChain Document objects with metadata.
"""
loader = UnstructuredLoader(file_path)
docs = loader.load()
@@ -80,12 +104,10 @@ def load_and_split_docx(file_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
grouped_text = defaultdict(list)

chunk_header = ""
doc_name = ""

if len(docs) > 0:
doc_name = remove_path_from_ref(file_path)
# split to remove the extension
doc_title = doc_name.split(".")[0]
chunk_header = f"# Doc. title: {doc_title}\n"
chunk_header, doc_name = get_chunk_header(file_path)

for doc in docs:
# fallback to 0 if not available
Expand Down Expand Up @@ -115,6 +137,6 @@ def load_and_split_docx(file_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OV
)
)

logger.info("Loaded %s chunks...", len(final_chunks))
logger.info("Successfully loaded and split %d chunks from %s", len(docs), file_path)

return final_chunks
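A short usage sketch for the two loaders (file names hypothetical; chunk sizes default to the module's CHUNK_SIZE and CHUNK_OVERLAP):

pdf_chunks = load_and_split_pdf("annual_report.pdf")
docx_chunks = load_and_split_docx("meeting_notes.docx")

# each chunk is a LangChain Document with a chunk header and reduced metadata
for chunk in pdf_chunks[:2]:
    print(chunk.metadata["source"], chunk.page_content[:80])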