diff --git a/agentic_rag/README.md b/agentic_rag/README.md index 6367cd6..e421048 100644 --- a/agentic_rag/README.md +++ b/agentic_rag/README.md @@ -4,6 +4,8 @@ An intelligent RAG (Retrieval Augmented Generation) system that uses an LLM agent to make decisions about information retrieval and response generation. The system processes PDF documents and can intelligently decide which knowledge base to query based on the user's question. +CoT output + The system has the following features: - Intelligent query routing @@ -12,8 +14,19 @@ The system has the following features: - Smart context retrieval and response generation - FastAPI-based REST API for document upload and querying - Support for both OpenAI-based agents or local, transformer-based agents (`Mistral-7B` by default) +- Support for quantized models (4-bit/8-bit) and Ollama models for faster inference - Optional Chain of Thought (CoT) reasoning for more detailed and structured responses +Gradio Interface + +Gradio Interface + +Gradio Interface + +Here you can find a result of using Chain of Thought (CoT) reasoning: + +CoT output + ## 0. Prerequisites and setup ### Prerequisites @@ -29,6 +42,8 @@ The system has the following features: - Minimum 16GB RAM (recommended >24GBs) - GPU with 8GB VRAM recommended for better performance - Will run on CPU if GPU is not available, but will be significantly slower. + - For quantized models (4-bit/8-bit): Reduced VRAM requirements (4-6GB) with minimal performance impact + - For Ollama models: Requires Ollama to be installed and running, with significantly reduced memory requirements ### Setup @@ -36,11 +51,11 @@ The system has the following features: ```bash git clone https://github.com/oracle-devrel/devrel-labs.git - cd agentic-rag + cd devrel-labs/agentic_rag pip install -r requirements.txt ``` -2. Authenticate with HuggingFace: +2. Authenticate with HuggingFace (for Hugging Face models only): The system uses `Mistral-7B` by default, which requires authentication with HuggingFace: @@ -63,6 +78,30 @@ The system has the following features: If no API key is provided, the system will automatically download and use `Mistral-7B-Instruct-v0.2` for text generation when using the local model. No additional configuration is needed. +4. For quantized models, ensure bitsandbytes is installed: + + ```bash + pip install bitsandbytes>=0.41.0 + ``` + +5. For Ollama models, install Ollama: + + a. Download and install Ollama from [ollama.com/download](https://ollama.com/download) for Windows, or run the following command in Linux: + + ```bash + curl -fsSL https://ollama.com/install.sh | sh + ``` + + b. Start the Ollama service + + c. Pull the models you want to use beforehand: + + ```bash + ollama pull llama3 + ollama pull phi3 + ollama pull qwen2 + ``` + ## 1. Getting Started You can launch this solution in three ways: @@ -93,19 +132,29 @@ python gradio_app.py This will start the Gradio server and automatically open the interface in your default browser at `http://localhost:7860`. The interface has two main tabs: -1. **Document Processing**: +1. **Model Management**: + - Download models in advance to prepare them for use + - View model information including size and VRAM requirements + - Check download status and error messages + +2. **Document Processing**: - Upload PDFs using the file uploader - Process web content by entering URLs - View processing status and results -2. **Chat Interface**: - - Select between Local (Mistral) and OpenAI models +3. 
**Chat Interface**: + - Select between different model options: + - Local (Mistral) - Default Mistral-7B model (recommended) + - Local (Mistral) with 4-bit or 8-bit quantization for faster inference + - Ollama models (llama3, phi-3, qwen2) as alternative options + - OpenAI (if API key is configured) - Toggle Chain of Thought reasoning for more detailed responses - Chat with your documents using natural language - Clear chat history as needed Note: The interface will automatically detect available models based on your configuration: -- Local Mistral model requires HuggingFace token in `config.yaml` +- Local Mistral model requires HuggingFace token in `config.yaml` (default option) +- Ollama models require Ollama to be installed and running (alternative options) - OpenAI model requires API key in `.env` file ### 3. Using Individual Python Components via Command Line @@ -301,14 +350,19 @@ This endpoint processes a query through the agentic RAG pipeline and returns a r ## Annex: Architecture +Architecture + The system consists of several key components: -1. **PDF Processor**: we use Docling to extract and chunk text from PDF documents -2. **Vector Store**: Manages document embeddings and similarity search using ChromaDB -3. **RAG Agent**: Makes intelligent decisions about query routing and response generation +1. **PDF Processor**: we use `docling` to extract and chunk text from PDF documents +2. **Web Processor**: we use `trafilatura` to extract and chunk text from websites +3. **GitHub Repository Processor**: we use `gitingest` to extract and chunk text from repositories +4. **Vector Store**: Manages document embeddings and similarity search using `ChromaDB` +5. **RAG Agent**: Makes intelligent decisions about query routing and response generation - OpenAI Agent: Uses `gpt-4-turbo-preview` for high-quality responses, but requires an OpenAI API key - Local Agent: Uses `Mistral-7B` as an open-source alternative -4. **FastAPI Server**: Provides REST API endpoints for document upload and querying +6. **FastAPI Server**: Provides REST API endpoints for document upload and querying +7. 
**Gradio Interface**: Provides a user-friendly web interface for interacting with the RAG system The RAG Agent flow is the following: diff --git a/agentic_rag/agents/agent_factory.py b/agentic_rag/agents/agent_factory.py index 2f26055..75254e8 100644 --- a/agentic_rag/agents/agent_factory.py +++ b/agentic_rag/agents/agent_factory.py @@ -27,11 +27,30 @@ class Agent(BaseModel): def log_prompt(self, prompt: str, prefix: str = ""): """Log a prompt being sent to the LLM""" - logger.info(f"\n{'='*80}\n{prefix} Prompt:\n{'-'*40}\n{prompt}\n{'='*80}") + # Check if the prompt contains context + if "Context:" in prompt: + # Split the prompt at "Context:" and keep only the first part + parts = prompt.split("Context:") + # Keep the first part and add a note that context is omitted + truncated_prompt = parts[0] + "Context: [Context omitted for brevity]" + if len(parts) > 2 and "Key Findings:" in parts[1]: + # For researcher prompts, keep the "Key Findings:" part + key_findings_part = parts[1].split("Key Findings:") + if len(key_findings_part) > 1: + truncated_prompt += "\nKey Findings:" + key_findings_part[1] + logger.info(f"\n{'='*80}\n{prefix} Prompt:\n{'-'*40}\n{truncated_prompt}\n{'='*80}") + else: + # If no context, log the full prompt + logger.info(f"\n{'='*80}\n{prefix} Prompt:\n{'-'*40}\n{prompt}\n{'='*80}") def log_response(self, response: str, prefix: str = ""): """Log a response received from the LLM""" - logger.info(f"\n{'='*80}\n{prefix} Response:\n{'-'*40}\n{response}\n{'='*80}") + # Log the response but truncate if it's too long + if len(response) > 500: + truncated_response = response[:500] + "... [response truncated]" + logger.info(f"\n{'='*80}\n{prefix} Response:\n{'-'*40}\n{truncated_response}\n{'='*80}") + else: + logger.info(f"\n{'='*80}\n{prefix} Response:\n{'-'*40}\n{response}\n{'='*80}") class PlannerAgent(Agent): """Agent responsible for breaking down problems and planning steps""" @@ -108,6 +127,7 @@ def research(self, query: str, step: str) -> List[Dict[str, Any]]: Key Findings:""" + # Create context string but don't log it context_str = "\n\n".join([f"Source {i+1}:\n{item['content']}" for i, item in enumerate(all_results)]) prompt = ChatPromptTemplate.from_template(template) messages = prompt.format_messages(step=step, context=context_str) @@ -140,6 +160,7 @@ def reason(self, query: str, step: str, context: List[Dict[str, Any]]) -> str: Conclusion:""" + # Create context string but don't log it context_str = "\n\n".join([f"Context {i+1}:\n{item['content']}" for i, item in enumerate(context)]) prompt = ChatPromptTemplate.from_template(template) messages = prompt.format_messages(step=step, query=query, context=context_str) diff --git a/agentic_rag/articles/kubernetes_rag.md b/agentic_rag/articles/kubernetes_rag.md new file mode 100644 index 0000000..48cd03d --- /dev/null +++ b/agentic_rag/articles/kubernetes_rag.md @@ -0,0 +1,184 @@ +# Agentic RAG: Enterprise-Scale Multi-Agent AI System on Oracle Cloud Infrastructure + +## Introduction + + + +Agentic RAG is an advanced Retrieval-Augmented Generation system that employs a multi-agent architecture with Chain-of-Thought reasoning, designed for enterprise-scale deployment on Oracle Cloud Infrastructure (OCI). + +The system leverages specialized AI agents for complex document analysis and query processing, while taking advantage of OCI's managed Kubernetes service and security features for production-grade deployment. 
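+
+The specialized agents mentioned above cooperate in a plan → research → reason → synthesize loop (see `agents/agent_factory.py` in the repository). The sketch below is illustrative only and uses hypothetical stand-in callables rather than the repository's actual agent classes:
+
+```python
+# Illustrative sketch only: hypothetical stand-ins for the planner, researcher,
+# reasoner and synthesizer agents; not the repository's actual classes.
+from typing import Callable, List
+
+def cot_pipeline(query: str,
+                 plan: Callable[[str], List[str]],
+                 research: Callable[[str, str], List[str]],
+                 reason: Callable[[str, str, List[str]], str],
+                 synthesize: Callable[[str, List[str]], str]) -> str:
+    steps = plan(query)                                    # break the query into steps
+    conclusions = []
+    for step in steps:
+        findings = research(query, step)                   # retrieve context for the step
+        conclusions.append(reason(query, step, findings))  # draw a conclusion per step
+    return synthesize(query, conclusions)                  # merge into the final answer
+```
+
+In the actual system, each of these roles is backed by the same local or OpenAI LLM, and the research step queries the `ChromaDB` vector store.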
+
+With this article, we want to show you how you can get started in a few steps to install and deploy this multi-agent RAG system using Oracle Kubernetes Engine (OKE) and OCI.
+
+## Features
+
+This Agentic RAG system is based on the following technologies:
+
+- Oracle Kubernetes Engine (OKE)
+- Oracle Cloud Infrastructure (OCI)
+- `ollama` as the inference server for most Large Language Models (LLMs) available in the solution (`llama3`, `phi3`, `qwen2`)
+- `Mistral-7B` language model, with an optional multi-agent Chain of Thought reasoning
+- `ChromaDB` as vector store and retrieval system
+- `Trafilatura`, `docling` and `gitingest` to extract the content from PDFs and web pages, and have them ready to be used by the RAG system
+- Multi-agent architecture with specialized agents:
+  - Planner Agent: Strategic decomposition of complex queries
+  - Research Agent: Intelligent information retrieval (from vector database)
+  - Reasoning Agent: Logical analysis and conclusion drawing
+  - Synthesis Agent: Comprehensive response generation
+- Support for both cloud-based (OpenAI) and local (Mistral-7B) language models
+- Step-by-step reasoning visualization
+- `Gradio` web interface for easy interaction with the RAG system
+
+There are several benefits to using containerized LLMs over running the LLMs directly on cloud instances. For example:
+
+- **Scalability**: you can easily scale the LLM workloads across Kubernetes clusters. In our case, we're deploying the solution with 4 agents in the same cluster, but you could deploy each agent in a different cluster if you wanted to accelerate the Chain-of-Thought reasoning processing time (horizontal scaling). You could also use vertical scaling by adding more resources to the same agent.
+- **Resource Optimization**: you can efficiently allocate GPU and memory resources for each agent
+- **Isolation**: each agent runs in its own container for better resource management
+- **Version Control**: easily update and roll back LLM versions and configurations
+- **Reproducibility**: have a consistent environment across development and production, which is crucial when you're working with complex LLM applications
+- **Cost Efficiency**: you pay only for the resources you need, and when you're done with your work, you can simply stop the Kubernetes cluster and you won't be charged for the resources anymore.
+- **Integration**: you can easily integrate the RAG system with other programming languages or frameworks, as we also made available a REST-based API to interact with the system, apart from the standard web interface.
+
+In conclusion, it's really easy to scale your system up and down with Kubernetes, without having to worry about the underlying infrastructure, installation, or configuration.
+
+Note that the way we've planned the infrastructure is important because it allows us to:
+1. Scale the `chromadb` vector store system independently
+2. Share the LLM container across agents, deploying it only once and then reusing it across all the agents
+3. Scale the `Research Agent` separately for parallel document processing, if needed
+4. Optimize memory and GPU resources, since there's only one LLM instance running
+
+## Deployment in Kubernetes
+
+We have devised two different ways to deploy in Kubernetes: either through a local or distributed system, each offering its own advantages.
+
+### Local Deployment
+
+This method is the easiest way to implement and deploy.
We call it local because every resource is deployed in the same pod. The advantages are the following: + +- **Simplicity**: All components run in a single pod, making deployment and management straightforward +- **Easier debugging**: Troubleshooting is simpler when all logs and components are in one place (we're looking to expand the standard logging mechanism that we have right now with `fluentd`) +- **Quick setup**: Ideal for testing, development, or smaller-scale deployments +- **Lower complexity**: No need to configure inter-service communication or network policies like port forwarding or such mechanisms. + +### Distributed System Deployment + +By decoupling the `ollama` LLM inference system to another pod, we could easily ready our system for **vertical scaling**: if we're ever running out of resources or we need to use a bigger model, we don't have to worry about the other solution components not having enough resources for processing and logging: we can simply scale up our inference pod and connect it via a FastAPI or similar system to allow the Gradio interface to make calls to the model, following a distributed system architecture. + +The advantages are: + +- **Independent Scaling**: Each component can be scaled according to its specific resource needs +- **Resource Optimization**: Dedicated resources for compute-intensive LLM inference separate from other components +- **High Availability**: System remains operational even if individual components fail, and we can have multiple pods running failover LLMs to help us with disaster recovery. +- **Flexible Model Deployment**: Easily swap or upgrade LLM models without affecting the rest of the system (also, with virtually zero downtime!) +- **Load Balancing**: Distribute inference requests across multiple LLM pods for better performance, thus allowing concurrent users in our Gradio interface. +- **Isolation**: Performance issues on the LLM side won't impact the interface +- **Cost Efficiency**: Allocate expensive GPU resources only where needed (inference) while using cheaper CPU resources for other components (e.g. we use GPU for Chain of Thought reasoning, while keeping a quantized CPU LLM for standard chatting). + +## Quick Start + +For this solution, we have currently implemented the local system deployment, which is what we'll cover in this section. + +First, we need to create a GPU OKE cluster with `zx` and Terraform. For this, you can follow the steps in [this repository](https://github.com/vmleon/oci-oke-gpu), or reuse your own Kubernetes cluster if you happen to already have one. + +Then, we can start setting up the solution in our cluster by following these steps. + +1. Clone the repository containing the Kubernetes manifests: + + ```bash + git clone https://github.com/oracle-devrel/devrel-labs.git + cd devrel-labs/agentic_rag/k8s + ``` + +2. Create a namespace: + + ```bash + kubectl create namespace agentic-rag + ``` + +3. Create a ConfigMap: + + This step will help our deployment for several reasons: + + 1. **Externalized Configuration**: It separates configuration from application code, following best practices for containerized applications + 2. **Environment-specific Settings**: Allows us to maintain different configurations for development, testing, and production environments + 3. **Credential Management**: Provides a way to inject API tokens (like Hugging Face) without hardcoding them in the image + 4. 
**Runtime Configuration**: Enables changing configuration without rebuilding or redeploying the application container + 5. **Consistency**: Ensures all pods use the same configuration when scaled horizontally + + In our specific case, the ConfigMap stores the Hugging Face Hub token for accessing (and downloading) the `mistral-7b` model (and CPU-quantized variants) + - Optionally, OpenAI API keys if using those models + - Any other environment-specific variables needed by the application, in case we want to make further development and increase the capabilities of the system with external API keys, authentication tokens... etc. + + Let's run the following command to create the config map: + + ```bash + # With a Hugging Face token + cat <`. + +## Resource Requirements + +The deployment of this solution requires the following minimum resources: + +- **CPU**: 4+ cores +- **Memory**: 16GB+ RAM +- **Storage**: 50GB+ +- **GPU**: recommended for faster inference. In theory, you can use `mistral-7b` CPU-quantized models, but it will be sub-optimal. + +## Conclusion + +You can check out the full AI solution and the deployment options we mention in this article in [the official GitHub repository](https://github.com/oracle-devrel/devrel-labs/tree/main/agentic_rag). \ No newline at end of file diff --git a/agentic_rag/gradio_app.py b/agentic_rag/gradio_app.py index e07cd17..6a8f98e 100644 --- a/agentic_rag/gradio_app.py +++ b/agentic_rag/gradio_app.py @@ -5,6 +5,8 @@ import tempfile from dotenv import load_dotenv import yaml +import torch +import time from pdf_processor import PDFProcessor from web_processor import WebProcessor @@ -77,31 +79,75 @@ def process_repo(repo_path: str) -> str: except Exception as e: return f"βœ— Error processing repository: {str(e)}" -def chat(message: str, history: List[List[str]], agent_type: str, use_cot: bool, language: str, collection: str) -> List[List[str]]: +def chat(message: str, history: List[List[str]], agent_type: str, use_cot: bool, collection: str) -> List[List[str]]: """Process chat message using selected agent and collection""" try: print("\n" + "="*50) print(f"New message received: {message}") - print(f"Agent: {agent_type}, CoT: {use_cot}, Language: {language}, Collection: {collection}") + print(f"Agent: {agent_type}, CoT: {use_cot}, Collection: {collection}") print("="*50 + "\n") + # Determine if we should skip analysis based on collection and interface type + # Skip analysis for General Knowledge or when using standard chat interface (not CoT) + skip_analysis = collection == "General Knowledge" or not use_cot + + # Parse agent type to determine model and quantization + quantization = None + model_name = None + + if "4-bit" in agent_type: + quantization = "4bit" + model_type = "Local (Mistral)" + elif "8-bit" in agent_type: + quantization = "8bit" + model_type = "Local (Mistral)" + elif "Ollama" in agent_type: + model_type = "Ollama" + # Extract model name from agent_type and use correct Ollama model names + if "llama3" in agent_type.lower(): + model_name = "ollama:llama3" + elif "phi-3" in agent_type.lower(): + model_name = "ollama:phi3" + elif "qwen2" in agent_type.lower(): + model_name = "ollama:qwen2" + else: + model_type = agent_type + # Select appropriate agent and reinitialize with correct settings - if agent_type == "Local (Mistral)": + if "Local" in model_type: + # For HF models, we need the token if not hf_token: response_text = "Local agent not available. Please check your HuggingFace token configuration." 
print(f"Error: {response_text}") return history + [[message, response_text]] - agent = LocalRAGAgent(vector_store, use_cot=use_cot) + agent = LocalRAGAgent(vector_store, use_cot=use_cot, collection=collection, + skip_analysis=skip_analysis, quantization=quantization) + elif model_type == "Ollama": + # For Ollama models + if model_name: + try: + agent = LocalRAGAgent(vector_store, model_name=model_name, use_cot=use_cot, + collection=collection, skip_analysis=skip_analysis) + except Exception as e: + response_text = f"Error initializing Ollama model: {str(e)}. Falling back to Local Mistral." + print(f"Error: {response_text}") + # Fall back to Mistral if Ollama fails + if hf_token: + agent = LocalRAGAgent(vector_store, use_cot=use_cot, collection=collection, + skip_analysis=skip_analysis) + else: + return history + [[message, "Local Mistral agent not available for fallback. Please check your HuggingFace token configuration."]] + else: + response_text = "Ollama model not specified correctly." + print(f"Error: {response_text}") + return history + [[message, response_text]] else: if not openai_key: response_text = "OpenAI agent not available. Please check your OpenAI API key configuration." print(f"Error: {response_text}") return history + [[message, response_text]] - agent = RAGAgent(vector_store, openai_api_key=openai_key, use_cot=use_cot) - - # Convert language selection to language code - lang_code = "es" if language == "Spanish" else "en" - agent.language = lang_code + agent = RAGAgent(vector_store, openai_api_key=openai_key, use_cot=use_cot, + collection=collection, skip_analysis=skip_analysis) # Process query and get response print("Processing query...") @@ -152,10 +198,31 @@ def chat(message: str, history: List[List[str]], agent_type: str, use_cot: bool, # Add final formatted response to history history.append([message, formatted_response]) else: + # For standard response (no CoT) formatted_response = response["answer"] print("\nStandard Response:") print("-" * 50) print(formatted_response) + + # Add sources if available + if response.get("context"): + print("\nSources Used:") + print("-" * 50) + sources_text = "\n\nπŸ“š Sources used:\n" + formatted_response += sources_text + print(sources_text) + + for ctx in response["context"]: + source = ctx["metadata"].get("source", "Unknown") + if "page_numbers" in ctx["metadata"]: + pages = ctx["metadata"].get("page_numbers", []) + source_line = f"- {source} (pages: {pages})\n" + else: + file_path = ctx["metadata"].get("file_path", "Unknown") + source_line = f"- {source} (file: {file_path})\n" + formatted_response += source_line + print(source_line) + history.append([message, formatted_response]) print("\n" + "="*50) @@ -179,8 +246,97 @@ def create_interface(): # πŸ€– Agentic RAG System Upload PDFs, process web content, repositories, and chat with your documents using local or OpenAI models. + + > **Note on Performance**: When using the Local (Mistral) model, initial loading can take 1-5 minutes, and each query may take 30-60 seconds to process depending on your hardware. OpenAI queries are typically much faster. 
""") + # Create model choices list for reuse + model_choices = [] + # HF models first if token is available + if hf_token: + model_choices.extend([ + "Local (Mistral)", + "Local (Mistral) - 4-bit Quantized", + "Local (Mistral) - 8-bit Quantized", + ]) + # Then Ollama models (don't require HF token) + model_choices.extend([ + "Ollama - llama3", + "Ollama - phi-3", + "Ollama - qwen2" + ]) + if openai_key: + model_choices.append("OpenAI") + + # Model Management Tab (First Tab) + with gr.Tab("Model Management"): + gr.Markdown(""" + ## Model Management + + Download models in advance to prepare them for use in the chat interface. + + ### Hugging Face Models (Default) + + The system uses Mistral-7B by default. For Hugging Face models (Mistral), you'll need a Hugging Face token in your config.yaml file. + + ### Ollama Models (Alternative) + + Ollama models are available as alternatives. For Ollama models, this will pull the model using the Ollama client. + Make sure Ollama is installed and running on your system. + You can download Ollama from [ollama.com/download](https://ollama.com/download) + """) + + with gr.Row(): + with gr.Column(): + model_dropdown = gr.Dropdown( + choices=model_choices, + value=model_choices[0] if model_choices else None, + label="Select Model to Download", + interactive=True + ) + download_button = gr.Button("Download Selected Model") + model_status = gr.Textbox( + label="Download Status", + placeholder="Select a model and click Download to begin...", + interactive=False + ) + + with gr.Column(): + gr.Markdown(""" + ### Model Information + + **Local (Mistral)**: The default Mistral-7B-Instruct-v0.2 model. + - Size: ~14GB + - VRAM Required: ~8GB + - Good balance of quality and speed + + **Local (Mistral) - 4-bit Quantized**: 4-bit quantized version of Mistral-7B. + - Size: ~4GB + - VRAM Required: ~4GB + - Faster inference with minimal quality loss + + **Local (Mistral) - 8-bit Quantized**: 8-bit quantized version of Mistral-7B. + - Size: ~7GB + - VRAM Required: ~6GB + - Balance between quality and memory usage + + **Ollama - llama3**: Meta's Llama 3 model via Ollama. + - Size: ~4GB + - Requires Ollama to be installed and running + - Excellent performance and quality + + **Ollama - phi-3**: Microsoft's Phi-3 model via Ollama. + - Size: ~4GB + - Requires Ollama to be installed and running + - Efficient small model with good performance + + **Ollama - qwen2**: Alibaba's Qwen2 model via Ollama. 
+ - Size: ~4GB + - Requires Ollama to be installed and running + - High-quality model with good performance + """) + + # Document Processing Tab with gr.Tab("Document Processing"): with gr.Row(): with gr.Column(): @@ -200,24 +356,26 @@ def create_interface(): with gr.Tab("Standard Chat Interface"): with gr.Row(): - with gr.Column(): + with gr.Column(scale=1): + # Create model choices with quantization options standard_agent_dropdown = gr.Dropdown( - choices=["Local (Mistral)", "OpenAI"] if openai_key else ["Local (Mistral)"], - value="Local (Mistral)", + choices=model_choices, + value=model_choices[0] if model_choices else None, label="Select Agent" ) - with gr.Column(): - standard_language_dropdown = gr.Dropdown( - choices=["English", "Spanish"], - value="English", - label="Response Language" - ) - with gr.Column(): + with gr.Column(scale=1): standard_collection_dropdown = gr.Dropdown( choices=["PDF Collection", "Repository Collection", "General Knowledge"], value="PDF Collection", label="Knowledge Collection" ) + gr.Markdown(""" + > **Collection Selection**: + > - This interface ALWAYS uses the selected collection without performing query analysis. + > - "PDF Collection": Will ALWAYS search the PDF documents regardless of query type. + > - "Repository Collection": Will ALWAYS search the repository code regardless of query type. + > - "General Knowledge": Will ALWAYS use the model's built-in knowledge without searching collections. + """) standard_chatbot = gr.Chatbot(height=400) with gr.Row(): standard_msg = gr.Textbox(label="Your Message", scale=9) @@ -226,24 +384,27 @@ def create_interface(): with gr.Tab("Chain of Thought Chat Interface"): with gr.Row(): - with gr.Column(): + with gr.Column(scale=1): + # Create model choices with quantization options cot_agent_dropdown = gr.Dropdown( - choices=["Local (Mistral)", "OpenAI"] if openai_key else ["Local (Mistral)"], - value="Local (Mistral)", + choices=model_choices, + value=model_choices[0] if model_choices else None, label="Select Agent" ) - with gr.Column(): - cot_language_dropdown = gr.Dropdown( - choices=["English", "Spanish"], - value="English", - label="Response Language" - ) - with gr.Column(): + with gr.Column(scale=1): cot_collection_dropdown = gr.Dropdown( choices=["PDF Collection", "Repository Collection", "General Knowledge"], value="PDF Collection", label="Knowledge Collection" ) + gr.Markdown(""" + > **Collection Selection**: + > - When a specific collection is selected, the system will ALWAYS use that collection without analysis: + > - "PDF Collection": Will ALWAYS search the PDF documents. + > - "Repository Collection": Will ALWAYS search the repository code. + > - "General Knowledge": Will ALWAYS use the model's built-in knowledge. + > - This interface shows step-by-step reasoning and may perform query analysis when needed. 
+ """) cot_chatbot = gr.Chatbot(height=400) with gr.Row(): cot_msg = gr.Textbox(label="Your Message", scale=9) @@ -255,6 +416,9 @@ def create_interface(): url_button.click(process_url, inputs=[url_input], outputs=[url_output]) repo_button.click(process_repo, inputs=[repo_input], outputs=[repo_output]) + # Model download event handler + download_button.click(download_model, inputs=[model_dropdown], outputs=[model_status]) + # Standard chat handlers standard_msg.submit( chat, @@ -263,7 +427,6 @@ def create_interface(): standard_chatbot, standard_agent_dropdown, gr.State(False), # use_cot=False - standard_language_dropdown, standard_collection_dropdown ], outputs=[standard_chatbot] @@ -275,7 +438,6 @@ def create_interface(): standard_chatbot, standard_agent_dropdown, gr.State(False), # use_cot=False - standard_language_dropdown, standard_collection_dropdown ], outputs=[standard_chatbot] @@ -290,7 +452,6 @@ def create_interface(): cot_chatbot, cot_agent_dropdown, gr.State(True), # use_cot=True - cot_language_dropdown, cot_collection_dropdown ], outputs=[cot_chatbot] @@ -302,7 +463,6 @@ def create_interface(): cot_chatbot, cot_agent_dropdown, gr.State(True), # use_cot=True - cot_language_dropdown, cot_collection_dropdown ], outputs=[cot_chatbot] @@ -322,14 +482,22 @@ def create_interface(): 2. **Standard Chat Interface**: - Quick responses without detailed reasoning steps - Select your preferred agent (Local Mistral or OpenAI) - - Choose your preferred response language - - Select which knowledge collection to query + - Select which knowledge collection to query: + - **PDF Collection**: Always searches PDF documents + - **Repository Collection**: Always searches code repositories + - **General Knowledge**: Uses the model's built-in knowledge without searching collections 3. **Chain of Thought Chat Interface**: - Detailed responses with step-by-step reasoning - See the planning, research, reasoning, and synthesis steps - Great for complex queries or when you want to understand the reasoning process - May take longer but provides more detailed and thorough answers + - Same collection selection options as the Standard Chat Interface + + 4. **Performance Expectations**: + - **Local (Mistral) model**: Initial loading takes 1-5 minutes, each query takes 30-60 seconds + - **OpenAI model**: Much faster responses, typically a few seconds per query + - Chain of Thought reasoning takes longer for both models Note: OpenAI agent requires an API key in `.env` file """) @@ -350,5 +518,141 @@ def main(): inbrowser=True ) +def download_model(model_type: str) -> str: + """Download a model and return status message""" + try: + print(f"Downloading model: {model_type}") + + # Parse model type to determine model and quantization + quantization = None + model_name = None + + if "4-bit" in model_type or "8-bit" in model_type: + # For HF models, we need the token + if not hf_token: + return "❌ Error: HuggingFace token not found in config.yaml. Please add your token first." + + model_name = "mistralai/Mistral-7B-Instruct-v0.2" # Default model + if "4-bit" in model_type: + quantization = "4bit" + elif "8-bit" in model_type: + quantization = "8bit" + + # Start download timer + start_time = time.time() + + try: + from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig + + # Download tokenizer first (smaller download to check access) + try: + tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token) + except Exception as e: + if "401" in str(e): + return f"❌ Error: This model is gated. 
Please accept the terms on the Hugging Face website: https://huggingface.co/{model_name}" + else: + return f"❌ Error downloading tokenizer: {str(e)}" + + # Set up model loading parameters + model_kwargs = { + "token": hf_token, + "device_map": None, # Don't load on GPU for download only + } + + # Apply quantization if specified + if quantization == '4bit': + try: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" + ) + model_kwargs["quantization_config"] = quantization_config + except ImportError: + return "❌ Error: bitsandbytes not installed. Please install with: pip install bitsandbytes>=0.41.0" + elif quantization == '8bit': + try: + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + model_kwargs["quantization_config"] = quantization_config + except ImportError: + return "❌ Error: bitsandbytes not installed. Please install with: pip install bitsandbytes>=0.41.0" + + # Download model (but don't load it fully to save memory) + AutoModelForCausalLM.from_pretrained( + model_name, + **model_kwargs + ) + + # Calculate download time + download_time = time.time() - start_time + return f"βœ… Successfully downloaded {model_type} in {download_time:.1f} seconds." + + except Exception as e: + return f"❌ Error downloading model: {str(e)}" + + elif "Ollama" in model_type: + # Extract model name from model_type + if "llama3" in model_type.lower(): + model_name = "llama3" + elif "phi-3" in model_type.lower(): + model_name = "phi3" + elif "qwen2" in model_type.lower(): + model_name = "qwen2" + else: + return "❌ Error: Unknown Ollama model type" + + # Use Ollama to pull the model + try: + import ollama + + print(f"Pulling Ollama model: {model_name}") + start_time = time.time() + + # Check if model already exists + try: + models = ollama.list().models + available_models = [model.model for model in models] + + # Check for model with or without :latest suffix + if model_name in available_models or f"{model_name}:latest" in available_models: + return f"βœ… Model {model_name} is already available in Ollama." + except Exception: + # If we can't check, proceed with pull anyway + pass + + # Pull the model with progress tracking + progress_text = "" + for progress in ollama.pull(model_name, stream=True): + status = progress.get('status') + if status: + progress_text = f"Status: {status}" + print(progress_text) + + # Show download progress + if 'completed' in progress and 'total' in progress: + completed = progress['completed'] + total = progress['total'] + if total > 0: + percent = (completed / total) * 100 + progress_text = f"Downloading: {percent:.1f}% ({completed}/{total})" + print(progress_text) + + # Calculate download time + download_time = time.time() - start_time + return f"βœ… Successfully pulled Ollama model {model_name} in {download_time:.1f} seconds." + + except ImportError: + return "❌ Error: ollama not installed. Please install with: pip install ollama" + except ConnectionError: + return "❌ Error: Could not connect to Ollama. Please make sure Ollama is installed and running." 
+ except Exception as e: + return f"❌ Error pulling Ollama model: {str(e)}" + else: + return "❌ Error: Unknown model type" + + except Exception as e: + return f"❌ Error: {str(e)}" + if __name__ == "__main__": main() \ No newline at end of file diff --git a/agentic_rag/img/architecture.png b/agentic_rag/img/architecture.png new file mode 100644 index 0000000..a7b1dcc Binary files /dev/null and b/agentic_rag/img/architecture.png differ diff --git a/agentic_rag/img/cot_final_answer.png b/agentic_rag/img/cot_final_answer.png new file mode 100644 index 0000000..47a69df Binary files /dev/null and b/agentic_rag/img/cot_final_answer.png differ diff --git a/agentic_rag/img/gradio_1.png b/agentic_rag/img/gradio_1.png new file mode 100644 index 0000000..422e2f8 Binary files /dev/null and b/agentic_rag/img/gradio_1.png differ diff --git a/agentic_rag/img/gradio_2.png b/agentic_rag/img/gradio_2.png new file mode 100644 index 0000000..dc31a06 Binary files /dev/null and b/agentic_rag/img/gradio_2.png differ diff --git a/agentic_rag/img/gradio_3.png b/agentic_rag/img/gradio_3.png new file mode 100644 index 0000000..7979ca1 Binary files /dev/null and b/agentic_rag/img/gradio_3.png differ diff --git a/agentic_rag/k8s/README.md b/agentic_rag/k8s/README.md new file mode 100644 index 0000000..9123071 --- /dev/null +++ b/agentic_rag/k8s/README.md @@ -0,0 +1,152 @@ +# Kubernetes Deployment for Agentic RAG + +This directory contains Kubernetes manifests and guides for deploying the Agentic RAG system on Kubernetes. + +## Deployment Options + +We currently provide a single deployment option with plans for a distributed deployment in the future: + +### Local Deployment + +This is a single-pod deployment where all components run in the same pod. It's simpler to deploy and manage, making it ideal for testing and development. + +**Features:** +- Includes both Hugging Face models and Ollama for inference +- Uses GPU acceleration for faster inference +- Simpler deployment and management +- Easier debugging (all logs in one place) +- Lower complexity (no inter-service communication) +- Quicker setup + +**Model Options:** +- **Hugging Face Models**: Uses `Mistral-7B` models from Hugging Face (requires a token) +- **Ollama Models**: Uses `ollama` for inference (llama3, phi3, qwen2) + +### Future: Distributed System Deployment + +A distributed system deployment that separates the LLM inference system into its own service is planned for future releases. This will allow for better resource allocation and scaling in production environments. + +**Advantages:** +- Independent scaling of components +- Better resource optimization +- Higher availability +- Flexible model deployment +- Load balancing capabilities + +## Deployment Guides + +We provide several guides for different environments: + +1. [**General Kubernetes Guide**](README_k8s.md): Basic instructions for any Kubernetes cluster +2. [**Oracle Kubernetes Engine (OKE) Guide**](OKE_DEPLOYMENT.md): Detailed instructions for deploying on OCI +3. 
[**Minikube Guide**](MINIKUBE.md): Quick start guide for local testing with Minikube + +## Directory Structure + +```bash +k8s/ +β”œβ”€β”€ README_MAIN.md # This file +β”œβ”€β”€ README.md # General Kubernetes guide +β”œβ”€β”€ OKE_DEPLOYMENT.md # Oracle Kubernetes Engine guide +β”œβ”€β”€ MINIKUBE.md # Minikube guide +β”œβ”€β”€ deploy.sh # Deployment script +└── local-deployment/ # Manifests for local deployment + β”œβ”€β”€ configmap.yaml + β”œβ”€β”€ deployment.yaml + └── service.yaml +``` + +## Quick Start + +For a quick start, use the deployment script. Just go into the script and replace your `HF_TOKEN` in line 17: + +```bash +# Make the script executable +chmod +x deploy.sh + +# Deploy with a Hugging Face token +./deploy.sh --hf-token "your-huggingface-token" --namespace agentic-rag + +# Or deploy without a Hugging Face token (Ollama models only) +./deploy.sh --namespace agentic-rag + +# Deploy without GPU support (not recommended for production) +./deploy.sh --cpu-only --namespace agentic-rag +``` + +## Resource Requirements + +The deployment requires the following minimum resources: + +- **CPU**: 4+ cores +- **Memory**: 16GB+ RAM +- **Storage**: 50GB+ +- **GPU**: 1 NVIDIA GPU (required for optimal performance) + +## Next Steps + +After deployment, you can: + +1. **Add Documents**: Upload PDFs, process web content, or add repositories to the knowledge base +2. **Configure Models**: Download and configure different models +3. **Customize**: Adjust the system to your specific needs +4. **Scale**: For production use, consider implementing the distributed deployment with persistent storage (coming soon) + +## Troubleshooting + +See the specific deployment guides for troubleshooting tips. Common issues include: + +- Insufficient resources +- Network connectivity problems +- Model download failures +- Configuration errors +- GPU driver issues + +### GPU-Related Issues + +If you encounter GPU-related issues: + +1. **Check GPU availability**: Ensure your Kubernetes cluster has GPU nodes available +2. **Verify NVIDIA drivers**: Make sure NVIDIA drivers are installed on the nodes +3. **Check NVIDIA device plugin**: Ensure the NVIDIA device plugin is installed in your cluster +4. **Inspect pod logs**: Check for GPU-related errors in the pod logs + +```bash +kubectl logs -f deployment/agentic-rag -n +``` + +## GPU Configuration Summary + +The deployment has been configured to use GPU acceleration by default for optimal performance: + +### Key GPU Configuration Changes + +1. **Resource Requests and Limits**: + - Each pod requests and is limited to 1 NVIDIA GPU + - Memory and CPU resources have been increased to better support GPU workloads + +2. **NVIDIA Container Support**: + - The deployment installs NVIDIA drivers and CUDA in the container + - Environment variables are set to enable GPU visibility and capabilities + +3. **Ollama GPU Configuration**: + - Ollama is configured to use GPU acceleration automatically + - Models like llama3, phi3, and qwen2 will benefit from GPU acceleration + +4. **Deployment Script Enhancements**: + - Added GPU availability detection + - Added `--cpu-only` flag for environments without GPUs + - Provides guidance for GPU monitoring and troubleshooting + +5. 
**Documentation Updates**: + - Added GPU-specific instructions for different Kubernetes environments + - Included troubleshooting steps for GPU-related issues + - Updated resource requirements to reflect GPU needs + +### CPU Fallback + +While the deployment is optimized for GPU usage, a CPU-only mode is available using the `--cpu-only` flag with the deployment script. However, this is not recommended for production use as inference performance will be significantly slower. + +## Contributing + +Contributions to improve the deployment manifests and guides are welcome. Please submit a pull request or open an issue. \ No newline at end of file diff --git a/agentic_rag/k8s/deploy.sh b/agentic_rag/k8s/deploy.sh new file mode 100644 index 0000000..bcc77d2 --- /dev/null +++ b/agentic_rag/k8s/deploy.sh @@ -0,0 +1,122 @@ +#!/bin/bash + +# Deployment script for Agentic RAG + +# Function to display usage +usage() { + echo "Usage: $0 [--hf-token TOKEN] [--namespace NAMESPACE] [--cpu-only]" + echo "" + echo "Options:" + echo " --hf-token TOKEN Hugging Face token (optional but recommended)" + echo " --namespace NAMESPACE Kubernetes namespace to deploy to (default: default)" + echo " --cpu-only Deploy without GPU support (not recommended for production)" + exit 1 +} + +# Default values +NAMESPACE="default" +HF_TOKEN="" +CPU_ONLY=false + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + --hf-token) + HF_TOKEN="$2" + shift 2 + ;; + --namespace) + NAMESPACE="$2" + shift 2 + ;; + --cpu-only) + CPU_ONLY=true + shift + ;; + *) + usage + ;; + esac +done + +# Create namespace if it doesn't exist +kubectl get namespace $NAMESPACE > /dev/null 2>&1 || kubectl create namespace $NAMESPACE + +echo "Deploying Agentic RAG to namespace $NAMESPACE..." + +# Check for GPU availability if not in CPU-only mode +if [[ "$CPU_ONLY" == "false" ]]; then + echo "Checking for GPU availability..." + GPU_COUNT=$(kubectl get nodes "-o=custom-columns=GPU:.status.allocatable.nvidia\.com/gpu" --no-headers | grep -v "" | wc -l) + + if [[ "$GPU_COUNT" -eq 0 ]]; then + echo "WARNING: No GPUs detected in the cluster!" + echo "The deployment is configured to use GPUs, but none were found." + echo "Options:" + echo " 1. Install the NVIDIA device plugin: kubectl create -f https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.14.0/nvidia-device-plugin.yml" + echo " 2. Use --cpu-only flag to deploy without GPU support (not recommended for production)" + echo " 3. Ensure your nodes have GPUs and proper drivers installed" + + read -p "Continue with deployment anyway? (y/n): " CONTINUE + if [[ "$CONTINUE" != "y" && "$CONTINUE" != "Y" ]]; then + echo "Deployment aborted." + exit 1 + fi + + echo "Continuing with deployment despite no GPUs detected..." + else + echo "Found $GPU_COUNT nodes with GPUs available." + fi +fi + +# Create ConfigMap with Hugging Face token if provided +if [[ -n "$HF_TOKEN" ]]; then + echo "Using provided Hugging Face token..." + cat < local-deployment/deployment-cpu.yaml + kubectl apply -n $NAMESPACE -f local-deployment/deployment-cpu.yaml + rm local-deployment/deployment-cpu.yaml +else + kubectl apply -n $NAMESPACE -f local-deployment/deployment.yaml +fi + +kubectl apply -n $NAMESPACE -f local-deployment/service.yaml + +echo "Deployment started. Check status with: kubectl get pods -n $NAMESPACE" +echo "Access the application with: kubectl get service agentic-rag -n $NAMESPACE" +echo "Note: Initial startup may take some time as models are downloaded." 
+ +# Provide additional guidance for monitoring GPU usage +if [[ "$CPU_ONLY" == "false" ]]; then + echo "" + echo "To monitor GPU usage:" + echo " 1. Check pod status: kubectl get pods -n $NAMESPACE" + echo " 2. View pod logs: kubectl logs -f deployment/agentic-rag -n $NAMESPACE" + echo " 3. Check GPU allocation: kubectl describe pod -l app=agentic-rag -n $NAMESPACE | grep -A5 'Allocated resources'" +fi \ No newline at end of file diff --git a/agentic_rag/k8s/local-deployment/configmap.yaml b/agentic_rag/k8s/local-deployment/configmap.yaml new file mode 100644 index 0000000..20cef3c --- /dev/null +++ b/agentic_rag/k8s/local-deployment/configmap.yaml @@ -0,0 +1,10 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: agentic-rag-config +data: + config.yaml: | + HUGGING_FACE_HUB_TOKEN: "your-huggingface-token" + # Optional OpenAI configuration + # .env: | + # OPENAI_API_KEY=your-openai-api-key \ No newline at end of file diff --git a/agentic_rag/k8s/local-deployment/deployment.yaml b/agentic_rag/k8s/local-deployment/deployment.yaml new file mode 100644 index 0000000..2129263 --- /dev/null +++ b/agentic_rag/k8s/local-deployment/deployment.yaml @@ -0,0 +1,116 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: agentic-rag + labels: + app: agentic-rag +spec: + replicas: 1 + selector: + matchLabels: + app: agentic-rag + template: + metadata: + labels: + app: agentic-rag + spec: + containers: + - name: agentic-rag + image: python:3.10-slim + resources: + requests: + memory: "8Gi" + cpu: "2" + ephemeral-storage: "50Gi" # Add this + limits: + memory: "16Gi" + cpu: "4" + ephemeral-storage: "100Gi" # Add this + ports: + - containerPort: 7860 + name: gradio + - containerPort: 11434 + name: ollama-api + volumeMounts: + - name: config-volume + mountPath: /app/config.yaml + subPath: config.yaml + - name: data-volume + mountPath: /app/embeddings + - name: chroma-volume + mountPath: /app/chroma_db + - name: ollama-models + mountPath: /root/.ollama + command: ["/bin/bash", "-c"] + args: + - | + apt-get update && apt-get install -y git curl gnupg + + # Install NVIDIA drivers and CUDA + echo "Installing NVIDIA drivers and CUDA..." + curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg + curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ + sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ + tee /etc/apt/sources.list.d/nvidia-container-toolkit.list + apt-get update && apt-get install -y nvidia-container-toolkit + + # Verify GPU is available + echo "Verifying GPU availability..." + nvidia-smi || echo "WARNING: nvidia-smi command failed. GPU might not be properly configured." + + # Install Ollama + echo "Installing Ollama..." + curl -fsSL https://ollama.com/install.sh | sh + + # Configure Ollama to use GPU + echo "Configuring Ollama for GPU usage..." + mkdir -p /root/.ollama + echo '{"gpu": {"enable": true}}' > /root/.ollama/config.json + + # Start Ollama in the background with GPU support + echo "Starting Ollama service with GPU support..." + ollama serve & + + # Wait for Ollama to be ready + echo "Waiting for Ollama to be ready..." + until curl -s http://localhost:11434/api/tags >/dev/null; do + sleep 5 + done + + # Verify models are using GPU + echo "Verifying models are using GPU..." 
+ curl -s http://localhost:11434/api/tags | grep -q "llama3" && echo "llama3 model is available" + + # Clone and set up the application + cd /app + git clone https://github.com/oracle-devrel/devrel-labs.git + cd devrel-labs/agentic_rag + pip install -r requirements.txt + + # Start the Gradio app + echo "Starting Gradio application..." + python gradio_app.py + env: + - name: PYTHONUNBUFFERED + value: "1" + - name: OLLAMA_HOST + value: "http://localhost:11434" + - name: NVIDIA_VISIBLE_DEVICES + value: "all" + - name: NVIDIA_DRIVER_CAPABILITIES + value: "compute,utility" + - name: TORCH_CUDA_ARCH_LIST + value: "7.0;7.5;8.0;8.6" + volumes: + - name: config-volume + configMap: + name: agentic-rag-config + - name: data-volume + persistentVolumeClaim: + claimName: agentic-rag-data-pvc + - name: chroma-volume + persistentVolumeClaim: + claimName: agentic-rag-chroma-pvc + - name: ollama-models + persistentVolumeClaim: + claimName: ollama-models-pvc \ No newline at end of file diff --git a/agentic_rag/k8s/local-deployment/pvcs.yaml b/agentic_rag/k8s/local-deployment/pvcs.yaml new file mode 100644 index 0000000..6ab8ff3 --- /dev/null +++ b/agentic_rag/k8s/local-deployment/pvcs.yaml @@ -0,0 +1,35 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: agentic-rag-data-pvc + namespace: agentic-rag +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 50Gi +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: agentic-rag-chroma-pvc + namespace: agentic-rag +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 50Gi +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: ollama-models-pvc + namespace: agentic-rag +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 50Gi # Larger storage for model files \ No newline at end of file diff --git a/agentic_rag/k8s/local-deployment/service.yaml b/agentic_rag/k8s/local-deployment/service.yaml new file mode 100644 index 0000000..59f6772 --- /dev/null +++ b/agentic_rag/k8s/local-deployment/service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: agentic-rag + labels: + app: agentic-rag +spec: + type: LoadBalancer # Use NodePort if LoadBalancer is not available + ports: + - port: 80 + targetPort: 7860 + protocol: TCP + name: http + selector: + app: agentic-rag \ No newline at end of file diff --git a/agentic_rag/local_rag_agent.py b/agentic_rag/local_rag_agent.py index 308b13c..4742160 100644 --- a/agentic_rag/local_rag_agent.py +++ b/agentic_rag/local_rag_agent.py @@ -1,13 +1,15 @@ from typing import List, Dict, Any, Optional from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline import torch -from pydantic import BaseModel, Field from store import VectorStore from agents.agent_factory import create_agents import argparse import yaml import os import logging +import time +import json +from pathlib import Path # Configure logging logging.basicConfig( @@ -17,18 +19,6 @@ ) logger = logging.getLogger(__name__) -class QueryAnalysis(BaseModel): - """Pydantic model for query analysis output""" - query_type: str = Field( - description="Type of query: 'pdf_documents', 'general_knowledge', or 'unsupported'" - ) - reasoning: str = Field( - description="Reasoning behind the query type selection" - ) - requires_context: bool = Field( - description="Whether the query requires additional context to answer" - ) - class LocalLLM: """Wrapper for local LLM to match LangChain's ChatOpenAI interface""" def __init__(self, pipeline): @@ -53,45 
+43,189 @@ def __init__(self, content): return Response(result.strip()) -class LocalRAGAgent: - def __init__(self, vector_store: VectorStore, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2", use_cot: bool = False, language: str = "en"): - """Initialize local RAG agent with vector store and local LLM""" - self.vector_store = vector_store - self.use_cot = use_cot - self.language = language +class OllamaModelHandler: + """Handler for Ollama models""" + def __init__(self, model_name: str): + """Initialize Ollama model handler - # Load HuggingFace token from config + Args: + model_name: Name of the Ollama model to use + """ + # Remove the 'ollama:' prefix if present + self.model_name = model_name.replace("ollama:", "") if model_name.startswith("ollama:") else model_name + self._check_ollama_running() + + def _check_ollama_running(self): + """Check if Ollama is running and the model is available""" try: - with open('config.yaml', 'r') as f: - config = yaml.safe_load(f) - token = config.get('HUGGING_FACE_HUB_TOKEN') - if not token: - raise ValueError("HUGGING_FACE_HUB_TOKEN not found in config.yaml") + import ollama + + # Check if Ollama is running + try: + models = ollama.list().models + available_models = [model.model for model in models] + print(f"Available Ollama models: {', '.join(available_models)}") + + # Check if the requested model is available + if self.model_name not in available_models: + # Try with :latest suffix + if f"{self.model_name}:latest" in available_models: + self.model_name = f"{self.model_name}:latest" + print(f"Using model with :latest suffix: {self.model_name}") + else: + print(f"Model '{self.model_name}' not found in Ollama. Available models: {', '.join(available_models)}") + print(f"You can pull it with: ollama pull {self.model_name}") + except Exception as e: + raise ConnectionError(f"Failed to connect to Ollama. Please make sure Ollama is running. Error: {str(e)}") + + except ImportError: + raise ImportError("Failed to import ollama. 
Please install with: pip install ollama") + + def __call__(self, prompt, max_new_tokens=512, temperature=0.1, top_p=0.95, **kwargs): + """Generate text using the Ollama model""" + try: + import ollama + + # Generate text + response = ollama.generate( + model=self.model_name, + prompt=prompt, + options={ + "num_predict": max_new_tokens, + "temperature": temperature, + "top_p": top_p + } + ) + + # Format result to match transformers pipeline output + formatted_result = [{ + "generated_text": response["response"] + }] + + return formatted_result + except Exception as e: - raise Exception(f"Failed to load HuggingFace token from config.yaml: {str(e)}") + raise Exception(f"Failed to generate text with Ollama: {str(e)}") + +class LocalRAGAgent: + def __init__(self, vector_store: VectorStore, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2", + use_cot: bool = False, collection: str = None, skip_analysis: bool = False, + quantization: str = None): + """Initialize local RAG agent with vector store and local LLM - # Load model and tokenizer - print("\nLoading model and tokenizer...") - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16, - device_map="auto", - token=token - ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + Args: + vector_store: Vector store for retrieving context + model_name: HuggingFace model name/path or Ollama model name + use_cot: Whether to use Chain of Thought reasoning + collection: Collection to search in (PDF, Repository, or General Knowledge) + skip_analysis: Whether to skip query analysis (kept for backward compatibility) + quantization: Quantization method to use (None, '4bit', '8bit') + """ + self.vector_store = vector_store + self.use_cot = use_cot + self.collection = collection + self.quantization = quantization + self.model_name = model_name + # skip_analysis parameter kept for backward compatibility but no longer used - # Create text generation pipeline - self.pipeline = pipeline( - "text-generation", - model=self.model, - tokenizer=self.tokenizer, - max_new_tokens=512, - do_sample=True, - temperature=0.1, - top_p=0.95, - device_map="auto" - ) - print("βœ“ Model loaded successfully") + # Check if this is an Ollama model + self.is_ollama = model_name.startswith("ollama:") + + if self.is_ollama: + # Extract the actual model name from the prefix + ollama_model_name = model_name.replace("ollama:", "") + + # Load Ollama model + print("\nLoading Ollama model...") + print(f"Model: {ollama_model_name}") + print("Note: Make sure Ollama is running on your system.") + + # Initialize Ollama model handler + self.ollama_handler = OllamaModelHandler(ollama_model_name) + + # Create pipeline-like interface + self.pipeline = self.ollama_handler + + else: + # Load HuggingFace token from config + try: + with open('config.yaml', 'r') as f: + config = yaml.safe_load(f) + token = config.get('HUGGING_FACE_HUB_TOKEN') + if not token: + raise ValueError("HUGGING_FACE_HUB_TOKEN not found in config.yaml") + except Exception as e: + raise Exception(f"Failed to load HuggingFace token from config.yaml: {str(e)}") + + # Load model and tokenizer + print("\nLoading model and tokenizer...") + print(f"Model: {model_name}") + if quantization: + print(f"Quantization: {quantization}") + print("Note: Initial loading and inference can take 1-5 minutes depending on your hardware.") + print("Subsequent queries will be faster but may still take 30-60 seconds per response.") + + # Check if CUDA is available and set appropriate dtype + 
if torch.cuda.is_available(): + print("CUDA is available. Using GPU acceleration.") + dtype = torch.float16 + else: + print("CUDA is not available. Using CPU only (this will be slow).") + dtype = torch.float32 + + # Set up model loading parameters + model_kwargs = { + "torch_dtype": dtype, + "device_map": "auto", + "token": token, + "low_cpu_mem_usage": True, + "offload_folder": "offload" + } + + # Apply quantization if specified + if quantization == '4bit': + try: + from transformers import BitsAndBytesConfig + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" + ) + model_kwargs["quantization_config"] = quantization_config + print("Using 4-bit quantization with bitsandbytes") + except ImportError: + print("Warning: bitsandbytes not installed. Falling back to standard loading.") + print("To use 4-bit quantization, install bitsandbytes: pip install bitsandbytes") + elif quantization == '8bit': + try: + from transformers import BitsAndBytesConfig + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + model_kwargs["quantization_config"] = quantization_config + print("Using 8-bit quantization with bitsandbytes") + except ImportError: + print("Warning: bitsandbytes not installed. Falling back to standard loading.") + print("To use 8-bit quantization, install bitsandbytes: pip install bitsandbytes") + + # Load model with appropriate settings + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + **model_kwargs + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + + # Create text generation pipeline with optimized settings + self.pipeline = pipeline( + "text-generation", + model=self.model, + tokenizer=self.tokenizer, + max_new_tokens=512, + do_sample=True, + temperature=0.1, + top_p=0.95, + device_map="auto" + ) + print("βœ“ Model loaded successfully") # Create LLM wrapper self.llm = LocalLLM(self.pipeline) @@ -101,25 +235,43 @@ def __init__(self, vector_store: VectorStore, model_name: str = "mistralai/Mistr def process_query(self, query: str) -> Dict[str, Any]: """Process a user query using the agentic RAG pipeline""" - # Analyze the query - analysis = self._analyze_query(query) - logger.info(f"Query analysis: {analysis}") + logger.info(f"Processing query with collection: {self.collection}") - if self.use_cot: - return self._process_query_with_cot(query, analysis) + # Process based on collection type and CoT setting + if self.collection == "General Knowledge": + # For General Knowledge, directly use general response + if self.use_cot: + return self._process_query_with_cot(query) + else: + return self._generate_general_response(query) else: - return self._process_query_standard(query, analysis) + # For PDF or Repository collections, use context-based processing + if self.use_cot: + return self._process_query_with_cot(query) + else: + return self._process_query_standard(query) - def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[str, Any]: + def _process_query_with_cot(self, query: str) -> Dict[str, Any]: """Process query using Chain of Thought reasoning with multiple agents""" logger.info("Processing query with Chain of Thought reasoning") - # Get initial context if needed + # Get initial context based on selected collection initial_context = [] - if analysis.requires_context and analysis.query_type != "unsupported": + if self.collection == "PDF Collection": + logger.info(f"Retrieving context from PDF 
Collection for query: '{query}'") pdf_context = self.vector_store.query_pdf_collection(query) + initial_context.extend(pdf_context) + logger.info(f"Retrieved {len(pdf_context)} chunks from PDF Collection") + # Don't log individual sources to keep console clean + elif self.collection == "Repository Collection": + logger.info(f"Retrieving context from Repository Collection for query: '{query}'") repo_context = self.vector_store.query_repo_collection(query) - initial_context = pdf_context + repo_context + initial_context.extend(repo_context) + logger.info(f"Retrieved {len(repo_context)} chunks from Repository Collection") + # Don't log individual sources to keep console clean + # For General Knowledge, no context is needed + else: + logger.info("Using General Knowledge collection, no context retrieval needed") try: # Step 1: Planning @@ -140,7 +292,8 @@ def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[s continue step_research = self.agents["researcher"].research(query, step) research_results.append({"step": step, "findings": step_research}) - logger.info(f"Research for step: {step}\nFindings: {step_research}") + # Don't log source indices to keep console clean + logger.info(f"Research for step: {step}") else: # If no researcher or no context, use the steps directly research_results = [{"step": step, "findings": []} for step in plan.split("\n") if step.strip()] @@ -160,7 +313,8 @@ def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[s result["findings"] if result["findings"] else [{"content": "Using general knowledge", "metadata": {"source": "General Knowledge"}}] ) reasoning_steps.append(step_reasoning) - logger.info(f"Reasoning for step: {result['step']}\n{step_reasoning}") + # Log just the step, not the full reasoning + logger.info(f"Reasoning for step: {result['step']}") # Step 4: Synthesize final answer logger.info("Step 4: Synthesis") @@ -169,7 +323,7 @@ def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[s return self._generate_general_response(query) final_answer = self.agents["synthesizer"].synthesize(query, reasoning_steps) - logger.info(f"Final synthesized answer:\n{final_answer}") + logger.info("Final answer synthesized successfully") return { "answer": final_answer, @@ -181,68 +335,42 @@ def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[s logger.info("Falling back to general response") return self._generate_general_response(query) - def _process_query_standard(self, query: str, analysis: QueryAnalysis) -> Dict[str, Any]: + def _process_query_standard(self, query: str) -> Dict[str, Any]: """Process query using standard approach without Chain of Thought""" - # If query type is unsupported, use general knowledge - if analysis.query_type == "unsupported": - return self._generate_general_response(query) - - # First try to get context from PDF documents - pdf_context = self.vector_store.query_pdf_collection(query) + # Initialize context variables + pdf_context = [] + repo_context = [] - # Then try repository documents - repo_context = self.vector_store.query_repo_collection(query) + # Get context based on selected collection + if self.collection == "PDF Collection": + logger.info(f"Retrieving context from PDF Collection for query: '{query}'") + pdf_context = self.vector_store.query_pdf_collection(query) + logger.info(f"Retrieved {len(pdf_context)} chunks from PDF Collection") + # Don't log individual sources to keep console clean + elif self.collection == "Repository 
Collection": + logger.info(f"Retrieving context from Repository Collection for query: '{query}'") + repo_context = self.vector_store.query_repo_collection(query) + logger.info(f"Retrieved {len(repo_context)} chunks from Repository Collection") + # Don't log individual sources to keep console clean # Combine all context all_context = pdf_context + repo_context # Generate response using context if available, otherwise use general knowledge - if all_context and analysis.requires_context: + if all_context: + logger.info(f"Generating response using {len(all_context)} context chunks") response = self._generate_response(query, all_context) else: + logger.info("No context found, using general knowledge") response = self._generate_general_response(query) return response - def _analyze_query(self, query: str) -> QueryAnalysis: - """Analyze the query to determine the best source of information""" - prompt = f"""You are an intelligent agent that analyzes user queries to determine the best source of information. - - Analyze the following query and determine: - 1. Whether it should query the PDF documents collection or general knowledge collection - 2. Your reasoning for this decision - 3. Whether the query requires additional context to provide a good answer - - Query: {query} - - Provide your response in the following JSON format: - {{ - "query_type": "pdf_documents OR general_knowledge OR unsupported", - "reasoning": "your reasoning here", - "requires_context": true OR false - }} - - Response:""" - - try: - response = self._generate_text(prompt) - # Extract JSON from the response using string manipulation - start_idx = response.find("{") - end_idx = response.rfind("}") + 1 - if start_idx != -1 and end_idx != -1: - json_str = response[start_idx:end_idx] - return QueryAnalysis.model_validate_json(json_str) - raise ValueError("Could not parse JSON from response") - except Exception as e: - # Default to PDF documents if parsing fails - return QueryAnalysis( - query_type="pdf_documents", - reasoning="Defaulting to PDF documents due to parsing error", - requires_context=True - ) - def _generate_text(self, prompt: str, max_length: int = 512) -> str: """Generate text using the local model""" + # Log start time for performance monitoring + start_time = time.time() + result = self.pipeline( prompt, max_new_tokens=max_length, @@ -252,6 +380,10 @@ def _generate_text(self, prompt: str, max_length: int = 512) -> str: return_full_text=False )[0]["generated_text"] + # Log completion time + elapsed_time = time.time() - start_time + logger.info(f"Text generation completed in {elapsed_time:.2f} seconds") + return result.strip() def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[str, Any]: @@ -260,8 +392,8 @@ def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[ for i, item in enumerate(context)]) template = """Answer the following query using the provided context. -If the context doesn't contain enough information to answer accurately, -say so explicitly. +Respond as if you are knowledgeable about the topic and incorporate the context naturally. +Do not mention limitations in the context or that you couldn't find specific information. 
Context: {context} @@ -271,16 +403,39 @@ def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[ Answer:""" prompt = template.format(context=context_str, query=query) - response = self._generate_text(prompt) + response_text = self._generate_text(prompt) + + # Add sources to response if available + sources = {} + if context: + # Group sources by document + for item in context: + source = item['metadata'].get('source', 'Unknown') + if source not in sources: + sources[source] = set() + + # Add page number if available + if 'page' in item['metadata']: + sources[source].add(str(item['metadata']['page'])) + # Add file path if available for code + if 'file_path' in item['metadata']: + sources[source] = item['metadata']['file_path'] + + # Print concise source information + print("\nSources detected:") + # Print a single line for each source without additional details + for source in sources: + print(f"- {source}") return { - "answer": response, - "context": context + "answer": response_text, + "context": context, + "sources": sources } def _generate_general_response(self, query: str) -> Dict[str, Any]: """Generate a response using general knowledge when no context is available""" - template = """You are a helpful AI assistant. While I don't have specific information from my document collection about this query, I'll share what I know about it. + template = """You are a helpful AI assistant. Answer the following query using your general knowledge. Query: {query} @@ -289,10 +444,8 @@ def _generate_general_response(self, query: str) -> Dict[str, Any]: prompt = template.format(query=query) response = self._generate_text(prompt) - prefix = "I didn't find specific information in my documents, but here's what I know about it:\n\n" - return { - "answer": prefix + response, + "answer": response, "context": [] } @@ -303,6 +456,10 @@ def main(): parser.add_argument("--model", default="mistralai/Mistral-7B-Instruct-v0.2", help="Model to use") parser.add_argument("--quiet", action="store_true", help="Disable verbose logging") parser.add_argument("--use-cot", action="store_true", help="Enable Chain of Thought reasoning") + parser.add_argument("--collection", choices=["PDF Collection", "Repository Collection", "General Knowledge"], + help="Specify which collection to query") + parser.add_argument("--skip-analysis", action="store_true", help="Skip query analysis step") + parser.add_argument("--verbose", action="store_true", help="Show full content of sources") args = parser.parse_args() @@ -319,7 +476,13 @@ def main(): logger.info(f"Initializing vector store from: {args.store_path}") store = VectorStore(persist_directory=args.store_path) logger.info("Initializing local RAG agent...") - agent = LocalRAGAgent(store, model_name=args.model, use_cot=args.use_cot) + agent = LocalRAGAgent( + store, + model_name=args.model, + use_cot=args.use_cot, + collection=args.collection, + skip_analysis=args.skip_analysis + ) print(f"\nProcessing query: {args.query}") print("=" * 50) @@ -339,10 +502,22 @@ def main(): if response.get("context"): print("\nSources used:") - for ctx in response["context"]: + print("-" * 50) + + # Print concise list of sources + for i, ctx in enumerate(response["context"]): source = ctx["metadata"].get("source", "Unknown") - pages = ctx["metadata"].get("page_numbers", []) - print(f"- {source} (pages: {pages})") + if "page_numbers" in ctx["metadata"]: + pages = ctx["metadata"].get("page_numbers", []) + print(f"[{i+1}] {source} (pages: {pages})") + else: + file_path = 
ctx["metadata"].get("file_path", "Unknown") + print(f"[{i+1}] {source} (file: {file_path})") + + # Only print content if verbose flag is set + if args.verbose: + content_preview = ctx["content"][:300] + "..." if len(ctx["content"]) > 300 else ctx["content"] + print(f" Content: {content_preview}\n") except Exception as e: logger.error(f"Error during execution: {str(e)}", exc_info=True) diff --git a/agentic_rag/main.py b/agentic_rag/main.py index 6a37fc2..5839e4e 100644 --- a/agentic_rag/main.py +++ b/agentic_rag/main.py @@ -34,18 +34,47 @@ pdf_processor = PDFProcessor() vector_store = VectorStore() -# Initialize RAG agent - use OpenAI if API key is available, otherwise use local model. by default = local model +# Check for Ollama availability +try: + import ollama + ollama_available = True + print("\nOllama is available. You can use Ollama models for RAG.") +except ImportError: + ollama_available = False + print("\nOllama not installed. You can install it with: pip install ollama") + +# Initialize RAG agent - use OpenAI if API key is available, otherwise use local model or Ollama openai_api_key = os.getenv("OPENAI_API_KEY") if openai_api_key: print("\nUsing OpenAI GPT-4 for RAG...") rag_agent = RAGAgent(vector_store=vector_store, openai_api_key=openai_api_key) else: - print("\nOpenAI API key not found. Using local Mistral model for RAG...") - rag_agent = LocalRAGAgent(vector_store=vector_store) + # Try to use local Mistral model first + try: + print("\nTrying to use local Mistral model...") + rag_agent = LocalRAGAgent(vector_store=vector_store) + print("Successfully initialized local Mistral model.") + except Exception as e: + print(f"\nFailed to initialize local Mistral model: {str(e)}") + + # Fall back to Ollama if Mistral fails and Ollama is available + if ollama_available: + try: + print("\nFalling back to Ollama with llama3 model...") + rag_agent = LocalRAGAgent(vector_store=vector_store, model_name="ollama:llama3") + print("Successfully initialized Ollama with llama3 model.") + except Exception as e: + print(f"\nFailed to initialize Ollama: {str(e)}") + print("No available models. Please check your configuration.") + raise e + else: + print("\nNo available models. 
Please check your configuration.") + raise e class QueryRequest(BaseModel): query: str use_cot: bool = False + model: Optional[str] = None # Allow specifying model in the request class QueryResponse(BaseModel): answer: str @@ -89,11 +118,35 @@ async def upload_pdf(file: UploadFile = File(...)): async def query(request: QueryRequest): """Process a query using the RAG agent""" try: - # Reinitialize agent with CoT setting - if openai_api_key: - rag_agent = RAGAgent(vector_store=vector_store, openai_api_key=openai_api_key, use_cot=request.use_cot) + # Determine which model to use + if request.model: + if request.model.startswith("ollama:") and ollama_available: + # Use specified Ollama model + rag_agent = LocalRAGAgent(vector_store=vector_store, model_name=request.model, use_cot=request.use_cot) + elif request.model == "openai" and openai_api_key: + # Use OpenAI + rag_agent = RAGAgent(vector_store=vector_store, openai_api_key=openai_api_key, use_cot=request.use_cot) + else: + # Use default local model + rag_agent = LocalRAGAgent(vector_store=vector_store, use_cot=request.use_cot) else: - rag_agent = LocalRAGAgent(vector_store=vector_store, use_cot=request.use_cot) + # Reinitialize agent with CoT setting using default model + if openai_api_key: + rag_agent = RAGAgent(vector_store=vector_store, openai_api_key=openai_api_key, use_cot=request.use_cot) + else: + # Try local Mistral first + try: + rag_agent = LocalRAGAgent(vector_store=vector_store, use_cot=request.use_cot) + except Exception as e: + print(f"Failed to initialize local Mistral model: {str(e)}") + # Fall back to Ollama if available + if ollama_available: + try: + rag_agent = LocalRAGAgent(vector_store=vector_store, model_name="ollama:llama3", use_cot=request.use_cot) + except Exception as e2: + raise Exception(f"Failed to initialize any model: {str(e2)}") + else: + raise e response = rag_agent.process_query(request.query) return response diff --git a/agentic_rag/rag_agent.py b/agentic_rag/rag_agent.py index 9e77f54..8a47fba 100644 --- a/agentic_rag/rag_agent.py +++ b/agentic_rag/rag_agent.py @@ -1,8 +1,6 @@ from typing import List, Dict, Any, Optional from langchain_openai import ChatOpenAI from langchain.prompts import ChatPromptTemplate -from langchain.output_parsers import PydanticOutputParser -from pydantic import BaseModel, Field from store import VectorStore from agents.agent_factory import create_agents import os @@ -14,76 +12,75 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) -class QueryAnalysis(BaseModel): - """Pydantic model for query analysis output""" - query_type: str = Field( - description="Type of query: 'pdf_documents', 'general_knowledge', or 'unsupported'" - ) - reasoning: str = Field( - description="Reasoning behind the query type selection" - ) - requires_context: bool = Field( - description="Whether the query requires additional context to answer" - ) - class RAGAgent: - def __init__(self, vector_store: VectorStore, openai_api_key: str, use_cot: bool = False, language: str = "en"): + def __init__(self, vector_store: VectorStore, openai_api_key: str, use_cot: bool = False, collection: str = None, skip_analysis: bool = False): """Initialize RAG agent with vector store and LLM""" self.vector_store = vector_store self.use_cot = use_cot - self.language = language + self.collection = collection + # skip_analysis parameter kept for backward compatibility but no longer used self.llm = ChatOpenAI( model="gpt-4-turbo-preview", 
temperature=0, api_key=openai_api_key ) - self.query_analyzer = self._create_query_analyzer() # Initialize specialized agents self.agents = create_agents(self.llm, vector_store) if use_cot else None - def _create_query_analyzer(self): - """Create a chain for analyzing queries""" - template = """You are an intelligent agent that analyzes user queries to determine the best source of information. - - Analyze the following query and determine: - 1. Whether it should query the PDF documents collection or general knowledge collection - 2. Your reasoning for this decision - 3. Whether the query requires additional context to provide a good answer - - Query: {query} - - {format_instructions} - """ - - prompt = ChatPromptTemplate.from_template(template) - output_parser = PydanticOutputParser(pydantic_object=QueryAnalysis) - - prompt = prompt.partial(format_instructions=output_parser.get_format_instructions()) - - return {"prompt": prompt, "parser": output_parser} - def process_query(self, query: str) -> Dict[str, Any]: """Process a user query using the agentic RAG pipeline""" - # Analyze the query - analysis = self._analyze_query(query) - logger.info(f"Query analysis: {analysis}") + logger.info(f"Processing query with collection: {self.collection}") - if self.use_cot: - return self._process_query_with_cot(query, analysis) + # Process based on collection type and CoT setting + if self.collection == "General Knowledge": + # For General Knowledge, directly use general response + if self.use_cot: + return self._process_query_with_cot(query) + else: + return self._generate_general_response(query) else: - return self._process_query_standard(query, analysis) + # For PDF or Repository collections, use context-based processing + if self.use_cot: + return self._process_query_with_cot(query) + else: + return self._process_query_standard(query) - def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[str, Any]: + def _process_query_with_cot(self, query: str) -> Dict[str, Any]: """Process query using Chain of Thought reasoning with multiple agents""" logger.info("Processing query with Chain of Thought reasoning") - # Get initial context if needed + # Get initial context based on selected collection initial_context = [] - if analysis.requires_context and analysis.query_type != "unsupported": + if self.collection == "PDF Collection": + logger.info(f"Retrieving context from PDF Collection for query: '{query}'") pdf_context = self.vector_store.query_pdf_collection(query) + initial_context.extend(pdf_context) + logger.info(f"Retrieved {len(pdf_context)} chunks from PDF Collection") + # Log each chunk with citation number but not full content + for i, chunk in enumerate(pdf_context): + source = chunk["metadata"].get("source", "Unknown") + pages = chunk["metadata"].get("page_numbers", []) + logger.info(f"Source [{i+1}]: {source} (pages: {pages})") + # Only log content preview at debug level + content_preview = chunk["content"][:150] + "..." 
if len(chunk["content"]) > 150 else chunk["content"] + logger.debug(f"Content preview for source [{i+1}]: {content_preview}") + elif self.collection == "Repository Collection": + logger.info(f"Retrieving context from Repository Collection for query: '{query}'") repo_context = self.vector_store.query_repo_collection(query) - initial_context = pdf_context + repo_context + initial_context.extend(repo_context) + logger.info(f"Retrieved {len(repo_context)} chunks from Repository Collection") + # Log each chunk with citation number but not full content + for i, chunk in enumerate(repo_context): + source = chunk["metadata"].get("source", "Unknown") + file_path = chunk["metadata"].get("file_path", "Unknown") + logger.info(f"Source [{i+1}]: {source} (file: {file_path})") + # Only log content preview at debug level + content_preview = chunk["content"][:150] + "..." if len(chunk["content"]) > 150 else chunk["content"] + logger.debug(f"Content preview for source [{i+1}]: {content_preview}") + # For General Knowledge, no context is needed + else: + logger.info("Using General Knowledge collection, no context retrieval needed") try: # Step 1: Planning @@ -104,7 +101,9 @@ def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[s continue step_research = self.agents["researcher"].research(query, step) research_results.append({"step": step, "findings": step_research}) - logger.info(f"Research for step: {step}\nFindings: {step_research}") + # Log which sources were used for this step + source_indices = [initial_context.index(finding) + 1 for finding in step_research if finding in initial_context] + logger.info(f"Research for step: {step}\nUsing sources: {source_indices}") else: # If no researcher or no context, use the steps directly research_results = [{"step": step, "findings": []} for step in plan.split("\n") if step.strip()] @@ -145,55 +144,97 @@ def _process_query_with_cot(self, query: str, analysis: QueryAnalysis) -> Dict[s logger.info("Falling back to general response") return self._generate_general_response(query) - def _process_query_standard(self, query: str, analysis: QueryAnalysis) -> Dict[str, Any]: + def _process_query_standard(self, query: str) -> Dict[str, Any]: """Process query using standard approach without Chain of Thought""" - # If query type is unsupported, use general knowledge - if analysis.query_type == "unsupported": - return self._generate_general_response(query) - - # First try to get context from PDF documents - pdf_context = self.vector_store.query_pdf_collection(query) + # Initialize context variables + pdf_context = [] + repo_context = [] - # Then try repository documents - repo_context = self.vector_store.query_repo_collection(query) + # Get context based on selected collection + if self.collection == "PDF Collection": + logger.info(f"Retrieving context from PDF Collection for query: '{query}'") + pdf_context = self.vector_store.query_pdf_collection(query) + logger.info(f"Retrieved {len(pdf_context)} chunks from PDF Collection") + # Log each chunk with citation number but not full content + for i, chunk in enumerate(pdf_context): + source = chunk["metadata"].get("source", "Unknown") + pages = chunk["metadata"].get("page_numbers", []) + logger.info(f"Source [{i+1}]: {source} (pages: {pages})") + # Only log content preview at debug level + content_preview = chunk["content"][:150] + "..." 
if len(chunk["content"]) > 150 else chunk["content"] + logger.debug(f"Content preview for source [{i+1}]: {content_preview}") + elif self.collection == "Repository Collection": + logger.info(f"Retrieving context from Repository Collection for query: '{query}'") + repo_context = self.vector_store.query_repo_collection(query) + logger.info(f"Retrieved {len(repo_context)} chunks from Repository Collection") + # Log each chunk with citation number but not full content + for i, chunk in enumerate(repo_context): + source = chunk["metadata"].get("source", "Unknown") + file_path = chunk["metadata"].get("file_path", "Unknown") + logger.info(f"Source [{i+1}]: {source} (file: {file_path})") + # Only log content preview at debug level + content_preview = chunk["content"][:150] + "..." if len(chunk["content"]) > 150 else chunk["content"] + logger.debug(f"Content preview for source [{i+1}]: {content_preview}") # Combine all context all_context = pdf_context + repo_context # Generate response using context if available, otherwise use general knowledge - if all_context and analysis.requires_context: + if all_context: + logger.info(f"Generating response using {len(all_context)} context chunks") response = self._generate_response(query, all_context) else: + logger.info("No context found, using general knowledge") response = self._generate_general_response(query) return response - def _analyze_query(self, query: str) -> QueryAnalysis: - """Analyze the query to determine the best source of information""" - chain_input = {"query": query} - result = self.llm.invoke(self.query_analyzer["prompt"].format_messages(**chain_input)) - return self.query_analyzer["parser"].parse(result.content) - def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[str, Any]: - """Generate a response using the retrieved context""" - context_str = "\n\n".join([f"Context {i+1}:\n{item['content']}" - for i, item in enumerate(context)]) + """Generate a response based on the query and context""" + # Format context for the prompt + formatted_context = "\n\n".join([f"Context {i+1}:\n{item['content']}" + for i, item in enumerate(context)]) - template = """Answer the following query using the provided context. -If the context doesn't contain enough information to answer accurately, -say so explicitly. - -Context: -{context} - -Query: {query} - -Answer:""" + # Create the prompt + system_prompt = """You are an AI assistant answering questions based on the provided context. +Answer the question based on the context provided. If the answer is not in the context, say "I don't have enough information to answer this question." 
Be concise and accurate.""" - prompt = ChatPromptTemplate.from_template(template) - messages = prompt.format_messages(context=context_str, query=query) + # Create messages for the chat model + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"Context:\n{formatted_context}\n\nQuestion: {query}"} + ] + + # Generate response response = self.llm.invoke(messages) + # Add sources to response if available + if context: + # Group sources by document + sources = {} + for item in context: + source = item['metadata'].get('source', 'Unknown') + if source not in sources: + sources[source] = set() + + # Add page number if available + if 'page' in item['metadata']: + sources[source].add(str(item['metadata']['page'])) + # Add file path if available for code + if 'file_path' in item['metadata']: + sources[source] = item['metadata']['file_path'] + + # Print concise source information + print("\nSources detected:") + for source, details in sources.items(): + if isinstance(details, set): # PDF with pages + pages = ", ".join(sorted(details)) + print(f"Document: {source} (pages: {pages})") + else: # Code with file path + print(f"Code file: {source}") + + response['sources'] = sources + return { "answer": response.content, "context": context @@ -201,7 +242,7 @@ def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[ def _generate_general_response(self, query: str) -> Dict[str, Any]: """Generate a response using general knowledge when no context is available""" - template = """You are a helpful AI assistant. While I don't have specific information from my document collection about this query, I'll share what I know about it. + template = """You are a helpful AI assistant. Answer the following query using your general knowledge. 
Query: {query} @@ -211,10 +252,8 @@ def _generate_general_response(self, query: str) -> Dict[str, Any]: messages = prompt.format_messages(query=query) response = self.llm.invoke(messages) - prefix = "I didn't find specific information in my documents, but here's what I know about it:\n\n" - return { - "answer": prefix + response.content, + "answer": response.content, "context": [] } @@ -223,6 +262,10 @@ def main(): parser.add_argument("--query", required=True, help="Query to process") parser.add_argument("--store-path", default="chroma_db", help="Path to the vector store") parser.add_argument("--use-cot", action="store_true", help="Enable Chain of Thought reasoning") + parser.add_argument("--collection", choices=["PDF Collection", "Repository Collection", "General Knowledge"], + help="Specify which collection to query") + parser.add_argument("--skip-analysis", action="store_true", help="Skip query analysis step") + parser.add_argument("--verbose", action="store_true", help="Show full content of sources") args = parser.parse_args() @@ -239,7 +282,13 @@ def main(): try: store = VectorStore(persist_directory=args.store_path) - agent = RAGAgent(store, openai_api_key=os.getenv("OPENAI_API_KEY"), use_cot=args.use_cot) + agent = RAGAgent( + store, + openai_api_key=os.getenv("OPENAI_API_KEY"), + use_cot=args.use_cot, + collection=args.collection, + skip_analysis=args.skip_analysis + ) print(f"\nProcessing query: {args.query}") print("=" * 50) @@ -259,10 +308,22 @@ def main(): if response.get("context"): print("\nSources used:") - for ctx in response["context"]: + print("-" * 50) + + # Print concise list of sources + for i, ctx in enumerate(response["context"]): source = ctx["metadata"].get("source", "Unknown") - pages = ctx["metadata"].get("page_numbers", []) - print(f"- {source} (pages: {pages})") + if "page_numbers" in ctx["metadata"]: + pages = ctx["metadata"].get("page_numbers", []) + print(f"[{i+1}] {source} (pages: {pages})") + else: + file_path = ctx["metadata"].get("file_path", "Unknown") + print(f"[{i+1}] {source} (file: {file_path})") + + # Only print content if verbose flag is set + if args.verbose: + content_preview = ctx["content"][:300] + "..." if len(ctx["content"]) > 300 else ctx["content"] + print(f" Content: {content_preview}\n") except Exception as e: print(f"\nβœ— Error: {str(e)}") diff --git a/agentic_rag/requirements.txt b/agentic_rag/requirements.txt index 5cb347c..cdb3273 100644 --- a/agentic_rag/requirements.txt +++ b/agentic_rag/requirements.txt @@ -15,4 +15,6 @@ trafilatura gradio lxml_html_clean langchain -gitingest \ No newline at end of file +gitingest +bitsandbytes +ollama \ No newline at end of file diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000..0ac3b7c --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python +""" +Launcher script for the planeLLM Gradio interface. + +This script provides a simple way to launch the Gradio interface +without having to import the module directly. 
+ +Usage: + python gradio_app.py +""" + +# Import directly from the modules in the root directory +import os +import gradio as gr +import time +from typing import Dict, List, Tuple, Any, Optional +import json + +# Import planeLLM components +from topic_explorer import TopicExplorer +from lesson_writer import PodcastWriter +from tts_generator import TTSGenerator + +# Create resources directory if it doesn't exist +os.makedirs('./resources', exist_ok=True) + +class PlaneLLMInterface: + """Main class for the Gradio interface of planeLLM.""" + + def __init__(self): + """Initialize the interface components.""" + # Initialize components + self.topic_explorer = TopicExplorer() + self.podcast_writer = PodcastWriter() + + # We'll initialize the TTS generator only when needed to save memory + self.tts_generator = None + + # Track available files + self.update_available_files() + + def update_available_files(self) -> Dict[str, List[str]]: + """Update and return lists of available files by type.""" + resources_dir = './resources' + + # Ensure directory exists + os.makedirs(resources_dir, exist_ok=True) + + # Get all files in resources directory + all_files = os.listdir(resources_dir) + + # Filter by type + self.available_files = { + 'content': [f for f in all_files if f.endswith('.txt') and ('content' in f or 'raw_lesson' in f)], + 'questions': [f for f in all_files if f.endswith('.txt') and 'questions' in f], + 'transcripts': [f for f in all_files if f.endswith('.txt') and 'podcast' in f], + 'audio': [f for f in all_files if f.endswith('.mp3')] + } + + return self.available_files + + def generate_topic_content(self, topic: str, progress=gr.Progress()) -> Tuple[str, str, str]: + """Generate educational content about a topic. + + Args: + topic: The topic to explore + progress: Gradio progress indicator + + Returns: + Tuple of (questions, content, status message) + """ + if not topic: + return "", "", "Error: Please enter a topic" + + try: + progress(0, desc="Initializing...") + + # Generate timestamp for file naming + timestamp = time.strftime("%Y%m%d_%H%M%S") + questions_file = f"questions_{timestamp}.txt" + content_file = f"content_{timestamp}.txt" + + progress(0.1, desc="Generating questions...") + questions = self.topic_explorer.generate_questions(topic) + + # Save questions to file + with open(f"./resources/{questions_file}", 'w', encoding='utf-8') as f: + questions_text = f"# Questions for {topic}\n\n" + for i, q in enumerate(questions, 1): + questions_text += f"{i}. 
{q}\n" + f.write(questions_text) + + progress(0.3, desc="Exploring questions...") + # Generate content for each question + results = {} + for i, question in enumerate(questions): + progress(0.3 + (0.6 * (i / len(questions))), + desc=f"Exploring question {i+1}/{len(questions)}") + response = self.topic_explorer.explore_question(question) + results[question] = response + + # Combine content + full_content = f"# {topic}\n\n" + for question, response in results.items(): + full_content += f"# {question}\n\n{response}\n\n" + + # Save content to file + with open(f"./resources/{content_file}", 'w', encoding='utf-8') as f: + f.write(full_content) + + progress(1.0, desc="Done!") + self.update_available_files() + + return questions_text, full_content, f"Content generated successfully and saved to {content_file}" + + except Exception as e: + return "", "", f"Error: {str(e)}" + + def create_podcast_transcript(self, content_file: str, detailed_transcript: bool, progress=gr.Progress()) -> Tuple[str, str]: + """Create podcast transcript from content file. + + Args: + content_file: Name of content file to use + detailed_transcript: Whether to use detailed question-by-question processing + progress: Gradio progress indicator + + Returns: + Tuple of (transcript, status message) + """ + if not content_file: + return "", "Error: Please select a content file" + + try: + progress(0, desc="Reading content file...") + + # Generate timestamp for file naming + timestamp = time.strftime("%Y%m%d_%H%M%S") + + # Read content from file + with open(f"./resources/{content_file}", 'r', encoding='utf-8') as f: + content = f.read() + + # Initialize podcast writer + self.podcast_writer = PodcastWriter() + + if detailed_transcript: + progress(0.2, desc="Generating detailed podcast transcript (processing each question individually)...") + transcript = self.podcast_writer.create_detailed_podcast_transcript(content) + transcript_type = "detailed" + else: + progress(0.2, desc="Generating standard podcast transcript...") + transcript = self.podcast_writer.create_podcast_transcript(content) + transcript_type = "standard" + + # Transcript is saved by the PodcastWriter class + # Find the most recently created transcript file + transcript_files = [f for f in os.listdir('./resources') + if f.startswith('podcast_transcript_') and f.endswith(f'{timestamp}.txt')] + + if transcript_files: + transcript_file = transcript_files[0] + else: + # Fallback - save transcript to file + transcript_file = f"podcast_transcript_{transcript_type}_{timestamp}.txt" + with open(f"./resources/{transcript_file}", 'w', encoding='utf-8') as f: + f.write(transcript) + + progress(1.0, desc="Done!") + self.update_available_files() + + return transcript, f"Transcript generated successfully and saved to {transcript_file}" + + except Exception as e: + return "", f"Error: {str(e)}" + + def generate_podcast_audio(self, transcript_file: str, model_type: str, progress=gr.Progress()) -> Tuple[str, str]: + """Generate podcast audio from transcript. 
+ + Args: + transcript_file: Name of transcript file to use + model_type: TTS model to use ('bark' or 'parler') + progress: Gradio progress indicator + + Returns: + Tuple of (audio path, status message) + """ + if not transcript_file: + return "", "Error: Please select a transcript file" + + try: + progress(0, desc=f"Initializing {model_type} model...") + + # Initialize TTS generator if needed + if self.tts_generator is None or self.tts_generator.model_type != model_type: + self.tts_generator = TTSGenerator(model_type=model_type) + + # Generate timestamp for file naming + timestamp = time.strftime("%Y%m%d_%H%M%S") + audio_file = f"podcast_{timestamp}.mp3" + audio_path = f"./resources/{audio_file}" + + progress(0.1, desc="Generating podcast audio...") + + # Read transcript from file + with open(f"./resources/{transcript_file}", 'r', encoding='utf-8') as f: + transcript = f.read() + + # Generate podcast audio + self.tts_generator.generate_podcast(transcript, output_path=audio_path) + + progress(1.0, desc="Done!") + self.update_available_files() + + return audio_path, f"Podcast audio generated successfully and saved to {audio_file}" + + except Exception as e: + return "", f"Error: {str(e)}" + +def create_interface(): + """Create and launch the Gradio interface.""" + # Initialize the interface + interface = PlaneLLMInterface() + + # Define the interface + with gr.Blocks(title="planeLLM Interface") as app: + gr.Markdown("# planeLLM: Educational Content Generation System") + + # Create tabs for different components + with gr.Tabs(): + # Topic Explorer Tab + with gr.Tab("Topic Explorer"): + gr.Markdown("## Generate Educational Content") + + with gr.Row(): + topic_input = gr.Textbox(label="Topic", placeholder="Enter a topic (e.g., Ancient Rome, Quantum Physics)") + generate_button = gr.Button("Generate Content") + + with gr.Row(): + with gr.Column(): + questions_output = gr.Textbox(label="Generated Questions", lines=10, interactive=False) + with gr.Column(): + content_output = gr.Textbox(label="Generated Content", lines=20, interactive=False) + + status_output = gr.Textbox(label="Status", interactive=False) + + # Connect the button to the function + generate_button.click( + fn=interface.generate_topic_content, + inputs=[topic_input], + outputs=[questions_output, content_output, status_output] + ) + + # Lesson Writer Tab + with gr.Tab("Lesson Writer"): + gr.Markdown("## Create Podcast Transcript") + + with gr.Row(): + # Dropdown for selecting content file + content_file_dropdown = gr.Dropdown( + label="Select Content File", + choices=interface.available_files['content'], + interactive=True + ) + refresh_content_button = gr.Button("Refresh Files") + + with gr.Row(): + detailed_transcript = gr.Checkbox( + label="Detailed Processing", + value=True, + info="Process each question individually for more detailed content (recommended)" + ) + + create_transcript_button = gr.Button("Create Transcript") + + transcript_output = gr.Textbox(label="Generated Transcript", lines=20, interactive=False) + transcript_status = gr.Textbox(label="Status", interactive=False) + + # Connect buttons to functions + refresh_content_button.click( + fn=lambda: gr.Dropdown(choices=interface.update_available_files()['content']), + inputs=[], + outputs=[content_file_dropdown] + ) + + create_transcript_button.click( + fn=interface.create_podcast_transcript, + inputs=[content_file_dropdown, detailed_transcript], + outputs=[transcript_output, transcript_status] + ) + + # TTS Generator Tab + with gr.Tab("TTS Generator"): + 
gr.Markdown("## Generate Podcast Audio") + + with gr.Row(): + # Dropdown for selecting transcript file + transcript_file_dropdown = gr.Dropdown( + label="Select Transcript File", + choices=interface.available_files['transcripts'], + interactive=True + ) + refresh_transcript_button = gr.Button("Refresh Files") + + with gr.Row(): + model_type = gr.Radio( + label="TTS Model", + choices=["bark", "parler"], + value="bark", + info="Bark: Higher quality but slower, Parler: Faster but lower quality" + ) + + generate_audio_button = gr.Button("Generate Audio") + + with gr.Row(): + audio_output = gr.Audio(label="Generated Audio", interactive=False) + + audio_status = gr.Textbox(label="Status", interactive=False) + + # Connect buttons to functions + refresh_transcript_button.click( + fn=lambda: gr.Dropdown(choices=interface.update_available_files()['transcripts']), + inputs=[], + outputs=[transcript_file_dropdown] + ) + + generate_audio_button.click( + fn=interface.generate_podcast_audio, + inputs=[transcript_file_dropdown, model_type], + outputs=[audio_output, audio_status] + ) + + # Add a footer + gr.Markdown("---\n*planeLLM: Bite-sized podcasts to learn about anything powered by the OCI GenAI Service*") + + # Launch the interface + return app + +if __name__ == "__main__": + app = create_interface() + app.launch(share=True) \ No newline at end of file diff --git a/oci-csv-json-translation/data/translated_dog_healthcare.json b/oci-csv-json-translation/data/translated_dog_healthcare.json new file mode 100644 index 0000000..59fc980 --- /dev/null +++ b/oci-csv-json-translation/data/translated_dog_healthcare.json @@ -0,0 +1,184 @@ +{ + "pets" : [ + { + "id" : 1, + "name" : "Pochi", + "species" : "dog", + "age" : 5, + "healthStatus" : "Healthy", + "nextVisit" : "2024-02-15", + "comments" : "A blood test will be conducted on the next visit." + }, + { + "id" : 2, + "name" : "Taro", + "species" : "dog", + "age" : 8, + "healthStatus" : "Needs Monitoring", + "nextVisit" : "2024-01-20", + "comments" : "Because the stomach is common sense, additional tests are required." + }, + { + "id" : 3, + "name" : "Momo", + "species" : "dog", + "age" : 3, + "healthStatus" : "Healthy", + "nextVisit" : "2024-03-01", + "comments" : "Rabies vaccination is complete.Don't forget to update next year." + }, + { + "id" : 4, + "name" : "Hachi", + "species" : "dog", + "age" : 12, + "healthStatus" : "Needs Monitoring", + "nextVisit" : "2024-01-25", + "comments" : "Symptoms of arthritis.Continue to receive pain relief." + }, + { + "id" : 5, + "name" : "Koro", + "species" : "dog", + "age" : 2, + "healthStatus" : "Healthy", + "nextVisit" : "2024-04-10", + "comments" : "All vaccinations are complete.He is in a very healthy state." + }, + { + "id" : 6, + "name" : "Shiro", + "species" : "dog", + "age" : 6, + "healthStatus" : "Treatment Required", + "nextVisit" : "2024-01-18", + "comments" : "Removal of teeth is necessary.Early symptoms of gingivitis." + }, + { + "id" : 7, + "name" : "Luna", + "species" : "dog", + "age" : 4, + "healthStatus" : "Healthy", + "nextVisit" : "2024-03-20", + "comments" : "Prescribed phylaria-preventive medications.Continued periodic administration." + }, + { + "id" : 8, + "name" : "Buddy", + "species" : "dog", + "age" : 9, + "healthStatus" : "Needs Monitoring", + "nextVisit" : "2024-02-01", + "comments" : "Beware of heart test results.Consider starting medication." 
+ }, + { + "id" : 9, + "name" : "Riki", + "species" : "dog", + "age" : 1, + "healthStatus" : "Healthy", + "nextVisit" : "2024-02-28", + "comments" : "The first vaccination was completed.The development is good." + }, + { + "id" : 10, + "name" : "Mei", + "species" : "dog", + "age" : 7, + "healthStatus" : "Treatment Required", + "nextVisit" : "2024-01-22", + "comments" : "Treatment of skin allergies.Meal restrictions continued." + }, + { + "id" : 11, + "name" : "Kenta", + "species" : "dog", + "age" : 5, + "healthStatus" : "Healthy", + "nextVisit" : "2024-04-05", + "comments" : "Annual medical examination.Especially no problem." + }, + { + "id" : 12, + "name" : "Sora", + "species" : "dog", + "age" : 10, + "healthStatus" : "Needs Monitoring", + "nextVisit" : "2024-02-10", + "comments" : "Test values for thyroid function are boundaries.Follow-up is required." + }, + { + "id" : 13, + "name" : "Hana", + "species" : "dog", + "age" : 3, + "healthStatus" : "Healthy", + "nextVisit" : "2024-03-15", + "comments" : "Mixed vaccination is complete.Weight management is fine." + }, + { + "id" : 14, + "name" : "Rocky", + "species" : "dog", + "age" : 8, + "healthStatus" : "Treatment Required", + "nextVisit" : "2024-01-30", + "comments" : "Treatment of ear infections.Provide guidance to maintain cleanliness." + }, + { + "id" : 15, + "name" : "Goro", + "species" : "dog", + "age" : 4, + "healthStatus" : "Healthy", + "nextVisit" : "2024-04-20", + "comments" : "Vaccination and phylaria prevention are carried out regularly." + }, + { + "id" : 16, + "name" : "Sakura", + "species" : "dog", + "age" : 6, + "healthStatus" : "Needs Monitoring", + "nextVisit" : "2024-02-05", + "comments" : "There is mild anemia.Beginning of diet." + }, + { + "id" : 17, + "name" : "Chibi", + "species" : "dog", + "age" : 2, + "healthStatus" : "Healthy", + "nextVisit" : "2024-03-10", + "comments" : "All vaccinations are complete.Growth is going well." + }, + { + "id" : 18, + "name" : "Max", + "species" : "dog", + "age" : 11, + "healthStatus" : "Treatment Required", + "nextVisit" : "2024-01-28", + "comments" : "Early symptoms of cataracts.Prescription of eye drops." + }, + { + "id" : 19, + "name" : "Yuki", + "species" : "dog", + "age" : 5, + "healthStatus" : "Healthy", + "nextVisit" : "2024-04-15", + "comments" : "Regularly screened.The vaccination is complete." + }, + { + "id" : 20, + "name" : "Kiki", + "species" : "dog", + "age" : 7, + "healthStatus" : "Needs Monitoring", + "nextVisit" : "2024-02-20", + "comments" : "Too much weight.Start diet restrictions and exercise therapy." 
+ } + ] +} \ No newline at end of file diff --git a/oci-language-translation/batch_text_translation.py b/oci-language-translation/batch_text_translation.py new file mode 100644 index 0000000..2ae3a09 --- /dev/null +++ b/oci-language-translation/batch_text_translation.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 + +""" +Batch Translation Limitations: +- Maximum 100 records/documents per batch +- Each document must be less than 5000 characters +- Total character limit across all documents: 20,000 characters +- Supported file formats: plain text +""" + +import oci +import yaml +import sys +import time +import datetime +from pathlib import Path + +def load_config(): + """Load configuration from config.yaml file""" + try: + print("Loading configuration from config.yaml...") + with open("config.yaml", "r") as file: + config = yaml.safe_load(file) + print("βœ“ Configuration loaded successfully") + return config + except Exception as e: + print(f"βœ— Error loading config.yaml: {str(e)}") + sys.exit(1) + +def load_sample_texts(filename="sample_texts.txt"): + """Load text documents from file, one per line""" + try: + print(f"\nLoading texts from {filename}...") + with open(filename, 'r', encoding='utf-8') as file: + # Filter out empty lines and strip whitespace + texts = [line.strip() for line in file if line.strip()] + print(f"βœ“ Successfully loaded {len(texts)} texts") + + # Print character count statistics + total_chars = sum(len(text) for text in texts) + avg_chars = total_chars / len(texts) if texts else 0 + print(f" β€’ Total characters: {total_chars:,}") + print(f" β€’ Average characters per text: {avg_chars:.1f}") + + # Check against limitations + if len(texts) > 100: + print("⚠ Warning: Number of texts exceeds 100 limit") + if total_chars > 20000: + print("⚠ Warning: Total characters exceed 20,000 limit") + for i, text in enumerate(texts, 1): + if len(text) > 5000: + print(f"⚠ Warning: Text {i} exceeds 5,000 character limit") + + return texts + except Exception as e: + print(f"βœ— Error loading sample texts: {str(e)}") + print("Using default sample texts instead...") + return [ + "This is the first document to translate.", + "Here is another document that needs translation.", + "And a third document with some more text." 
+ ] + +def init_client(): + """Initialize OCI AI Language client""" + try: + print("\nInitializing OCI AI Language client...") + config = oci.config.from_file(profile_name="comm") + client = oci.ai_language.AIServiceLanguageClient(config=config) + print("βœ“ Client initialized successfully") + return client + except Exception as e: + print(f"βœ— Error initializing OCI client: {str(e)}") + sys.exit(1) + +def translate_batch_documents(ai_client, documents, source_language, target_language, compartment_id): + """Translate a batch of documents using OCI Language service""" + try: + start_time = time.time() + print(f"\nPreparing {len(documents)} documents for translation...") + print(f" β€’ Source language: {source_language}") + print(f" β€’ Target language: {target_language}") + + # Prepare the documents for translation + text_documents = [ + oci.ai_language.models.TextDocument( + key=f"doc_{i}", + text=doc, + language_code=source_language + ) for i, doc in enumerate(documents) + ] + print("βœ“ Documents prepared successfully") + + # Create batch translation request + batch_translation_details = oci.ai_language.models.BatchLanguageTranslationDetails( + documents=text_documents, + compartment_id=compartment_id, + target_language_code=target_language + ) + + # Send translation request + print("\nSending batch translation request...") + response = ai_client.batch_language_translation( + batch_language_translation_details=batch_translation_details + ) + print("βœ“ Translation request sent successfully") + + # Process results + results = [] + if response and response.data and response.data.documents: + success_count = 0 + for doc in response.data.documents: + if doc.translated_text: + success_count += 1 + results.append(doc.translated_text if doc.translated_text else None) + + end_time = time.time() + duration = end_time - start_time + + print(f"\nβœ“ Translation completed in {duration:.1f} seconds") + print(f" β€’ Successfully translated: {success_count}/{len(documents)} documents") + print(f" β€’ Success rate: {(success_count/len(documents))*100:.1f}%") + + return results + else: + print("\nβœ— No translation results received") + return None + + except Exception as e: + print(f"\nβœ— Error during batch translation: {str(e)}") + return None + +def main(): + try: + start_time = time.time() + print("=" * 60) + print("OCI Language Batch Text Translation".center(60)) + print("=" * 60) + + # Load configuration + config = load_config() + + # Get configuration values + compartment_id = config["language_translation"]["compartment_id"] + source_language = config["language_translation"]["source_language"] + target_language = config["language_translation"]["target_language"] + + # Initialize client + ai_client = init_client() + + # Load documents from file + documents = load_sample_texts() + + # Translate documents + translated_texts = translate_batch_documents( + ai_client, + documents, + source_language, + target_language, + compartment_id + ) + + # Print results + if translated_texts: + print("\nDetailed Translation Results:") + print("-" * 60) + for i, (original, translated) in enumerate(zip(documents, translated_texts), 1): + print(f"\nDocument {i}:") + print(f"Original ({source_language}): {original}") + if translated: + print(f"Translated ({target_language}): {translated}") + else: + print("βœ— Translation failed") + print("-" * 60) + else: + print("\nβœ— Translation process failed") + + end_time = time.time() + total_duration = end_time - start_time + print(f"\nTotal execution time: 
{total_duration:.1f} seconds") + print("=" * 60) + + except Exception as e: + print(f"\nβœ— Error: {str(e)}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/oci-subtitle-translation/translate_srt.py b/oci-subtitle-translation/translate_srt.py new file mode 100644 index 0000000..53f823a --- /dev/null +++ b/oci-subtitle-translation/translate_srt.py @@ -0,0 +1,157 @@ +import oci +import yaml +import argparse +import os +import time +from pathlib import Path + +def load_config(): + """Load configuration from config.yaml""" + with open('config.yaml', 'r') as f: + return yaml.safe_load(f) + +def get_language_client(): + """Initialize and return the OCI Language client""" + config = oci.config.from_file() + return oci.ai_language.AIServiceLanguageClient(config) + +def upload_to_object_storage(object_storage_client, namespace, bucket_name, file_path): + """Upload file to OCI Object Storage""" + file_name = os.path.basename(file_path) + + with open(file_path, 'rb') as f: + object_storage_client.put_object( + namespace, + bucket_name, + file_name, + f + ) + return file_name + +def wait_for_job_completion(client, job_id, compartment_id, max_wait_seconds=1800, wait_interval_seconds=30): + """Wait for the translation job to complete""" + for _ in range(0, max_wait_seconds, wait_interval_seconds): + get_job_response = client.get_job( + job_id=job_id, + compartment_id=compartment_id + ) + + status = get_job_response.data.lifecycle_state + if status == "SUCCEEDED": + return True + elif status in ["FAILED", "CANCELED"]: + print(f"Job failed with status: {status}") + return False + + time.sleep(wait_interval_seconds) + + return False + +def translate_srt(client, object_storage_client, config, input_file, source_lang='en', target_lang='es'): + """Translate an SRT file using OCI Language Async Document Translation""" + try: + # Validate file size (20MB limit) + file_size = os.path.getsize(input_file) + if file_size > 20 * 1024 * 1024: # 20MB in bytes + raise ValueError("Input file exceeds 20MB limit") + + # Upload file to Object Storage + input_object_name = upload_to_object_storage( + object_storage_client, + config['speech']['namespace'], + config['speech']['bucket_name'], + input_file + ) + + # Create document details + document_details = oci.ai_language.models.ObjectLocation( + namespace_name=config['speech']['namespace'], + bucket_name=config['speech']['bucket_name'], + object_names=[input_object_name] + ) + + # Create job details + create_job_details = oci.ai_language.models.CreateBatchLanguageTranslationJobDetails( + compartment_id=config['language']['compartment_id'], + display_name=f"Translate_{os.path.basename(input_file)}_{target_lang}", + source_language_code=source_lang, + target_language_code=target_lang, + input_location=document_details, + output_location=document_details, + model_id="PRETRAINED_LANGUAGE_TRANSLATION" + ) + + # Create translation job + response = client.create_job( + create_job_details=create_job_details + ) + + job_id = response.data.id + print(f"Translation job created with ID: {job_id}") + + # Wait for job completion + if wait_for_job_completion(client, job_id, config['language']['compartment_id']): + print(f"Successfully translated to {target_lang}") + return True + else: + print("Translation job failed or timed out") + return False + + except Exception as e: + print(f"Error translating to {target_lang}: {str(e)}") + return False + +def main(): + # Define supported languages + SUPPORTED_LANGUAGES = { + 'ar': 'Arabic', 'hr': 
'Croatian', 'cs': 'Czech', 'da': 'Danish', + 'nl': 'Dutch', 'en': 'English', 'fi': 'Finnish', 'fr': 'French', + 'fr-CA': 'French Canadian', 'de': 'German', 'el': 'Greek', + 'he': 'Hebrew', 'hu': 'Hungarian', 'it': 'Italian', 'ja': 'Japanese', + 'ko': 'Korean', 'no': 'Norwegian', 'pl': 'Polish', 'pt': 'Portuguese', + 'pt-BR': 'Portuguese Brazilian', 'ro': 'Romanian', 'ru': 'Russian', + 'zh-CN': 'Simplified Chinese', 'sk': 'Slovak', 'sl': 'Slovenian', + 'es': 'Spanish', 'sv': 'Swedish', 'th': 'Thai', 'zh-TW': 'Traditional Chinese', + 'tr': 'Turkish', 'vi': 'Vietnamese' + } + + parser = argparse.ArgumentParser(description='Translate SRT files using OCI Language') + parser.add_argument('--input-file', required=True, help='Input SRT file path') + parser.add_argument('--source-lang', default='en', help='Source language code') + parser.add_argument('--target-langs', nargs='+', help='Target language codes (space-separated)') + args = parser.parse_args() + + # Validate input file + if not os.path.exists(args.input_file): + print(f"Error: Input file {args.input_file} not found") + return + + # Load configuration + config = load_config() + + # Initialize clients + language_client = get_language_client() + object_storage_client = oci.object_storage.ObjectStorageClient(oci.config.from_file()) + + # If no target languages specified, translate to all supported languages + target_langs = args.target_langs if args.target_langs else SUPPORTED_LANGUAGES.keys() + + # Translate to each target language + for lang in target_langs: + if lang not in SUPPORTED_LANGUAGES: + print(f"Warning: Unsupported language code '{lang}', skipping...") + continue + + if lang != args.source_lang: + print(f"Translating to {SUPPORTED_LANGUAGES[lang]} ({lang})...") + translate_srt( + language_client, + object_storage_client, + config, + args.input_file, + args.source_lang, + lang + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/planeLLM b/planeLLM new file mode 160000 index 0000000..e97203f --- /dev/null +++ b/planeLLM @@ -0,0 +1 @@ +Subproject commit e97203fcba812bc4d7345cbf75085061d635e539 diff --git a/podcast_controller.py b/podcast_controller.py new file mode 100644 index 0000000..03c2217 --- /dev/null +++ b/podcast_controller.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +""" +Podcast Controller for planeLLM. + +This script orchestrates the entire podcast generation pipeline, from topic exploration +to audio generation. It provides a simple command-line interface to generate educational +podcasts on any topic. 
+ +Examples: + python podcast_controller.py --topic "Ancient Rome" + python podcast_controller.py --topic "Quantum Physics" --tts-model parler + python podcast_controller.py --topic "Machine Learning" --config my_config.yaml + python podcast_controller.py --topic "Artificial Intelligence" --detailed-transcript + +""" + +import os +import time +import argparse +from topic_explorer import TopicExplorer +from lesson_writer import PodcastWriter +from tts_generator import TTSGenerator + +def main(): + """Run the podcast generation pipeline.""" + # Parse command line arguments + parser = argparse.ArgumentParser(description='Generate an educational podcast on any topic') + parser.add_argument('--topic', required=True, help='Topic to generate a podcast about') + parser.add_argument('--tts-model', default='bark', choices=['bark', 'parler'], + help='TTS model to use (default: bark)') + parser.add_argument('--config', default='config.yaml', + help='Path to configuration file (default: config.yaml)') + parser.add_argument('--output', help='Output path for the audio file') + parser.add_argument('--detailed-transcript', action='store_true', + help='Process each question individually for more detailed content') + + args = parser.parse_args() + + # Create resources directory if it doesn't exist + os.makedirs('./resources', exist_ok=True) + + # Generate timestamp for file naming + timestamp = time.strftime("%Y%m%d_%H%M%S") + + # Step 1: Generate educational content + print(f"\n=== Step 1: Exploring topic '{args.topic}' ===") + explorer = TopicExplorer(config_file=args.config) + questions = explorer.generate_questions(args.topic) + + # Save questions to file + questions_file = f"questions_{timestamp}.txt" + with open(f'./resources/{questions_file}', 'w', encoding='utf-8') as file: + file.write("\n".join(questions)) + print(f"Questions saved to ./resources/{questions_file}") + + # Generate content for each question + print("\nGenerating educational content...") + content = "" + for i, question in enumerate(questions[:2]): # Limit to first 2 questions for brevity + print(f"Exploring question {i+1}/{len(questions[:2])}: {question}") + question_content = explorer.explore_question(question) + content += f"# {question}\n\n{question_content}\n\n" + + # Save raw content to file + content_file = f"content_{timestamp}.txt" + with open(f'./resources/{content_file}', 'w', encoding='utf-8') as file: + file.write(content) + print(f"Raw content saved to ./resources/{content_file}") + + # Step 2: Create podcast transcript + print(f"\n=== Step 2: Creating podcast transcript ===") + writer = PodcastWriter(config_file=args.config) + + if args.detailed_transcript: + print("Using detailed transcript generation (processing each question individually)") + transcript = writer.create_detailed_podcast_transcript(content) + else: + transcript = writer.create_podcast_transcript(content) + + # Transcript is saved by the PodcastWriter class + transcript_file = [f for f in os.listdir('./resources') + if f.startswith('podcast_transcript_') and f.endswith(f'{timestamp}.txt')] + if transcript_file: + transcript_path = f"./resources/{transcript_file[0]}" + else: + transcript_path = f"./resources/podcast_transcript_{timestamp}.txt" + + # Step 3: Generate audio + print(f"\n=== Step 3: Generating podcast audio ===") + tts = TTSGenerator(model_type=args.tts_model, config_file=args.config) + + # Determine output path + if args.output: + output_path = args.output + else: + output_path = f"./resources/podcast_{timestamp}.mp3" + + # Generate audio + 
audio_path = tts.generate_podcast(transcript, output_path=output_path) + + # Print summary + print("\n=== Podcast Generation Complete ===") + print(f"Questions: ./resources/{questions_file}") + print(f"Content: ./resources/{content_file}") + print(f"Transcript: {transcript_path}") + print(f"Audio: {audio_path}") + print("\nThank you for using planeLLM!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/rag_in_a_box/README.md b/rag_in_a_box/README.md index a5c277b..12a3e7b 100644 --- a/rag_in_a_box/README.md +++ b/rag_in_a_box/README.md @@ -2,6 +2,8 @@ ## Introduction +Podman commands + An innovative RAG (Retrieval Augmented Generation) system designed to leverage an LLM agent for effective information retrieval and response generation. This system efficiently processes various document formats and intelligently selects the appropriate knowledge base based on user queries. @@ -14,6 +16,8 @@ This solution includes 2 containers: - Oracle Database 23ai - AIRAG Container +Podman commands + ## 1. Install the containers #### Windows Deployment @@ -48,17 +52,20 @@ podman images ``` 3. Check running containers: + ```bash podman ps ``` 4. Check container logs: + ```bash podman logs -f 23aidb podman logs -f airagdb23aiinbox ``` 5. Connect to the database: + ```bash podman exec -it 23aidb sqlplus VECDEMO/@FREEPDB1 ``` diff --git a/rag_in_a_box/images/airag_in_a_box.png b/rag_in_a_box/images/airag_in_a_box.png new file mode 100644 index 0000000..1253e90 Binary files /dev/null and b/rag_in_a_box/images/airag_in_a_box.png differ diff --git a/rag_in_a_box/images/streamlit_interface.png b/rag_in_a_box/images/streamlit_interface.png new file mode 100644 index 0000000..7c23493 Binary files /dev/null and b/rag_in_a_box/images/streamlit_interface.png differ diff --git a/rag_in_a_box/install_colima_docker_macosx copy.md b/rag_in_a_box/install_colima_docker_macosx copy.md deleted file mode 100644 index dffac90..0000000 --- a/rag_in_a_box/install_colima_docker_macosx copy.md +++ /dev/null @@ -1,171 +0,0 @@ -# Run Oracle Database 23ai Free on Mac computers with Apple silicon - - -But what about Oracle 23ai and the newer M1/M2/M3 ARM based Apple silicon Macs I hear you ask, so here you go. - - -See the next blog to understand the complexity [Here](https://ronekins.com/2024/07/02/run-oracle-database-23ai-free-on-mac-computers-with-apple-silicon/) - - -## Preinstallation - -You need to install Rosetta2 if it is not already installed (Open Terminal in your MAC OSX Mx): - -```Code - -softwareupdate --install-rosetta - - - -I have read and agree to the terms of the software license agreement. A list of Apple SLAs may be found here: https://www.apple.com/legal/sla/ -Type A and press return to agree: A -2024-08-14 17:58:43.771 softwareupdate[67564:9084300] Package Authoring Error: 062-01890: Package reference com.apple.pkg.RosettaUpdateAuto is missing installKBytes attribute -Install of Rosetta 2 finished successfully - - -``` - -Start by installing the Homebrew package manager, if not already installed. - -Now, install Colima container runtime and Docker if not already installed on your Mac. 
- -```Code - -brew --version - -brew install colima docker -brew install docker-compose -brew reinstall qemu -``` - -You could see next results: - -```Code - -zsh completions have been installed to: - /opt/homebrew/share/zsh/site-functions - -To start colima now and restart at login: - brew services start colima -Or, if you don't want/need a background service you can just run: - /opt/homebrew/opt/colima/bin/colima start -f - -``` - -Restart colima - -```Code - -brew services restart colima - -``` - - -Check the installation: - -```Code - -colima --version -colima version 0.7.3 - -docker --version -Docker version 27.1.2, build d01f264bcc -``` - - -If we now try docker ps to check for processes, we see a docker daemon error message: - -```Code -docker ps -Cannot connect to the Docker daemon at unix:///var/run/docker.sock. Is the docker daemon running? -``` - - -## Start Colima Container Runtime - -If you want to deploy containers created in x86_64 architecture, we must start Colima Container Runtime like this: - -```Code -# colima start --arch x86_64 --vm-type=vz --vz-rosetta --mount-type=virtiofs --memory 8 --cpu 4 - -colima start --arch x86_64 --memory 8 --cpu 4 - -WARN[0000] 'architecture' cannot be updated after initial setup, discarded -WARN[0000] 'virtual machine type' cannot be updated after initial setup, discarded -WARN[0000] 'volume mount type' cannot be updated after initial setup, discarded -INFO[0000] starting colima -INFO[0000] runtime: docker -INFO[0002] starting ... context=vm -INFO[0013] provisioning ... context=docker -INFO[0014] starting ... context=docker -INFO[0016] done - -``` - -We can confirm the Colima Container Runtime configuration with colima status and colima list, for example. - -```Code - -colima status - -INFO[0001] colima is running using macOS Virtualization.Framework -INFO[0001] arch: x86_64 -INFO[0001] runtime: docker -INFO[0001] mountType: virtiofs -INFO[0001] socket: unix:///Users/operard/.colima/default/docker.sock - - - -colima list - -PROFILE STATUS ARCH CPUS MEMORY DISK RUNTIME ADDRESS -default Running aarch64 2 8GiB 60GiB docker - -``` - -Before we try and start any container, let’s check for Docker processes using docker ps. - -We now no longer see the Docker daemon error message. - -```Code - -docker - -CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES -We can check for existing docker images with docker images, for example. - -docker images - -REPOSITORY TAG IMAGE ID CREATED SIZE -The docker context show command should return colima, which means Docker runs under Colima and you can therefore use docker commands as usual. - -docker context show - -colima -``` - - -## Clean the colima and docker deployment in MAC OSX - -If you must clean your environment in order to reset it, use the next commands: - -```Code - -colima stop - - -INFO[0000] stopping colima -INFO[0000] stopping ... context=docker -INFO[0004] stopping ... context=vm -INFO[0007] done - -colima delete - - -are you sure you want to delete colima and all settings? 
[y/N] y -INFO[0001] deleting colima -INFO[0002] done - -``` - - diff --git a/tts_generator.py b/tts_generator.py new file mode 100644 index 0000000..6fc6b86 --- /dev/null +++ b/tts_generator.py @@ -0,0 +1,305 @@ +import warnings +# Suppress all warnings +warnings.filterwarnings('ignore') + +# Specifically suppress the attention mask warnings +warnings.filterwarnings('ignore', message='.*The attention mask.*') +warnings.filterwarnings('ignore', message='.*The pad token id is not set.*') +warnings.filterwarnings('ignore', message='.*You have modified the pretrained model configuration.*') +warnings.filterwarnings('ignore', category=UserWarning) +warnings.filterwarnings('ignore', category=FutureWarning) + +import os +import torch +# Suppress Flash Attention 2 warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" +# Suppress HF text generation warnings +os.environ["HF_SUPPRESS_GENERATION_WARNINGS"] = "true" +# Additional environment variables to suppress warnings +os.environ["TRANSFORMERS_VERBOSITY"] = "error" +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import time +import yaml +import re +import shutil +from typing import Dict, List, Optional, Union, Tuple +from pydub import AudioSegment +import tempfile +import tqdm + +# Disable logging from transformers +import logging +logging.getLogger("transformers").setLevel(logging.ERROR) +logging.getLogger("transformers.generation_utils").setLevel(logging.ERROR) + +class TTSGenerator: + """Class for generating podcast audio from transcripts.""" + + def __init__(self, model_type: str = "bark", config_file: str = 'config.yaml') -> None: + """Initialize the TTS generator. + + Args: + model_type: Type of TTS model to use ('bark', 'parler', or 'coqui') + config_file: Path to configuration file + + Raises: + ValueError: If model_type is not supported + """ + self.model_type = model_type.lower() + + if self.model_type not in ["bark", "parler", "coqui"]: + raise ValueError("Unsupported TTS model type. Choose 'bark', 'parler', or 'coqui'") + + # Check for FFmpeg dependencies + self.ffmpeg_available = self._check_ffmpeg() + if not self.ffmpeg_available: + print("WARNING: FFmpeg/ffprobe not found. 
Audio export may fail.") + print("Please install FFmpeg: https://ffmpeg.org/download.html") + + # Load configuration + with open(config_file, 'r', encoding='utf-8') as file: + self.config = yaml.safe_load(file) + + # Initialize model-specific components + if self.model_type == "bark": + self._init_bark() + elif self.model_type == "parler": + self._init_parler() + else: # coqui + self._init_coqui() + + # Initialize execution time tracking + self.execution_times = { + 'start_time': 0, + 'total_time': 0, + 'segments': [] + } + + def _check_ffmpeg(self) -> bool: + """Check if FFmpeg and ffprobe are available.""" + ffmpeg = shutil.which("ffmpeg") + ffprobe = shutil.which("ffprobe") + return ffmpeg is not None and ffprobe is not None + + def _init_bark(self) -> None: + """Initialize the Bark TTS model.""" + print("Initializing Bark TTS model...") + from transformers import AutoProcessor, BarkModel + + # Load model and processor + self.processor = AutoProcessor.from_pretrained("suno/bark") + self.model = BarkModel.from_pretrained("suno/bark") + + # Set pad token ID to avoid warnings + self.model.config.pad_token_id = self.model.config.eos_token_id + + # Move model to GPU if available + if torch.cuda.is_available(): + self.model = self.model.to("cuda") + print("Bark model loaded on GPU") + else: + print("Bark model loaded on CPU") + + # Define speaker presets + self.speakers = { + "Speaker 1": "v2/en_speaker_6", # Male expert + "Speaker 2": "v2/en_speaker_9", # Female student + "Speaker 3": "v2/en_speaker_3" # Second expert + } + + def _init_parler(self) -> None: + """Initialize the Parler TTS model.""" + print("Initializing Parler TTS model...") + try: + # Try both import paths for compatibility + try: + from parler_tts import ParlerTTS + except ImportError: + from parler.tts import ParlerTTS + + # Initialize Parler TTS + self.model = ParlerTTS() + + # Define speaker presets (speaker IDs for Parler) + self.speakers = { + "Speaker 1": 0, # Male expert + "Speaker 2": 1, # Female student + "Speaker 3": 2 # Second expert + } + self.parler_available = True + except ImportError: + print("WARNING: Parler TTS module not found. Using fallback TTS instead.") + print("To install Parler TTS, run: pip install git+https://github.com/huggingface/parler-tts.git") + # Fall back to Bark if Parler is not available + self.model_type = "bark" + self._init_bark() + self.parler_available = False + + def _init_coqui(self) -> None: + """Initialize the Coqui TTS model.""" + print("Initializing Coqui TTS model...") + try: + from TTS.api import TTS + + # Initialize Coqui TTS with a multi-speaker model + # Using VITS model which supports multi-speaker synthesis + self.model = TTS("tts_models/multilingual/multi-dataset/xtts_v2") + + # Define speaker presets (speaker names for Coqui XTTS) + self.speakers = { + "Speaker 1": "p326", # Male expert + "Speaker 2": "p225", # Female student + "Speaker 3": "p330" # Second expert + } + self.coqui_available = True + + # Store sample rate for later use + self.sample_rate = 24000 # Default for XTTS + + except ImportError: + print("WARNING: Coqui TTS module not found. Using fallback TTS instead.") + print("To install Coqui TTS, run: pip install TTS") + # Fall back to Bark if Coqui is not available + self.model_type = "bark" + self._init_bark() + self.coqui_available = False + except Exception as e: + print(f"WARNING: Error initializing Coqui TTS: {str(e)}. 
Using fallback TTS instead.") + # Fall back to Bark if there's an error with Coqui + self.model_type = "bark" + self._init_bark() + self.coqui_available = False + + def _generate_audio_bark(self, text: str, speaker: str) -> AudioSegment: + """Generate audio using Bark TTS. + + Args: + text: Text to convert to speech + speaker: Speaker identifier + + Returns: + AudioSegment containing the generated speech + """ + try: + # Prepare inputs + inputs = self.processor( + text=text, + voice_preset=self.speakers[speaker], + return_tensors="pt" + ) + + # Create attention mask if not present + if "attention_mask" not in inputs: + # Create attention mask (all 1s, same shape as input_ids) + inputs["attention_mask"] = torch.ones_like(inputs["input_ids"]) + + # Move inputs to GPU if available + if torch.cuda.is_available(): + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + # Generate audio with specific generation parameters + generation_kwargs = { + "pad_token_id": self.model.config.pad_token_id, + "do_sample": True, + "temperature": 0.7, + "max_new_tokens": 250 + } + + # Make a clean copy of inputs without any generation parameters + # to avoid conflicts with generation_kwargs + model_inputs = {} + for k, v in inputs.items(): + if k not in ["max_new_tokens", "do_sample", "temperature", "pad_token_id"]: + model_inputs[k] = v + + # Generate the audio + speech_output = self.model.generate(**model_inputs, **generation_kwargs) + + # Convert to audio segment + audio_array = speech_output.cpu().numpy().squeeze() + + # Save to temporary file and load as AudioSegment + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: + temp_path = temp_file.name + + # Save as WAV + import scipy.io.wavfile as wavfile + wavfile.write(temp_path, rate=24000, data=audio_array) + + # Load as AudioSegment + if not self.ffmpeg_available: + print("WARNING: FFmpeg not available. Using silent audio as fallback.") + audio_segment = AudioSegment.silent(duration=len(audio_array) * 1000 // 24000) + else: + try: + audio_segment = AudioSegment.from_wav(temp_path) + except Exception as e: + print(f"Error loading audio segment: {str(e)}") + # Fallback to silent audio + audio_segment = AudioSegment.silent(duration=len(audio_array) * 1000 // 24000) + + # Clean up temporary file + try: + os.unlink(temp_path) + except Exception as e: + print(f"Warning: Could not delete temporary file {temp_path}: {str(e)}") + + return audio_segment + except Exception as e: + print(f"Error in _generate_audio_bark: {str(e)}") + # Return a silent segment as fallback + return AudioSegment.silent(duration=1000) + + def _generate_audio_coqui(self, text: str, speaker: str) -> AudioSegment: + """Generate audio using Coqui TTS. + + Args: + text: Text to convert to speech + speaker: Speaker identifier + + Returns: + AudioSegment containing the generated speech + """ + try: + # Create a temporary file to save the audio + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: + temp_path = temp_file.name + + # Generate audio with Coqui TTS + # For XTTS, we need to provide a reference audio file for the speaker + # Since we don't have that, we'll use the built-in speaker IDs + self.model.tts_to_file( + text=text, + file_path=temp_path, + speaker=self.speakers[speaker], + language="en" + ) + + # Load as AudioSegment + if not self.ffmpeg_available: + print("WARNING: FFmpeg not available. 
Using silent audio as fallback.") + # Estimate duration based on text length (rough approximation) + estimated_duration = len(text) * 60 # ~60ms per character + audio_segment = AudioSegment.silent(duration=estimated_duration) + else: + try: + audio_segment = AudioSegment.from_wav(temp_path) + except Exception as e: + print(f"Error loading audio segment: {str(e)}") + # Fallback to silent audio + estimated_duration = len(text) * 60 # ~60ms per character + audio_segment = AudioSegment.silent(duration=estimated_duration) + + # Clean up temporary file + try: + os.unlink(temp_path) + except Exception as e: + print(f"Warning: Could not delete temporary file {temp_path}: {str(e)}") + + return audio_segment + except Exception as e: + print(f"Error generating audio with Coqui: {str(e)}") + # Return a silent segment as fallback + return AudioSegment.silent(duration=1000) \ No newline at end of file
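
# --- Illustrative sketch (editor's note, not part of the diff above) ---------
# The generate_podcast() method that podcast_controller.py calls is not shown
# in this hunk. The sketch below shows one way the per-speaker helpers from
# TTSGenerator (_generate_audio_bark / _generate_audio_coqui), which each
# return a pydub AudioSegment, could be stitched into a single MP3. The
# "Speaker N:" transcript convention matches the keys of self.speakers in the
# class above; the helper name and everything else here are assumptions, not
# the author's implementation.
import re
from pydub import AudioSegment

def stitch_podcast_sketch(tts, transcript: str, output_path: str) -> str:
    """Hypothetical helper: parse 'Speaker N:' lines and concatenate audio."""
    combined = AudioSegment.empty()
    pause = AudioSegment.silent(duration=300)  # short gap between speaker turns

    for line in transcript.splitlines():
        match = re.match(r"^(Speaker \d+):\s*(.+)$", line.strip())
        if not match:
            continue  # skip blank lines and stage directions
        speaker, text = match.group(1), match.group(2)
        if speaker not in tts.speakers:
            speaker = "Speaker 1"  # fall back to the default voice
        # Bark is used here for simplicity; a real implementation would
        # dispatch on tts.model_type (bark / parler / coqui).
        segment = tts._generate_audio_bark(text, speaker)
        combined += segment + pause

    combined.export(output_path, format="mp3")  # requires FFmpeg on the PATH
    return output_path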