diff --git a/BedrockPromptCachingRoutingDemo/.gitignore b/BedrockPromptCachingRoutingDemo/.gitignore new file mode 100644 index 000000000..99cae2e24 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/.gitignore @@ -0,0 +1,67 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +venv/ +env/ +ENV/ +.venv/ +.env/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# AWS +.aws/ +*.pem + +# Logs +*.log +logs/ + +# Temporary files +*.tmp +*.temp + +# Jupyter Notebook +.ipynb_checkpoints + +# Streamlit +.streamlit/ + +# Benchmark results +benchmark_results_*.csv +benchmark_plot_*.png \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/LICENSE b/BedrockPromptCachingRoutingDemo/LICENSE new file mode 100644 index 000000000..46e552a63 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Amazon Bedrock Workshop + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/README.md b/BedrockPromptCachingRoutingDemo/README.md new file mode 100644 index 000000000..1ea7958c7 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/README.md @@ -0,0 +1,129 @@ +# Amazon Bedrock Prompt Caching and Routing Workshop + +This repository contains a complete implementation of Amazon Bedrock's prompt caching and routing capabilities using the latest Claude 4.5 models. + +## Features + +- **Prompt Caching**: Reduce latency and costs by caching frequently used prompts +- **Prompt Routing**: Intelligently route requests to optimal models +- **Latest Models**: Updated to use Claude Haiku 4.5, Sonnet 4.5, and Opus 4.1 +- **Global Endpoints**: Compatible across all AWS regions +- **Multiple Interfaces**: Both CLI and Streamlit web applications + +## Project Structure + +``` +BedrockPromptDemo/ +├── src/ +│ ├── bedrock_prompt_caching.py # CLI application for prompt caching +│ ├── bedrock_prompt_routing.py # CLI application for prompt routing +│ ├── prompt_caching_app.py # Streamlit UI for prompt caching +│ ├── prompt_router_app.py # Streamlit UI for prompt routing +│ ├── model_manager.py # Model selection and management +│ ├── bedrock_service.py # Bedrock API service wrapper +│ └── file_processor.py # File processing utilities +├── requirements.txt # Python dependencies +└── README.md # This file +``` + +## Latest Models Supported + +- **Claude Haiku 4.5**: `anthropic.claude-haiku-4-5-20251001-v1:0` +- **Claude Sonnet 4.5**: `anthropic.claude-sonnet-4-5-20250929-v1:0` +- **Claude Opus 4.1**: `anthropic.claude-opus-4-1-20250805-v1:0` +- **Amazon Nova Models**: `amazon.nova-micro-v1:0`, `amazon.nova-lite-v1:0`, 
`amazon.nova-pro-v1:0` + +## Prerequisites + +- AWS CLI configured with appropriate credentials +- Python 3.8+ +- Access to Amazon Bedrock with Claude models enabled + +## Installation + +1. Clone this repository: +```bash +git clone +cd BedrockPromptDemo +``` + +2. Install dependencies: +```bash +pip install -r requirements.txt +``` + +3. Configure AWS credentials: +```bash +aws configure +``` + +## Usage + +### CLI Applications + +**Prompt Caching:** +```bash +cd src +python bedrock_prompt_caching.py +``` + +**Prompt Routing:** +```bash +cd src +python bedrock_prompt_routing.py +``` + +### Web Applications + +**Prompt Caching UI:** +```bash +cd src +streamlit run prompt_caching_app.py +``` + +**Prompt Routing UI:** +```bash +cd src +streamlit run prompt_router_app.py +``` + +## Key Features + +### Prompt Caching +- Automatically caches document context for faster subsequent queries +- Shows cache hit/miss statistics +- Demonstrates cost and latency benefits +- Supports multi-turn conversations + +### Prompt Routing +- Intelligently routes requests to optimal models +- Displays routing decisions and model selection +- Tracks usage statistics across different models +- Supports file uploads (PDF, DOCX, TXT) + +### Model Management +- Dynamic model selection from available Bedrock models +- Inference profile resolution for optimal performance +- Fallback model configuration +- Global endpoint support for multi-region compatibility + +## Configuration + +The applications use global model endpoints by default, making them compatible across all AWS regions. Models are automatically resolved to regional endpoints by Bedrock's routing service. + +## Workshop Learning Objectives + +This code demonstrates: +1. How to implement prompt caching to reduce costs and latency +2. How to use prompt routing for intelligent model selection +3. Best practices for Bedrock API integration +4. Performance monitoring and usage tracking +5. 
Multi-modal file processing capabilities + +## Contributing + +Feel free to submit issues and enhancement requests! + +## License + +This project is licensed under the MIT License - see the LICENSE file for details. \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/bedrock_prompt_caching_routing.ipynb b/BedrockPromptCachingRoutingDemo/bedrock_prompt_caching_routing.ipynb new file mode 100644 index 000000000..c800c52c2 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/bedrock_prompt_caching_routing.ipynb @@ -0,0 +1,642 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Amazon Bedrock Prompt Caching and Routing Workshop\n", + "\n", + "## Overview\n", + "\n", + "This notebook demonstrates Amazon Bedrock's prompt caching and routing capabilities using the latest Claude models. You'll learn how to reduce latency and costs through intelligent prompt caching and how to route requests to optimal models based on your specific needs.\n", + "\n", + "**Key Learning Outcomes:**\n", + "- Implement prompt caching to reduce costs and latency\n", + "- Use prompt routing for intelligent model selection\n", + "- Understand best practices for Bedrock API integration\n", + "- Monitor performance and usage statistics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Context or Details about feature/use case\n", + "\n", + "### Prompt Caching\n", + "Prompt caching allows you to cache frequently used prompts, reducing both latency and costs for subsequent requests. 
This is particularly useful for:\n", + "- Document analysis workflows\n", + "- Multi-turn conversations\n", + "- Repetitive query patterns\n", + "\n", + "### Prompt Routing\n", + "Prompt routing intelligently directs requests to the most appropriate model based on:\n", + "- Query complexity\n", + "- Cost optimization requirements\n", + "- Performance needs\n", + "- Model capabilities\n", + "\n", + "### Supported Models\n", + "- **Claude Haiku 4.5**: Fast, cost-effective for simple tasks\n", + "- **Claude Sonnet 4.5**: Balanced performance and cost\n", + "- **Claude Opus 4.1**: Most capable for complex reasoning\n", + "- **Amazon Nova Models**: Latest AWS-native models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "Before running this notebook, ensure you have:\n", + "\n", + "1. **AWS Account** with appropriate permissions\n", + "2. **Amazon Bedrock access** with Claude models enabled\n", + "3. **AWS CLI configured** with credentials\n", + "4. **Python 3.8+** installed\n", + "5. **Required Python packages** (installed in Setup section)\n", + "\n", + "### Required AWS Permissions\n", + "Your AWS credentials need the following permissions:\n", + "- `bedrock:InvokeModel`\n", + "- `bedrock:ListFoundationModels`\n", + "- `bedrock:GetModelInvocationLoggingConfiguration`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Let's install the required dependencies and set up our environment." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Install required packages\n",
+    "!pip install boto3 streamlit pandas numpy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Import required libraries\n",
+    "import boto3\n",
+    "import json\n",
+    "import time\n",
+    "from datetime import datetime\n",
+    "import pandas as pd\n",
+    "from typing import Dict, List, Optional, Tuple\n",
+    "\n",
+    "# Initialize Bedrock client\n",
+    "bedrock_client = boto3.client('bedrock-runtime', region_name='us-east-1')\n",
+    "\n",
+    "print(\"✅ Setup complete! Bedrock client initialized.\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Implementation\n",
+    "\n",
+    "### Model Manager Class\n",
+    "\n",
+    "First, let's create a model manager to handle different Claude models and their configurations."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ModelManager:\n",
+    "    \"\"\"Manages Bedrock model selection and configuration\"\"\"\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        # Latest Claude model IDs (see README); model availability varies by region\n",
+    "        self.models = {\n",
+    "            'haiku': 'anthropic.claude-haiku-4-5-20251001-v1:0',\n",
+    "            'sonnet': 'anthropic.claude-sonnet-4-5-20250929-v1:0',\n",
+    "            'opus': 'anthropic.claude-opus-4-1-20250805-v1:0'\n",
+    "        }\n",
+    "\n",
+    "    def get_model_id(self, model_name: str) -> str:\n",
+    "        \"\"\"Get the full model ID for a given model name\"\"\"\n",
+    "        return self.models.get(model_name.lower(), self.models['sonnet'])\n",
+    "\n",
+    "    def list_available_models(self) -> List[str]:\n",
+    "        \"\"\"List all available model names\"\"\"\n",
+    "        return list(self.models.keys())\n",
+    "\n",
+    "# Initialize model manager\n",
+    "model_manager = ModelManager()\n",
+    "print(f\"Available models: {model_manager.list_available_models()}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Bedrock Service Class\n",
+    "\n",
+    "Now let's create a service class to handle Bedrock API interactions with caching capabilities."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BedrockService:\n",
+    "    \"\"\"Service class for Bedrock API interactions with caching\"\"\"\n",
+    "\n",
+    "    def __init__(self, client):\n",
+    "        self.client = client\n",
+    "        self.cache = {}  # Simple in-memory cache (application-level, for demo purposes)\n",
+    "        self.cache_stats = {'hits': 0, 'misses': 0}\n",
+    "\n",
+    "    def _generate_cache_key(self, model_id: str, prompt: str) -> str:\n",
+    "        \"\"\"Generate a cache key for the prompt\"\"\"\n",
+    "        return f\"{model_id}:{hash(prompt)}\"\n",
+    "\n",
+    "    def invoke_model_with_cache(self, model_id: str, prompt: str, use_cache: bool = True) -> Dict:\n",
+    "        \"\"\"Invoke model with optional caching\"\"\"\n",
+    "        cache_key = self._generate_cache_key(model_id, prompt)\n",
+    "\n",
+    "        # Check cache first\n",
+    "        if use_cache and cache_key in self.cache:\n",
+    "            self.cache_stats['hits'] += 1\n",
+    "            print(\"🎯 Cache HIT - Using cached response\")\n",
+    "            return self.cache[cache_key]\n",
+    "\n",
+    "        # Cache miss - make API call\n",
+    "        self.cache_stats['misses'] += 1\n",
+    "        print(\"🔄 Cache MISS - Making API call\")\n",
+    "\n",
+    "        start_time = time.time()\n",
+    "\n",
+    "        # Prepare request body\n",
+    "        body = {\n",
+    "            \"anthropic_version\": \"bedrock-2023-05-31\",\n",
+    "            \"max_tokens\": 1000,\n",
+    "            \"messages\": [\n",
+    "                {\n",
+    "                    \"role\": \"user\",\n",
+    "                    \"content\": prompt\n",
+    "                }\n",
+    "            ]\n",
+    "        }\n",
+    "\n",
+    "        # Make API call\n",
+    "        response = self.client.invoke_model(\n",
+    "            modelId=model_id,\n",
+    "            body=json.dumps(body)\n",
+    "        )\n",
+    "\n",
+    "        # Parse response\n",
+    "        response_body = json.loads(response['body'].read())\n",
+    "\n",
+    "        # Add timing information\n",
+    "        response_body['latency_ms'] = round((time.time() - start_time) * 1000, 2)\n",
+    "        response_body['timestamp'] = datetime.now().isoformat()\n",
+    "\n",
+    "        # Cache the response\n",
+    "        if use_cache:\n",
+    "            self.cache[cache_key] = response_body\n",
+    "\n",
+    "        return response_body\n",
+    "\n",
+    "    def get_cache_stats(self) -> Dict:\n",
+    "        \"\"\"Get cache performance statistics\"\"\"\n",
+    "        total = self.cache_stats['hits'] + self.cache_stats['misses']\n",
+    "        hit_rate = (self.cache_stats['hits'] / total * 100) if total > 0 else 0\n",
+    "\n",
+    "        return {\n",
+    "            'cache_hits': self.cache_stats['hits'],\n",
+    "            'cache_misses': self.cache_stats['misses'],\n",
+    "            'hit_rate_percent': round(hit_rate, 2),\n",
+    "            'cached_items': len(self.cache)\n",
+    "        }\n",
+    "\n",
+    "# Initialize Bedrock service\n",
+    "bedrock_service = BedrockService(bedrock_client)\n",
+    "print(\"✅ Bedrock service initialized with caching capabilities\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Prompt Router Class\n",
+    "\n",
+    "Let's create a prompt router that intelligently selects the best model based on query characteristics."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class PromptRouter:\n", + " \"\"\"Intelligent prompt routing based on query characteristics\"\"\"\n", + " \n", + " def __init__(self, model_manager: ModelManager):\n", + " self.model_manager = model_manager\n", + " self.routing_stats = {}\n", + " \n", + " def analyze_query_complexity(self, prompt: str) -> str:\n", + " \"\"\"Analyze prompt complexity and return complexity level\"\"\"\n", + " word_count = len(prompt.split())\n", + " \n", + " # Simple heuristics for complexity\n", + " complex_keywords = ['analyze', 'compare', 'evaluate', 'reasoning', 'complex', 'detailed']\n", + " simple_keywords = ['summarize', 'list', 'what is', 'define', 'simple']\n", + " \n", + " has_complex = any(keyword in prompt.lower() for keyword in complex_keywords)\n", + " has_simple = any(keyword in prompt.lower() for keyword in simple_keywords)\n", + " \n", + " if word_count > 100 or has_complex:\n", + " return 'complex'\n", + " elif word_count < 20 or has_simple:\n", + " return 'simple'\n", + " else:\n", + " return 'medium'\n", + " \n", + " def route_prompt(self, prompt: str, priority: str = 'balanced') -> Tuple[str, str]:\n", + " \"\"\"Route prompt to optimal model based on complexity and priority\"\"\"\n", + " complexity = self.analyze_query_complexity(prompt)\n", + " \n", + " # Routing logic\n", + " if priority == 'cost':\n", + " model_name = 'haiku' # Always use cheapest\n", + " elif priority == 'performance':\n", + " model_name = 'opus' # Always use most capable\n", + " else: # balanced\n", + " if complexity == 'simple':\n", + " model_name = 'haiku'\n", + " elif complexity == 'complex':\n", + " model_name = 'opus'\n", + " else:\n", + " model_name = 'sonnet'\n", + " \n", + " # Track routing decisions\n", + " if model_name not in self.routing_stats:\n", + " self.routing_stats[model_name] = 0\n", + " self.routing_stats[model_name] += 1\n", + " \n", + " model_id = 
self.model_manager.get_model_id(model_name)\n", + " \n", + " print(f\"🎯 Routing Decision: {complexity} complexity → {model_name} model\")\n", + " \n", + " return model_id, model_name\n", + " \n", + " def get_routing_stats(self) -> Dict:\n", + " \"\"\"Get routing statistics\"\"\"\n", + " return self.routing_stats.copy()\n", + "\n", + "# Initialize prompt router\n", + "prompt_router = PromptRouter(model_manager)\n", + "print(\"✅ Prompt router initialized\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Demo 1: Prompt Caching in Action\n", + "\n", + "Let's demonstrate how prompt caching works by making repeated requests." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sample document for caching demo\n", + "sample_document = \"\"\"\n", + "Amazon Web Services (AWS) is a comprehensive cloud computing platform provided by Amazon. \n", + "It offers over 200 fully featured services from data centers globally. AWS serves millions \n", + "of customers including startups, large enterprises, and government agencies. 
Key services \n", + "include compute power, database storage, content delivery, and machine learning capabilities.\n", + "\"\"\"\n", + "\n", + "# First request - will be cached\n", + "prompt1 = f\"Based on this document: {sample_document}\\n\\nQuestion: What is AWS?\"\n", + "model_id = model_manager.get_model_id('sonnet')\n", + "\n", + "print(\"=== First Request (Cache Miss Expected) ===\")\n", + "response1 = bedrock_service.invoke_model_with_cache(model_id, prompt1, use_cache=True)\n", + "print(f\"Response: {response1['content'][0]['text'][:100]}...\")\n", + "print(f\"Latency: {response1['latency_ms']}ms\")\n", + "print()\n", + "\n", + "# Second identical request - should hit cache\n", + "print(\"=== Second Identical Request (Cache Hit Expected) ===\")\n", + "response2 = bedrock_service.invoke_model_with_cache(model_id, prompt1, use_cache=True)\n", + "print(f\"Response: {response2['content'][0]['text'][:100]}...\")\n", + "print(f\"Latency: {response2['latency_ms']}ms\")\n", + "print()\n", + "\n", + "# Display cache statistics\n", + "cache_stats = bedrock_service.get_cache_stats()\n", + "print(\"=== Cache Performance ===\")\n", + "for key, value in cache_stats.items():\n", + " print(f\"{key}: {value}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Demo 2: Intelligent Prompt Routing\n", + "\n", + "Now let's see how the prompt router selects different models based on query complexity." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test different types of queries\n", + "test_queries = [\n", + " \"What is machine learning?\", # Simple\n", + " \"Explain the differences between supervised and unsupervised learning algorithms, including their use cases and performance characteristics.\", # Complex\n", + " \"List the main AWS compute services.\", # Simple\n", + " \"Analyze the trade-offs between microservices and monolithic architectures in cloud-native applications.\" # Complex\n", + "]\n", + "\n", + "print(\"=== Prompt Routing Demonstration ===\")\n", + "results = []\n", + "\n", + "for i, query in enumerate(test_queries, 1):\n", + " print(f\"\\n--- Query {i} ---\")\n", + " print(f\"Query: {query[:60]}...\")\n", + " \n", + " # Route the prompt\n", + " model_id, model_name = prompt_router.route_prompt(query, priority='balanced')\n", + " \n", + " # Make the request (without caching for this demo)\n", + " response = bedrock_service.invoke_model_with_cache(model_id, query, use_cache=False)\n", + " \n", + " results.append({\n", + " 'query': query[:50] + '...',\n", + " 'model': model_name,\n", + " 'latency_ms': response['latency_ms'],\n", + " 'response_length': len(response['content'][0]['text'])\n", + " })\n", + " \n", + " print(f\"Selected Model: {model_name}\")\n", + " print(f\"Latency: {response['latency_ms']}ms\")\n", + " print(f\"Response: {response['content'][0]['text'][:100]}...\")\n", + "\n", + "# Display routing statistics\n", + "print(\"\\n=== Routing Statistics ===\")\n", + "routing_stats = prompt_router.get_routing_stats()\n", + "for model, count in routing_stats.items():\n", + " print(f\"{model}: {count} requests\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Demo 3: Performance Comparison\n", + "\n", + "Let's compare performance with and without caching, and across different routing strategies." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Performance comparison\n", + "test_prompt = \"Explain the benefits of cloud computing for small businesses.\"\n", + "model_id = model_manager.get_model_id('sonnet')\n", + "\n", + "# Test without caching\n", + "print(\"=== Performance Comparison ===\")\n", + "print(\"\\n1. Without Caching:\")\n", + "times_no_cache = []\n", + "for i in range(3):\n", + " response = bedrock_service.invoke_model_with_cache(model_id, test_prompt, use_cache=False)\n", + " times_no_cache.append(response['latency_ms'])\n", + " print(f\" Request {i+1}: {response['latency_ms']}ms\")\n", + "\n", + "# Test with caching\n", + "print(\"\\n2. With Caching:\")\n", + "times_with_cache = []\n", + "for i in range(3):\n", + " response = bedrock_service.invoke_model_with_cache(model_id, test_prompt, use_cache=True)\n", + " times_with_cache.append(response['latency_ms'])\n", + " print(f\" Request {i+1}: {response['latency_ms']}ms\")\n", + "\n", + "# Calculate savings\n", + "avg_no_cache = sum(times_no_cache) / len(times_no_cache)\n", + "avg_with_cache = sum(times_with_cache) / len(times_with_cache)\n", + "savings_percent = ((avg_no_cache - avg_with_cache) / avg_no_cache) * 100\n", + "\n", + "print(f\"\\n=== Performance Summary ===\")\n", + "print(f\"Average latency without cache: {avg_no_cache:.2f}ms\")\n", + "print(f\"Average latency with cache: {avg_with_cache:.2f}ms\")\n", + "print(f\"Performance improvement: {savings_percent:.1f}%\")\n", + "\n", + "# Final cache statistics\n", + "final_stats = bedrock_service.get_cache_stats()\n", + "print(f\"\\n=== Final Cache Statistics ===\")\n", + "for key, value in final_stats.items():\n", + " print(f\"{key}: {value}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Other Considerations or Advanced section or Best Practices\n", + "\n", + "### Best Practices for Prompt Caching\n", + "\n", + "1. 
**Cache Key Design**: Use meaningful cache keys that include model version and relevant parameters\n", + "2. **TTL Management**: Implement time-to-live (TTL) for cache entries to ensure freshness\n", + "3. **Memory Management**: Monitor cache size and implement eviction policies for production use\n", + "4. **Cache Warming**: Pre-populate cache with frequently used prompts\n", + "\n", + "### Best Practices for Prompt Routing\n", + "\n", + "1. **Complexity Analysis**: Develop sophisticated heuristics for query complexity\n", + "2. **Cost Monitoring**: Track costs across different models to optimize routing decisions\n", + "3. **Performance Metrics**: Monitor latency and quality metrics for each model\n", + "4. **Fallback Strategies**: Implement fallback models for high availability\n", + "\n", + "### Production Considerations\n", + "\n", + "- **Persistent Caching**: Use Redis or DynamoDB for distributed caching\n", + "- **Monitoring**: Implement comprehensive logging and metrics\n", + "- **Security**: Ensure sensitive data is not cached inappropriately\n", + "- **Rate Limiting**: Implement proper rate limiting for API calls\n", + "- **Error Handling**: Add robust error handling and retry logic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example of production-ready cache with TTL\n", + "import time\n", + "from typing import Optional\n", + "\n", + "class ProductionCache:\n", + " \"\"\"Production-ready cache with TTL and size limits\"\"\"\n", + " \n", + " def __init__(self, max_size: int = 1000, default_ttl: int = 3600):\n", + " self.cache = {}\n", + " self.timestamps = {}\n", + " self.max_size = max_size\n", + " self.default_ttl = default_ttl\n", + " \n", + " def get(self, key: str) -> Optional[dict]:\n", + " \"\"\"Get item from cache if not expired\"\"\"\n", + " if key not in self.cache:\n", + " return None\n", + " \n", + " # Check if expired\n", + " if time.time() - self.timestamps[key] > 
self.default_ttl:\n", + " del self.cache[key]\n", + " del self.timestamps[key]\n", + " return None\n", + " \n", + " return self.cache[key]\n", + " \n", + " def set(self, key: str, value: dict) -> None:\n", + " \"\"\"Set item in cache with size management\"\"\"\n", + " # Evict oldest items if at capacity\n", + " if len(self.cache) >= self.max_size:\n", + " oldest_key = min(self.timestamps.keys(), key=lambda k: self.timestamps[k])\n", + " del self.cache[oldest_key]\n", + " del self.timestamps[oldest_key]\n", + " \n", + " self.cache[key] = value\n", + " self.timestamps[key] = time.time()\n", + "\n", + "# Example usage\n", + "prod_cache = ProductionCache(max_size=100, default_ttl=1800) # 30 minutes TTL\n", + "print(\"✅ Production cache example created\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next Steps\n", + "\n", + "Now that you've learned the basics of prompt caching and routing, here are some next steps to explore:\n", + "\n", + "### 1. Advanced Routing Strategies\n", + "- Implement machine learning-based routing decisions\n", + "- Add user preference learning\n", + "- Develop cost-aware routing algorithms\n", + "\n", + "### 2. Integration Patterns\n", + "- Build a REST API wrapper around these capabilities\n", + "- Create a Streamlit web application for interactive use\n", + "- Integrate with existing applications and workflows\n", + "\n", + "### 3. Monitoring and Analytics\n", + "- Set up CloudWatch metrics for cache performance\n", + "- Implement cost tracking across different models\n", + "- Create dashboards for routing decision analysis\n", + "\n", + "### 4. Scale and Production\n", + "- Deploy using AWS Lambda for serverless scaling\n", + "- Implement distributed caching with ElastiCache\n", + "- Add comprehensive error handling and logging\n", + "\n", + "### 5. 
Explore Additional Features\n", + "- Multi-modal prompt routing (text, images, documents)\n", + "- Streaming responses with caching\n", + "- Custom model fine-tuning integration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleanup\n", + "\n", + "Let's clean up any resources and display final statistics." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Display final statistics\n", + "print(\"=== Workshop Summary ===\")\n", + "print(f\"Total API calls made: {bedrock_service.cache_stats['hits'] + bedrock_service.cache_stats['misses']}\")\n", + "print(f\"Cache hits: {bedrock_service.cache_stats['hits']}\")\n", + "print(f\"Cache misses: {bedrock_service.cache_stats['misses']}\")\n", + "print(f\"Cache hit rate: {bedrock_service.get_cache_stats()['hit_rate_percent']}%\")\n", + "print(f\"Items in cache: {len(bedrock_service.cache)}\")\n", + "\n", + "print(\"\\n=== Model Usage ===\")\n", + "routing_stats = prompt_router.get_routing_stats()\n", + "for model, count in routing_stats.items():\n", + " print(f\"{model}: {count} requests\")\n", + "\n", + "# Clear cache to free memory\n", + "bedrock_service.cache.clear()\n", + "bedrock_service.cache_stats = {'hits': 0, 'misses': 0}\n", + "\n", + "print(\"\\n✅ Cleanup complete! 
Cache cleared and statistics reset.\")\n", + "print(\"\\n🎉 Workshop completed successfully!\")\n", + "print(\"\\nYou've learned how to:\")\n", + "print(\"- Implement prompt caching for cost and latency optimization\")\n", + "print(\"- Use intelligent prompt routing for model selection\")\n", + "print(\"- Monitor performance and usage statistics\")\n", + "print(\"- Apply best practices for production deployments\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/cloudformation-template.yaml b/BedrockPromptCachingRoutingDemo/cloudformation-template.yaml new file mode 100644 index 000000000..a7b4636ba --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/cloudformation-template.yaml @@ -0,0 +1,143 @@ +AWSTemplateFormatVersion: '2010-09-09' +Description: 'CloudFormation template for Bedrock Prompt Caching Demo' + +Parameters: + InstanceType: + Type: String + Default: t3.medium + Description: EC2 instance type for running the demo + + KeyPairName: + Type: AWS::EC2::KeyPair::KeyName + Description: EC2 Key Pair for SSH access + +Resources: + # IAM Role for EC2 instance + BedrockDemoRole: + Type: AWS::IAM::Role + Properties: + AssumeRolePolicyDocument: + Version: '2012-10-17' + Statement: + - Effect: Allow + Principal: + Service: ec2.amazonaws.com + Action: sts:AssumeRole + ManagedPolicyArns: + - arn:aws:iam::aws:policy/AmazonBedrockFullAccess + Policies: + - PolicyName: BedrockPromptCachingPolicy + PolicyDocument: + Version: '2012-10-17' + Statement: + - Effect: Allow + Action: + - bedrock:InvokeModel + - 
bedrock:InvokeModelWithResponseStream + - bedrock:ListFoundationModels + - bedrock:GetFoundationModel + - bedrock:ListPromptRouters + - bedrock:GetPromptRouter + Resource: '*' + + # Instance Profile + BedrockDemoInstanceProfile: + Type: AWS::IAM::InstanceProfile + Properties: + Roles: + - !Ref BedrockDemoRole + + # Security Group + BedrockDemoSecurityGroup: + Type: AWS::EC2::SecurityGroup + Properties: + GroupDescription: Security group for Bedrock demo instance + SecurityGroupIngress: + - IpProtocol: tcp + FromPort: 22 + ToPort: 22 + CidrIp: 0.0.0.0/0 + - IpProtocol: tcp + FromPort: 8501 + ToPort: 8501 + CidrIp: 0.0.0.0/0 + + # EC2 Instance + BedrockDemoInstance: + Type: AWS::EC2::Instance + Properties: + ImageId: ami-0c02fb55956c7d316 # Amazon Linux 2023 + InstanceType: !Ref InstanceType + KeyName: !Ref KeyPairName + IamInstanceProfile: !Ref BedrockDemoInstanceProfile + SecurityGroupIds: + - !Ref BedrockDemoSecurityGroup + UserData: + Fn::Base64: !Sub | + #!/bin/bash + yum update -y + yum install -y python3 python3-pip git + + # Install dependencies + pip3 install boto3 requests pandas matplotlib seaborn numpy PyPDF2 python-docx streamlit + + # Copy code files directly + cd /home/ec2-user + mkdir -p bedrock-prompt-demo/src + + # Download files from your local system or S3 + # For now, create placeholder files + cat > bedrock-prompt-demo/src/requirements.txt << 'EOF' +boto3 +requests +pandas +matplotlib +seaborn +numpy +PyPDF2 +python-docx +streamlit +EOF + + chown -R ec2-user:ec2-user bedrock-prompt-demo + + # Set environment variables + echo 'export AWS_REGION=${AWS::Region}' >> /home/ec2-user/.bashrc + + # Create systemd service for Streamlit + cat > /etc/systemd/system/bedrock-demo.service << EOF + [Unit] + Description=Bedrock Prompt Demo + After=network.target + + [Service] + Type=simple + User=ec2-user + WorkingDirectory=/home/ec2-user/bedrock-prompt-demo/src + Environment=AWS_REGION=${AWS::Region} + ExecStart=/usr/bin/python3 -m streamlit run 
prompt_router_app.py --server.port=8501 --server.address=0.0.0.0 + Restart=always + + [Install] + WantedBy=multi-user.target + EOF + + systemctl enable bedrock-demo + systemctl start bedrock-demo + +Outputs: + InstanceId: + Description: EC2 Instance ID + Value: !Ref BedrockDemoInstance + + PublicIP: + Description: Public IP address + Value: !GetAtt BedrockDemoInstance.PublicIp + + StreamlitURL: + Description: Streamlit application URL + Value: !Sub 'http://${BedrockDemoInstance.PublicIp}:8501' + + SSHCommand: + Description: SSH command to connect + Value: !Sub 'ssh -i ${KeyPairName}.pem ec2-user@${BedrockDemoInstance.PublicIp}' \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/requirements.txt b/BedrockPromptCachingRoutingDemo/requirements.txt new file mode 100644 index 000000000..34bc59513 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/requirements.txt @@ -0,0 +1,10 @@ +boto3>=1.34.0 +streamlit>=1.28.0 +pandas>=1.5.0 +numpy>=1.24.0 +matplotlib>=3.6.0 +seaborn>=0.12.0 +requests>=2.28.0 +PyPDF2>=3.0.0 +python-docx>=0.8.11 +openpyxl>=3.1.0 \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/RomeoAndJuliet.txt b/BedrockPromptCachingRoutingDemo/src/RomeoAndJuliet.txt new file mode 100644 index 000000000..341745335 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/RomeoAndJuliet.txt @@ -0,0 +1,272 @@ +The Project Gutenberg eBook of Romeo and Juliet + +This ebook is for the use of anyone anywhere in the United States and +most other parts of the world at no cost and with almost no restrictions +whatsoever. You may copy it, give it away or re-use it under the terms +of the Project Gutenberg License included with this ebook or online +at www.gutenberg.org. If you are not located in the United States, +you will have to check the laws of the country where you are located +before using this eBook. 
+ +Title: Romeo and Juliet + +Author: William Shakespeare + +Release date: November 1, 1998 [eBook #1513] + Most recently updated: June 19, 2024 + +Language: English + +Credits: the PG Shakespeare Team, a team of about twenty Project Gutenberg volunteers + + +*** START OF THE PROJECT GUTENBERG EBOOK ROMEO AND JULIET *** + + + + +THE TRAGEDY OF ROMEO AND JULIET + +by William Shakespeare + + + + +Contents + +THE PROLOGUE. + +ACT I +Scene I. A public place. +Scene II. A Street. +Scene III. Room in Capulet’s House. +Scene IV. A Street. +Scene V. A Hall in Capulet’s House. + +ACT II +CHORUS. +Scene I. An open place adjoining Capulet’s Garden. +Scene II. Capulet’s Garden. +Scene III. Friar Lawrence’s Cell. +Scene IV. A Street. +Scene V. Capulet’s Garden. +Scene VI. Friar Lawrence’s Cell. + +ACT III +Scene I. A public Place. +Scene II. A Room in Capulet’s House. +Scene III. Friar Lawrence’s cell. +Scene IV. A Room in Capulet’s House. +Scene V. An open Gallery to Juliet’s Chamber, overlooking the Garden. + +ACT IV +Scene I. Friar Lawrence’s Cell. +Scene II. Hall in Capulet’s House. +Scene III. Juliet’s Chamber. +Scene IV. Hall in Capulet’s House. +Scene V. Juliet’s Chamber; Juliet on the bed. + +ACT V +Scene I. Mantua. A Street. +Scene II. Friar Lawrence’s Cell. +Scene III. A churchyard; in it a Monument belonging to the Capulets. + + + + + Dramatis Personæ + +ESCALUS, Prince of Verona. +MERCUTIO, kinsman to the Prince, and friend to Romeo. +PARIS, a young Nobleman, kinsman to the Prince. +Page to Paris. + +MONTAGUE, head of a Veronese family at feud with the Capulets. +LADY MONTAGUE, wife to Montague. +ROMEO, son to Montague. +BENVOLIO, nephew to Montague, and friend to Romeo. +ABRAM, servant to Montague. +BALTHASAR, servant to Romeo. + +CAPULET, head of a Veronese family at feud with the Montagues. +LADY CAPULET, wife to Capulet. +JULIET, daughter to Capulet. +TYBALT, nephew to Lady Capulet. +CAPULET’S COUSIN, an old man. +NURSE to Juliet. 
+PETER, servant to Juliet’s Nurse. +SAMPSON, servant to Capulet. +GREGORY, servant to Capulet. +Servants. + +FRIAR LAWRENCE, a Franciscan. +FRIAR JOHN, of the same Order. +An Apothecary. +CHORUS. +Three Musicians. +An Officer. +Citizens of Verona; several Men and Women, relations to both houses; +Maskers, Guards, Watchmen and Attendants. + +SCENE. During the greater part of the Play in Verona; once, in the +Fifth Act, at Mantua. + + + + +THE PROLOGUE + + + Enter Chorus. + +CHORUS. +Two households, both alike in dignity, +In fair Verona, where we lay our scene, +From ancient grudge break to new mutiny, +Where civil blood makes civil hands unclean. +From forth the fatal loins of these two foes +A pair of star-cross’d lovers take their life; +Whose misadventur’d piteous overthrows +Doth with their death bury their parents’ strife. +The fearful passage of their death-mark’d love, +And the continuance of their parents’ rage, +Which, but their children’s end, nought could remove, +Is now the two hours’ traffic of our stage; +The which, if you with patient ears attend, +What here shall miss, our toil shall strive to mend. + + [_Exit._] + + + + +ACT I + +SCENE I. A public place. + + + Enter Sampson and Gregory armed with swords and bucklers. + +SAMPSON. +Gregory, on my word, we’ll not carry coals. + +GREGORY. +No, for then we should be colliers. + +SAMPSON. +I mean, if we be in choler, we’ll draw. + +GREGORY. +Ay, while you live, draw your neck out o’ the collar. + +SAMPSON. +I strike quickly, being moved. + +GREGORY. +But thou art not quickly moved to strike. + +SAMPSON. +A dog of the house of Montague moves me. + +GREGORY. +To move is to stir; and to be valiant is to stand: therefore, if thou +art moved, thou runn’st away. + +SAMPSON. +A dog of that house shall move me to stand. +I will take the wall of any man or maid of Montague’s. + +GREGORY. +That shows thee a weak slave, for the weakest goes to the wall. + +SAMPSON. 
+True, and therefore women, being the weaker vessels, are ever thrust to +the wall: therefore I will push Montague’s men from the wall, and +thrust his maids to the wall. + +GREGORY. +The quarrel is between our masters and us their men. + +SAMPSON. +’Tis all one, I will show myself a tyrant: when I have fought with the +men I will be civil with the maids, I will cut off their heads. + +GREGORY. +The heads of the maids? + +SAMPSON. +Ay, the heads of the maids, or their maidenheads; take it in what sense +thou wilt. + +GREGORY. +They must take it in sense that feel it. + +SAMPSON. +Me they shall feel while I am able to stand: and ’tis known I am a +pretty piece of flesh. + +GREGORY. +’Tis well thou art not fish; if thou hadst, thou hadst been poor John. +Draw thy tool; here comes of the house of Montagues. + + Enter Abram and Balthasar. + +SAMPSON. +My naked weapon is out: quarrel, I will back thee. + +GREGORY. +How? Turn thy back and run? + +SAMPSON. +Fear me not. + +GREGORY. +No, marry; I fear thee! + +SAMPSON. +Let us take the law of our sides; let them begin. + +GREGORY. +I will frown as I pass by, and let them take it as they list. + +SAMPSON. +Nay, as they dare. I will bite my thumb at them, which is disgrace to +them if they bear it. + +ABRAM. +Do you bite your thumb at us, sir? + +SAMPSON. +I do bite my thumb, sir. + +ABRAM. +Do you bite your thumb at us, sir? + +SAMPSON. +Is the law of our side if I say ay? + +GREGORY. +No. + +SAMPSON. +No sir, I do not bite my thumb at you, sir; but I bite my thumb, sir. + +GREGORY. +Do you quarrel, sir? + +ABRAM. +Quarrel, sir? No, sir. + +SAMPSON. +But if you do, sir, I am for you. I serve as good a man as you. + +ABRAM. +No better. + +SAMPSON. +Well, sir. + + Enter Benvolio. + +GREGORY. +Say better; here comes one of my master’s kinsmen. 
\ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/__init__.py b/BedrockPromptCachingRoutingDemo/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/BedrockPromptCachingRoutingDemo/src/bedrock_claude_code.py b/BedrockPromptCachingRoutingDemo/src/bedrock_claude_code.py new file mode 100644 index 000000000..ed9813769 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/bedrock_claude_code.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +import os +import subprocess +import sys + +class ClaudeSetup: + """Class to handle Claude Code setup and execution.""" + + def install_claude_code(self): + """Install Claude Code using npm.""" + print("Installing Claude Code...") + result = subprocess.run(["npm", "install", "-g", "@anthropic-ai/claude-code"], + capture_output=True, text=True) + if result.returncode != 0: + print(f"Error installing Claude Code: {result.stderr}") + sys.exit(1) + print("Claude Code installed successfully.") + + def configure_environment(self, model="sonnet"): + """Configure environment variables for Bedrock with Claude models. 
+ + Args: + model: Either "sonnet" for Claude 3.7 Sonnet or "haiku" for Claude 3.5 Haiku + """ + print("Configuring environment variables...") + os.environ["CLAUDE_CODE_USE_BEDROCK"] = "1" + + if model == "haiku": + os.environ["ANTHROPIC_MODEL"] = "us.anthropic.claude-3-5-haiku-20241022-v1:0" + os.environ["ANTHROPIC_SMALL_FAST_MODEL"] = "us.anthropic.claude-3-5-haiku-20241022-v1:0" + print(f"Using Claude 3.5 Haiku model: {os.environ['ANTHROPIC_MODEL']}") + else: + os.environ["ANTHROPIC_MODEL"] = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" + print(f"Using Claude 3.7 Sonnet model: {os.environ['ANTHROPIC_MODEL']}") + + # Disable prompt caching by default; "true" matches the toggle in simulate_user_chat() + os.environ["DISABLE_PROMPT_CACHING"] = "true" + print("Prompt caching is disabled by default.") + print("To enable caching (if you have access): unset DISABLE_PROMPT_CACHING") + + def launch_claude(self): + """Launch Claude Code with the configured environment.""" + print("Launching Claude Code...") + subprocess.run(["claude"], env=os.environ) + + def run_setup(self): + """Run the complete setup process.""" + self.install_claude_code() + self.configure_environment() + self.launch_claude() + +def simulate_user_chat(): + """Simulate a user chat session that executes the setup with a friendly interface.""" + print("\n" + "="*60) + print("🤖 Welcome to Claude Code Setup Assistant 🤖".center(60)) + print("="*60) + + print("\n👋 Hi there! I'll help you set up Claude Code with AWS Bedrock.") + print(" This assistant will guide you through the installation and configuration process.") + + # Ask if user wants to proceed with installation + print("\n📦 First, we need to install Claude Code.") + proceed = input(" Would you like to proceed with installation? (y/n, default: y): ").lower() or "y" + + if proceed != "y": + print("\n❌ Setup cancelled. 
You can run this script again when you're ready.") + return + + setup = ClaudeSetup() + setup.install_claude_code() + + # Ask user about model selection with more details + print("\n🧠 Model Selection:") + print(" Claude offers different models with varying capabilities:") + print(" 1. Claude 3.7 Sonnet - More powerful, better for complex tasks") + print(" 2. Claude 3.5 Haiku - Faster, great for simpler tasks") + + model_choice = input("\n Which model would you prefer? (1/2, default: 1): ") or "1" + + if model_choice == "2": + model = "haiku" + model_name = "Claude 3.5 Haiku" + else: + model = "sonnet" + model_name = "Claude 3.7 Sonnet" + + print(f"\n✅ Selected {model_name}") + setup.configure_environment(model) + + # Ask user about caching preferences with explanation + print("\n🔄 Prompt Caching:") + print(" Prompt caching can reduce latency and cost by reusing previously processed prompt prefixes.") + print(" Note: This feature requires special access.") + print(" 1. Enable caching (if you have access)") + print(" 2. Disable caching") + + cache_choice = input("\n What's your preference? (1/2, default: 2): ") or "2" + + if cache_choice == "1": + if "DISABLE_PROMPT_CACHING" in os.environ: + del os.environ["DISABLE_PROMPT_CACHING"] + print("\n✅ Prompt caching has been enabled.") + else: + os.environ["DISABLE_PROMPT_CACHING"] = "true" + print("\n✅ Prompt caching has been disabled.") + + # Final confirmation before launch + print("\n🚀 Ready to launch Claude Code!") + launch = input(" Would you like to launch it now? (y/n, default: y): ").lower() or "y" + + if launch == "y": + print("\n🔄 Launching Claude Code...") + setup.launch_claude() + print(f"\n✨ Success! Claude Code is now running with {model_name} on AWS Bedrock.") + else: + print("\n👍 Setup complete! 
You can launch Claude Code later with the 'claude' command.") + + print("\n💡 Tip: You can reconfigure these settings by running this script again.") + print("\n" + "="*60) + print("Thank you for using Claude Code Setup Assistant!".center(60)) + print("="*60 + "\n") + +if __name__ == "__main__": + simulate_user_chat() \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/bedrock_context_retrieval_app.py b/BedrockPromptCachingRoutingDemo/src/bedrock_context_retrieval_app.py new file mode 100644 index 000000000..ae70cacad --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/bedrock_context_retrieval_app.py @@ -0,0 +1,405 @@ +import gradio as gr +import os +import json +import boto3 +import time +from pathlib import Path +from typing import Dict, Any, List, Tuple + +# Import the necessary classes from the module +from bedrock_ctxt_retrieval import BedrockKnowledgeBaseManager, ResponseFormatter, ChunkingStrategySelector + +# Global variables to track state +kb_manager = None +formatter = None +kb_ids = {} +kb_names = {} + +def initialize_system() -> str: + """Initialize the BedrockKnowledgeBaseManager with default settings.""" + global kb_manager, formatter + try: + kb_manager = BedrockKnowledgeBaseManager() + formatter = ResponseFormatter() + return f"System initialized successfully in region: {kb_manager.region}" + except Exception as e: + return f"Error initializing system: {str(e)}" + +def setup_knowledge_bases(chunking_strategy: str) -> str: + """Set up knowledge bases based on selected chunking strategy.""" + global kb_manager, kb_ids, kb_names + + if not kb_manager: + return "Please initialize the system first." 
+ + try: + # Check if lambda function file exists and create a copy with the expected name + lambda_source_file = "lambda_custom_chunking_function.py" + lambda_target_file = "lambda_function.py" + + # Create a copy of the lambda file with the expected name if CUSTOM or BOTH is selected + if chunking_strategy in ["CUSTOM", "BOTH"]: + if os.path.exists(lambda_source_file): + with open(lambda_source_file, 'r') as source: + content = source.read() + with open(lambda_target_file, 'w') as target: + target.write(content) + else: + # Create a simple lambda function if the source doesn't exist + with open(lambda_target_file, 'w') as f: + f.write(""" +import json + +def lambda_handler(event, context): + # Simple custom chunking function + chunks = [] + text = event.get('text', '') + + # Split by paragraphs and create chunks + paragraphs = text.split('\\n\\n') + for i, para in enumerate(paragraphs): + if para.strip(): + chunks.append({ + 'text': para.strip(), + 'metadata': {'chunk_id': i, 'source': 'custom_chunking'} + }) + + return { + 'statusCode': 200, + 'chunks': chunks + } +""") + + # Base names for resources + kb_base_name = 'kb' + kb_description = "Knowledge Base containing complex PDF." 
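+ + # Illustrative example (assumption: the custom-transformation Lambda returns + # chunks shaped like the placeholder handler above). For an input text of + # "First para.\n\nSecond para." the handler's 'chunks' list would be: + # [{'text': 'First para.', 'metadata': {'chunk_id': 0, 'source': 'custom_chunking'}}, + # {'text': 'Second para.', 'metadata': {'chunk_id': 1, 'source': 'custom_chunking'}}]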
+ results = [] + + # Create standard knowledge base if selected + if chunking_strategy in ["FIXED_SIZE", "BOTH"]: + kb_name_standard = f"standard-{kb_base_name}" + kb_id_standard = kb_manager.create_knowledge_base( + kb_name=kb_name_standard, + kb_description=kb_description, + chunking_strategy="FIXED_SIZE" + ) + kb_ids["standard"] = kb_id_standard + kb_names["standard"] = kb_name_standard + results.append(f"Standard KB ID: {kb_id_standard}") + + # Create custom chunking knowledge base if selected + if chunking_strategy in ["CUSTOM", "BOTH"]: + kb_name_custom = f"custom-{kb_base_name}" + intermediate_bucket_name = f"{kb_name_custom}-intermediate-{kb_manager.timestamp_suffix}" + lambda_function_name = f"{kb_name_custom}-lambda-{kb_manager.timestamp_suffix}" + + kb_id_custom = kb_manager.create_knowledge_base( + kb_name=kb_name_custom, + kb_description=kb_description, + chunking_strategy="CUSTOM", + lambda_function_name=lambda_function_name, + intermediate_bucket_name=intermediate_bucket_name + ) + kb_ids["custom"] = kb_id_custom + kb_names["custom"] = kb_name_custom + results.append(f"Custom KB ID: {kb_id_custom}") + + return f"Knowledge bases created successfully.\n" + "\n".join(results) + except Exception as e: + return f"Error creating knowledge bases: {str(e)}" + +def upload_directory(directory_path: str) -> str: + """Upload files from a directory to knowledge base buckets.""" + global kb_manager, kb_names + + if not kb_manager: + return "Please initialize the system first." + + if not kb_names: + return "Please set up knowledge bases first." 
+ + try: + results = [] + for kb_type, kb_name in kb_names.items(): + bucket_name = f'{kb_name}-{kb_manager.timestamp_suffix}' + kb_manager.upload_directory_to_s3(directory_path, bucket_name) + results.append(f"Files uploaded to {kb_type} bucket: {bucket_name}") + + return "\n".join(results) + except Exception as e: + return f"Error uploading files: {str(e)}" + +def start_ingestion_jobs() -> str: + """Start ingestion jobs for all knowledge bases.""" + global kb_manager, kb_names + + if not kb_manager: + return "Please initialize the system first." + + if not kb_names: + return "Please set up knowledge bases first." + + try: + results = [] + for kb_type, kb_name in kb_names.items(): + kb_manager.start_ingestion_job(kb_name) + results.append(f"Ingestion job started for {kb_type} knowledge base: {kb_name}") + + return "\n".join(results) + except Exception as e: + return f"Error starting ingestion jobs: {str(e)}" + +def query_knowledge_base(kb_type: str, query_text: str, num_results: int) -> Tuple[str, str]: + """Query the knowledge base using retrieve and generate.""" + global kb_manager, kb_ids + + if not kb_manager: + return "Please initialize the system first.", "" + + if not kb_ids or kb_type not in kb_ids: + return f"Knowledge base {kb_type} not found.", "" + + try: + # Get the KB ID + kb_id = kb_ids[kb_type] + + # Perform retrieve and generate + response = kb_manager.retrieve_and_generate(kb_id, query_text, num_results) + + # Format the response + answer = response['output']['text'] + + # Format citations + citations = "" + if 'citations' in response and response['citations']: + response_refs = response['citations'][0]['retrievedReferences'] + citations = f"Citations ({len(response_refs)}):\n\n" + for num, chunk in enumerate(response_refs, 1): + citations += f"Chunk {num}: {chunk['content']['text']}\n" + citations += f"Location: {chunk['location']}\n" + if 'metadata' in chunk: + citations += f"Metadata: {chunk['metadata']}\n" + citations += "\n" + + return 
answer, citations + except Exception as e: + return f"Error querying knowledge base: {str(e)}", "" + +def retrieve_only(kb_type: str, query_text: str, num_results: int) -> str: + """Perform a retrieve operation using the knowledge base.""" + global kb_manager, kb_ids + + if not kb_manager: + return "Please initialize the system first." + + if not kb_ids or kb_type not in kb_ids: + return f"Knowledge base {kb_type} not found." + + try: + # Get the KB ID + kb_id = kb_ids[kb_type] + + # Perform retrieve operation + response = kb_manager.retrieve(kb_id, query_text, num_results) + + # Format the results + results = response.get('retrievalResults', []) + output = f"Retrieved {len(results)} results:\n\n" + + for num, chunk in enumerate(results, 1): + output += f"Chunk {num}: {chunk['content']['text']}\n" + output += f"Location: {chunk['location']}\n" + output += f"Score: {chunk['score']}\n" + if 'metadata' in chunk: + output += f"Metadata: {chunk['metadata']}\n" + output += "\n" + + return output + except Exception as e: + return f"Error retrieving from knowledge base: {str(e)}" + +def run_ragas_evaluation(questions: List[str], ground_truths: List[str]) -> str: + """Run RAGAS evaluation on the knowledge bases.""" + global kb_manager, kb_ids + + # The Gradio textboxes pass newline-separated strings; normalize them to lists + if isinstance(questions, str): + questions = [q.strip() for q in questions.splitlines() if q.strip()] + if isinstance(ground_truths, str): + ground_truths = [g.strip() for g in ground_truths.splitlines() if g.strip()] + + if not kb_manager: + return "Please initialize the system first." + + if len(kb_ids) < 2 or "standard" not in kb_ids or "custom" not in kb_ids: + return "Both standard and custom knowledge bases are required for evaluation." + + try: + # Import the RAG evaluator + try: + from rag_evaluator import RAGEvaluator + except ImportError: + return "RAG evaluator module not found. Please ensure it's available in the path." 
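+ + # Note (assumption about rag_evaluator's contract): questions and + # ground_truths are parallel lists of equal length, one ground truth per + # question, e.g. questions[0] is scored against ground_truths[0].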
+ + # Create a Bedrock runtime client with appropriate configuration + bedrock_runtime_client = boto3.client( + 'bedrock-runtime', + region_name=kb_manager.region, + config=boto3.session.Config( + read_timeout=900, # 15 minutes + connect_timeout=60, + retries={'max_attempts': 3} + ) + ) + + # Initialize the RAG evaluator + evaluator = RAGEvaluator( + bedrock_runtime_client=bedrock_runtime_client, + bedrock_agent_runtime_client=kb_manager.bedrock_agent_runtime_client + ) + + # Compare knowledge base strategies + kb_strategy_map = { + "Default Chunking": kb_ids["standard"], + "Contextual Chunking": kb_ids["custom"] + } + + # Run the evaluation + comparison_df = evaluator.compare_kb_strategies(kb_strategy_map, questions, ground_truths) + + # Format and save the results + styled_df = evaluator.format_comparison(comparison_df) + comparison_df.to_csv("ragas_evaluation_results.csv") + + return f"RAGAS Evaluation completed. Results saved to ragas_evaluation_results.csv\n\n{styled_df.to_string()}" + except Exception as e: + return f"Error running RAGAS evaluation: {str(e)}" + +def delete_all_resources() -> str: + """Delete all knowledge bases and associated resources.""" + global kb_manager, kb_ids, kb_names + + if not kb_manager: + return "Please initialize the system first." + + if not kb_names: + return "No knowledge bases to delete." 
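+ + # Note: kb_ids and kb_names are keyed by kb_type, e.g. + # kb_ids == {"standard": "<standard-kb-id>", "custom": "<custom-kb-id>"}; + # the cleanup loop below iterates kb_names.items() as (kb_type, kb_name) pairs.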
+ + try: + results = [] + for kb_type, kb_name in list(kb_names.items()): + kb_manager.delete_knowledge_base( + kb_name, + delete_s3_bucket=True, + delete_iam_roles_and_policies=True, + delete_lambda_function=(kb_type == "custom") + ) + results.append(f"Knowledge base {kb_name} deleted successfully.") + + # Remove from tracking dictionaries (both are keyed by kb_type, not kb_name) + if kb_type in kb_ids: + del kb_ids[kb_type] + del kb_names[kb_type] + + return "\n".join(results) + except Exception as e: + return f"Error deleting resources: {str(e)}" + +# Create the Gradio interface +with gr.Blocks(title="Bedrock Knowledge Base Manager") as app: + gr.Markdown("# AWS Bedrock Knowledge Base Manager") + + with gr.Tab("Setup"): + with gr.Row(): + init_button = gr.Button("Initialize System") + + with gr.Row(): + chunking_strategy = gr.Radio( + ["FIXED_SIZE", "CUSTOM", "BOTH"], + label="Chunking Strategy", + value="FIXED_SIZE", + info="Select the chunking strategy for your knowledge base(s)" + ) + setup_button = gr.Button("Setup Knowledge Bases") + + with gr.Row(): + dir_path = gr.Textbox(label="Data Directory Path", value="synthetic_dataset") + upload_button = gr.Button("Upload Data") + + setup_output = gr.Textbox(label="Setup Status", lines=5) + + init_button.click(initialize_system, outputs=setup_output) + setup_button.click(setup_knowledge_bases, inputs=[chunking_strategy], outputs=setup_output) + upload_button.click(upload_directory, inputs=[dir_path], outputs=setup_output) + + with gr.Tab("Start Ingestion"): + ingest_button = gr.Button("Start Ingestion Jobs") + ingest_output = gr.Textbox(label="Ingestion Status", lines=3) + ingest_button.click(start_ingestion_jobs, outputs=ingest_output) + + with gr.Tab("Query Knowledge Base"): + available_kb_types = gr.Dropdown( + choices=["standard", "custom"], + label="Knowledge Base Type", + value="standard", + interactive=True + ) + + with gr.Row(): + query_text = gr.Textbox(label="Query") + query_num_results = gr.Slider( + minimum=1, + maximum=10, + value=5, + 
step=1, + label="Number of Results" + ) + + with gr.Row(): + query_button = gr.Button("Query (Retrieve and Generate)") + retrieve_button = gr.Button("Retrieve Only") + + with gr.Row(): + query_answer = gr.Textbox(label="Answer", lines=10) + query_citations = gr.Textbox(label="Citations", lines=10) + + retrieve_output = gr.Textbox(label="Retrieved Results", lines=15) + + query_button.click( + query_knowledge_base, + inputs=[available_kb_types, query_text, query_num_results], + outputs=[query_answer, query_citations] + ) + + retrieve_button.click( + retrieve_only, + inputs=[available_kb_types, query_text, query_num_results], + outputs=retrieve_output + ) + + with gr.Tab("RAGAS Evaluation"): + gr.Markdown("### Run RAGAS Evaluation on Knowledge Bases") + + with gr.Row(): + eval_questions = gr.Textbox( + label="Evaluation Questions (one per line)", + lines=5, + value="What was the primary reason for the increase in net cash provided by operating activities for Octank Financial in 2021?\nIn which year did Octank Financial have the highest net cash used in investing activities, and what was the primary reason for this?\nWhat was the primary source of cash inflows from financing activities for Octank Financial in 2021?\nBased on the information provided, what can you infer about Octank Financial's overall financial health and growth prospects?" + ) + + eval_ground_truths = gr.Textbox( + label="Ground Truths (one per line)", + lines=5, + value="The increase in net cash provided by operating activities was primarily due to an increase in net income and favorable changes in operating assets and liabilities.\nOctank Financial had the highest net cash used in investing activities in 2021, at $360 million. 
The primary reason for this was an increase in purchases of property, plant, and equipment and marketable securities\nThe primary source of cash inflows from financing activities for Octank Financial in 2021 was an increase in proceeds from the issuance of common stock and long-term debt.\nBased on the information provided, Octank Financial appears to be in a healthy financial position and has good growth prospects. The company has consistently increased its net cash provided by operating activities, indicating strong profitability and efficient management of working capital. Additionally, Octank Financial has been investing in long-term assets, such as property, plant, and equipment, and marketable securities, which suggests plans for future growth and expansion. The company has also been able to finance its growth through the issuance of common stock and long-term debt, indicating confidence from investors and lenders. Overall, Octank Financial's steady increase in cash and cash equivalents over the past three years provides a strong foundation for future growth and investment opportunities." 
+ ) + + eval_button = gr.Button("Run RAGAS Evaluation") + eval_output = gr.Textbox(label="Evaluation Results", lines=20) + + eval_button.click( + run_ragas_evaluation, + inputs=[eval_questions, eval_ground_truths], + outputs=eval_output + ) + + with gr.Tab("Manage Resources"): + delete_button = gr.Button("Delete All Resources") + delete_output = gr.Textbox(label="Deletion Status", lines=3) + delete_button.click(delete_all_resources, outputs=delete_output) + +if __name__ == "__main__": + app.launch() \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/bedrock_ctxt_retrieval.py b/BedrockPromptCachingRoutingDemo/src/bedrock_ctxt_retrieval.py new file mode 100644 index 000000000..f17d53ba1 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/bedrock_ctxt_retrieval.py @@ -0,0 +1,802 @@ +import os +import sys +import time +import boto3 +import logging +import pprint +import json +import zipfile +import io +from pathlib import Path +from typing import List, Dict, Any, Optional, Union +from botocore.client import Config + +class LambdaManager: + """ + A class to manage AWS Lambda functions for custom chunking strategies. + """ + + def __init__(self, region: Optional[str] = None): + """ + Initialize the LambdaManager with AWS clients. + + Args: + region: AWS region to use. If None, uses the default from session. + """ + self.session = boto3.session.Session() + self.region = region or self.session.region_name + self.lambda_client = boto3.client('lambda', region_name=self.region) + self.iam_client = boto3.client('iam', region_name=self.region) + self.sts_client = boto3.client('sts', region_name=self.region) + + # Configure logging + self.logger = logging.getLogger(__name__) + + def create_or_update_lambda(self, function_name: str, role_arn: str, source_file: str) -> str: + """ + Create or update a Lambda function from a local Python file. 
+ + Args: + function_name: Name of the Lambda function + role_arn: ARN of the IAM role for the Lambda function + source_file: Path to the Lambda function source file + + Returns: + The ARN of the Lambda function + """ + # Read the Lambda function code + try: + with open(source_file, 'rb') as f: + code_content = f.read() + except Exception as e: + self.logger.error(f"Error reading Lambda source file {source_file}: {str(e)}") + raise + + # Create a ZIP file in memory + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: + zip_file.writestr('lambda_function.py', code_content) + + zip_buffer.seek(0) + zip_content = zip_buffer.read() + + # Check if the Lambda function already exists + try: + self.lambda_client.get_function(FunctionName=function_name) + + # Update the existing function + self.logger.info(f"Updating existing Lambda function: {function_name}") + response = self.lambda_client.update_function_code( + FunctionName=function_name, + ZipFile=zip_content + ) + except self.lambda_client.exceptions.ResourceNotFoundException: + # Create a new function + self.logger.info(f"Creating new Lambda function: {function_name}") + response = self.lambda_client.create_function( + FunctionName=function_name, + Runtime='python3.9', + Role=role_arn, + Handler='lambda_function.lambda_handler', + Code={ + 'ZipFile': zip_content + }, + Timeout=900, # 15 minutes + MemorySize=1024 + ) + + # Wait for the function to be active + waiter = self.lambda_client.get_waiter('function_active') + waiter.wait(FunctionName=function_name) + + return response['FunctionArn'] + + def create_lambda_role(self, role_name: str) -> str: + """ + Create an IAM role for the Lambda function with necessary permissions. 
+ + Args: + role_name: Name of the IAM role + + Returns: + The ARN of the created role + """ + # Define the trust policy for Lambda + trust_policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "lambda.amazonaws.com"}, + "Action": "sts:AssumeRole" + } + ] + } + + try: + # Create the role + response = self.iam_client.create_role( + RoleName=role_name, + AssumeRolePolicyDocument=json.dumps(trust_policy) + ) + role_arn = response['Role']['Arn'] + + # Attach basic Lambda execution policy + self.iam_client.attach_role_policy( + RoleName=role_name, + PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' + ) + + # Attach S3 access policy + self.iam_client.attach_role_policy( + RoleName=role_name, + PolicyArn='arn:aws:iam::aws:policy/AmazonS3FullAccess' + ) + + # Wait for the role to be available + time.sleep(10) + + return role_arn + except self.iam_client.exceptions.EntityAlreadyExistsException: + # If role already exists, get its ARN + response = self.iam_client.get_role(RoleName=role_name) + return response['Role']['Arn'] + except Exception as e: + self.logger.error(f"Error creating Lambda role {role_name}: {str(e)}") + raise + + def update_lambda_timeout(self, function_name: str, timeout_seconds: int = 900) -> None: + """ + Update the timeout configuration for a Lambda function. + + Args: + function_name: Name of the Lambda function + timeout_seconds: Timeout value in seconds (default: 900 seconds / 15 minutes) + """ + try: + response = self.lambda_client.update_function_configuration( + FunctionName=function_name, + Timeout=timeout_seconds + ) + self.logger.info(f"Successfully updated Lambda timeout to {timeout_seconds} seconds for {function_name}") + except Exception as e: + self.logger.error(f"Error updating Lambda timeout for {function_name}: {str(e)}") + raise + + def get_lambda_role(self, function_name: str) -> str: + """ + Get the IAM role associated with a Lambda function. 
+ + Args: + function_name: Name of the Lambda function + + Returns: + The name of the IAM role + """ + try: + response = self.lambda_client.get_function_configuration( + FunctionName=function_name + ) + role_arn = response['Role'] + role_name = role_arn.split('/')[-1] + + self.logger.info(f"Lambda Role ARN: {role_arn}") + self.logger.info(f"Lambda Role Name: {role_name}") + + return role_name + except Exception as e: + self.logger.error(f"Error getting Lambda role for {function_name}: {str(e)}") + raise + + def create_bedrock_policy(self, role_name: str) -> None: + """ + Create and attach a policy for Bedrock access to a Lambda role. + + Args: + role_name: Name of the IAM role + """ + # Define the policy document for Bedrock access + bedrock_policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModelWithResponseStream" + ], + "Resource": [ + "*" + ] + } + ] + } + + policy_name = 'BedrockClaudeAccess' + + try: + # Create the policy + try: + response = self.iam_client.create_policy( + PolicyName=policy_name, + PolicyDocument=json.dumps(bedrock_policy) + ) + policy_arn = response['Policy']['Arn'] + self.logger.info(f"Created policy {policy_name} with ARN: {policy_arn}") + except self.iam_client.exceptions.EntityAlreadyExistsException: + # If policy already exists, get its ARN + account_id = self.sts_client.get_caller_identity()['Account'] + policy_arn = f'arn:aws:iam::{account_id}:policy/{policy_name}' + self.logger.info(f"Policy {policy_name} already exists with ARN: {policy_arn}") + + # Attach the policy to the role + self.iam_client.attach_role_policy( + RoleName=role_name, + PolicyArn=policy_arn + ) + self.logger.info(f"Successfully attached Bedrock policy to role {role_name}") + except Exception as e: + self.logger.error(f"Error creating or attaching Bedrock policy: {str(e)}") + raise + +class BedrockKnowledgeBaseManager: + """ + A class to manage AWS Bedrock Knowledge Base operations including creation, + 
data ingestion, and querying.
+    """
+
+    def __init__(self, region: Optional[str] = None):
+        """
+        Initialize the BedrockKnowledgeBaseManager with AWS clients and configuration.
+
+        Args:
+            region: AWS region to use. If None, uses the default from session.
+        """
+        # Configure logging
+        logging.basicConfig(
+            format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
+            level=logging.INFO
+        )
+        self.logger = logging.getLogger(__name__)
+
+        # Set up AWS session and clients, pinning every client to the resolved
+        # region so an explicit `region` argument is honored consistently
+        self.session = boto3.session.Session()
+        self.region = region or self.session.region_name
+        self.s3_client = boto3.client('s3', region_name=self.region)
+        self.sts_client = boto3.client('sts', region_name=self.region)
+        self.bedrock_agent_client = boto3.client('bedrock-agent', region_name=self.region)
+        self.bedrock_agent_runtime_client = boto3.client('bedrock-agent-runtime', region_name=self.region)
+
+        # Get account ID
+        self.account_id = self.sts_client.get_caller_identity()["Account"]
+
+        # Add parent directory to path for imports
+        self._setup_import_paths()
+
+        # Import the BedrockKnowledgeBase class after path setup
+        from knowledge_base import BedrockKnowledgeBase
+        self.BedrockKnowledgeBase = BedrockKnowledgeBase
+
+        # Generate timestamp suffix for resource naming
+        self.timestamp_suffix = self._generate_timestamp_suffix()
+
+        # Default foundation model
+        self.foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
+
+        # Knowledge base instances
+        self.kb_instances = {}
+
+        # Lambda manager for custom chunking
+        self.lambda_manager = LambdaManager(self.region)
+
+    def _setup_import_paths(self) -> None:
+        """Set up Python import paths to include parent directory."""
+        current_path = Path().resolve().parent
+        if str(current_path) not in sys.path:
+            sys.path.append(str(current_path))
+
+    def _generate_timestamp_suffix(self) -> str:
+        """Generate a timestamp suffix for unique resource naming."""
+        current_time = time.time()
+        return time.strftime("%Y%m%d%H%M%S", time.localtime(current_time))[-7:]
+
+    def _create_s3_bucket(self, bucket_name: str) -> None:
+        """
+
+        Create an S3 bucket if it doesn't exist.
+
+        Args:
+            bucket_name: Name of the bucket to create
+        """
+        try:
+            # Outside us-east-1, create_bucket requires an explicit LocationConstraint
+            if self.region and self.region != "us-east-1":
+                self.s3_client.create_bucket(
+                    Bucket=bucket_name,
+                    CreateBucketConfiguration={"LocationConstraint": self.region}
+                )
+            else:
+                self.s3_client.create_bucket(Bucket=bucket_name)
+            self.logger.info(f"Created bucket: {bucket_name}")
+        except (self.s3_client.exceptions.BucketAlreadyExists,
+                self.s3_client.exceptions.BucketAlreadyOwnedByYou):
+            self.logger.info(f"Bucket already exists: {bucket_name}")
+        except Exception as e:
+            self.logger.error(f"Error creating bucket {bucket_name}: {e}")
+            raise
+
+    def setup_custom_chunking_lambda(self, lambda_name: str) -> str:
+        """
+        Set up the Lambda function for custom chunking.
+
+        Args:
+            lambda_name: Base name for the Lambda function
+
+        Returns:
+            The name of the created Lambda function
+        """
+        # Create a unique name for the Lambda function
+        function_name = f"{lambda_name}-{self.timestamp_suffix}"
+
+        # Create IAM role for the Lambda function
+        role_name = f"{function_name}-role"
+        role_arn = self.lambda_manager.create_lambda_role(role_name)
+
+        # Create or update the Lambda function
+        lambda_source_file = "lambda_custom_chunking_function.py"
+        self.lambda_manager.create_or_update_lambda(function_name, role_arn, lambda_source_file)
+
+        # Configure the Lambda function
+        self.lambda_manager.update_lambda_timeout(function_name)
+
+        # Attach Bedrock policy to the role
+        self.lambda_manager.create_bedrock_policy(role_name)
+
+        return function_name
+
+    def create_knowledge_base(self,
+                              kb_name: str,
+                              kb_description: str,
+                              chunking_strategy: str = "FIXED_SIZE",
+                              suffix_override: Optional[str] = None,
+                              lambda_function_name: Optional[str] = None,
+                              intermediate_bucket_name: Optional[str] = None) -> str:
+        """
+        Create a knowledge base with the specified configuration.
+ + Args: + kb_name: Base name for the knowledge base + kb_description: Description of the knowledge base + chunking_strategy: Strategy for chunking documents (FIXED_SIZE or CUSTOM) + suffix_override: Optional override for the timestamp suffix + lambda_function_name: Name of the Lambda function for custom chunking + intermediate_bucket_name: Name of the intermediate S3 bucket for custom chunking + + Returns: + The knowledge base ID + """ + suffix = suffix_override or self.timestamp_suffix + full_kb_name = f"{kb_name}-{suffix}" + bucket_name = full_kb_name + + # Create S3 bucket for the knowledge base + self._create_s3_bucket(bucket_name) + + # Define data sources + data_source = [{"type": "S3", "bucket_name": bucket_name}] + + # Create the knowledge base instance based on chunking strategy + if chunking_strategy == "CUSTOM": + # If intermediate bucket name is not provided, create one + if not intermediate_bucket_name: + intermediate_bucket_name = f"{full_kb_name}-intermediate" + + # Create the intermediate bucket + self._create_s3_bucket(intermediate_bucket_name) + + # If lambda function name is not provided, create one + if not lambda_function_name: + lambda_function_name = self.setup_custom_chunking_lambda(f"{kb_name}-lambda") + + kb_instance = self.BedrockKnowledgeBase( + kb_name=full_kb_name, + kb_description=kb_description, + data_sources=data_source, + lambda_function_name=lambda_function_name, + intermediate_bucket_name=intermediate_bucket_name, + chunking_strategy=chunking_strategy, + suffix=f"{suffix}-c" + ) + else: + kb_instance = self.BedrockKnowledgeBase( + kb_name=full_kb_name, + kb_description=kb_description, + data_sources=data_source, + chunking_strategy=chunking_strategy, + suffix=f"{suffix}-f" + ) + + # Store the instance for later use + self.kb_instances[full_kb_name] = kb_instance + + return kb_instance.get_knowledge_base_id() + + def upload_directory_to_s3(self, local_path: str, bucket_name: str, + skip_files: List[str] = ["LICENSE", "NOTICE", 
"README.md"]) -> None: + """ + Upload all files from a local directory to an S3 bucket. + + Args: + local_path: Path to the local directory + bucket_name: Name of the target S3 bucket + skip_files: List of filenames to skip + """ + for root, _, files in os.walk(local_path): + for file in files: + file_to_upload = os.path.join(root, file) + if file not in skip_files: + self.logger.info(f"Uploading file {file_to_upload} to {bucket_name}") + self.s3_client.upload_file(file_to_upload, bucket_name, file) + else: + self.logger.info(f"Skipping file {file_to_upload}") + + def start_ingestion_job(self, kb_name: str) -> None: + """ + Start the ingestion job for a knowledge base. + + Args: + kb_name: Name of the knowledge base + """ + full_kb_name = f"{kb_name}-{self.timestamp_suffix}" + if full_kb_name in self.kb_instances: + self.kb_instances[full_kb_name].start_ingestion_job() + # Wait for ingestion to complete + self.logger.info("Waiting for ingestion to complete...") + time.sleep(30) + else: + self.logger.error(f"Knowledge base {full_kb_name} not found") + raise ValueError(f"Knowledge base {full_kb_name} not found") + + def retrieve_and_generate(self, kb_id: str, query: str, + num_results: int = 5) -> Dict[str, Any]: + """ + Perform a retrieve and generate operation using the knowledge base. 
+
+        Args:
+            kb_id: Knowledge base ID
+            query: Query text
+            num_results: Number of results to retrieve
+
+        Returns:
+            Response from the retrieve and generate operation
+        """
+        response = self.bedrock_agent_runtime_client.retrieve_and_generate(
+            input={"text": query},
+            retrieveAndGenerateConfiguration={
+                "type": "KNOWLEDGE_BASE",
+                "knowledgeBaseConfiguration": {
+                    "knowledgeBaseId": kb_id,
+                    "modelArn": f"arn:aws:bedrock:{self.region}::foundation-model/{self.foundation_model}",
+                    "retrievalConfiguration": {
+                        "vectorSearchConfiguration": {
+                            "numberOfResults": num_results
+                        }
+                    }
+                }
+            }
+        )
+        return response
+
+    def retrieve(self, kb_id: str, query: str, num_results: int = 5) -> Dict[str, Any]:
+        """
+        Perform a retrieve operation using the knowledge base.
+
+        Args:
+            kb_id: Knowledge base ID
+            query: Query text
+            num_results: Number of results to retrieve
+
+        Returns:
+            Response from the retrieve operation
+        """
+        # Note: nextToken is intentionally omitted on the first call; it is only
+        # valid when continuing pagination with a token from a previous response
+        response = self.bedrock_agent_runtime_client.retrieve(
+            knowledgeBaseId=kb_id,
+            retrievalConfiguration={
+                "vectorSearchConfiguration": {
+                    "numberOfResults": num_results,
+                }
+            },
+            retrievalQuery={
+                "text": query
+            }
+        )
+        return response
+
+    def delete_knowledge_base(self, kb_name: str,
+                              delete_s3_bucket: bool = False,
+                              delete_iam_roles_and_policies: bool = True,
+                              delete_lambda_function: bool = False) -> None:
+        """
+        Delete a knowledge base and optionally its associated resources.
+ + Args: + kb_name: Name of the knowledge base + delete_s3_bucket: Whether to delete the associated S3 bucket + delete_iam_roles_and_policies: Whether to delete IAM roles and policies + delete_lambda_function: Whether to delete the Lambda function + """ + full_kb_name = f"{kb_name}-{self.timestamp_suffix}" + if full_kb_name in self.kb_instances: + self.logger.info(f"Deleting knowledge base: {full_kb_name}") + self.kb_instances[full_kb_name].delete_kb( + delete_s3_bucket=delete_s3_bucket, + delete_iam_roles_and_policies=delete_iam_roles_and_policies, + delete_lambda_function=delete_lambda_function + ) + # Remove from instances dictionary + del self.kb_instances[full_kb_name] + else: + self.logger.warning(f"Knowledge base {full_kb_name} not found for deletion") + +class ResponseFormatter: + """ + A utility class to format and print responses from Bedrock operations. + """ + + @staticmethod + def print_citations(response_citations: List[Dict[str, Any]]) -> None: + """ + Print citation information from a response. + + Args: + response_citations: List of citation references + """ + print(f"# of citations or chunks used to generate the response: {len(response_citations)}") + for num, chunk in enumerate(response_citations, 1): + print(f'Chunk {num}: {chunk["content"]["text"]}\n') + print(f'Chunk {num} Location: {chunk["location"]}\n') + print(f'Chunk {num} Metadata: {chunk["metadata"]}\n') + + @staticmethod + def print_retrieval_results(response: Dict[str, Any]) -> None: + """ + Print retrieval results from a response. 
+ + Args: + response: Response containing retrieval results + """ + results = response.get('retrievalResults', []) + print(f"# of retrieved results: {len(results)}") + for num, chunk in enumerate(results, 1): + print(f'Chunk {num}: {chunk["content"]["text"]}\n') + print(f'Chunk {num} Location: {chunk["location"]}\n') + print(f'Chunk {num} Score: {chunk["score"]}\n') + print(f'Chunk {num} Metadata: {chunk["metadata"]}\n') + +class ChunkingStrategySelector: + """ + A class to handle user selection of chunking strategies. + """ + + STRATEGIES = { + "1": {"name": "FIXED_SIZE", "description": "Standard fixed-size chunking"}, + "2": {"name": "CUSTOM", "description": "Custom chunking using Lambda function"}, + "3": {"name": "BOTH", "description": "Run both fixed-size and custom chunking for comparison"} + } + + @classmethod + def get_user_selection(cls) -> Dict[str, Any]: + """ + Get the user's selection of chunking strategy. + + Returns: + Dictionary with strategy information + """ + print("\n=== Chunking Strategy Selection ===") + print("Select a chunking strategy for your knowledge base:") + + for key, strategy in cls.STRATEGIES.items(): + print(f"{key}. {strategy['name']}: {strategy['description']}") + + while True: + selection = input("\nEnter your selection (1-3): ") + if selection in cls.STRATEGIES: + strategy_info = cls.STRATEGIES[selection] + print(f"Selected: {strategy_info['name']} - {strategy_info['description']}") + return {"strategy": strategy_info['name']} + else: + print("Invalid selection. 
Please try again.") + +def main(): + """Main function to demonstrate the usage of the classes.""" + try: + # Initialize the manager + kb_manager = BedrockKnowledgeBaseManager() + formatter = ResponseFormatter() + + # Check if lambda function file exists and create a copy with the expected name + lambda_source_file = "lambda_custom_chunking_function.py" + lambda_target_file = "lambda_function.py" + + # Create a copy of the lambda file with the expected name + if os.path.exists(lambda_source_file) and not os.path.exists(lambda_target_file): + print(f"Creating a copy of {lambda_source_file} as {lambda_target_file}") + with open(lambda_source_file, 'r') as source: + content = source.read() + with open(lambda_target_file, 'w') as target: + target.write(content) + elif not os.path.exists(lambda_source_file): + print(f"Error: Lambda source file {lambda_source_file} not found.") + print("Please create this file with your custom chunking code.") + return + + # Get user selection for chunking strategy + strategy_selection = ChunkingStrategySelector.get_user_selection() + chunking_strategy = strategy_selection["strategy"] + + # Base names for resources + kb_base_name = 'kb' + kb_description = "Knowledge Base containing complex PDF." 
+ + # Create knowledge bases with different chunking strategies + kb_ids = {} + kb_names = {} + + # Create standard knowledge base if selected + if chunking_strategy in ["FIXED_SIZE", "BOTH"]: + kb_name_standard = f"standard-{kb_base_name}" + print(f"\nCreating knowledge base with FIXED_SIZE chunking strategy...") + kb_id_standard = kb_manager.create_knowledge_base( + kb_name=kb_name_standard, + kb_description=kb_description, + chunking_strategy="FIXED_SIZE" + ) + kb_ids["standard"] = kb_id_standard + kb_names["standard"] = kb_name_standard + + # Upload data to the S3 bucket + bucket_name = f'{kb_name_standard}-{kb_manager.timestamp_suffix}' + print(f"Uploading data to bucket: {bucket_name}") + kb_manager.upload_directory_to_s3("synthetic_dataset", bucket_name) + + # Start ingestion job + print("Starting ingestion job...") + kb_manager.start_ingestion_job(kb_name_standard) + + # Create custom chunking knowledge base if selected + if chunking_strategy in ["CUSTOM", "BOTH"]: + kb_name_custom = f"custom-{kb_base_name}" + print(f"\nCreating knowledge base with CUSTOM chunking strategy...") + + # Create intermediate bucket name + intermediate_bucket_name = f"{kb_name_custom}-intermediate-{kb_manager.timestamp_suffix}" + + # Set up Lambda function for custom chunking + lambda_function_name = f"{kb_name_custom}-lambda-{kb_manager.timestamp_suffix}" + + kb_id_custom = kb_manager.create_knowledge_base( + kb_name=kb_name_custom, + kb_description=kb_description, + chunking_strategy="CUSTOM", + lambda_function_name=lambda_function_name, + intermediate_bucket_name=intermediate_bucket_name + ) + kb_ids["custom"] = kb_id_custom + kb_names["custom"] = kb_name_custom + + # Upload data to the S3 bucket + bucket_name = f'{kb_name_custom}-{kb_manager.timestamp_suffix}' + print(f"Uploading data to bucket: {bucket_name}") + kb_manager.upload_directory_to_s3("synthetic_dataset", bucket_name) + + # Start ingestion job + print("Starting ingestion job...") + 
kb_manager.start_ingestion_job(kb_name_custom) + + # Define a query + query = "Provide a summary of consolidated statements of cash flows of Octank Financial for the fiscal years ended December 31, 2019." + + # Wait for knowledge base to be ready + print("\nWaiting for knowledge base to be ready...") + time.sleep(20) + + # Process query with each knowledge base + for kb_type, kb_id in kb_ids.items(): + print(f"\n=== Results from {kb_type.upper()} knowledge base ===") + print(f"Knowledge Base ID: {kb_id}") + + # Perform retrieve and generate + print("\nPerforming retrieve and generate operation...") + response = kb_manager.retrieve_and_generate(kb_id, query) + print(response['output']['text'], end='\n'*2) + + # Format and print citations + print("\nCitation information:") + response_refs = response['citations'][0]['retrievedReferences'] + formatter.print_citations(response_refs) + + # Perform retrieve operation + print("\nPerforming retrieve operation...") + response_ret = kb_manager.retrieve(kb_id, query) + formatter.print_retrieval_results(response_ret) + + # Print summary of knowledge base IDs + print("\n=== Knowledge Base Summary ===") + for kb_type, kb_id in kb_ids.items(): + print(f"{kb_type.capitalize()}: {kb_id}") + + # Run RAGAS evaluation if multiple knowledge bases are created + if len(kb_ids) > 1 and "standard" in kb_ids and "custom" in kb_ids: + print("\n=== Running RAGAS Evaluation ===") + input("Press Enter to Run the RAGAS Evaluation...") + + # Import the RAG evaluator + from rag_evaluator import RAGEvaluator + + # Create a Bedrock runtime client with appropriate configuration + bedrock_runtime_client = boto3.client( + 'bedrock-runtime', + region_name=kb_manager.region, + config=Config( + read_timeout=900, # 15 minutes + connect_timeout=60, + retries={'max_attempts': 3} + ) + ) + + # Initialize the RAG evaluator + evaluator = RAGEvaluator( + bedrock_runtime_client=bedrock_runtime_client, + 
bedrock_agent_runtime_client=kb_manager.bedrock_agent_runtime_client + ) + + # Define evaluation questions and ground truths + questions = [ + "What was the primary reason for the increase in net cash provided by operating activities for Octank Financial in 2021?", + "In which year did Octank Financial have the highest net cash used in investing activities, and what was the primary reason for this?", + "What was the primary source of cash inflows from financing activities for Octank Financial in 2021?", + "Based on the information provided, what can you infer about Octank Financial's overall financial health and growth prospects?" + ] + ground_truths = [ + "The increase in net cash provided by operating activities was primarily due to an increase in net income and favorable changes in operating assets and liabilities.", + "Octank Financial had the highest net cash used in investing activities in 2021, at $360 million. The primary reason for this was an increase in purchases of property, plant, and equipment and marketable securities", + "The primary source of cash inflows from financing activities for Octank Financial in 2021 was an increase in proceeds from the issuance of common stock and long-term debt.", + "Based on the information provided, Octank Financial appears to be in a healthy financial position and has good growth prospects. The company has consistently increased its net cash provided by operating activities, indicating strong profitability and efficient management of working capital. Additionally, Octank Financial has been investing in long-term assets, such as property, plant, and equipment, and marketable securities, which suggests plans for future growth and expansion. The company has also been able to finance its growth through the issuance of common stock and long-term debt, indicating confidence from investors and lenders. 
Overall, Octank Financial's steady increase in cash and cash equivalents over the past three years provides a strong foundation for future growth and investment opportunities." + ] + + # Compare knowledge base strategies + kb_strategy_map = { + "Default Chunking": kb_ids["standard"], + "Contextual Chunking": kb_ids["custom"] + } + + comparison_df = evaluator.compare_kb_strategies(kb_strategy_map, questions, ground_truths) + + # Format and display the comparison + styled_df = evaluator.format_comparison(comparison_df) + print("\n=== RAGAS Evaluation Results ===") + print(styled_df.to_string()) + + # Save the results to a CSV file + comparison_df.to_csv("ragas_evaluation_results.csv") + print("\nEvaluation results saved to ragas_evaluation_results.csv") + + finally: + # Clean up resources before exiting + print("\nCleaning up resources...") + input("Press Enter to Delete the Resources...") + try: + # Delete the knowledge bases and associated resources + for kb_type, kb_name in kb_names.items(): + print(f"Cleaning up {kb_type} knowledge base: {kb_name}") + kb_manager.delete_knowledge_base( + kb_name, + delete_s3_bucket=True, + delete_iam_roles_and_policies=True, + delete_lambda_function=False if kb_type == "standard" else True + ) + print(f"Knowledge base {kb_name} cleanup completed successfully.") + + # Clean up the temporary lambda function file + if os.path.exists(lambda_target_file): + print(f"Removing temporary file: {lambda_target_file}") + os.remove(lambda_target_file) + + except Exception as e: + print(f"Error during cleanup: {e}") + +if __name__ == "__main__": + main() + + diff --git a/BedrockPromptCachingRoutingDemo/src/bedrock_prompt_caching.py b/BedrockPromptCachingRoutingDemo/src/bedrock_prompt_caching.py new file mode 100644 index 000000000..230c8c0b8 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/bedrock_prompt_caching.py @@ -0,0 +1,946 @@ +""" +Bedrock Prompt Caching CLI Application + +This module provides a class-based implementation for 
interacting with Amazon Bedrock +with prompt caching capabilities and a CLI interface for user interaction. +""" + +# Standard libraries +import json +import time +import os +from enum import Enum +from datetime import datetime +from typing import Dict, List, Optional, Tuple, Any + +# Data processing and visualization +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import seaborn as sns + +# AWS and external services +import requests + +# Local imports +from file_processor import FileProcessor # For processing different file types +from bedrock_service import BedrockService # For interacting with Bedrock services +from model_manager import ModelManager # For managing Bedrock models + +# Cache mode constants for controlling prompt caching behavior +class CACHE(str, Enum): + """Enumeration of cache modes for Bedrock prompt caching""" + ON = "ON" # Enable caching with checkpoint + OFF = "OFF" # Disable caching completely + READ = "READ" # Cache hit - reading from cache + WRITE = "WRITE" # Cache miss - writing to cache + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented + +class CacheManager: + """Manages the caching of responses + + This class provides in-memory caching for Bedrock responses. 
+ """ + + def __init__(self): + """Initialize cache manager""" + self.cache_store = {} # Store cache information by cache key + + def store_cache_info(self, cache_key: str, is_cache_hit: bool, document: str, query: str, metrics: Dict, turn: int = 0) -> None: + """Store detailed cache information for analysis + + Args: + cache_key: The unique cache key + is_cache_hit: Whether this was a cache hit + document: The document content + query: The user's question + metrics: Usage metrics from the API response + turn: The conversation turn number (default: 0) + """ + # Handle different key formats in the metrics dictionary + cache_read_tokens = ( + metrics.get("cache_read_input_tokens", 0) or + metrics.get("cacheReadInputTokens", 0) + ) + + cache_creation_tokens = ( + metrics.get("cache_creation_input_tokens", 0) or + metrics.get("cacheCreationInputTokens", 0) + ) + + input_tokens = metrics.get("inputTokens", 0) + output_tokens = metrics.get("outputTokens", 0) + + self.cache_store[cache_key] = { + "cache_key": cache_key, + "is_cache_hit": is_cache_hit, + "cached_content": document if turn == 0 else "", + "question": query, + "cache_creation_tokens": cache_creation_tokens, + "cache_read_tokens": cache_read_tokens, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "invocation_latency": metrics.get("response_time_seconds", 0), + "turn": turn + 1 + } + + def get_cache_summary(self, cache_key: str) -> str: + """Get a formatted summary of cache information for a specific key + + Args: + cache_key: The cache key to get information for + + Returns: + Formatted string with cache summary + """ + if cache_key not in self.cache_store: + return "No cache information available for this query." 
+ + cache_info = self.cache_store[cache_key] + is_cache_hit = cache_info["is_cache_hit"] + + # Calculate performance metrics + input_tokens = cache_info.get("input_tokens", 0) + input_tokens_cache_read = cache_info.get("cache_read_tokens", 0) + input_tokens_cache_create = cache_info.get("cache_creation_tokens", 0) + invocation_latency = cache_info.get("invocation_latency", 0) + + # For cache hits, if cache_read_tokens is still 0, use input_tokens + if is_cache_hit and input_tokens_cache_read == 0 and input_tokens > 0: + input_tokens_cache_read = input_tokens + + total_input_tokens = input_tokens + input_tokens_cache_read + percentage_cached = (input_tokens_cache_read / total_input_tokens * 100) if total_input_tokens > 0 else 0 + + # Calculate estimated cost savings (assuming $0.01 per 1K tokens) + token_cost_per_k = 0.01 + estimated_savings = (input_tokens_cache_read / 1000) * token_cost_per_k + + # Calculate latency benefit (assuming average response time of 2 seconds without cache) + avg_response_time = 2.0 + latency_benefit = ((avg_response_time - invocation_latency) / avg_response_time * 100) if invocation_latency > 0 else 0 + + summary = ["\n📊 Cache Summary:"] + summary.append(f" Cache key: {cache_key}") + + if is_cache_hit: + summary.append(f" ✅ CACHE HIT") + summary.append(f" Cache read tokens: {input_tokens_cache_read}") + summary.append(f" Input tokens saved: {input_tokens_cache_read}") + summary.append(f" {percentage_cached:.1f}% of input prompt cached ({total_input_tokens} tokens)") + + # Add cost and latency benefits + summary.append(f" Estimated cost savings: ${estimated_savings:.4f}") + summary.append(f" Latency improvement: {latency_benefit:.1f}%") + + # Show what was retrieved from cache + if cache_info["turn"] == 1: + summary.append(" Content retrieved from cache: Document context") + cached_content = cache_info["cached_content"] + if cached_content: + preview = cached_content[:100] + "..." 
if len(cached_content) > 100 else cached_content + summary.append(f" Cached content preview: \"{preview}\"") + else: + summary.append(" Content retrieved from cache: Previous question context") + + summary.append(" This means the model didn't need to process this content again,") + summary.append(" resulting in faster response time and lower token usage.") + else: + summary.append(f" ❌ CACHE MISS") + summary.append(f" Cache creation tokens: {input_tokens_cache_create}") + + # Show what was written to cache + if cache_info["turn"] == 1: + summary.append(" Content written to cache: Document context") + cached_content = cache_info["cached_content"] + if cached_content: + preview = cached_content[:100] + "..." if len(cached_content) > 100 else cached_content + summary.append(f" Cached content preview: \"{preview}\"") + else: + summary.append(" Content written to cache: Current question context") + + summary.append(" This content will be cached for future similar queries.") + summary.append(" Future queries will benefit from faster response times and lower costs.") + + return "\n".join(summary) + +class BedrockChat: + """Main class for interacting with Bedrock for document Q&A + + This class orchestrates the document processing, model selection, + and Bedrock API interactions with prompt caching capabilities. 
+ """ + + def __init__(self): + """Initialize the chat components and service dependencies""" + # Initialize service components + self.bedrock_service = BedrockService() # Manages Bedrock API clients + self.runtime_client = self.bedrock_service.get_runtime_client() # For inference calls + self.model_manager = ModelManager() # For model selection + self.cache_manager = CacheManager() # For response caching + + # State variables + self.current_document = "" # Currently loaded document + self.current_model_id = "" # Selected model ID + self.temperature = 0.0 # Temperature setting for generation + self.blog = "" # For storing blog content for benchmarking + + # Get available models + try: + self.available_models = self.model_manager.get_available_models() + except Exception as e: + print(f"Warning: Could not fetch available models: {e}") + self.available_models = [ + "anthropic.claude-haiku-4-5-20251001-v1:0", + "anthropic.claude-sonnet-4-5-20250929-v1:0", + "anthropic.claude-opus-4-1-20250805-v1:0" + ] + + def set_document(self, document: str) -> None: + """Set the current document for chat""" + self.current_document = document + + def set_model(self, model_id: str) -> None: + """Set the current model ID""" + self.current_model_id = model_id + + def set_temperature(self, temperature: float) -> None: + """Set the temperature for generation""" + self.temperature = temperature + + def select_model(self): + """Display available models and let user select one""" + return self.model_manager.select_model() + + def chat_with_document(self, query: str, use_cache: bool = True, checkpoint: bool = False) -> Tuple[str, Dict, bool, str]: + """ + Process a query against the current document + + Args: + query: The user's question + use_cache: Whether to check cache before calling Bedrock + checkpoint: Whether to use a checkpoint for caching + + Returns: + Tuple of (response_text, usage_info, from_cache, cache_key) + """ + if not self.current_document: + return "No document loaded. 
Please load a document first.", {}, False, "" + + if not self.current_model_id: + return "No model selected. Please select a model first.", {}, False, "" + + # Check for empty query + if not query.strip(): + return "Please enter a question.", {}, False, "" + + # Generate a simple cache key + cache_key = f"{hash(self.current_document)}-{hash(query)}-{self.current_model_id}-{self.temperature}" + + # Prepare the prompt + instructions = self._get_instructions() + document_content = f"Here is the document: {self.current_document} " + + # Create message body based on cache settings + if use_cache: + # Include cache point for caching + messages_body = [ + { + 'role': 'user', + 'content': [ + {'text': instructions}, + {'text': document_content}, + { + "cachePoint": { + "type": "default" + } + }, + {'text': query} + ] + } + ] + else: + # No cache point when caching is disabled + messages_body = [ + { + 'role': 'user', + 'content': [ + {'text': instructions}, + {'text': document_content}, + {'text': query} + ] + } + ] + + inference_config = { + 'maxTokens': 500, + 'temperature': self.temperature, + 'topP': 1 + } + + # Get updated model ID for specific Claude models + model_id = self.model_manager.get_model_arn_from_inference_profiles(self.current_model_id) + if model_id != self.current_model_id: + print(f"\nUsing updated model ID: {model_id}") + + # Call Bedrock + start_time = time.time() + response = self.runtime_client.converse( + messages=messages_body, + modelId=model_id, + inferenceConfig=inference_config + ) + end_time = time.time() + + # Process response + output_message = response["output"]["message"] + response_text = output_message["content"][0]["text"] + usage_info = response["usage"] + + # Add response time to usage info + usage_info["response_time_seconds"] = end_time - start_time + + # Determine if this was a cache hit or miss based on metrics + is_cache_hit = usage_info.get("cache_read_input_tokens", 0) > 0 or usage_info.get("cacheReadInputTokens", 0) > 0 + + 
# Store cache information + self.cache_manager.store_cache_info( + cache_key=cache_key, + is_cache_hit=is_cache_hit, + document=self.current_document, + query=query, + metrics=usage_info + ) + + return response_text, usage_info, is_cache_hit, cache_key + + def _get_instructions(self) -> str: + """Return the instructions for the LLM""" + return ( + "I will provide you with a document, followed by a question about its content. " + "Your task is to analyze the document, extract relevant information, and provide " + "a comprehensive answer to the question. Please follow these detailed instructions:" + + "\n\n1. Identifying Relevant Quotes:" + "\n - Carefully read through the entire document." + "\n - Identify sections of the text that are directly relevant to answering the question." + "\n - Select quotes that provide key information, context, or support for the answer." + "\n - Quotes should be concise and to the point, typically no more than 2-3 sentences each." + "\n - Choose a diverse range of quotes if multiple aspects of the question need to be addressed." + "\n - Aim to select between 2 to 5 quotes, depending on the complexity of the question." + + "\n\n2. Presenting the Quotes:" + "\n - List the selected quotes under the heading 'Relevant quotes:'" + "\n - Number each quote sequentially, starting from [1]." + "\n - Present each quote exactly as it appears in the original text, enclosed in quotation marks." + "\n - If no relevant quotes can be found, write 'No relevant quotes' instead." + + "\n\n3. Formulating the Answer:" + "\n - Begin your answer with the heading 'Answer:' on a new line after the quotes." + "\n - Provide a clear, concise, and accurate answer to the question based on the information in the document." + "\n - Ensure your answer is comprehensive and addresses all aspects of the question." + "\n - Use information from the quotes to support your answer." 
+ "\n - Add the bracketed number of the relevant quote at the end of each sentence or point that uses information from that quote." + + "\n\n4. Handling Uncertainty:" + "\n - If the document does not contain enough information to fully answer the question, clearly state this in your answer." + "\n - Provide any partial information that is available." + + "\n\n5. Formatting and Style:" + "\n - Use clear paragraph breaks to separate different points or aspects of your answer." + "\n - Ensure proper grammar, punctuation, and spelling throughout your response." + "\n - Maintain a professional and neutral tone throughout your answer." + ) + + def run_response_latency_benchmark(self, test_configs, epochs=3): + """ + Benchmark response latency metrics for different models and cache modes + + Args: + test_configs: List of test configuration dictionaries with model_id, model_name, and cache_mode + epochs: Number of test iterations to run for each configuration + + Returns: + List of datapoints with benchmark results + """ + datapoints = [] + + for test_config in test_configs: + print(f"[{test_config['model_name']}]") + + # Get updated model ID for specific Claude models + model_id = self.model_manager.get_model_arn_from_inference_profiles(test_config['model_id']) + if model_id != test_config['model_id']: + print(f"Using updated model ID: {model_id}") + + # Prepare the converse command + converse_cmd = { + "modelId": model_id, + "messages": [ + { + "role": "user", + "content": [] + } + ], + "inferenceConfig": { + "maxTokens": 500, + "temperature": self.temperature, + "topP": 1 + } + } + + for cache_mode in test_config['cache_mode']: + # Create a copy for modification + cmd = converse_cmd.copy() + + # Set content with blog text + cmd["messages"][0]["content"] = [ + { + "text": self.blog + }, + { + "text": "what is it about in 20 words." 
+                }
+            ]
+
+            # Add cache point if needed
+            if cache_mode == CACHE.ON:
+                cmd["messages"][0]["content"].insert(1, {
+                    "cachePoint": {
+                        "type": "default"
+                    }
+                })
+
+            for epoch in range(epochs):
+                start_time = time.time()
+
+                # Call the API with streaming
+                response = self.runtime_client.converse_stream(**cmd)
+
+                ttft = None
+
+                for chunk in response['stream']:
+                    if "messageStart" in chunk:
+                        pass
+                    elif "contentBlockStop" in chunk:
+                        pass
+                    elif "messageStop" in chunk:
+                        pass
+                    elif "contentBlockDelta" in chunk:
+                        text = chunk["contentBlockDelta"].get('delta', {}).get("text", None)
+                        if text and ttft is None:
+                            # Record time to first token on the first non-empty delta
+                            ttft = time.time() - start_time
+
+                    elif "metadata" in chunk:
+                        if 'cacheReadInputTokens' in chunk["metadata"]['usage']:
+                            if chunk["metadata"]['usage']['cacheWriteInputTokens'] > 1:
+                                cache_result = CACHE.WRITE
+                            elif chunk["metadata"]['usage']['cacheReadInputTokens'] > 1:
+                                cache_result = CACHE.READ
+                            else:
+                                print(json.dumps(chunk, sort_keys=False, indent=4))
+                                assert False, 'Unclear cache state'
+                        else:
+                            cache_result = CACHE.OFF
+
+                        latencyMs = chunk["metadata"]["metrics"]["latencyMs"] / 1000
+                        requestId = response['ResponseMetadata']['RequestId']
+
+                        datapoints.append({
+                            'model': test_config['model_name'],
+                            'cache': cache_result,
+                            'measure': 'first_token',
+                            'time': ttft,
+                            'requestId': requestId,
+                        })
+
+                        datapoints.append({
+                            'model': test_config['model_name'],
+                            'cache': cache_result,
+                            'measure': 'last_token',
+                            'time': latencyMs,
+                            'requestId': requestId,
+                        })
+
+                        print(f"{epoch:2} {cache_mode},{cache_result} | ttft={ttft:.1f}s | last={latencyMs:.1f}s | {requestId}")
+
+                    else:
+                        # Unexpected chunk type: log it with the time elapsed since the request started
+                        print('\n\nchunk +{:.3f}s \n{}'.format(
+                            time.time() - start_time,
+                            json.dumps(chunk, sort_keys=False, indent=4)
+                        ))
+
+            time.sleep(30)
+
+        return datapoints
+
+    def add_median_labels(self, ax, fmt=".1f"):
+        """
+        Add text labels to the median lines of a seaborn boxplot.
+ + Args: + ax: plt.Axes, e.g. the return value of sns.boxplot() + fmt: format string for the median value + """ + lines = ax.get_lines() + boxes = [c for c in ax.get_children() if "Patch" in str(c)] + start = 4 + if not boxes: # seaborn v0.13 => fill=False => no patches => +1 line + boxes = [c for c in ax.get_lines() if len(c.get_xdata()) == 5] + start += 1 + lines_per_box = len(lines) // len(boxes) + for median in lines[start::lines_per_box]: + x, y = (data.mean() for data in median.get_data()) + # choose value depending on horizontal or vertical plot orientation + value = x if len(set(median.get_xdata())) == 1 else y + text = ax.text(x, y, f'{value:{fmt}}', ha='center', va='center', color='white') + # create median-colored border around white text for contrast + text.set_path_effects([ + path_effects.Stroke(linewidth=3, foreground=median.get_color()), + path_effects.Normal(), + ]) + + def visualize_benchmark(self, datapoints): + """ + Visualize benchmark results using seaborn boxplots + + Args: + datapoints: List of benchmark datapoints + """ + if not datapoints: + print("No benchmark data to visualize.") + return + + df = pd.DataFrame(datapoints) + + # Save results to CSV for later analysis + try: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + csv_filename = f"benchmark_results_{timestamp}.csv" + df.to_csv(csv_filename) + print(f"Benchmark results saved to {csv_filename}") + except Exception as e: + print(f"Could not save results to CSV: {str(e)}") + + try: + sns.set_style("whitegrid") + n_models = df['model'].nunique() + + f, axes = plt.subplots(n_models, 1, figsize=(6, n_models * 6.4)) + + # Convert axes to array if there's only one model + axes = np.array([axes]) if n_models == 1 else axes + + for i, model in enumerate(df['model'].unique()): + cond = df['model'] == model + df_i = df.loc[cond] + + ax = sns.boxplot(df_i, + ax=axes[i], + x='measure', + y='time', + hue=df_i[['cache']].apply(tuple, axis=1)) + + ax.tick_params(axis='x', rotation=45) + 
ax.set_xlabel(None)
+                self.add_median_labels(ax)
+                ax.legend(loc='upper left')
+                ax.set_title(f'Response latency by cache mode - {model}', fontsize=14)
+
+            plt.tight_layout()
+
+            # Save plot to file; regenerate the timestamp here so a failed
+            # CSV save above cannot leave it undefined
+            try:
+                plot_filename = f"benchmark_plot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
+                plt.savefig(plot_filename)
+                print(f"Plot saved to {plot_filename}")
+            except Exception as e:
+                print(f"Could not save plot: {str(e)}")
+
+            plt.show(block=False)  # Non-blocking display
+            plt.pause(0.1)  # Small pause to render the plot
+
+            input("\nPress Enter to continue...")
+            plt.close()
+        except Exception as e:
+            print(f"Error during visualization: {str(e)}")
+
+
+class ChatCLI:
+    """Command-line interface for the Bedrock Chat application
+
+    This class provides a user-friendly CLI for interacting with the BedrockChat
+    functionality, including document loading, model selection, and chat sessions.
+    """
+
+    def __init__(self):
+        """Initialize the CLI interface with sample content"""
+        # Core chat functionality
+        self.chat = BedrockChat()
+
+        # Sample AWS blog URLs for demonstration
+        self.sample_topics = [
+            'https://aws.amazon.com/blogs/aws/reduce-costs-and-latency-with-amazon-bedrock-intelligent-prompt-routing-and-prompt-caching-preview/',
+            'https://aws.amazon.com/blogs/machine-learning/enhance-conversational-ai-with-advanced-routing-techniques-with-amazon-bedrock/',
+            'https://aws.amazon.com/blogs/security/cost-considerations-and-common-options-for-aws-network-firewall-log-management/'
+        ]
+
+        # Sample questions for user convenience
+        self.sample_questions = [
+            'what is it about?',
+            'what are the use cases?',
+            'Translate "Hello" to French (temperature=0.3)',
+            'Translate "Hello" to French (temperature=0.4)'
+        ]
+
+    def _run_benchmark(self):
+        """Run TTFT benchmark tests"""
+        if not self.chat.current_document:
+            print("\nNo document loaded. Please load a document first.")
+            return
+
+        if not self.chat.current_model_id:
+            print("\nNo model selected. 
Please select a model first.") + return + + # Store the current document as blog for benchmarking + self.chat.blog = self.chat.current_document + + # Define test configurations + tests = [ + { + 'model_id': self.chat.current_model_id, + 'model_name': self.chat.current_model_id.split(':')[0], + 'cache_mode': [CACHE.OFF, CACHE.ON] + } + ] + + # Ask for number of epochs + try: + epochs = int(input("\nEnter number of test iterations (default: 3): ") or "3") + except ValueError: + epochs = 3 + + print(f"\nRunning benchmark with {epochs} iterations...") + print("This may take several minutes. Please wait...") + + try: + # Run the benchmark + datapoints = self.chat.run_response_latency_benchmark(tests, epochs) + + # Visualize results + print("\nGenerating visualization...") + self.chat.visualize_benchmark(datapoints) + + # Show summary statistics + df = pd.DataFrame(datapoints) + print("\nBenchmark Results Summary:") + print(df.groupby(['model', 'cache', 'measure'])['time'].agg(['mean', 'median', 'min', 'max'])) + except Exception as e: + print(f"\nError during benchmark: {str(e)}") + print("Returning to chat menu...") + + def display_welcome(self): + """Display welcome message and system info""" + print("\n" + "="*60) + print("BEDROCK PROMPT CACHING CLI".center(60)) + print("="*60) + print("\nSystem Information:") + print(f"Bedrock Runtime Client initialized") + print("\nThis application demonstrates Amazon Bedrock's prompt caching capabilities.") + print("You can chat with documents and see if responses come from cache or LLM.") + print("You can also use the multi-turn chat feature to visualize cache hits and misses.") + print("="*60) + + def display_system_diagram(self): + """Display a simple ASCII diagram of the system flow""" + diagram = """ + System Flow Diagram: + ┌─────────────┐ ┌───────────────┐ ┌─────────────────┐ + │ User Query │────▶│ Cache Manager │────▶│ Cache Hit? 
│ + └─────────────┘ └───────────────┘ └────────┬────────┘ + │ │ + │ │ + │ ┌────▼─────┐ + ┌──────▼──────┐ ┌───────────────┐ │ │ + │ Document │ │ Bedrock │ No │ Yes │ + │ Processor │────▶│ Service │◀───────┘ │ + └─────────────┘ └───────┬───────┘ │ + │ │ + │ │ + ┌─────────────┐ ┌──────▼──────┐ ┌─────▼─────┐ + │ User │◀────│ Response │◀─────────────│ Retrieve │ + │ Interface │ │ Processing │ │ from Cache│ + └─────────────┘ └────────────┘ └───────────┘ + """ + print(diagram) + + def load_document_menu(self): + """Menu for loading a document""" + print("\n--- DOCUMENT LOADING OPTIONS ---") + print("1. Load from sample URLs") + print("2. Enter custom URL") + print("3. Enter file path") + print("0. Return to main menu") + + choice = input("\nEnter your choice: ") + + if choice == "1": + print("\nSample URLs:") + for i, url in enumerate(self.sample_topics, 1): + print(f"{i}. {url}") + + url_choice = input("\nSelect URL number: ") + try: + url_index = int(url_choice) - 1 + if 0 <= url_index < len(self.sample_topics): + url = self.sample_topics[url_index] + print(f"\nFetching document from: {url}") + try: + response = requests.get(url) + response.raise_for_status() + document = response.text + if document: + self.chat.set_document(document) + print(f"Document loaded successfully ({len(document)} characters)") + except Exception as e: + print(f"Error fetching document: {e}") + else: + print("Invalid selection") + except ValueError: + print("Please enter a valid number") + + elif choice == "2": + url = input("\nEnter URL: ") + print(f"Fetching document from: {url}") + try: + response = requests.get(url) + response.raise_for_status() + document = response.text + if document: + self.chat.set_document(document) + print(f"Document loaded successfully ({len(document)} characters)") + except Exception as e: + print(f"Error fetching document: {e}") + + elif choice == "3": + file_path = input("\nEnter file path: ") + # Check file extension + _, ext = os.path.splitext(file_path) + if 
ext.lower() in FileProcessor.SUPPORTED_EXTENSIONS: + try: + # Create a file-like object with name attribute for FileProcessor + import io + with open(file_path, 'rb') as f: + file_content = f.read() + + class FileObj: + def __init__(self, content, name): + self.name = name + self._content = content + self._io = io.BytesIO(content) + + def getvalue(self): + return self._content + + def seek(self, pos, whence=0): + return self._io.seek(pos, whence) + + def tell(self): + return self._io.tell() + + def read(self, size=-1): + return self._io.read(size) + + def close(self): + pass + + # Process the file using FileProcessor + file_obj = FileObj(file_content, os.path.basename(file_path)) + document = FileProcessor.process_uploaded_file(file_obj) + file_obj.close() + + if document: + self.chat.set_document(document) + print(f"Document loaded successfully ({len(document)} characters)") + print("\nProceeding to model selection...") + self.model_selection_menu() + except Exception as e: + print(f"Error processing file: {e}") + else: + print(f"Unsupported file type. Supported types: {', '.join(FileProcessor.SUPPORTED_EXTENSIONS)}") + + def model_selection_menu(self): + """Menu for selecting a model""" + model_id = self.chat.select_model() + self.chat.set_model(model_id) + print(f"\nSelected model: {model_id}") + + # Set temperature + while True: + try: + temp = float(input("\nEnter temperature (0.0-1.0): ")) + if 0 <= temp <= 1: + self.chat.set_temperature(temp) + print(f"Temperature set to: {temp}") + break + else: + print("Temperature must be between 0.0 and 1.0") + except ValueError: + print("Please enter a valid number") + + def chat_menu(self): + """Interactive chat session with the document""" + if not self.chat.current_document: + print("\nNo document loaded. Please load a document first.") + return + + if not self.chat.current_model_id: + print("\nNo model selected. 
Please select a model first.") + return + + print("\n--- CHAT SESSION ---") + print("Type 'exit' to return to main menu") + print("Type 'sample' to see sample questions") + print("Type 'cache on/off' to toggle cache usage") + print("Type 'benchmark' to run TTFT benchmarks") + print("Type 'stats' to show cache statistics") + + use_cache = True + last_cache_key = "" + + while True: + print("\nSettings:") + print(f"- Cache: {'ON' if use_cache else 'OFF'}") + + query = input("\nYour question: ") + + if query.lower() == 'exit': + break + + elif query.lower() == 'sample': + print("\nSample questions:") + for i, q in enumerate(self.sample_questions, 1): + print(f"{i}. {q}") + continue + + elif query.lower() == 'cache on': + use_cache = True + print("Cache enabled") + continue + + elif query.lower() == 'cache off': + use_cache = False + print("Cache disabled") + continue + + elif query.lower() == 'benchmark': + self._run_benchmark() + continue + + elif query.lower() == 'stats': + if last_cache_key: + cache_summary = self.chat.cache_manager.get_cache_summary(last_cache_key) + print(cache_summary) + else: + print("\nNo cache information available yet. 
Ask a question first.") + continue + + # Process the query + try: + print("\nProcessing your question...") + response_text, usage, from_cache, cache_key = self.chat.chat_with_document( + query, use_cache=use_cache + ) + last_cache_key = cache_key + except Exception as e: + print(f"\nError: {str(e)}") + continue + + # Display source information + if from_cache: + print("\n[RESPONSE FROM CACHE]") + else: + print("\n[RESPONSE FROM LLM]") + + # Display the response + print("\n" + "="*60) + print(response_text) + print("="*60) + + # Display usage information + print("\nUsage Information:") + print(f"Input tokens: {usage.get('inputTokens', 'N/A')}") + print(f"Output tokens: {usage.get('outputTokens', 'N/A')}") + print(f"Response time: {usage.get('response_time_seconds', 'N/A'):.2f} seconds") + + # Display cache information + cache_read = usage.get("cache_read_input_tokens", 0) or usage.get("cacheReadInputTokens", 0) + cache_write = usage.get("cache_creation_input_tokens", 0) or usage.get("cacheCreationInputTokens", 0) + + # For cache hits from file cache, simulate cache metrics + if from_cache and cache_read == 0: + cache_read = usage.get("inputTokens", 0) + + if cache_read > 0: + print(f"Cache read tokens: {cache_read}") + print("✅ CACHE HIT: Content was retrieved from cache") + elif cache_write > 0: + print(f"Cache write tokens: {cache_write}") + print("📝 CACHE WRITE: Content was written to cache") + + print("\nType 'stats' to see detailed cache information") + + def main_menu(self): + """Main menu for the application""" + self.display_welcome() + self.display_system_diagram() + + while True: + print("\n" + "="*60) + print("MAIN MENU".center(60)) + print("="*60) + print("1. Load Document") + print("2. Select Model") + print("3. Start Chat Session") + print("0. 
Exit") + + choice = input("\nEnter your choice: ") + + if choice == "1": + self.load_document_menu() + elif choice == "2": + self.model_selection_menu() + elif choice == "3": + self.chat_menu() + elif choice == "0": + print("\nExiting application. Goodbye!") + break + else: + print("Invalid choice. Please try again.") + +def main(): + """Main entry point for the application + + Initializes the CLI interface and starts the main menu loop. + """ + try: + cli = ChatCLI() # Create CLI instance + cli.main_menu() # Start the main menu + except KeyboardInterrupt: + print("\n\nApplication terminated by user. Goodbye!") + except Exception as e: + print(f"\nAn unexpected error occurred: {str(e)}") + print("The application will now exit.") + +if __name__ == "__main__": + main() + diff --git a/BedrockPromptCachingRoutingDemo/src/bedrock_prompt_routing.py b/BedrockPromptCachingRoutingDemo/src/bedrock_prompt_routing.py new file mode 100644 index 000000000..aee01d855 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/bedrock_prompt_routing.py @@ -0,0 +1,581 @@ +""" +Amazon Bedrock Prompt Router Chat Application + +This module provides a command-line interface for interacting with Amazon Bedrock +using prompt routers. It tracks usage statistics and allows switching between +different prompt routers. +""" +import json +import boto3 +import os +import time +from file_processor import FileProcessor + + +class UsageStats: + """ + Tracks and calculates token usage statistics for chat interactions. + + This class maintains counters for input/output tokens, calculates rates, + and provides reporting functionality. + """ + def __init__(self): + # Initialize counters and tracking variables + self.total_input_tokens = 0 + self.total_output_tokens = 0 + self.total_chats = 0 + self.start_time = time.time() + self.model_invocations = {} + + def calculate_tokens(self, text): + """ + Calculate approximate number of tokens in text. 
+ + Args: + text (str): The text to calculate tokens for + + Returns: + int: Estimated token count (based on 4 chars per token) + """ + # Simple token estimation (average 4 chars per token) + return max(1, len(text.strip()) // 4) + + def track_usage(self, input_text, output_text, model_used=None): + """ + Track token usage for a single chat interaction. + + Args: + input_text (str): User input text + output_text (str): Model response text + model_used (str, optional): Name of the model used + + Returns: + dict: Usage statistics for the current interaction + """ + # Calculate token counts + input_tokens = self.calculate_tokens(input_text) + output_tokens = self.calculate_tokens(output_text) + + # Update totals + self.total_input_tokens += input_tokens + self.total_output_tokens += output_tokens + self.total_chats += 1 + + # Track model invocations + if model_used: + self.model_invocations[model_used] = self.model_invocations.get(model_used, 0) + 1 + + # Calculate rates + elapsed_minutes = max(0.1, (time.time() - self.start_time) / 60) + tpm = (input_tokens + output_tokens) / elapsed_minutes + rpm = self.total_chats / elapsed_minutes + + # Print current interaction stats + print("\nUsage Statistics (Current Interaction):") + print("-" * 50) + print(f"{'Metric':<20} {'Count':<15} {'Details'}") + print("-" * 50) + print(f"{'Input Tokens':<20} {input_tokens:<15} (Approximate)") + print(f"{'Output Tokens':<20} {output_tokens:<15} (Approximate)") + print(f"{'Total Tokens':<20} {input_tokens + output_tokens:<15}") + print(f"{'TPM':<20} {tpm:>15.2f} (Tokens/minute)") + print(f"{'RPM':<20} {rpm:>15.2f} (Requests/minute)") + if model_used: + print(f"{'Model Used':<20} {model_used:<15} ({self.model_invocations[model_used]} invocations)") + + return { + 'input_tokens': input_tokens, + 'output_tokens': output_tokens, + 'total_tokens': input_tokens + output_tokens, + 'tpm': tpm, + 'rpm': rpm + } + + def print_total_stats(self): + """ + Print comprehensive usage statistics for the 
entire session. + """ + elapsed_minutes = max(0.1, (time.time() - self.start_time) / 60) + total_tokens = self.total_input_tokens + self.total_output_tokens + avg_tokens_per_chat = total_tokens / max(1, self.total_chats) + tpm = total_tokens / elapsed_minutes + rpm = self.total_chats / elapsed_minutes + + print("\nTotal Usage Statistics:") + print("=" * 60) + print(f"{'Metric':<25} {'Count':<15} {'Details'}") + print("=" * 60) + print(f"{'Total Chats':<25} {self.total_chats:<15}") + print(f"{'Total Input Tokens':<25} {self.total_input_tokens:<15} (Approximate)") + print(f"{'Total Output Tokens':<25} {self.total_output_tokens:<15} (Approximate)") + print(f"{'Total Tokens':<25} {total_tokens:<15}") + print(f"{'Avg Tokens per Chat':<25} {avg_tokens_per_chat:>15.2f}") + print(f"{'Overall TPM':<25} {tpm:>15.2f} (Tokens/minute)") + print(f"{'Overall RPM':<25} {rpm:>15.2f} (Requests/minute)") + print(f"{'Session Duration':<25} {elapsed_minutes:>15.2f} (Minutes)") + if self.model_invocations: + print("\nModel Invocations:") + for model, count in self.model_invocations.items(): + print(f"{'- ' + model:<25} {count:<15} invocations") + print("=" * 60) + + +class ChatSession: + """ + Manages a chat session with Amazon Bedrock. + + Handles message history, sending messages to Bedrock, and processing responses. + """ + def __init__(self, model_id=None, region="us-east-1"): + # Initialize Bedrock client + self.bedrock_runtime = boto3.client( + "bedrock-runtime", + region_name=region + ) + # Set default model ID or use provided one + self.model_id = model_id or "anthropic.claude-sonnet-4-5-20250929-v1:0" + # Initialize conversation history + self.messages = [] + self.usage_stats = UsageStats() + + def add_message(self, content, role="user"): + """ + Add a message to the conversation history. 
+ + Args: + content (str): Message content + role (str): Message role (user or assistant) + """ + self.messages.append({ + "role": role, + "content": [{"text": content}] + }) + + def send_message(self, message=None): + """ + Send a message to Bedrock and process the streaming response. + + Args: + message (str, optional): User message to send + + Returns: + tuple: (trace_data, model_used) - Routing trace data and model ID + """ + if message: + self.add_message(message) + + try: + # Get streaming response from Bedrock + response = self.bedrock_runtime.converse_stream( + modelId=self.model_id, + messages=self.messages + ) + + # Process the streaming response + assistant_response = "" + trace_data = None + model_used = None + + print("\nAssistant: ", end="") + for chunk in response["stream"]: + if "contentBlockDelta" in chunk: + text = chunk["contentBlockDelta"]["delta"].get("text", "") + print(text, end="", flush=True) + assistant_response += text + + if "metadata" in chunk: + if "trace" in chunk["metadata"]: + trace_data = chunk["metadata"]["trace"] + # Extract the model used from trace data + if "promptRouter" in trace_data and "invokedModelId" in trace_data["promptRouter"]: + full_model_id = trace_data["promptRouter"]["invokedModelId"] + # Extract just the model name after the last '/' + model_used = full_model_id.split('/')[-1] if '/' in full_model_id else full_model_id + elif "selectedRoute" in trace_data: + full_model_id = trace_data["selectedRoute"].get("modelId", "Unknown model") + model_used = full_model_id.split('/')[-1] if '/' in full_model_id else full_model_id + + print("\n") + + # Track usage statistics + self.usage_stats.track_usage(message or "", assistant_response, model_used) + + if assistant_response: + self.add_message(assistant_response, role="assistant") + + if not model_used: + model_used = self.model_id + + return trace_data, model_used + + except Exception as e: + print(f"\nError: {str(e)}") + return None, None + + +class PromptRouterManager: 
+ """ + Manages Amazon Bedrock prompt routers. + + Provides functionality to list, select, and get details about prompt routers. + """ + def __init__(self, region="us-east-1"): + self.bedrock = boto3.client('bedrock', region_name=region) + self.region = region + self.account_id = os.getenv('AWS_ACCOUNT_ID') + + # Try to get account ID if not provided in environment + if not self.account_id: + try: + sts = boto3.client('sts') + self.account_id = sts.get_caller_identity()['Account'] + except Exception as e: + print(f"Warning: Could not determine AWS account ID: {e}") + self.account_id = None + + # Initialize fallback routers for when API calls fail + self.fallback_routers = self._get_fallback_routers() + + def _get_fallback_routers(self): + """ + Get fallback routers configuration when API calls fail. + + Returns: + list: List of default router configurations + """ + if not self.account_id: + return [] + + return [ + { + 'name': 'anthropic.claude', + 'arn': f'arn:aws:bedrock:{self.region}:{self.account_id}:default-prompt-router/anthropic.claude:1', + 'provider': 'Anthropic', + 'type': 'default' + }, + { + 'name': 'meta.llama', + 'arn': f'arn:aws:bedrock:{self.region}:{self.account_id}:default-prompt-router/meta.llama:1', + 'provider': 'Meta', + 'type': 'default' + } + ] + + def extract_provider_and_name(self, router_arn): + """ + Extract provider and name from router ARN. 
+ + Args: + router_arn (str): The ARN of the prompt router + + Returns: + tuple: (provider, router_name) - Provider and name extracted from ARN + """ + provider = 'Unknown' + router_name = 'Default Router' + + # Extract provider from ARN + if 'anthropic' in router_arn.lower(): + provider = 'Anthropic' + elif 'meta' in router_arn.lower() or 'llama' in router_arn.lower(): + provider = 'Meta' + elif 'cohere' in router_arn.lower(): + provider = 'Cohere' + elif 'ai21' in router_arn.lower(): + provider = 'AI21' + elif 'mistral' in router_arn.lower(): + provider = 'Mistral' + elif 'amazon' in router_arn.lower(): + provider = 'Amazon' + else: + # For unknown providers, extract from ARN + parts = router_arn.split('/') + if len(parts) > 1: + model_part = parts[-1].split(':')[0] + if '.' in model_part: + provider = model_part.split('.')[0].capitalize() + + # For unknown router names, use the model part from ARN + parts = router_arn.split('/') + if len(parts) > 1: + router_name = parts[-1].split(':')[0] + + return provider, router_name + + def get_prompt_routers(self): + """ + Get all available prompt routers in Bedrock. 
+ + Returns: + list: List of prompt router configurations + """ + prompt_routers = [] + + try: + # Get custom prompt routers + try: + custom_routers = self.bedrock.list_prompt_routers(type='custom', maxResults=100) + for router in custom_routers.get('promptRouterSummaries', []): + router_arn = router.get('promptRouterArn', '') + router_name = router.get('promptRouterName', 'Custom Router') + + provider, extracted_name = self.extract_provider_and_name(router_arn) + + if router_name == 'Custom Router': + router_name = extracted_name + + prompt_routers.append({ + 'name': router_name, + 'arn': router_arn, + 'provider': provider, + 'type': 'custom' + }) + except Exception as e: + print(f"Could not fetch custom prompt routers: {str(e)}") + + # Get default prompt routers + try: + default_routers = self.bedrock.list_prompt_routers(type='default', maxResults=100) + for router in default_routers.get('promptRouterSummaries', []): + router_arn = router.get('promptRouterArn', '') + router_name = router.get('promptRouterName', 'Default Router') + + provider, extracted_name = self.extract_provider_and_name(router_arn) + + if router_name == 'Default Router': + router_name = extracted_name + + prompt_routers.append({ + 'name': router_name, + 'arn': router_arn, + 'provider': provider, + 'type': 'default' + }) + except Exception as e: + print(f"Could not fetch default prompt routers: {str(e)}") + if self.fallback_routers: + prompt_routers.extend(self.fallback_routers) + + return prompt_routers + + except Exception as e: + print(f"Error fetching prompt routers: {str(e)}") + return self.fallback_routers if self.fallback_routers else [] + + def get_router_details(self, router_arn): + """ + Get details of a prompt router including supported models. 
+ + Args: + router_arn (str): The ARN of the prompt router + + Returns: + dict: Router details including supported models + """ + try: + response = self.bedrock.get_prompt_router(promptRouterArn=router_arn) + + # Extract supported models + supported_models = [] + if 'models' in response: + for model in response['models']: + if 'modelArn' in model: + model_arn = model['modelArn'] + if '/' in model_arn: + model_id = model_arn.split('/')[-1] + supported_models.append(model_id) + + # Add fallback model if present + if 'fallbackModel' in response and 'modelArn' in response['fallbackModel']: + fallback_arn = response['fallbackModel']['modelArn'] + if '/' in fallback_arn: + fallback_id = fallback_arn.split('/')[-1] + if fallback_id not in supported_models: + supported_models.append(fallback_id) + + return { + 'name': response.get('promptRouterName', ''), + 'description': response.get('description', ''), + 'supported_models': supported_models, + 'type': response.get('type', '') + } + except Exception as e: + print(f"Could not fetch router details for {router_arn}: {str(e)}") + return { + 'name': '', + 'description': '', + 'supported_models': [], + 'type': '' + } + + def select_prompt_router(self): + """ + Display available prompt routers and let user select one. 
+ + Returns: + str: ARN of the selected prompt router or None if no selection + """ + routers = self.get_prompt_routers() + + if not routers: + print("No prompt routers available.") + return None + + print("\nAvailable Prompt Routers:") + print("-" * 60) + print(f"{'#':<3} {'Name':<25} {'Provider':<10} {'Type':<10}") + print("-" * 60) + + for i, router in enumerate(routers): + print(f"{i+1:<3} {router['name']:<25} {router['provider']:<10} {router['type']:<10}") + + while True: + try: + choice = input("\nSelect a prompt router (number) or press Enter for default: ") + if not choice: + return routers[0]['arn'] + + choice = int(choice) + if 1 <= choice <= len(routers): + return routers[choice-1]['arn'] + else: + print(f"Please enter a number between 1 and {len(routers)}") + except ValueError: + print("Please enter a valid number") + + +def main(): + """ + Main function to run the Bedrock Prompt Router Chat application. + """ + print("Bedrock Prompt Router Chat") + print("=" * 50) + + # Get region from environment or use default + region = os.getenv('AWS_REGION', 'us-east-1') + router_manager = PromptRouterManager(region=region) + + print("\nFetching available prompt routers...") + router_arn = router_manager.select_prompt_router() + + # Set up model ID based on router selection + if not router_arn: + print("Using default model as no router was selected.") + model_id = os.getenv('BEDROCK_MODEL_ID', "anthropic.claude-sonnet-4-5-20250929-v1:0") + else: + model_id = router_arn + router_details = router_manager.get_router_details(model_id) + print(f"\nUsing prompt router: {model_id}") + if router_details['supported_models']: + print("\nSupported Models:") + for model in router_details['supported_models']: + print(f"- {model}") + + # Initialize chat session + chat = ChatSession(model_id=model_id, region=region) + + # Display welcome message and commands + print("\nWelcome to the Bedrock Chat!") + print("Commands:") + print("- Type 'exit' or 'quit' to end the chat") + 
print("- Type 'trace' to toggle routing information") + print("- Type 'router' to switch prompt routers") + print("- Type 'models' to see supported models for current router") + print("- Type 'stats' to see total usage statistics") + print("- Type 'upload' to process a file (PDF, DOCX, or TXT)\n") + + # Main chat loop + show_trace = False + while True: + user_input = input("You: ").strip() + + # Handle exit command + if user_input.lower() in ["exit", "quit"]: + chat.usage_stats.print_total_stats() + print("\nGoodbye!") + break + + # Handle trace command + if user_input.lower() == "trace": + show_trace = not show_trace + print(f"Trace display: {'ON' if show_trace else 'OFF'}") + continue + + # Handle router command + if user_input.lower() == "router": + new_router_arn = router_manager.select_prompt_router() + if new_router_arn: + chat.model_id = new_router_arn + print(f"\nSwitched to prompt router: {new_router_arn}") + continue + + # Handle models command + if user_input.lower() == "models": + router_details = router_manager.get_router_details(chat.model_id) + if router_details['supported_models']: + print("\nSupported Models:") + for model in router_details['supported_models']: + print(f"- {model}") + else: + print("No model information available for this router") + continue + + # Handle stats command + if user_input.lower() == "stats": + chat.usage_stats.print_total_stats() + continue + + # Handle file upload command + if user_input.lower() == "upload": + print("\nEnter the path to your file (PDF, DOCX, or TXT):") + file_path = input().strip() + + if not os.path.exists(file_path): + print("File not found. Please check the path and try again.") + continue + + try: + import io + with open(file_path, 'rb') as file: + file_content = file.read() + uploaded_file = io.BytesIO(file_content) + uploaded_file.name = os.path.basename(file_path) + + if not FileProcessor.is_supported_file(uploaded_file.name): + print("Unsupported file type. 
Please upload a PDF, DOCX, or TXT file.") + continue + + print("Processing file...") + extracted_text = FileProcessor.process_uploaded_file(uploaded_file) + + if extracted_text: + print("\nFile content extracted successfully. Sending to chat...") + trace, model_used = chat.send_message(extracted_text) + else: + print("Could not extract text from the file.") + continue + except Exception as e: + print(f"Error processing file: {str(e)}") + continue + + # Process regular chat message + trace, model_used = chat.send_message(user_input) + + if trace is None: + continue + + print(f"\n[Response generated by: {model_used}]") + + # Display trace information if enabled + if show_trace and trace: + print("\nRouting Trace:") + if "promptRouter" in trace and "invokedModelId" in trace["promptRouter"]: + full_model_id = trace["promptRouter"]["invokedModelId"] + model_name = full_model_id.split('/')[-1] + print(f"Model Used: {model_name}") + else: + print(json.dumps(trace, indent=2)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/bedrock_service.py b/BedrockPromptCachingRoutingDemo/src/bedrock_service.py new file mode 100644 index 000000000..580fd3ba8 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/bedrock_service.py @@ -0,0 +1,126 @@ + +# AWS SDK for Python +import boto3 # For AWS API interactions +from botocore.exceptions import ClientError, NoCredentialsError # For handling API errors + + +class BedrockService: + """Handles connections to Amazon Bedrock services + + This class initializes and provides access to Bedrock clients: + - bedrock-runtime: For model inference and prompt caching operations + - bedrock: For management operations (model listing, etc.) 
+    """
+
+    def __init__(self, region_name=None):
+        """Initialize Bedrock clients using boto3
+
+        Args:
+            region_name: AWS region name (optional)
+
+        Raises:
+            NoCredentialsError: If AWS credentials are not found
+            ClientError: If there's an error creating the clients
+        """
+        try:
+            # Create clients with explicit parameters to avoid potential issues
+            kwargs = {'service_name': 'bedrock-runtime'}
+            if region_name:
+                kwargs['region_name'] = region_name
+
+            self._bedrock_runtime = boto3.client(**kwargs)
+
+            kwargs['service_name'] = 'bedrock'
+            self._bedrock = boto3.client(**kwargs)
+
+        except NoCredentialsError:
+            # NoCredentialsError takes no message argument; print guidance
+            # for the user and re-raise the original exception
+            print("AWS credentials not found. Please configure your credentials.")
+            raise
+        except Exception as e:
+            raise ClientError(
+                error_response={"Error": {"Message": f"Failed to initialize Bedrock clients: {str(e)}"}},
+                operation_name="__init__"
+            ) from e
+
+    def get_runtime_client(self):
+        """Return the Bedrock runtime client for inference operations
+
+        Returns:
+            boto3 client for bedrock-runtime
+
+        Raises:
+            RuntimeError: If the client is not initialized
+        """
+        if not hasattr(self, '_bedrock_runtime') or self._bedrock_runtime is None:
+            raise RuntimeError("Bedrock runtime client is not initialized")
+        return self._bedrock_runtime
+
+    def get_bedrock_client(self):
+        """Return the Bedrock management client for admin operations
+
+        Returns:
+            boto3 client for bedrock
+
+        Raises:
+            RuntimeError: If the client is not initialized
+        """
+        if not hasattr(self, '_bedrock') or self._bedrock is None:
+            raise RuntimeError("Bedrock management client is not initialized")
+        return self._bedrock
+
+    @property
+    def bedrock_runtime(self):
+        """Property to access the bedrock runtime client"""
+        return self.get_runtime_client()
+
+    @property
+    def bedrock(self):
+        """Property to access the bedrock management client"""
+        return self.get_bedrock_client()
+
+    def list_inference_profiles(self, max_results=None, next_token=None, type_equals=None):
+        
"""List inference profiles from Bedrock + + Args: + max_results: Maximum number of results to return + next_token: Token for pagination + type_equals: Filter by profile type ('SYSTEM_DEFINED' or 'APPLICATION') + + Returns: + Response from the list_inference_profiles API call + + Raises: + ClientError: If there's an error calling the API + RuntimeError: If the client is not initialized + """ + # Validate parameters + if max_results is not None and not isinstance(max_results, int): + raise ValueError("max_results must be an integer") + + if type_equals is not None and type_equals not in ['SYSTEM_DEFINED', 'APPLICATION']: + raise ValueError("type_equals must be 'SYSTEM_DEFINED' or 'APPLICATION'") + + params = {} + if max_results: + params['maxResults'] = max_results + if next_token: + params['nextToken'] = next_token + if type_equals: + params['typeEquals'] = type_equals + + try: + return self.bedrock.list_inference_profiles(**params) + except ClientError as e: + # Re-raise with more context + raise ClientError( + error_response=e.response, + operation_name="list_inference_profiles" + ) from e + + def __del__(self): + """Clean up resources when the object is garbage collected""" + # Clear references to help with garbage collection + if hasattr(self, '_bedrock_runtime'): + self._bedrock_runtime = None + if hasattr(self, '_bedrock'): + self._bedrock = None + diff --git a/BedrockPromptCachingRoutingDemo/src/file_processor.py b/BedrockPromptCachingRoutingDemo/src/file_processor.py new file mode 100644 index 000000000..70d7fdf3c --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/file_processor.py @@ -0,0 +1,103 @@ +import os +import PyPDF2 +import docx +import tempfile + +class FileProcessor: + """Class to handle file uploads and text extraction""" + + SUPPORTED_EXTENSIONS = {'.pdf', '.docx', '.txt', '.py', '.md', '.json', '.html', '.css', '.js'} + + @staticmethod + def extract_text_from_pdf(file): + """Extract text from a PDF file""" + pdf_text = "" + try: + # File 
is already a BytesIO object from bedrock_prompt_routing.py + pdf_reader = PyPDF2.PdfReader(file) + for page_num in range(len(pdf_reader.pages)): + page = pdf_reader.pages[page_num] + pdf_text += page.extract_text() + "\n\n" + return pdf_text + except Exception as e: + print(f"Error extracting text from PDF: {str(e)}") + return "" + + @staticmethod + def extract_text_from_docx(file): + """Extract text from a DOCX file""" + try: + # Save the uploaded file to a temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp: + tmp.write(file.getvalue()) + tmp_path = tmp.name + + # Open the temporary file with python-docx + doc = docx.Document(tmp_path) + text = [paragraph.text for paragraph in doc.paragraphs] + + # Clean up the temporary file + os.unlink(tmp_path) + + return "\n\n".join(text) + except Exception as e: + print(f"Error extracting text from DOCX: {str(e)}") + return "" + + @staticmethod + def extract_text_from_txt(file): + """Extract text from a TXT file""" + try: + # Handle different file object types + if hasattr(file, 'getvalue'): + # Standard file object from Gradio + return file.getvalue().decode('utf-8') + elif hasattr(file, 'read'): + # File-like object with read method + return file.read().decode('utf-8') + elif isinstance(file, str): + # Already a string + return file + else: + # Try to convert to string + return str(file) + except Exception as e: + print(f"Error extracting text from TXT: {str(e)}") + return "" + + @classmethod + def process_uploaded_file(cls, uploaded_file): + """Process an uploaded file and extract text""" + if uploaded_file is None: + return "" + + # Get file extension safely + if hasattr(uploaded_file, 'name'): + file_name = uploaded_file.name + else: + # If it's a path string + file_name = str(uploaded_file) + + # Extract extension + file_ext = os.path.splitext(file_name)[1].lower() + + # Process based on extension + if file_ext == '.pdf': + return cls.extract_text_from_pdf(uploaded_file) + elif 
file_ext == '.docx':
+            return cls.extract_text_from_docx(uploaded_file)
+        elif file_ext in {'.txt', '.py', '.md', '.json', '.html', '.css', '.js'}:
+            # All remaining supported extensions are plain-text formats
+            return cls.extract_text_from_txt(uploaded_file)
+        else:
+            print(f"Unsupported file type: {file_ext}")
+            # Try to extract as text anyway for common text-based formats
+            try:
+                return cls.extract_text_from_txt(uploaded_file)
+            except Exception:
+                return ""
+
+    @classmethod
+    def is_supported_file(cls, filename):
+        """Check if the file type is supported"""
+        ext = os.path.splitext(filename)[1].lower()
+        return ext in cls.SUPPORTED_EXTENSIONS
\ No newline at end of file
diff --git a/BedrockPromptCachingRoutingDemo/src/images/prompt-caching.png b/BedrockPromptCachingRoutingDemo/src/images/prompt-caching.png
new file mode 100644
index 000000000..8b00c4bec
Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/images/prompt-caching.png differ
diff --git a/BedrockPromptCachingRoutingDemo/src/knowledge_base.py b/BedrockPromptCachingRoutingDemo/src/knowledge_base.py
new file mode 100644
index 000000000..69b7000a8
--- /dev/null
+++ b/BedrockPromptCachingRoutingDemo/src/knowledge_base.py
@@ -0,0 +1,1325 @@
+import json
+import boto3
+import time
+from botocore.exceptions import ClientError
+from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, RequestError
+import pprint
+from retrying import retry
+import zipfile
+from io import BytesIO
+import warnings
+import random
+warnings.filterwarnings('ignore')
+
+valid_generation_models = ["anthropic.claude-3-5-sonnet-20240620-v1:0",
+                           "anthropic.claude-3-5-haiku-20241022-v1:0",
+                           "anthropic.claude-3-sonnet-20240229-v1:0",
+                           "anthropic.claude-3-haiku-20240307-v1:0",
+                           "amazon.nova-micro-v1:0"]
+
+valid_reranking_models = ["cohere.rerank-v3-5:0",
+                          "amazon.rerank-v1:0"]
+
+valid_embedding_models = ["cohere.embed-multilingual-v3",
+                          "cohere.embed-english-v3",
+                          "amazon.titan-embed-text-v1",
+                          "amazon.titan-embed-text-v2:0"]
+
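The `valid_*` allow-lists above are enforced later by `BedrockKnowledgeBase._validate_models`, which rejects any model id outside the list with a `ValueError`. A minimal, self-contained sketch of that check (the standalone list is copied from the module; the helper name is ours, not part of the class):

```python
# Self-contained illustration of the allow-list validation performed by
# BedrockKnowledgeBase._validate_models: an id outside the list raises a
# ValueError naming the accepted options.
VALID_EMBEDDING_MODELS = [
    "cohere.embed-multilingual-v3",
    "cohere.embed-english-v3",
    "amazon.titan-embed-text-v1",
    "amazon.titan-embed-text-v2:0",
]

def validate_embedding_model(model_id):
    """Return the model id unchanged, or raise ValueError if unsupported."""
    if model_id not in VALID_EMBEDDING_MODELS:
        raise ValueError(
            f"Invalid embedding model. Your embedding model should be one of "
            f"{VALID_EMBEDDING_MODELS}"
        )
    return model_id

print(validate_embedding_model("amazon.titan-embed-text-v2:0"))
# -> amazon.titan-embed-text-v2:0
```

Validating up front like this fails fast, before any IAM, OpenSearch, or S3 resources are created.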
+embedding_context_dimensions = { + "cohere.embed-multilingual-v3": 512, + "cohere.embed-english-v3": 512, + "amazon.titan-embed-text-v1": 1536, + "amazon.titan-embed-text-v2:0": 1024 +} + +pp = pprint.PrettyPrinter(indent=2) + +def interactive_sleep(seconds: int): + dots = '' + for i in range(seconds): + dots += '.' + print(dots, end='\r') + time.sleep(1) + +class BedrockKnowledgeBase: + """ + Support class that allows for: + - creation (or retrieval) of a Knowledge Base for Amazon Bedrock with all its pre-requisites + (including OSS, IAM roles and Permissions and S3 bucket) + - Ingestion of data into the Knowledge Base + - Deletion of all resources created + """ + def __init__( + self, + kb_name=None, + kb_description=None, + data_sources=None, + multi_modal=None, + parser=None, + intermediate_bucket_name=None, + lambda_function_name=None, + embedding_model="amazon.titan-embed-text-v2:0", + generation_model="anthropic.claude-3-sonnet-20240229-v1:0", + reranking_model="cohere.rerank-v3-5:0", + graph_model="anthropic.claude-3-haiku-20240307-v1:0", + chunking_strategy="FIXED_SIZE", + suffix=None, + vector_store="OPENSEARCH_SERVERLESS" # can be OPENSEARCH_SERVERLESS or NEPTUNE_ANALYTICS + ): + """ + Class initializer + Args: + kb_name(str): The name of the Knowledge Base. + kb_description(str): The description of the Knowledge Base. + data_sources(list): The list of data source used for the Knowledge Base. + multi_modal(bool): Whether the Knowledge Base supports multi-modal data. + parser(str): The parser to be used for the Knowledge Base. + intermediate_bucket_name(str): The name of the intermediate S3 bucket to be used for custom chunking strategy. + lambda_function_name(str): The name of the Lambda function to be used for custom chunking strategy. + embedding_model(str): The embedding model to be used for the Knowledge Base. + generation_model(str): The generation model to be used for the Knowledge Base. 
+ reranking_model(str): The reranking model to be used for the Knowledge Base. + chunking_strategy(str): The chunking strategy to be used for the Knowledge Base. + suffix(str): A suffix to be used for naming resources. + """ + + boto3_session = boto3.session.Session() + self.region_name = boto3_session.region_name + self.iam_client = boto3_session.client('iam') + self.lambda_client = boto3.client('lambda') + self.account_number = boto3.client('sts').get_caller_identity().get('Account') + self.suffix = suffix or f'{self.region_name}-{self.account_number}' + self.identity = boto3.client('sts').get_caller_identity()['Arn'] + self.aoss_client = boto3_session.client('opensearchserverless') + self.neptune_client = boto3.client('neptune-graph') + self.s3_client = boto3.client('s3') + self.bedrock_agent_client = boto3.client('bedrock-agent') + credentials = boto3.Session().get_credentials() + self.awsauth = AWSV4SignerAuth(credentials, self.region_name, 'aoss') + + self.kb_name = kb_name or f"default-knowledge-base-{self.suffix}" + self.vector_store = vector_store + self.graph_name = self.kb_name + self.kb_description = kb_description or "Default Knowledge Base" + + self.data_sources = data_sources + self.bucket_names=[d["bucket_name"] for d in self.data_sources if d['type']== 'S3'] + self.secrets_arns = [d["credentialsSecretArn"] for d in self.data_sources if d['type']== 'CONFLUENCE'or d['type']=='SHAREPOINT' or d['type']=='SALESFORCE'] + self.chunking_strategy = chunking_strategy + self.multi_modal = multi_modal + self.parser = parser + + if multi_modal or chunking_strategy == "CUSTOM" : + self.intermediate_bucket_name = intermediate_bucket_name or f"{self.kb_name}-intermediate-{self.suffix}" + self.lambda_function_name = lambda_function_name or f"{self.kb_name}-lambda-{self.suffix}" + else: + self.intermediate_bucket_name = None + self.lambda_function_name = None + + self.embedding_model = embedding_model + self.generation_model = generation_model + self.reranking_model 
= reranking_model + self.graph_model = graph_model + + self._validate_models() + + self.encryption_policy_name = f"bedrock-sample-rag-sp-{self.suffix}" + self.network_policy_name = f"bedrock-sample-rag-np-{self.suffix}" + self.access_policy_name = f'bedrock-sample-rag-ap-{self.suffix}' + self.kb_execution_role_name = f'AmazonBedrockExecutionRoleForKnowledgeBase_{self.suffix}' + self.fm_policy_name = f'AmazonBedrockFoundationModelPolicyForKnowledgeBase_{self.suffix}' + self.s3_policy_name = f'AmazonBedrockS3PolicyForKnowledgeBase_{self.suffix}' + self.sm_policy_name = f'AmazonBedrockSecretPolicyForKnowledgeBase_{self.suffix}' + self.cw_log_policy_name = f'AmazonBedrockCloudWatchPolicyForKnowledgeBase_{self.suffix}' + self.oss_policy_name = f'AmazonBedrockOSSPolicyForKnowledgeBase_{self.suffix}' + self.lambda_policy_name = f'AmazonBedrockLambdaPolicyForKnowledgeBase_{self.suffix}' + self.bda_policy_name = f'AmazonBedrockBDAPolicyForKnowledgeBase_{self.suffix}' + self.neptune_policy_name = f'AmazonBedrockNeptunePolicyForKnowledgeBase_{self.suffix}' + self.lambda_arn = None + self.roles = [self.kb_execution_role_name] + + self.vector_store_name = f'bedrock-sample-rag-{self.suffix}' + self.index_name = f"bedrock-sample-rag-index-{self.suffix}" + self.graph_id = None + + self._setup_resources() + + def _validate_models(self): + if self.embedding_model not in valid_embedding_models: + raise ValueError(f"Invalid embedding model. Your embedding model should be one of {valid_embedding_models}") + if self.generation_model not in valid_generation_models: + raise ValueError(f"Invalid Generation model. Your generation model should be one of {valid_generation_models}") + if self.reranking_model not in valid_reranking_models: + raise ValueError(f"Invalid Reranking model. 
Your reranking model should be one of {valid_reranking_models}") + + def _setup_resources(self): + print("========================================================================================") + print(f"Step 1 - Creating or retrieving S3 bucket(s) for Knowledge Base documents") + self.create_s3_bucket() + + print("========================================================================================") + print(f"Step 2 - Creating Knowledge Base Execution Role ({self.kb_execution_role_name}) and Policies") + self.bedrock_kb_execution_role = self.create_bedrock_execution_role_multi_ds(self.bucket_names, self.secrets_arns) + self.bedrock_kb_execution_role_name = self.bedrock_kb_execution_role['Role']['RoleName'] + + if self.vector_store == "OPENSEARCH_SERVERLESS": + print("========================================================================================") + print(f"Step 3a - Creating OSS encryption, network and data access policies") + self.encryption_policy, self.network_policy, self.access_policy = self.create_policies_in_oss() + + print("========================================================================================") + print(f"Step 3b - Creating OSS Collection (this step takes a couple of minutes to complete)") + self.host, self.collection, self.collection_id, self.collection_arn = self.create_oss() + self.oss_client = OpenSearch( + hosts=[{'host': self.host, 'port': 443}], + http_auth=self.awsauth, + use_ssl=True, + verify_certs=True, + connection_class=RequestsHttpConnection, + timeout=300 + ) + + print("========================================================================================") + print(f"Step 3c - Creating OSS Vector Index") + self.create_vector_index() + else: + print("========================================================================================") + print(f"Step 3 - Creating Neptune Analytics Graph Index: might take upto 5-7 minutes") + self.graph_id = self.create_neptune() + + + + 
print("========================================================================================") + print(f"Step 4 - Will create Lambda Function if chunking strategy selected as CUSTOM") + if self.chunking_strategy == "CUSTOM": + print(f"Creating lambda function... as chunking strategy is {self.chunking_strategy}") + response = self.create_lambda() + self.lambda_arn = response['FunctionArn'] + print(response) + print(f"Lambda function ARN: {self.lambda_arn}") + else: + print(f"Not creating lambda function as chunking strategy is {self.chunking_strategy}") + + print("========================================================================================") + print(f"Step 5 - Creating Knowledge Base") + self.knowledge_base, self.data_source = self.create_knowledge_base(self.data_sources) + print("========================================================================================") + + def create_s3_bucket(self, multi_modal=False): + + buckets_to_check = self.bucket_names.copy() + # if multi_modal: + # buckets_to_check.append(buckets_to_check[0] + '-multi-modal-storage') + + if self.multi_modal or self.chunking_strategy == "CUSTOM": + buckets_to_check.append(self.intermediate_bucket_name) + + print(buckets_to_check) + print('buckets_to_check: ', buckets_to_check) + + existing_buckets = [] + for bucket_name in buckets_to_check: + try: + self.s3_client.head_bucket(Bucket=bucket_name) + existing_buckets.append(bucket_name) + print(f'Bucket {bucket_name} already exists - retrieving it!') + except ClientError: + pass + + buckets_to_create = [b for b in buckets_to_check if b not in existing_buckets] + + for bucket_name in buckets_to_create: + print(f'Creating bucket {bucket_name}') + if self.region_name == "us-east-1": + self.s3_client.create_bucket(Bucket=bucket_name) + else: + self.s3_client.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={'LocationConstraint': self.region_name} + ) + + def create_lambda(self): + # add to function + lambda_iam_role = 
self.create_lambda_role() + self.lambda_iam_role_name = lambda_iam_role['Role']['RoleName'] + self.roles.append(self.lambda_iam_role_name) + # Package up the lambda function code + s = BytesIO() + z = zipfile.ZipFile(s, 'w') + z.write("lambda_function.py") + z.close() + zip_content = s.getvalue() + + # Create Lambda Function + lambda_function = self.lambda_client.create_function( + FunctionName=self.lambda_function_name, + Runtime='python3.12', + Timeout=60, + Role=lambda_iam_role['Role']['Arn'], + Code={'ZipFile': zip_content}, + Handler='lambda_function.lambda_handler' + ) + return lambda_function + + def create_lambda_role(self): + lambda_function_role = f'{self.kb_name}-lambda-role-{self.suffix}' + s3_access_policy_name = f'{self.kb_name}-s3-policy' + # Create IAM Role for the Lambda function + try: + assume_role_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Service": "lambda.amazonaws.com" + }, + "Action": "sts:AssumeRole" + } + ] + } + + assume_role_policy_document_json = json.dumps(assume_role_policy_document) + + lambda_iam_role = self.iam_client.create_role( + RoleName=lambda_function_role, + AssumeRolePolicyDocument=assume_role_policy_document_json + ) + + # Pause to make sure role is created + time.sleep(10) + except self.iam_client.exceptions.EntityAlreadyExistsException: + lambda_iam_role = self.iam_client.get_role(RoleName=lambda_function_role) + + # Attach the AWSLambdaBasicExecutionRole policy + self.iam_client.attach_role_policy( + RoleName=lambda_function_role, + PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' + ) + + # Create a policy to grant access to the intermediate S3 bucket + s3_access_policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:ListBucket", + "s3:PutObject" + ], + "Resource": [ + f"arn:aws:s3:::{self.intermediate_bucket_name}", + 
f"arn:aws:s3:::{self.intermediate_bucket_name}/*" + ], + "Condition": { + "StringEquals": { + "aws:ResourceAccount": f"{self.account_number}" + } + } + } + ] + } + + # Create the policy + s3_access_policy_json = json.dumps(s3_access_policy) + s3_access_policy_response = self.iam_client.create_policy( + PolicyName=s3_access_policy_name, + PolicyDocument= s3_access_policy_json + ) + + # Attach the policy to the Lambda function's role + self.iam_client.attach_role_policy( + RoleName=lambda_function_role, + PolicyArn=s3_access_policy_response['Policy']['Arn'] + ) + return lambda_iam_role + + def create_bedrock_execution_role_multi_ds(self, bucket_names=None, secrets_arns=None): + """ + Create Knowledge Base Execution IAM Role and its required policies. + If role and/or policies already exist, retrieve them + Returns: + IAM role + """ + + bucket_names = self.bucket_names.copy() + if self.intermediate_bucket_name: + bucket_names.append(self.intermediate_bucket_name) + + # Check if the role already exists + try: + existing_role = self.iam_client.get_role(RoleName=self.kb_execution_role_name) + print(f"Using existing IAM role: {self.kb_execution_role_name}") + return existing_role + except self.iam_client.exceptions.NoSuchEntityException: + # Role doesn't exist, continue with creation + print(f"Creating new IAM role: {self.kb_execution_role_name}") + pass + + # 1. Create and attach policy for foundation models + foundation_model_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModel", + ], + "Resource": [ + f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.embedding_model}", + f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.generation_model}", + f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.reranking_model}", + f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.graph_model}" + ] + } + ] + } + + # 2. 
Define policy documents for s3 bucket + if bucket_names: + s3_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:ListBucket", + "s3:PutObject", + "s3:DeleteObject" + ], + "Resource": [item for sublist in [[f'arn:aws:s3:::{bucket}', f'arn:aws:s3:::{bucket}/*'] for bucket in bucket_names] for item in sublist], + "Condition": { + "StringEquals": { + "aws:ResourceAccount": f"{self.account_number}" + } + } + } + ] + } + if self.vector_store == "NEPTUNE_ANALYTICS": + neptune_policy_name = { + "Version": "2012-10-17", + "Statement": [{ + "Sid": "NeptuneAnalyticsAccess", + "Effect": "Allow", + "Action": [ + "*" + ], + "Resource": f"arn:aws:neptune-graph:{self.region_name}:{self.account_number}:graph/*" + } + ] + } + + + # 3. Define policy documents for secrets manager + if secrets_arns: + secrets_manager_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "secretsmanager:GetSecretValue", + "secretsmanager:PutSecretValue" + ], + "Resource": secrets_arns + } + ] + } + + # 4. Define policy documents for BDA + bda_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "BDAGetStatement", + "Effect": "Allow", + "Action": [ + "bedrock:GetDataAutomationStatus" + ], + "Resource": f"arn:aws:bedrock:{self.region_name}:{self.account_number}:data-automation-invocation/*" + }, + { + "Sid": "BDAInvokeStatement", + "Effect": "Allow", + "Action": [ + "bedrock:InvokeDataAutomationAsync" + ], + "Resource": f"arn:aws:bedrock:{self.region_name}:aws:data-automation-project/public-rag-default" + } + ] + } + + + # 5. 
Define policy documents for lambda + if self.chunking_strategy == "CUSTOM": + lambda_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "LambdaInvokeFunctionStatement", + "Effect": "Allow", + "Action": [ + "lambda:InvokeFunction" + ], + "Resource": [ + f"arn:aws:lambda:{self.region_name}:{self.account_number}:function:{self.lambda_function_name}:*" + ], + "Condition": { + "StringEquals": { + "aws:ResourceAccount": f"{self.account_number}" + } + } + } + ] + } + + cw_log_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "logs:CreateLogStream", + "logs:PutLogEvents", + "logs:DescribeLogStreams" + ], + "Resource": "arn:aws:logs:*:*:log-group:/aws/bedrock/invokemodel:*" + } + ] + } + + assume_role_policy_document = { + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Allow", + "Principal": { + "Service": "bedrock.amazonaws.com" + }, + "Action": "sts:AssumeRole", + "Condition": { + "StringEquals": { + "aws:SourceAccount": f"{self.account_number}" + }, + "ArnLike": { + "AWS:SourceArn": f"arn:aws:bedrock:{self.region_name}:{self.account_number}:knowledge-base/*" + } + } + }] + } + + # combine all policies into one list from policy documents + policies = [ + (self.fm_policy_name, foundation_model_policy_document, 'Policy for accessing foundation model'), + (self.cw_log_policy_name, cw_log_policy_document, 'Policy for writing logs to CloudWatch Logs'), + ] + if self.bucket_names: + policies.append((self.s3_policy_name, s3_policy_document, 'Policy for reading documents from s3')) + if self.secrets_arns: + policies.append((self.sm_policy_name, secrets_manager_policy_document, 'Policy for accessing secret manager')) + if self.chunking_strategy == 'CUSTOM': + policies.append((self.lambda_policy_name, lambda_policy_document, 'Policy for invoking lambda function')) + if self.multi_modal: + policies.append((self.bda_policy_name, bda_policy_document, 'Policy for accessing BDA')) + if 
self.vector_store == "NEPTUNE_ANALYTICS":
+            policies.append((self.neptune_policy_name, neptune_policy_name, 'Policy for Neptune Vector Store'))
+
+        # create bedrock execution role
+        try:
+            bedrock_kb_execution_role = self.iam_client.create_role(
+                RoleName=self.kb_execution_role_name,
+                AssumeRolePolicyDocument=json.dumps(assume_role_policy_document),
+                Description='Amazon Bedrock Knowledge Base Execution Role for accessing OSS, secrets manager and S3',
+                MaxSessionDuration=3600
+            )
+        except self.iam_client.exceptions.EntityAlreadyExistsException:
+            # If role already exists, get it
+            bedrock_kb_execution_role = self.iam_client.get_role(RoleName=self.kb_execution_role_name)
+            print(f"Using existing IAM role: {self.kb_execution_role_name}")
+
+        # create and attach the policies to the bedrock execution role
+        for policy_name, policy_document, description in policies:
+            policy = self.iam_client.create_policy(
+                PolicyName=policy_name,
+                PolicyDocument=json.dumps(policy_document),
+                Description=description,
+            )
+            self.iam_client.attach_role_policy(
+                RoleName=bedrock_kb_execution_role["Role"]["RoleName"],
+                PolicyArn=policy["Policy"]["Arn"]
+            )
+
+        return bedrock_kb_execution_role
+
+    def create_neptune(self):
+        response = self.neptune_client.create_graph(
+            graphName=self.graph_name,
+            tags={
+                'usecase': 'graphRAG'
+            },
+            publicConnectivity=True,
+            vectorSearchConfiguration={
+                'dimension': embedding_context_dimensions[self.embedding_model]
+            },
+            replicaCount=1,
+            deletionProtection=True,
+            provisionedMemory=16
+        )
+        graph_id = response["id"]
+
+        try:
+            while self.neptune_client.get_graph(graphIdentifier=graph_id)["status"] == "CREATING":
+                print("Graph is still being created...")
+                time.sleep(90)
+            # Re-fetch the status; the create_graph response above is stale by now
+            if self.neptune_client.get_graph(graphIdentifier=graph_id)["status"] == "AVAILABLE":
+                print("Graph created successfully")
+        except KeyError as e:
+            print(f"Error: 'status' key not found in response dictionary: {e}")
+        except Exception as e:
+            print(f"An 
unexpected error occurred: {e}") + return graph_id + + def create_policies_in_oss(self): + """ + Create OpenSearch Serverless policy and attach it to the Knowledge Base Execution role. + If policy already exists, attaches it + """ + try: + encryption_policy = self.aoss_client.create_security_policy( + name=self.encryption_policy_name, + policy=json.dumps( + { + 'Rules': [{'Resource': ['collection/' + self.vector_store_name], + 'ResourceType': 'collection'}], + 'AWSOwnedKey': True + }), + type='encryption' + ) + except self.aoss_client.exceptions.ConflictException: + encryption_policy = self.aoss_client.get_security_policy( + name=self.encryption_policy_name, + type='encryption' + ) + + try: + network_policy = self.aoss_client.create_security_policy( + name=self.network_policy_name, + policy=json.dumps( + [ + {'Rules': [{'Resource': ['collection/' + self.vector_store_name], + 'ResourceType': 'collection'}], + 'AllowFromPublic': True} + ]), + type='network' + ) + except self.aoss_client.exceptions.ConflictException: + network_policy = self.aoss_client.get_security_policy( + name=self.network_policy_name, + type='network' + ) + + try: + access_policy = self.aoss_client.create_access_policy( + name=self.access_policy_name, + policy=json.dumps( + [ + { + 'Rules': [ + { + 'Resource': ['collection/' + self.vector_store_name], + 'Permission': [ + 'aoss:CreateCollectionItems', + 'aoss:DeleteCollectionItems', + 'aoss:UpdateCollectionItems', + 'aoss:DescribeCollectionItems'], + 'ResourceType': 'collection' + }, + { + 'Resource': ['index/' + self.vector_store_name + '/*'], + 'Permission': [ + 'aoss:CreateIndex', + 'aoss:DeleteIndex', + 'aoss:UpdateIndex', + 'aoss:DescribeIndex', + 'aoss:ReadDocument', + 'aoss:WriteDocument'], + 'ResourceType': 'index' + }], + 'Principal': [self.identity, self.bedrock_kb_execution_role['Role']['Arn']], + 'Description': 'Easy data policy'} + ]), + type='data' + ) + except self.aoss_client.exceptions.ConflictException: + access_policy = 
self.aoss_client.get_access_policy( + name=self.access_policy_name, + type='data' + ) + + return encryption_policy, network_policy, access_policy + + def create_oss(self): + """ + Create OpenSearch Serverless Collection. If already existent, retrieve + """ + try: + collection = self.aoss_client.create_collection(name=self.vector_store_name, type='VECTORSEARCH') + collection_id = collection['createCollectionDetail']['id'] + collection_arn = collection['createCollectionDetail']['arn'] + except self.aoss_client.exceptions.ConflictException: + collection = self.aoss_client.batch_get_collection(names=[self.vector_store_name])['collectionDetails'][0] + collection_id = collection['id'] + collection_arn = collection['arn'] + pp.pprint(collection) + + host = collection_id + '.' + self.region_name + '.aoss.amazonaws.com' + print(host) + + response = self.aoss_client.batch_get_collection(names=[self.vector_store_name]) + while (response['collectionDetails'][0]['status']) == 'CREATING': + print('Creating collection...') + interactive_sleep(30) + response = self.aoss_client.batch_get_collection(names=[self.vector_store_name]) + print('\nCollection successfully created:') + pp.pprint(response["collectionDetails"]) + + try: + self.create_oss_policy_attach_bedrock_execution_role(collection_id) + print("Sleeping for a minute to ensure data access rules have been enforced") + interactive_sleep(60) + except Exception as e: + print("Policy already exists") + pp.pprint(e) + + return host, collection, collection_id, collection_arn + + def create_oss_policy_attach_bedrock_execution_role(self, collection_id): + oss_policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "aoss:APIAccessAll" + ], + "Resource": [ + f"arn:aws:aoss:{self.region_name}:{self.account_number}:collection/{collection_id}" + ] + } + ] + } + try: + oss_policy = self.iam_client.create_policy( + PolicyName=self.oss_policy_name, + 
PolicyDocument=json.dumps(oss_policy_document), + Description='Policy for accessing opensearch serverless', + ) + oss_policy_arn = oss_policy["Policy"]["Arn"] + except self.iam_client.exceptions.EntityAlreadyExistsException: + oss_policy_arn = f"arn:aws:iam::{self.account_number}:policy/{self.oss_policy_name}" + + print("Opensearch serverless arn: ", oss_policy_arn) + + self.iam_client.attach_role_policy( + RoleName=self.bedrock_kb_execution_role["Role"]["RoleName"], + PolicyArn=oss_policy_arn + ) + + def create_vector_index(self): + """ + Create OpenSearch Serverless vector index. If existent, ignore + """ + body_json = { + "settings": { + "index.knn": "true", + "number_of_shards": 1, + "knn.algo_param.ef_search": 512, + "number_of_replicas": 0, + }, + "mappings": { + "properties": { + "vector": { + "type": "knn_vector", + "dimension": embedding_context_dimensions[self.embedding_model], + "method": { + "name": "hnsw", + "engine": "faiss", + "space_type": "l2" + }, + }, + "text": { + "type": "text" + }, + "text-metadata": { + "type": "text"} + } + } + } + + try: + response = self.oss_client.indices.create(index=self.index_name, body=json.dumps(body_json)) + print('\nCreating index:') + pp.pprint(response) + interactive_sleep(60) + except RequestError as e: + print(f'Error while trying to create the index, with error {e.error}') + + def create_chunking_strategy_config(self, strategy): + configs = { + + "GRAPH": { + "contextEnrichmentConfiguration": { + "bedrockFoundationModelConfiguration": { + "enrichmentStrategyConfiguration": { + "method": "CHUNK_ENTITY_EXTRACTION" + }, + "modelArn": f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.graph_model}" + }, + "type": "BEDROCK_FOUNDATION_MODEL" + } + }, + + "NONE": { + "chunkingConfiguration": {"chunkingStrategy": "NONE"} + }, + "FIXED_SIZE": { + "chunkingConfiguration": { + "chunkingStrategy": "FIXED_SIZE", + "fixedSizeChunkingConfiguration": { + "maxTokens": 300, + "overlapPercentage": 20 + } + } + }, + 
"HIERARCHICAL": { + "chunkingConfiguration": { + "chunkingStrategy": "HIERARCHICAL", + "hierarchicalChunkingConfiguration": { + "levelConfigurations": [{"maxTokens": 1500}, {"maxTokens": 300}], + "overlapTokens": 60 + } + } + }, + "SEMANTIC": { + "chunkingConfiguration": { + "chunkingStrategy": "SEMANTIC", + "semanticChunkingConfiguration": { + "maxTokens": 300, + "bufferSize": 1, + "breakpointPercentileThreshold": 95} + } + }, + "CUSTOM": { + "customTransformationConfiguration": { + "intermediateStorage": { + "s3Location": { + "uri": f"s3://{self.intermediate_bucket_name}/" + } + }, + "transformations": [ + { + "transformationFunction": { + "transformationLambdaConfiguration": { + "lambdaArn": self.lambda_arn + } + }, + "stepToApply": "POST_CHUNKING" + } + ] + }, + "chunkingConfiguration": {"chunkingStrategy": "NONE"} + } + } + return configs.get(strategy, configs["NONE"]) + + @retry(wait_random_min=1000, wait_random_max=2000, stop_max_attempt_number=7) + def create_knowledge_base(self, data_sources): + """ + Create Knowledge Base and its Data Source. 
If existent, retrieve + """ + if self.graph_id: + storage_configuration = { + "type": "NEPTUNE_ANALYTICS", + "neptuneAnalyticsConfiguration": { + "graphArn": f"arn:aws:neptune-graph:{self.region_name}:{self.account_number}:graph/{self.graph_id}", + "fieldMapping": { + "textField": "text", + "metadataField": "text-metadata" + } + } + } + else: + storage_configuration = { + "type": "OPENSEARCH_SERVERLESS", + "opensearchServerlessConfiguration": { + "collectionArn": self.collection_arn, + "vectorIndexName": self.index_name, + "fieldMapping": { + "vectorField": "vector", + "textField": "text", + "metadataField": "text-metadata" + } + } + } + + # create Knowledge Bases + embedding_model_arn = f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.embedding_model}" + knowledgebase_configuration = { "type": "VECTOR", "vectorKnowledgeBaseConfiguration": { "embeddingModelArn": embedding_model_arn}} + + if self.multi_modal: + supplemental_storageLocation={"storageLocations": [{ "s3Location": { "uri": f"s3://{self.intermediate_bucket_name}"},"type": "S3"}]} + knowledgebase_configuration['vectorKnowledgeBaseConfiguration']['supplementalDataStorageConfiguration'] = supplemental_storageLocation + + try: + create_kb_response = self.bedrock_agent_client.create_knowledge_base( + name=self.kb_name, + description=self.kb_description, + roleArn=self.bedrock_kb_execution_role['Role']['Arn'], + knowledgeBaseConfiguration=knowledgebase_configuration, + storageConfiguration=storage_configuration, + ) + kb = create_kb_response["knowledgeBase"] + pp.pprint(kb) + except self.bedrock_agent_client.exceptions.ConflictException: + kbs = self.bedrock_agent_client.list_knowledge_bases(maxResults=100) + kb_id = next((kb['knowledgeBaseId'] for kb in kbs['knowledgeBaseSummaries'] if kb['name'] == self.kb_name), None) + response = self.bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb_id) + kb = response['knowledgeBase'] + pp.pprint(kb) + + # create Data Sources + print("Creating Data 
Sources") + try: + ds_list = self.create_data_sources(kb['knowledgeBaseId'], self.data_sources) + pp.pprint(ds_list) + except self.bedrock_agent_client.exceptions.ConflictException: + ds_id = self.bedrock_agent_client.list_data_sources( + knowledgeBaseId=kb['knowledgeBaseId'], + maxResults=100 + )['dataSourceSummaries'][0]['dataSourceId'] + get_ds_response = self.bedrock_agent_client.get_data_source( + dataSourceId=ds_id, + knowledgeBaseId=kb['knowledgeBaseId'] + ) + ds = get_ds_response["dataSource"] + pp.pprint(ds) + ds_list = [ds] + return kb, ds_list + + def create_data_sources(self, kb_id, data_sources): + """ + Create Data Sources for the Knowledge Base. + """ + ds_list = [] + + # create a data source for each data source type in the data_sources list + for idx, ds in enumerate(data_sources): + + # The data source to ingest documents from, into the OpenSearch Serverless knowledge base index + s3_data_source_congiguration = { + "type": "S3", + "s3Configuration":{ + "bucketArn": "", + # "inclusionPrefixes":["*.*"] # you can use this if you want to create a KB using data within s3 prefixes.
+ } + } + + confluence_data_source_congiguration = { + "confluenceConfiguration": { + "sourceConfiguration": { + "hostUrl": "", + "hostType": "SAAS", + "authType": "", # BASIC | OAUTH2_CLIENT_CREDENTIALS + "credentialsSecretArn": "" + + }, + "crawlerConfiguration": { + "filterConfiguration": { + "type": "PATTERN", + "patternObjectFilter": { + "filters": [ + { + "objectType": "Attachment", + "inclusionFilters": [ + ".*\\.pdf" + ], + "exclusionFilters": [ + ".*private.*\\.pdf" + ] + } + ] + } + } + } + }, + "type": "CONFLUENCE" + } + + sharepoint_data_source_congiguration = { + "sharePointConfiguration": { + "sourceConfiguration": { + "tenantId": "", + "hostType": "ONLINE", + "domain": "domain", + "siteUrls": [], + "authType": "", # BASIC | OAUTH2_CLIENT_CREDENTIALS + "credentialsSecretArn": "" + + }, + "crawlerConfiguration": { + "filterConfiguration": { + "type": "PATTERN", + "patternObjectFilter": { + "filters": [ + { + "objectType": "Attachment", + "inclusionFilters": [ + ".*\\.pdf" + ], + "exclusionFilters": [ + ".*private.*\\.pdf" + ] + } + ] + } + } + } + }, + "type": "SHAREPOINT" + } + + + salesforce_data_source_congiguration = { + "salesforceConfiguration": { + "sourceConfiguration": { + "hostUrl": "", + "authType": "", # BASIC | OAUTH2_CLIENT_CREDENTIALS + "credentialsSecretArn": "" + }, + "crawlerConfiguration": { + "filterConfiguration": { + "type": "PATTERN", + "patternObjectFilter": { + "filters": [ + { + "objectType": "Attachment", + "inclusionFilters": [ + ".*\\.pdf" + ], + "exclusionFilters": [ + ".*private.*\\.pdf" + ] + } + ] + } + } + } + }, + "type": "SALESFORCE" + } + + webcrawler_data_source_congiguration = { + "webConfiguration": { + "sourceConfiguration": { + "urlConfiguration": { + "seedUrls": [] + } + }, + "crawlerConfiguration": { + "crawlerLimits": { + "rateLimit": 50 + }, + "scope": "HOST_ONLY", + "inclusionFilters": [], + "exclusionFilters": [] + } + }, + "type": "WEB" + } + + # Set the data source configuration based on the Data source 
type + + if ds['type'] == "S3": + print(f'{idx +1 } data source: S3') + ds_name = f'{kb_id}-s3' + s3_data_source_congiguration["s3Configuration"]["bucketArn"] = f'arn:aws:s3:::{ds["bucket_name"]}' + # print(s3_data_source_congiguration) + data_source_configuration = s3_data_source_congiguration + + if ds['type'] == "CONFLUENCE": + print(f'{idx +1 } data source: CONFLUENCE') + ds_name = f'{kb_id}-confluence' + confluence_data_source_congiguration['confluenceConfiguration']['sourceConfiguration']['hostUrl'] = ds['hostUrl'] + confluence_data_source_congiguration['confluenceConfiguration']['sourceConfiguration']['authType'] = ds['authType'] + confluence_data_source_congiguration['confluenceConfiguration']['sourceConfiguration']['credentialsSecretArn'] = ds['credentialsSecretArn'] + # print(confluence_data_source_congiguration) + data_source_configuration = confluence_data_source_congiguration + + if ds['type'] == "SHAREPOINT": + print(f'{idx +1 } data source: SHAREPOINT') + ds_name = f'{kb_id}-sharepoint' + sharepoint_data_source_congiguration['sharePointConfiguration']['sourceConfiguration']['tenantId'] = ds['tenantId'] + sharepoint_data_source_congiguration['sharePointConfiguration']['sourceConfiguration']['domain'] = ds['domain'] + sharepoint_data_source_congiguration['sharePointConfiguration']['sourceConfiguration']['authType'] = ds['authType'] + sharepoint_data_source_congiguration['sharePointConfiguration']['sourceConfiguration']['siteUrls'] = ds["siteUrls"] + sharepoint_data_source_congiguration['sharePointConfiguration']['sourceConfiguration']['credentialsSecretArn'] = ds['credentialsSecretArn'] + # print(sharepoint_data_source_congiguration) + data_source_configuration = sharepoint_data_source_congiguration + + + if ds['type'] == "SALESFORCE": + print(f'{idx +1 } data source: SALESFORCE') + ds_name = f'{kb_id}-salesforce' + salesforce_data_source_congiguration['salesforceConfiguration']['sourceConfiguration']['hostUrl'] = ds['hostUrl'] + 
salesforce_data_source_congiguration['salesforceConfiguration']['sourceConfiguration']['authType'] = ds['authType'] + salesforce_data_source_congiguration['salesforceConfiguration']['sourceConfiguration']['credentialsSecretArn'] = ds['credentialsSecretArn'] + # print(salesforce_data_source_congiguration) + data_source_configuration = salesforce_data_source_congiguration + + if ds['type'] == "WEB": + print(f'{idx + 1} data source: WEB') + ds_name = f'{kb_id}-web' + webcrawler_data_source_congiguration['webConfiguration']['sourceConfiguration']['urlConfiguration']['seedUrls'] = ds['seedUrls'] + webcrawler_data_source_congiguration['webConfiguration']['crawlerConfiguration']['inclusionFilters'] = ds['inclusionFilters'] + webcrawler_data_source_congiguration['webConfiguration']['crawlerConfiguration']['exclusionFilters'] = ds['exclusionFilters'] + # print(webcrawler_data_source_congiguration) + data_source_configuration = webcrawler_data_source_congiguration + + + # Create a DataSource in KnowledgeBase + chunking_strategy_configuration = self.create_chunking_strategy_config(self.chunking_strategy) + print("============Chunking config========\n", chunking_strategy_configuration) + vector_ingestion_configuration = chunking_strategy_configuration + + if self.multi_modal: + if self.parser == "BEDROCK_FOUNDATION_MODEL": + parsing_configuration = {"bedrockFoundationModelConfiguration": + {"parsingModality": "MULTIMODAL", "modelArn": f"arn:aws:bedrock:{self.region_name}::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"}, + "parsingStrategy": "BEDROCK_FOUNDATION_MODEL"} + elif self.parser == 'BEDROCK_DATA_AUTOMATION': + parsing_configuration = {"bedrockDataAutomationConfiguration": {"parsingModality": "MULTIMODAL"}, "parsingStrategy": "BEDROCK_DATA_AUTOMATION"} + else: + # guard against a NameError further down when an unknown parser is configured + raise ValueError(f"Unsupported multi-modal parser: {self.parser}") + + vector_ingestion_configuration["parsingConfiguration"] = parsing_configuration + + create_ds_response = self.bedrock_agent_client.create_data_source( + name = ds_name, + description = self.kb_description, + 
knowledgeBaseId = kb_id, + dataSourceConfiguration = data_source_configuration, + vectorIngestionConfiguration = vector_ingestion_configuration + ) + ds = create_ds_response["dataSource"] + pp.pprint(ds) + # self.data_sources[idx]['dataSourceId'].append(ds['dataSourceId']) + ds_list.append(ds) + return ds_list + + + def start_ingestion_job(self): + """ + Start an ingestion job to synchronize data from an S3 bucket to the Knowledge Base + """ + + for idx, ds in enumerate(self.data_sources): + try: + start_job_response = self.bedrock_agent_client.start_ingestion_job( + knowledgeBaseId=self.knowledge_base['knowledgeBaseId'], + dataSourceId=self.data_source[idx]["dataSourceId"] + ) + job = start_job_response["ingestionJob"] + print(f"job {idx+1} started successfully\n") + pp.pprint(job) + while job['status'] not in ["COMPLETE", "FAILED", "STOPPED"]: + get_job_response = self.bedrock_agent_client.get_ingestion_job( + knowledgeBaseId=self.knowledge_base['knowledgeBaseId'], + dataSourceId=self.data_source[idx]["dataSourceId"], + ingestionJobId=job["ingestionJobId"] + ) + job = get_job_response["ingestionJob"] + pp.pprint(job) + interactive_sleep(40) + + except Exception as e: + print(f"Couldn't start {idx} job.\n") + print(e) + + + def get_knowledge_base_id(self): + """ + Get Knowledge Base Id + """ + pp.pprint(self.knowledge_base["knowledgeBaseId"]) + return self.knowledge_base["knowledgeBaseId"] + + def get_bucket_name(self): + """ + Get the name of the bucket connected with the Knowledge Base Data Source + """ + pp.pprint(f"Bucket connected with KB: {self.bucket_name}") + return self.bucket_name + + def delete_kb(self, delete_s3_bucket=False, delete_iam_roles_and_policies=True, delete_lambda_function=False): + """ + Delete the Knowledge Base resources + Args: + delete_s3_bucket (bool): boolean to indicate if s3 bucket should also be deleted + delete_iam_roles_and_policies (bool): boolean to indicate if IAM roles and Policies should also be deleted + 
delete_lambda_function (bool): boolean to indicate if Lambda function should also be deleted + """ + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + + # delete knowledge base and data source. + + # Delete knowledge base and data sources + try: + # First delete all data sources + for ds in self.data_source: + try: + self.bedrock_agent_client.delete_data_source( + dataSourceId=ds["dataSourceId"], + knowledgeBaseId=self.knowledge_base['knowledgeBaseId'] + ) + print(f"Deleted data source {ds['dataSourceId']}") + except self.bedrock_agent_client.exceptions.ResourceNotFoundException: + print(f"Data source {ds['dataSourceId']} not found") + except Exception as e: + print(f"Error deleting data source {ds['dataSourceId']}: {str(e)}") + + # Then delete the knowledge base + self.bedrock_agent_client.delete_knowledge_base( + knowledgeBaseId=self.knowledge_base['knowledgeBaseId'] + ) + print("======== Knowledge base and all data sources deleted =========") + + except self.bedrock_agent_client.exceptions.ResourceNotFoundException as e: + print("Knowledge base not found:", e) + except Exception as e: + print(f"Error during knowledge base deletion: {str(e)}") + + # delete s3 bucket + if delete_s3_bucket==True: + self.delete_s3() + + # delete IAM role and policies + if delete_iam_roles_and_policies: + self.delete_iam_roles_and_policies() + + if delete_lambda_function: + try: + self.delete_lambda_function() + print(f"Deleted Lambda function {self.lambda_function_name}") + except self.lambda_client.exceptions.ResourceNotFoundException: + print(f"Lambda function {self.lambda_function_name} not found.") + + # delete vector index and collection from vector store + if self.vector_store=="OPENSEARCH_SERVERLESS": + try: + self.aoss_client.delete_collection(id=self.collection_id) + self.aoss_client.delete_access_policy( + type="data", + name=self.access_policy_name + ) + self.aoss_client.delete_security_policy( + type="network", + name=self.network_policy_name + ) + 
self.aoss_client.delete_security_policy( + type="encryption", + name=self.encryption_policy_name + ) + print("======== Vector Index, collection and associated policies deleted =========") + except Exception as e: + print(e) + else: + try: + # disable delete protection + response = self.neptune_client.update_graph( + graphIdentifier=self.graph_id, + deletionProtection=False) + print("======= Delete protection disabled before deleting the graph: ", response['deletionProtection']) + + # delete the graph + self.neptune_client.delete_graph( + graphIdentifier=self.graph_id, + skipSnapshot=True) + print("========= Neptune Analytics Graph Deleted =================================") + except Exception as e: + print(e) + + + def delete_iam_roles_and_policies(self): + for role_name in self.roles: + try: + self.iam_client.get_role(RoleName=role_name) + except self.iam_client.exceptions.NoSuchEntityException: + print(f"Role {role_name} does not exist") + continue + print(f"Found role {role_name}") + attached_policies = self.iam_client.list_attached_role_policies(RoleName=role_name)["AttachedPolicies"] + print(f"======Attached policies with role {role_name}========\n", attached_policies) + for attached_policy in attached_policies: + policy_arn = attached_policy["PolicyArn"] + policy_name = attached_policy["PolicyName"] + self.iam_client.detach_role_policy(RoleName=role_name, PolicyArn=policy_arn) + print(f"Detached policy {policy_name} from role {role_name}") + if policy_arn.split("/")[1] == "service-role": + print(f"Skipping deletion of service-role policy {policy_name}") + else: + self.iam_client.delete_policy(PolicyArn=policy_arn) + print(f"Deleted policy {policy_name} from role {role_name}") + + self.iam_client.delete_role(RoleName=role_name) + print(f"Deleted role {role_name}") + print("======== All IAM roles and policies deleted =========") + + def bucket_exists(self, bucket): + s3 = boto3.resource('s3') + return s3.Bucket(bucket) in s3.buckets.all() + + def
delete_s3(self): + """ + Delete the objects contained in the Knowledge Base S3 bucket. + Once the bucket is empty, delete the bucket + """ + s3 = boto3.resource('s3') + bucket_names = self.bucket_names.copy() + if self.intermediate_bucket_name: + bucket_names.append(self.intermediate_bucket_name) + + for bucket_name in bucket_names: + try: + bucket = s3.Bucket(bucket_name) + if bucket in s3.buckets.all(): + print(f"Found bucket {bucket_name}") + # Delete all objects including versions (if versioning enabled) + bucket.object_versions.delete() + bucket.objects.all().delete() + print(f"Deleted all objects in bucket {bucket_name}") + + # Delete the bucket + bucket.delete() + print(f"Deleted bucket {bucket_name}") + else: + print(f"Bucket {bucket_name} does not exist, skipping deletion") + except Exception as e: + print(f"Error deleting bucket {bucket_name}: {str(e)}") + + print("======== S3 bucket deletion process completed =========") + + + def delete_lambda_function(self): + """ + Delete the Knowledge Base Lambda function + Delete the IAM role used by the Knowledge Base Lambda function + """ + # delete lambda function + try: + self.lambda_client.delete_function(FunctionName=self.lambda_function_name) + print(f"======== Lambda function {self.lambda_function_name} deleted =========") + except Exception as e: + print(e) \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/lambda_custom_chunking_function.py b/BedrockPromptCachingRoutingDemo/src/lambda_custom_chunking_function.py new file mode 100644 index 000000000..877ab85d7 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/lambda_custom_chunking_function.py @@ -0,0 +1,239 @@ +import json +import boto3 +import os +import logging +import traceback +from botocore.exceptions import ClientError + +# Set up logging +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +# Constants for chunking +MAX_TOKENS = 1000 +OVERLAP_PERCENTAGE = 0.20 + +def estimate_tokens(text): + """ + Rough 
estimation of tokens (approximation: 1 word ≈ 0.75 tokens, i.e. roughly 1.33 words per token) + """ + words = text.split() + return int(len(words) * 0.75) + + def chunk_text(text, max_tokens=MAX_TOKENS, overlap_percentage=OVERLAP_PERCENTAGE): + """ + Chunk text based on words with specified max tokens and overlap + """ + if not text: + return [] + + # Split text into words + words = text.split() + if not words: + return [] + + # Estimate words per chunk based on token limit + # Assuming an average of 0.75 tokens per word + words_per_chunk = int(max_tokens * 1.33) # Convert tokens to approximate words + overlap_words = int(words_per_chunk * overlap_percentage) + + chunks = [] + current_position = 0 + total_words = len(words) + + while current_position < total_words: + # Calculate end position for current chunk + chunk_end = min(current_position + words_per_chunk, total_words) + + # Get current chunk words + chunk_words = words[current_position:chunk_end] + + # If this isn't the last chunk, try to find a good break point + if chunk_end < total_words: + # Look for sentence-ending punctuation in the last few words + for i in range(len(chunk_words) - 1, max(len(chunk_words) - 10, 0), -1): + if chunk_words[i].endswith(('.', '!', '?')): + chunk_end = current_position + i + 1 + chunk_words = chunk_words[:i + 1] + break + + # Join words back into text (avoid shadowing the chunk_text function name) + chunk_str = ' '.join(chunk_words) + chunks.append(chunk_str.strip()) + + # Move position considering overlap + current_position = chunk_end - overlap_words if chunk_end < total_words else chunk_end + + return chunks + + def write_output_to_s3(s3_client, bucket_name, file_name, json_data): + """ + Write JSON data to S3 bucket + """ + try: + json_string = json.dumps(json_data) + response = s3_client.put_object( + Bucket=bucket_name, + Key=file_name, + Body=json_string, + ContentType='application/json' + ) + + if response['ResponseMetadata']['HTTPStatusCode'] == 200: + print(f"Successfully uploaded {file_name} to {bucket_name}") + return True + else: + 
print(f"Failed to upload {file_name} to {bucket_name}") + return False + + except ClientError as e: + print(f"Error occurred: {e}") + return False + + def read_from_s3(s3_client, bucket_name, file_name): + """ + Read JSON data from S3 bucket + """ + try: + response = s3_client.get_object(Bucket=bucket_name, Key=file_name) + return json.loads(response['Body'].read().decode('utf-8')) + except ClientError as e: + # re-raise instead of silently returning None, which would crash the caller later + print(f"Error reading file from S3: {str(e)}") + raise + + def parse_s3_path(s3_path): + """ + Parse S3 path into bucket and key + """ + s3_path = s3_path.replace('s3://', '') + parts = s3_path.split('/', 1) + if len(parts) != 2: + raise ValueError("Invalid S3 path format") + return parts[0], parts[1] + + def invoke_model_with_response_stream(bedrock_runtime, prompt, max_tokens=1000): + """ + Invoke Bedrock model with streaming response + """ + model_id = 'anthropic.claude-3-haiku-20240307-v1:0' + request_body = json.dumps({ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": max_tokens, + "messages": [ + { + "role": "user", + "content": prompt + } + ], + "temperature": 0.0, + }) + + try: + response = bedrock_runtime.invoke_model_with_response_stream( + modelId=model_id, + contentType='application/json', + accept='application/json', + body=request_body + ) + + for event in response.get('body'): + chunk = json.loads(event['chunk']['bytes'].decode()) + if chunk['type'] == 'content_block_delta': + yield chunk['delta']['text'] + elif chunk['type'] == 'message_delta': + if 'stop_reason' in chunk['delta']: + break + + except ClientError as e: + print(f"An error occurred: {e}") + yield None + + # Define the contextual retrieval prompt + contextual_retrieval_prompt = """ + <document> + {doc_content} + </document> + + Here is the chunk we want to situate within the whole document + + <chunk> + {chunk_content} + </chunk> + + Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. 
+ Answer only with the succinct context and nothing else. + """ + +def lambda_handler(event, context): + """ + Lambda handler function + """ + logger.debug('input={}'.format(json.dumps(event))) + + s3_client = boto3.client('s3') + bedrock_runtime = boto3.client( + service_name='bedrock-runtime', + region_name='us-east-1' + ) + + input_files = event.get('inputFiles') + input_bucket = event.get('bucketName') + + if not all([input_files, input_bucket]): + raise ValueError("Missing required input parameters") + + output_files = [] + for input_file in input_files: + processed_batches = [] + for batch in input_file.get('contentBatches'): + input_key = batch.get('key') + + if not input_key: + raise ValueError("Missing uri in content batch") + + file_content = read_from_s3(s3_client, bucket_name=input_bucket, file_name=input_key) + print(file_content.get('fileContents')) + + original_document_content = ''.join( + content.get('contentBody') + for content in file_content.get('fileContents') + if content + ) + + chunked_content = { + 'fileContents': [] + } + + for content in file_content.get('fileContents'): + content_body = content.get('contentBody', '') + content_type = content.get('contentType', '') + content_metadata = content.get('contentMetadata', {}) + + # Apply chunking strategy + chunks = chunk_text(content_body) + + for chunk in chunks: + prompt = contextual_retrieval_prompt.format( + doc_content=original_document_content, + chunk_content=chunk + ) + response_stream = invoke_model_with_response_stream(bedrock_runtime, prompt) + chunk_context = ''.join(chunk_text for chunk_text in response_stream if chunk_text) + + chunked_content['fileContents'].append({ + "contentBody": chunk_context + "\n\n" + chunk, + "contentType": content_type, + "contentMetadata": content_metadata, + }) + + output_key = f"Output/{input_key}" + write_output_to_s3(s3_client, input_bucket, output_key, chunked_content) + processed_batches.append({"key": output_key}) + + output_files.append({ + 
"originalFileLocation": input_file.get('originalFileLocation'), + "fileMetadata": {}, + "contentBatches": processed_batches + }) + + return { + "outputFiles": output_files + } \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/model_manager.py b/BedrockPromptCachingRoutingDemo/src/model_manager.py new file mode 100644 index 000000000..374e2bf25 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/model_manager.py @@ -0,0 +1,315 @@ + +import weakref + +from bedrock_service import BedrockService + +class ModelManager: + """Manages model information and selection + + This class maintains a list of Bedrock models that support prompt caching + and provides methods for displaying and selecting models. It can be used + independently or integrated with other components. + """ + + def __init__(self, bedrock_service=None): + """Initialize with models that support prompt caching + + Args: + bedrock_service: BedrockService instance for API calls + If None, a new instance will be created + + Raises: + TypeError: If bedrock_service is not a BedrockService instance + """ + # Validate bedrock_service type if provided + if bedrock_service is not None and not isinstance(bedrock_service, BedrockService): + raise TypeError("bedrock_service must be an instance of BedrockService") + + # Initialize bedrock_service first so it's available for any methods called during initialization + # Use weakref to avoid circular references if this ModelManager is used by the same object + # that owns the bedrock_service + if bedrock_service: + self._bedrock_service_ref = weakref.ref(bedrock_service) + self._owns_bedrock_service = False + else: + self._bedrock_service = BedrockService() + self._bedrock_service_ref = weakref.ref(self._bedrock_service) + self._owns_bedrock_service = True + + # Get models after bedrock_service is initialized + self.models = self._get_prompt_cache_enabled_models() + + @property + def bedrock_service(self): + """Get the bedrock service instance 
+ + Returns: + BedrockService instance + + Raises: + RuntimeError: If the bedrock service has been garbage collected + """ + service = self._bedrock_service_ref() + if service is None: + raise RuntimeError("BedrockService is no longer available") + return service + + def _get_prompt_cache_enabled_models(self): + """Return a dictionary of models that support prompt caching + + This method returns a static list of models known to support prompt caching. + In a production environment, this could be dynamically determined by + querying the Bedrock service for models with specific capabilities. + + Returns: + Dictionary with model categories as keys and lists of model IDs as values + + Raises: + RuntimeError: If there's an error creating the model list + """ + try: + return { + "Anthropic Claude Models": [ + "anthropic.claude-haiku-4-5-20251001-v1:0", + "anthropic.claude-sonnet-4-5-20250929-v1:0", + "anthropic.claude-opus-4-1-20250805-v1:0" + ], + "Amazon Nova Models": [ + "amazon.nova-micro-v1:0", + "amazon.nova-lite-v1:0", + "amazon.nova-pro-v1:0" + ] + } + except Exception as e: + raise RuntimeError(f"Failed to create model list: {str(e)}") + + def display_models(self): + """Display available models with headers + + Raises: + RuntimeError: If models data structure is invalid + """ + if not self.models: + print("No models available to display") + return + + if not isinstance(self.models, dict): + raise RuntimeError("Models data structure is invalid") + + for category, model_list in self.models.items(): + if not isinstance(model_list, list): + print(f"Warning: Skipping invalid category '{category}'") + continue + + print(f"\n{category}:") + for i, model in enumerate(model_list, 1): + if isinstance(model, str): + print(f"{i}. {model}") + else: + print(f"{i}. 
") + + print("") # Add a blank line at the end + + def get_model_arn_from_inference_profiles(self, model_id): + """Get model ARN from inference profiles for non-ON_DEMAND models + + This method resolves model IDs to their actual ARNs by checking: + 1. If the model is an ON_DEMAND type (returns as-is) + 2. If the model has a specific mapping to an inference profile + 3. If the model has a default mapping as fallback + + Args: + model_id: The model ID or alias to resolve + + Returns: + Resolved model ID that can be used with Bedrock APIs + + Raises: + ValueError: If model_id is invalid + RuntimeError: If there's an error communicating with Bedrock service + """ + if not model_id or not isinstance(model_id, str): + raise ValueError("Model ID must be a non-empty string") + + # Define specific model mappings to ensure correct matching + model_mappings = { + "anthropic.claude-sonnet-4-5-20250929-v1:0": "claude-sonnet-4-5", + "anthropic.claude-haiku-4-5-20251001-v1:0": "claude-haiku-4-5", + "anthropic.claude-opus-4-1-20250805-v1:0": "claude-opus-4-1" + } + + # Default mappings as fallback + default_mappings = { + "anthropic.claude-sonnet-4-5-20250929-v1:0": "anthropic.claude-sonnet-4-5-20250929-v1:0", + "anthropic.claude-haiku-4-5-20251001-v1:0": "anthropic.claude-haiku-4-5-20251001-v1:0", + "anthropic.claude-opus-4-1-20250805-v1:0": "anthropic.claude-opus-4-1-20250805-v1:0" + } + + # First check if we can use the default mapping directly (faster) + if model_id in default_mappings: + try: + # Try to get the bedrock service, but handle the case where it's no longer available + bedrock_service = self.bedrock_service + except RuntimeError: + # If bedrock service is gone, fall back to default mapping + return default_mappings.get(model_id) + + try: + # Try to get the bedrock service + bedrock_service = self.bedrock_service + + # First check if the model is an ON_DEMAND type + on_demand_models = [] + try: + response = 
bedrock_service.bedrock.list_foundation_models(byInferenceType="ON_DEMAND") + for model in response.get('modelSummaries', []): + model_id_value = model.get('modelId') + if model_id_value: + on_demand_models.append(model_id_value) + except Exception as e: + raise RuntimeError(f"Failed to list foundation models: {str(e)}") + + # If the model is ON_DEMAND, no need to look up inference profiles + if model_id in on_demand_models: + return model_id + + # If we don't have a specific mapping for this model, return as is + if model_id not in model_mappings: + return model_id + + # Get the specific model identifier to look for + model_identifier = model_mappings.get(model_id) + + # For non-ON_DEMAND models, check inference profiles + try: + response = bedrock_service.list_inference_profiles(type_equals='SYSTEM_DEFINED') + except Exception as e: + raise RuntimeError(f"Failed to list inference profiles: {str(e)}") + + # Search for the model in the profiles using exact model identifier + for profile in response.get('inferenceProfileSummaries', []): + profile_arn = profile.get('inferenceProfileArn', '') + + if profile_arn and model_identifier in profile_arn: + # Extract the model ID from the ARN (last part after the slash) + try: + extracted_model_id = profile_arn.split('/')[-1] + print(f"Found inference profile for {model_id}: {profile_arn}") + return extracted_model_id + except Exception: + # If extraction fails, continue to next profile + continue + + # If no matching profile found, use default mapping + if model_id in default_mappings: + print(f"No matching inference profile found, using default mapping for {model_id}") + return default_mappings.get(model_id) + + # If no default mapping, return original model ID + return model_id + + except RuntimeError as e: + # Re-raise RuntimeError for service communication issues + raise + except Exception as e: + print(f"Error getting model ID from inference profiles: {e}") + # Use default mapping if available + if model_id in 
default_mappings: + return default_mappings.get(model_id) + return model_id + + def select_model(self): + """Allow user to select a model from the available options + + This method displays all available models with their categories, + marks the default model, and prompts the user to make a selection. + It then resolves the selected model ID using inference profiles. + + Returns: + Resolved model ID ready to use with Bedrock APIs + + Raises: + RuntimeError: If there's an error resolving the model ID or no models are available + ValueError: If user input is invalid after multiple attempts + """ + if not self.models: + raise RuntimeError("No models available for selection") + + print("\nAvailable models:") + + # Get default model ID and name for marking + default_model_name = "Claude Sonnet 4.5" + default_model = "anthropic.claude-sonnet-4-5-20250929-v1:0" + + try: + default_model_resolved = self.get_model_arn_from_inference_profiles(default_model) + except Exception as e: + print(f"Warning: Could not resolve default model: {str(e)}") + default_model_resolved = default_model + + # Display models with default model marked + model_index = 1 + all_models = [] + + # Validate models structure + if not isinstance(self.models, dict): + raise RuntimeError("Models data structure is invalid") + + for category, model_list in self.models.items(): + if not isinstance(model_list, list): + print(f"Warning: Skipping invalid category '{category}'") + continue + + print(f"\n{category}:") + for model in model_list: + if not isinstance(model, str): + continue # Skip invalid models + + if model == default_model or model == default_model_resolved: + print(f"{model_index}. {model} [DEFAULT - {default_model_name}]") + else: + print(f"{model_index}. 
{model}") + all_models.append(model) + model_index += 1 + + if not all_models: + raise RuntimeError("No valid models found for selection") + + max_attempts = 3 + attempts = 0 + + while attempts < max_attempts: + try: + choice = int(input("\nSelect a model (enter number): ")) + if 1 <= choice <= len(all_models): + selected_model = all_models[choice-1] + try: + resolved_model = self.get_model_arn_from_inference_profiles(selected_model) + if resolved_model != selected_model: + print(f"Selected model {selected_model} resolved to {resolved_model}") + return resolved_model + except Exception as e: + print(f"Error resolving model: {str(e)}") + # Fall back to selected model if resolution fails + print(f"Using unresolved model ID: {selected_model}") + return selected_model + else: + print(f"Please enter a number between 1 and {len(all_models)}") + except ValueError: + print("Please enter a valid number") + + attempts += 1 + + # If we've exhausted attempts, use default model + print(f"Maximum attempts reached. Using default model: {default_model_resolved}") + return default_model_resolved + + def __del__(self): + """Clean up resources when the object is garbage collected""" + # Clear references to help with garbage collection + if hasattr(self, '_owns_bedrock_service') and self._owns_bedrock_service: + if hasattr(self, '_bedrock_service'): + self._bedrock_service = None + + self._bedrock_service_ref = None + self.models = None \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/prompt_caching_app.py b/BedrockPromptCachingRoutingDemo/src/prompt_caching_app.py new file mode 100644 index 000000000..d6651b5c5 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/prompt_caching_app.py @@ -0,0 +1,1495 @@ +""" +Bedrock Prompt Caching Gradio Application + +This module provides a web interface for the Bedrock Prompt Caching CLI application +using Gradio. 
+""" + +import gradio as gr +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import io +import time +import os +from PIL import Image +from typing import List, Tuple +import requests + +# Import from the CLI module +from bedrock_prompt_caching import ( + BedrockChat, ModelManager, CACHE, BedrockService +) +from prompt_caching_multi_turn import PromptCachingExperiment +from file_processor import FileProcessor +from bedrock_claude_code import ClaudeSetup + + +class GradioBedrockApp: + """Gradio interface for Bedrock Prompt Caching""" + + def __init__(self): + """Initialize the Gradio application with required components""" + # Initialize core components + self.bedrock_service = BedrockService() + self.chat = BedrockChat() + self.model_manager = ModelManager() + self.prompt_caching_experiment = PromptCachingExperiment( + bedrock_service=self.bedrock_service, + model_manager=self.model_manager + ) + self.claude_setup = ClaudeSetup() + + # Sample URLs for demonstration + self.sample_urls = [ + 'https://aws.amazon.com/blogs/aws/reduce-costs-and-latency-with-amazon-bedrock-intelligent-prompt-routing-and-prompt-caching-preview/', + 'https://aws.amazon.com/blogs/machine-learning/enhance-conversational-ai-with-advanced-routing-techniques-with-amazon-bedrock/', + 'https://aws.amazon.com/blogs/security/cost-considerations-and-common-options-for-aws-network-firewall-log-management/' + ] + + # Chat state management + self.history = [] # Will store messages in dict format with 'role' and 'content' keys + self.use_cache = True + self.use_checkpoint = False + + # Multi-turn chat state + self.multi_turn_conversation = [] + self.multi_turn_turn = 0 + self.multi_turn_context = "" + + # Common questions about Bedrock prompt caching for quick access + self.common_questions = [ + "What is Amazon Bedrock prompt caching?", + "How does prompt caching reduce costs?", + "What are the benefits of using checkpoints?", + "Which models support 
prompt caching?", + "How much latency improvement can I expect?", + "How is prompt caching different from RAG?", + "Can I use prompt caching with streaming responses?", + "How do I implement prompt caching in my application?", + "What are the limitations of prompt caching?", + "How does prompt caching handle similar but not identical prompts?" + ] + + # Claude Code settings + self.claude_working_dir = os.getcwd() + self.claude_model = "sonnet" # Default model + self.claude_caching = True # Default caching enabled + + def get_models(self) -> List[str]: + """Get a flattened list of available models with caching support""" + # Get models from the model_manager, which maintains the latest model list + # with proper caching support information + all_models = [] + + # Get models from model categories + for category, models in self.model_manager.models.items(): + # Add category prefix to each model for better organization in dropdown + for model in models: + # Get the short name for display + model_short = model.split('/')[-1].split(':')[0] + all_models.append(f"{category}: {model_short}") + + return all_models + + def get_model_id_from_display_name(self, display_name: str) -> str: + """Convert display name back to model ID""" + if not display_name or ":" not in display_name: + # Fall back to the default model (Claude Sonnet 4.5) if the display name is invalid + return "us.anthropic.claude-sonnet-4-5-20250929-v1:0" + + # Extract category and model short name + category, model_short = display_name.split(":", 1) + model_short = model_short.strip() + + # Find matching model in the category + if category in self.model_manager.models: + for model_id in self.model_manager.models[category]: + if model_short in model_id: + return model_id + + # If not found, use model_manager to resolve + return self.model_manager.get_model_arn_from_inference_profiles(model_short) + + def load_document_from_url(self, url: str) -> str: + """Load document from URL and set it as the current document""" + if not url: + return "Please 
enter a URL" + + try: + response = requests.get(url, timeout=30) # Timeout to avoid hanging indefinitely on unresponsive URLs + response.raise_for_status() + document = response.text + if document: + self.chat.set_document(document) + return f"Document loaded successfully ({len(document)} characters)" + else: + return "Empty document received from URL" + except Exception as e: + return f"Error fetching document from URL: {str(e)}" + + def load_document_from_file(self, file) -> str: + """Load document from uploaded file and set it as the current document""" + if file is None: + return "No file uploaded" + + try: + # Check if file is supported + file_name = file.name if hasattr(file, 'name') else str(file) + if not FileProcessor.is_supported_file(file_name): + # Register the unknown extension so the file is not rejected outright + file_ext = os.path.splitext(file_name)[1].lower() + if file_ext: + FileProcessor.SUPPORTED_EXTENSIONS.add(file_ext) + + # Handle file path case (when file is a path string) + if isinstance(file, str) or (hasattr(file, 'name') and not hasattr(file, 'getvalue')): + # If it's a file path, read the file directly + try: + with open(file if isinstance(file, str) else file.name, 'r', encoding='utf-8') as f: + text = f.read() + except Exception as e: + return f"Error reading file: {str(e)}" + else: + # Process file using FileProcessor + text = FileProcessor.process_uploaded_file(file) + + if text: + self.chat.set_document(text) + return f"Document loaded successfully ({len(text)} characters)" + else: + return "No text extracted from file" + except Exception as e: + return f"Error loading document: {str(e)}" + + def set_model_and_temp(self, display_name: str, temperature: float) -> str: + """Set the model and temperature for inference""" + # Convert display name to actual model ID + model_id = self.get_model_id_from_display_name(display_name) + + # Resolve the model ID using model_manager if needed + resolved_model_id = self.model_manager.get_model_arn_from_inference_profiles(model_id) + if resolved_model_id != model_id: + model_id = 
resolved_model_id + + self.chat.set_model(model_id) + self.chat.set_temperature(temperature) + + # Return a more informative message + model_short = model_id.split('/')[-1].split(':')[0] + return f"Model set to {model_short} with temperature {temperature}" + + def toggle_cache(self, use_cache: bool) -> str: + """Toggle cache usage on/off""" + self.use_cache = use_cache + return f"Cache {'enabled' if use_cache else 'disabled'}" + + def toggle_checkpoint(self, use_checkpoint: bool) -> str: + """Toggle checkpoint usage on/off""" + self.use_checkpoint = use_checkpoint + return f"Checkpoint {'enabled' if use_checkpoint else 'disabled'}" + + def chat_with_document(self, query: str) -> Tuple[List, str, str, str]: + """Process a query against the loaded document and update chat history""" + if not self.chat.current_document: + return self.history, "No document loaded. Please load a document first.", "", "" + + if not self.chat.current_model_id: + return self.history, "No model selected. Please select a model first.", "", "" + + if not query.strip(): + return self.history, "Please enter a question.", "", "" + + try: + # Measure response time + start_time = time.time() + response_text, usage, from_cache, cache_key = self.chat.chat_with_document( + query, use_cache=self.use_cache, checkpoint=False + ) + cache_retrieval_time = time.time() - start_time + + # Format usage info based on cache hit or miss + cache_read = usage.get("cache_read_input_tokens", 0) or usage.get("cacheReadInputTokens", 0) + cache_write = usage.get("cache_creation_input_tokens", 0) or usage.get("cacheCreationInputTokens", 0) + + # For cache hits from file cache, simulate cache metrics + if from_cache and cache_read == 0: + cache_read = usage.get("inputTokens", 0) + + if from_cache: + # Calculate latency benefit percentage + standard_response_time = 2.0 # Estimated standard response time without cache + latency_benefit = ((standard_response_time - cache_retrieval_time) / standard_response_time) * 100 + + # 
Calculate token savings + token_savings_percentage = (cache_read / (cache_read + usage.get("inputTokens", 0))) * 100 if cache_read > 0 else 0 + + usage_info = ( + f"Response retrieved from cache\n" + f"Cache retrieval time: {cache_retrieval_time:.4f} seconds\n" + f"Cache hit: {from_cache}\n" + f"Cache read tokens: {cache_read}\n" + f"Latency reduction: {latency_benefit:.1f}%\n" + f"Token savings: {token_savings_percentage:.1f}%" + ) + else: + usage_info = ( + f"Input tokens: {usage.get('inputTokens', 'N/A')}\n" + f"Output tokens: {usage.get('outputTokens', 'N/A')}\n" + f"Response time: {usage.get('response_time_seconds', 0.0):.2f} seconds" + ) + + if cache_write > 0: + usage_info += f"\nCache write tokens: {cache_write}" + + # Get detailed cache summary + cache_summary = "" + if cache_key: + cache_summary = self.chat.cache_manager.get_cache_summary(cache_key) + + # Add cache checkpoint information + input_tokens = usage.get('inputTokens', 0) + cache_read_tokens = usage.get('cache_read_input_tokens', 0) or usage.get('cacheReadInputTokens', 0) + total_tokens = input_tokens + cache_read_tokens # Total tokens including those from cache + model_name = self.chat.current_model_id + + # Determine minimum cacheable prompt size: 512 tokens for Amazon Nova models, + # 1024 tokens for all other models + min_tokens = 512 if "nova" in model_name else 1024 + + # Add cache checkpoint information + cache_summary += "\n\n### Cache Checkpoint Information\n" + cache_summary += "Cache checkpoints have minimum token requirements that vary by model:\n" + cache_summary += f"- Current model: {model_name}\n" + cache_summary += f"- Minimum tokens required: {min_tokens}\n" + cache_summary += f"- Your total tokens: {total_tokens} (input: {input_tokens}, cache: {cache_read_tokens})\n" + + # Determine if document meets minimum requirements based on total tokens + if total_tokens >= 
min_tokens: + cache_summary += f"✅ Your prompt meets the minimum token requirement ({total_tokens} ≥ {min_tokens})\n" + cache_summary += f"- First checkpoint can be defined after {min_tokens} tokens\n" + cache_summary += f"- Second checkpoint can be defined after {min_tokens * 2} tokens\n" + else: + cache_summary += f"❌ Your prompt does not meet the minimum token requirement ({total_tokens} < {min_tokens})\n" + cache_summary += "- Your prefix will not be cached\n" + + cache_summary += "\nCache has a five minute Time To Live (TTL), which resets with each successful cache hit." + cache_summary += "\nIf no cache hits occur within the TTL window, your cache expires." + + # Add business benefits section to cache summary + if from_cache: + cache_summary += "\n\n### Business Benefits of Prompt Caching\n" + cache_summary += "- **Cost Reduction**: Cached tokens pricing is different from LLM token usage costs\n" + cache_summary += "- **Improved Latency**: Faster responses by skipping redundant processing\n" + cache_summary += "- **Consistent Responses**: Same prompts yield identical outputs\n" + cache_summary += "- **Scalability**: Handle more requests with the same resources\n" + + # Update history with the new Q&A pair using messages format + self.history.append({"role": "user", "content": query}) + self.history.append({"role": "assistant", "content": response_text}) + + return self.history, "", usage_info, cache_summary + + except Exception as e: + error_message = f"Error: {str(e)}" + # Special handling for specific models that might cause issues + if "claude-3-5-sonnet" in self.chat.current_model_id or "nova-pro" in self.chat.current_model_id: + error_message = f"This model ({self.chat.current_model_id}) currently has compatibility issues with prompt caching. Please try a different model." 
+ + # Return all four expected values even in case of error + return self.history, error_message, "", "" + + def clear_history(self) -> Tuple[List, str, str, str]: + """Clear the chat history""" + self.history = [] + return [], "Chat history cleared", "No queries yet", "No cache information yet" + + def run_benchmark(self, epochs: int) -> Tuple[str, str]: + """Run TTFT (Time To First Token) benchmark tests""" + if not self.chat.current_document: + return "No document loaded. Please load a document first.", None + + if not self.chat.current_model_id: + return "No model selected. Please select a model first.", None + + print(f"Running benchmark with {epochs} iterations...") + print("This may take several minutes. Please wait...") + + # Store the current document as blog for benchmarking + self.chat.blog = self.chat.current_document + + # Define test configurations + tests = [ + { + 'model_id': self.chat.current_model_id, + 'model_name': self.chat.current_model_id.split(':')[0], + 'cache_mode': [CACHE.OFF, CACHE.ON] + } + ] + + try: + # Run the benchmark + datapoints = self.chat.run_response_latency_benchmark(tests, epochs) + + # Create visualization + if not datapoints: + return "No benchmark data to visualize.", None + + df = pd.DataFrame(datapoints) + + # Save results to CSV for later analysis + timestamp = time.strftime("%Y%m%d_%H%M%S") + csv_filename = f"benchmark_results_{timestamp}.csv" + df.to_csv(csv_filename) + + # Create plot + plt_img = self._create_benchmark_plot(df) + + # Print raw data for debugging + print("Raw benchmark data:") + for dp in datapoints: + print(f"{dp['model']} - {dp['cache']} - {dp['measure']} - {dp['time']:.2f}s") + + # Format summary statistics in a more readable way + summary = df.groupby(['model', 'cache', 'measure'])['time'].agg(['mean', 'median', 'min', 'max']) + + # Create a more readable summary text + summary_text = "## Benchmark Results Summary\n\n" + + # Process each model separately + for model_name in 
summary.index.get_level_values('model').unique(): + model_short_name = model_name.split('-')[0].split('.')[-1].capitalize() + summary_text += f"### Model: {model_short_name}\n\n" + + # Get cache modes for this model + cache_modes = summary.loc[model_name].index.get_level_values('cache').unique() + + # Store baseline values for comparison + baseline_first_token = None + baseline_last_token = None + + # Get baseline values from CACHE.OFF if available + if "CACHE.OFF" in cache_modes: + try: + baseline_first_token = summary.loc[(model_name, "CACHE.OFF", 'first_token')] + baseline_last_token = summary.loc[(model_name, "CACHE.OFF", 'last_token')] + except: + pass + + for cache_mode in cache_modes: + # Map the cache mode to a user-friendly status + if cache_mode == "CACHE.OFF": + cache_status = "Cache OFF" + elif cache_mode == "CACHE.ON": + cache_status = "Cache ON" + elif cache_mode == "CACHE.READ": + cache_status = "Cache HIT" + elif cache_mode == "CACHE.WRITE": + cache_status = "Cache WRITE" + else: + cache_status = str(cache_mode) + + summary_text += f"#### {cache_status}\n\n" + + # Format as a table with wider columns to accommodate percentages + summary_text += "| Metric | Mean (sec) | Median (sec) | Min (sec) | Max (sec) |\n" + summary_text += "|--------|-------------------|-------------------|----------------|----------------|\n" + + try: + # First token metrics + first_token = summary.loc[(model_name, cache_mode, 'first_token')] + + # Calculate percentage differences for all metrics + mean_diff = median_diff = min_diff = max_diff = "" + + if cache_mode != "CACHE.OFF" and baseline_first_token is not None: + # For mean + if baseline_first_token['mean'] > 0: + mean_pct = ((baseline_first_token['mean'] - first_token['mean']) / baseline_first_token['mean']) * 100 + mean_diff = f" ({mean_pct:.1f}% {'faster' if mean_pct > 0 else 'slower'})" + + # For median + if baseline_first_token['median'] > 0: + median_pct = ((baseline_first_token['median'] - first_token['median']) 
/ baseline_first_token['median']) * 100 + median_diff = f" ({median_pct:.1f}% {'faster' if median_pct > 0 else 'slower'})" + + # For min + if baseline_first_token['min'] > 0: + min_pct = ((baseline_first_token['min'] - first_token['min']) / baseline_first_token['min']) * 100 + min_diff = f" ({min_pct:.1f}% {'faster' if min_pct > 0 else 'slower'})" + + # For max + if baseline_first_token['max'] > 0: + max_pct = ((baseline_first_token['max'] - first_token['max']) / baseline_first_token['max']) * 100 + max_diff = f" ({max_pct:.1f}% {'faster' if max_pct > 0 else 'slower'})" + + summary_text += f"| Time to First Token | {first_token['mean']:.2f}{mean_diff} | {first_token['median']:.2f}{median_diff} | {first_token['min']:.2f}{min_diff} | {first_token['max']:.2f}{max_diff} |\n" + + # Last token metrics + last_token = summary.loc[(model_name, cache_mode, 'last_token')] + + # Calculate percentage differences for all metrics + mean_diff = median_diff = min_diff = max_diff = "" + + if cache_mode != "CACHE.OFF" and baseline_last_token is not None: + # For mean + if baseline_last_token['mean'] > 0: + mean_pct = ((baseline_last_token['mean'] - last_token['mean']) / baseline_last_token['mean']) * 100 + mean_diff = f" ({mean_pct:.1f}% {'faster' if mean_pct > 0 else 'slower'})" + + # For median + if baseline_last_token['median'] > 0: + median_pct = ((baseline_last_token['median'] - last_token['median']) / baseline_last_token['median']) * 100 + median_diff = f" ({median_pct:.1f}% {'faster' if median_pct > 0 else 'slower'})" + + # For min + if baseline_last_token['min'] > 0: + min_pct = ((baseline_last_token['min'] - last_token['min']) / baseline_last_token['min']) * 100 + min_diff = f" ({min_pct:.1f}% {'faster' if min_pct > 0 else 'slower'})" + + # For max + if baseline_last_token['max'] > 0: + max_pct = ((baseline_last_token['max'] - last_token['max']) / baseline_last_token['max']) * 100 + max_diff = f" ({max_pct:.1f}% {'faster' if max_pct > 0 else 'slower'})" + + summary_text += 
f"| Total Response Time | {last_token['mean']:.2f}{mean_diff} | {last_token['median']:.2f}{median_diff} | {last_token['min']:.2f}{min_diff} | {last_token['max']:.2f}{max_diff} |\n" + except: + summary_text += "| Data not available | - | - | - | - |\n" + + summary_text += "\n" + + # Calculate speedup if we have both cache off and any cache on mode + try: + # Find all first token times for each cache mode + cache_times = {} + for mode in cache_modes: + try: + cache_times[mode] = summary.loc[(model_name, mode, 'first_token')]['mean'] + except: + pass + + # If we have both OFF and any other mode, calculate speedup + if 'CACHE.OFF' in cache_times: + cache_off_time = cache_times['CACHE.OFF'] + + # Find the best cache hit time (READ is preferred) + cache_hit_time = None + hit_mode = None + for mode in ['CACHE.READ', 'CACHE.ON', 'CACHE.WRITE']: + if mode in cache_times: + cache_hit_time = cache_times[mode] + hit_mode = mode + break + + if cache_hit_time is not None and cache_off_time > 0 and hit_mode is not None: + speedup = (cache_off_time - cache_hit_time) / cache_off_time * 100 + + # Calculate token savings if available + token_savings = "N/A" + try: + # Get average token usage for each mode + cache_read_tokens = 0 + + # Try different column names for cache read tokens + for col_name in ['cacheReadInputTokens', 'cache_read_input_tokens']: + if col_name in df.columns: + cache_read_tokens = df[df['cache'] == hit_mode][col_name].mean() + if not pd.isna(cache_read_tokens) and cache_read_tokens > 0: + break + + # Get input tokens + input_tokens = 0 + for col_name in ['inputTokens', 'input_tokens']: + if col_name in df.columns: + input_tokens = df[df['cache'] == 'CACHE.OFF'][col_name].mean() + if not pd.isna(input_tokens) and input_tokens > 0: + break + + if input_tokens > 0 and cache_read_tokens > 0: + token_savings = (cache_read_tokens / input_tokens) * 100 + except Exception as e: + print(f"Error calculating token savings: {e}") + + summary_text += f"**Cache Speedup: 
{speedup:.1f}%** (comparing CACHE.OFF vs {hit_mode})\n\n" + if token_savings != "N/A": + summary_text += f"**Token Savings: {token_savings:.1f}%** of input tokens retrieved from cache\n\n" + + # Add business impact + summary_text += "### Business Benefits\n\n" + summary_text += "- **Cost Reduction**: Lower token usage means reduced API costs\n" + summary_text += "- **Improved User Experience**: Faster response times lead to better user engagement\n" + summary_text += "- **Higher Throughput**: Process more requests with the same resources\n" + summary_text += "- **Reduced Latency**: Critical for real-time applications\n\n" + except Exception as e: + print(f"Error calculating speedup: {e}") + + summary_text += f"\nResults saved to {csv_filename}" + + return summary_text, plt_img + + except Exception as e: + return f"Error during benchmark: {str(e)}", None + + def _create_benchmark_plot(self, df): + """Create benchmark plot comparing cache performance and return as image""" + import seaborn as sns + import numpy as np + + sns.set_style("whitegrid") + n_models = df['model'].nunique() + + f, axes = plt.subplots(n_models, 1, figsize=(10, n_models * 6)) + + # Convert axes to array if there's only one model + axes = np.array([axes]) if n_models == 1 else axes + + for i, model in enumerate(df['model'].unique()): + cond = df['model'] == model + df_i = df.loc[cond] + + ax = sns.boxplot(df_i, + ax=axes[i], + x='measure', + y='time', + hue=df_i[['cache']].apply(tuple, axis=1)) + + ax.tick_params(axis='x', rotation=45) + ax.set_xlabel(None) + self._add_median_labels(ax) + ax.legend(loc='upper left') + ax.set_title(f'Time to First Token (TTFT) - {model}', fontsize=14) + + plt.tight_layout() + + # Convert plot to image + buf = io.BytesIO() + plt.savefig(buf, format='png') + buf.seek(0) + plt_img = Image.open(buf) + + # Close the figure to avoid leaking memory across repeated benchmark runs + plt.close(f) + + return plt_img + + def _add_median_labels(self, ax, fmt=".1f"): + """Add text labels to the median lines of a seaborn boxplot""" + lines 
= ax.get_lines() + boxes = [c for c in ax.get_children() if "Patch" in str(c)] + start = 4 + if not boxes: # seaborn v0.13 => fill=False => no patches => +1 line + boxes = [c for c in ax.get_lines() if len(c.get_xdata()) == 5] + start += 1 + lines_per_box = len(lines) // len(boxes) + for median in lines[start::lines_per_box]: + x, y = (data.mean() for data in median.get_data()) + # choose value depending on horizontal or vertical plot orientation + value = x if len(set(median.get_xdata())) == 1 else y + text = ax.text(x, y, f'{value:{fmt}}', ha='center', va='center', color='white') + # create median-colored border around white text for contrast + text.set_path_effects([ + path_effects.Stroke(linewidth=3, foreground=median.get_color()), + path_effects.Normal(), + ]) + + # Multi-Turn Chat Methods + def load_multi_turn_document_from_url(self, url: str) -> str: + """Load document from URL for multi-turn chat""" + if not url: + return "Please enter a URL" + + try: + result = self.prompt_caching_experiment.load_context_from_url(url) + if result: + self.multi_turn_context = self.prompt_caching_experiment.sample_text + self.multi_turn_conversation = [] + self.multi_turn_turn = 0 + return f"Document loaded successfully ({len(self.multi_turn_context)} characters)" + else: + return "Failed to load document from URL" + except Exception as e: + return f"Error loading document from URL: {str(e)}" + + def load_multi_turn_document_from_file(self, file) -> str: + """Load document from file for multi-turn chat""" + if file is None: + return "No file uploaded" + + try: + # Handle file path case (when file is a path string) + if isinstance(file, str) or (hasattr(file, 'name') and not hasattr(file, 'getvalue')): + # If it's a file path, read the file directly + file_path = file if isinstance(file, str) else file.name + try: + with open(file_path, 'r', encoding='utf-8') as f: + file_content = f.read() + # Set the content directly + self.multi_turn_context = file_content + 
self.multi_turn_conversation = [] + self.multi_turn_turn = 0 + self.prompt_caching_experiment.set_context_text(file_content) + return f"Document loaded successfully ({len(file_content)} characters)" + except Exception as e: + return f"Error reading file: {str(e)}" + else: + # Get file path from Gradio file object + file_path = file.name if hasattr(file, 'name') else file + + result = self.prompt_caching_experiment.load_context_from_file(file_path) + if result: + self.multi_turn_context = self.prompt_caching_experiment.sample_text + self.multi_turn_conversation = [] + self.multi_turn_turn = 0 + return f"Document loaded successfully ({len(self.multi_turn_context)} characters)" + else: + return "Failed to load document from file" + except Exception as e: + return f"Error loading document: {str(e)}" + + def set_multi_turn_model(self, display_name: str) -> str: + """Set the model for multi-turn chat""" + # Convert display name to actual model ID + model_id = self.get_model_id_from_display_name(display_name) + + # Resolve the model ID using model_manager + resolved_model_id = self.model_manager.get_model_arn_from_inference_profiles(model_id) + if resolved_model_id != model_id: + model_id = resolved_model_id + + self.multi_turn_model_id = model_id + + # Get a shorter display name for the model + model_short = model_id.split('/')[-1].split(':')[0] + return f"Using model: {model_short}" + + def multi_turn_chat(self, query: str, max_tokens: int = 2048, temperature: float = 0.5, + top_p: float = 0.8, top_k: int = 250, stop_sequences: str = "") -> Tuple[List, str, str]: + """Process a query in multi-turn chat mode with model parameters""" + if not self.multi_turn_context: + return [], "No document loaded. 
Please load a document first.", "" + + if not hasattr(self, 'multi_turn_model_id'): + # Default to Claude Sonnet 4.5 + default_model = "us.anthropic.claude-sonnet-4-5-20250929-v1:0" + self.multi_turn_model_id = self.model_manager.get_model_arn_from_inference_profiles(default_model) + + if not query.strip(): + return [], "Please enter a question.", "" + + try: + # Set the context text in the experiment + self.prompt_caching_experiment.set_context_text(self.multi_turn_context) + + # Record the start time + start_time = time.time() + + # Convert stop_sequences from string to list if provided + stop_seq_list = None + if stop_sequences: + stop_seq_list = [seq.strip() for seq in stop_sequences.split(',')] + + # Store model parameters in experiment + self.prompt_caching_experiment.model_params = { + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, # top_k applies to Anthropic models + "stop_sequences": stop_seq_list + } + + # Process the turn with model parameters + turn_data = self.prompt_caching_experiment.process_turn( + self.multi_turn_turn, + self.multi_turn_conversation, + query, + self.multi_turn_model_id + ) + + # Add the turn data to the experiment's all_experiments_data + self.prompt_caching_experiment.all_experiments_data.append(turn_data) + + # Record the end time + end_time = time.time() + elapsed_time = end_time - start_time + + # Increment turn counter + self.multi_turn_turn += 1 + + # Get the response text + if self.multi_turn_conversation and len(self.multi_turn_conversation) >= 2: + response_text = self.multi_turn_conversation[-1]["content"][0]["text"] + else: + response_text = "No response generated."
+ + # Format chat history for display with messages format + chat_history = [] + for i, msg in enumerate(self.multi_turn_conversation): + if msg["role"] == "user": + # Skip displaying the context text for readability + if i == 0 and len(msg["content"]) > 1: + user_text = msg['content'][-1]['text'] + else: + user_text = ' '.join([c['text'] for c in msg['content'] if 'text' in c]) + chat_history.append({"role": "user", "content": user_text}) + else: + chat_history.append({"role": "assistant", "content": msg['content'][0]['text']}) + + # Get turn metrics using the prompt_caching_experiment's method + turn_metrics = self.prompt_caching_experiment.get_turn_metrics(turn_data) + + # Get cache summary using the prompt_caching_experiment's method + cache_summary = self.prompt_caching_experiment.get_cache_summary(turn_data) + + # Combine metrics and cache summary for display + usage_info = f"{turn_metrics}\n\n{cache_summary}" + + # Add business benefits if this is a cache hit + if turn_data['is_cache_hit']: + # Calculate latency benefit percentage + standard_response_time = 2.0 # Estimated standard response time without cache + latency_benefit = ((standard_response_time - elapsed_time) / standard_response_time) * 100 + + # Calculate token savings + input_tokens = turn_data['input_tokens'] + input_tokens_cache_read = turn_data['cache_read_input_tokens'] + token_savings_percentage = (input_tokens_cache_read / (input_tokens_cache_read + input_tokens) * 100) if input_tokens_cache_read > 0 else 0 + + # Add business benefits section + usage_info += "\n\n### Business Benefits of Prompt Caching\n" + usage_info += f"- **Cost Reduction**: {token_savings_percentage:.1f}% token savings\n" + usage_info += f"- **Improved Latency**: {latency_benefit:.1f}% faster response time\n" + usage_info += "- **Consistent Responses**: Same prompts yield identical outputs\n" + usage_info += "- **Scalability**: Handle more requests with the same resources\n" + + return chat_history, "", usage_info + + 
except Exception as e: + return [], f"Error: {str(e)}", "" + + def clear_multi_turn_history(self) -> Tuple[List, str]: + """Clear the multi-turn chat history""" + self.multi_turn_conversation = [] + self.multi_turn_turn = 0 + return [], "Chat history cleared" + + def show_experiment_stats(self) -> str: + """Show summary statistics for the multi-turn chat experiment""" + # Access the experiment data directly from the experiment object + all_experiments_data = self.prompt_caching_experiment.all_experiments_data + + if not all_experiments_data: + return "No experiment data available. Please chat with the model first." + + try: + # Use the experiment's built-in method to get a formatted summary + return self.prompt_caching_experiment.get_experiment_summary() + + except Exception as e: + return f"Error generating statistics: {str(e)}" + + # Claude Code methods + + def get_current_working_dir(self) -> str: + """Get the current working directory for Claude Code""" + return self.claude_working_dir + + def change_working_dir(self, new_dir: str) -> str: + """Change the working directory for Claude Code""" + try: + if not os.path.exists(new_dir): + return f"Directory does not exist: {new_dir}" + + if not os.path.isdir(new_dir): + return f"Not a directory: {new_dir}" + + self.claude_working_dir = new_dir + return f"Working directory changed to: {new_dir}" + except Exception as e: + return f"Error changing directory: {str(e)}" + + def install_claude_code(self) -> str: + """Install Claude Code using npm""" + try: + # Show progress message + yield "Installing Claude Code... This may take a moment." + + import subprocess + result = subprocess.run(["npm", "install", "-g", "@anthropic-ai/claude-code"], + capture_output=True, text=True) + if result.returncode != 0: + yield f"Error installing Claude Code: {result.stderr}" + else: + yield "Claude Code installed successfully." 
+ except Exception as e: + yield f"Error: {str(e)}" + + def check_aws_config(self) -> str: + """Check AWS configuration""" + try: + import subprocess + result = subprocess.run(["aws", "sts", "get-caller-identity"], + capture_output=True, text=True) + if result.returncode == 0: + return f"AWS credentials configured correctly:\n{result.stdout}" + else: + return f"AWS credentials not configured correctly:\n{result.stderr}" + except Exception as e: + return f"Error checking AWS configuration: {str(e)}" + + def check_claude_version(self) -> str: + """Check Claude Code version""" + try: + import subprocess + result = subprocess.run(["claude", "--version"], + capture_output=True, text=True) + if result.returncode == 0: + return f"Claude Code version: {result.stdout}" + else: + return f"Error checking Claude Code version: {result.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + def configure_claude_environment(self, model: str, enable_caching: bool) -> str: + """Configure environment variables for Claude Code""" + try: + os.environ["CLAUDE_CODE_USE_BEDROCK"] = "1" + + if model == "haiku": + model_id = "us.anthropic.claude-3-5-haiku-20241022-v1:0" + os.environ["ANTHROPIC_MODEL"] = model_id + os.environ["ANTHROPIC_SMALL_FAST_MODEL"] = model_id + model_name = "Claude 3.5 Haiku" + self.claude_model = "haiku" + else: + model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" + os.environ["ANTHROPIC_MODEL"] = model_id + model_name = "Claude 3.7 Sonnet" + self.claude_model = "sonnet" + + if enable_caching: + if "DISABLE_PROMPT_CACHING" in os.environ: + del os.environ["DISABLE_PROMPT_CACHING"] + caching_status = "enabled" + self.claude_caching = True + else: + os.environ["DISABLE_PROMPT_CACHING"] = "true" + caching_status = "disabled" + self.claude_caching = False + + return f"Environment configured with {model_name} ({model_id}). Prompt caching is {caching_status}." 
+ except Exception as e: + return f"Error configuring environment: {str(e)}" + + def generate_environment_script(self, model: str, enable_caching: bool) -> str: + """Generate a script with environment variables for Claude Code""" + try: + if model == "haiku": + model_id = "us.anthropic.claude-3-5-haiku-20241022-v1:0" + model_name = "Claude 3.5 Haiku" + else: + model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" + model_name = "Claude 3.7 Sonnet" + + script = f"""#!/bin/bash +# Environment setup for {model_name} + +# Set Bedrock integration +export CLAUDE_CODE_USE_BEDROCK=1 + +# Set model +export ANTHROPIC_MODEL='{model_id}' +""" + + if model == "haiku": + script += f"export ANTHROPIC_SMALL_FAST_MODEL='{model_id}'\n" + + if enable_caching: + script += """ +# Enable prompt caching +# Remove DISABLE_PROMPT_CACHING if it exists +if [ -n "$DISABLE_PROMPT_CACHING" ]; then + unset DISABLE_PROMPT_CACHING +fi +""" + else: + script += """ +# Disable prompt caching +export DISABLE_PROMPT_CACHING=true +""" + + script += """ +# Launch Claude Code +claude +""" + return script + except Exception as e: + return f"# Error generating script: {str(e)}" + + def run_claude_setup_script(self) -> str: + """Run the bedrock_claude_code.py setup script""" + try: + yield "Running Claude Code setup script..." 
+ + # Get the script path + script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "bedrock_claude_code.py") + + # Check if the script exists + if not os.path.exists(script_path): + # This method is a generator, so a plain `return value` would be silently discarded; yield the error first + yield f"Error: Setup script not found at {script_path}" + return + + # Run the script + import subprocess + process = subprocess.Popen( + ["python3", script_path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + # Read output in real-time + while True: + # Check if process has terminated + if process.poll() is not None: + # Read any remaining output + remaining_stdout, remaining_stderr = process.communicate() + if remaining_stdout: + yield f"Script output: {remaining_stdout}" + if remaining_stderr: + yield f"Script errors: {remaining_stderr}" + break + + # Read the next line of output (readline blocks until a line or EOF is available) + stdout_chunk = process.stdout.readline() + if stdout_chunk: + yield f"Script output: {stdout_chunk}" + + # Small sleep to prevent CPU hogging + time.sleep(0.1) + + yield "Claude Code setup script completed. Please check your terminal for the interactive Claude Code session." + except Exception as e: + yield f"Error running Claude Code setup script: {str(e)}" + + def launch_claude_shell(self) -> str: + """Launch Claude Code shell directly""" + try: + yield "Launching Claude Code shell..."
+ + # Configure environment variables + env_vars = os.environ.copy() + env_vars["CLAUDE_CODE_USE_BEDROCK"] = "1" + + # Use the model from dropdown + if hasattr(self, 'claude_model') and self.claude_model == "haiku": + env_vars["ANTHROPIC_MODEL"] = "us.anthropic.claude-3-5-haiku-20241022-v1:0" + env_vars["ANTHROPIC_SMALL_FAST_MODEL"] = "us.anthropic.claude-3-5-haiku-20241022-v1:0" + else: + env_vars["ANTHROPIC_MODEL"] = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" + + # Handle caching setting + if hasattr(self, 'claude_caching') and self.claude_caching: + if "DISABLE_PROMPT_CACHING" in env_vars: + del env_vars["DISABLE_PROMPT_CACHING"] + else: + env_vars["DISABLE_PROMPT_CACHING"] = "true" + + # Launch Claude shell + import subprocess + subprocess.run(["claude"], env=env_vars, cwd=self.claude_working_dir) + + yield "Claude Code shell session ended." + except Exception as e: + yield f"Error launching Claude Code shell: {str(e)}" + + + + +def create_gradio_interface(): + """Create and launch the Gradio interface""" + app = GradioBedrockApp() + + with gr.Blocks(title="Amazon Bedrock Prompt Caching Demo") as interface: + gr.Markdown("# Amazon Bedrock Prompt Caching Demo") + # Resolve the diagram path relative to this file so the app runs outside the author's machine + gr.Image(os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "prompt-caching.png"), label="Prompt Caching Diagram") + + with gr.Tabs() as tabs: + # RAG Chat Tab + with gr.TabItem("RAG Chat"): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("### 1. Load Document") + + with gr.Tab("From URL"): + url_input = gr.Textbox(label="Enter URL") + url_examples = gr.Examples( + examples=app.sample_urls, + inputs=url_input + ) + load_url_btn = gr.Button("Load from URL") + url_status = gr.Textbox(label="Status", interactive=False) + + with gr.Tab("From File"): + file_input = gr.File(label="Upload Document") + load_file_btn = gr.Button("Load File") + file_status = gr.Textbox(label="Status", interactive=False) + + gr.Markdown("### 2. 
Select Model") + model_dropdown = gr.Dropdown( + choices=app.get_models(), + label="Model (with caching support)", + value=app.get_models()[0] if app.get_models() else None, + info="Select a model that supports prompt caching" + ) + temperature_slider = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.0, + step=0.1, + label="Temperature" + ) + set_model_btn = gr.Button("Set Model") + model_status = gr.Textbox(label="Status", interactive=False) + + gr.Markdown("### Settings") + cache_checkbox = gr.Checkbox(label="Use Cache", value=True) + gr.Markdown(""" + #### Cache Checkpoint Information + Cache checkpoints have minimum token requirements: + - Claude 3.7 Sonnet: 1,024 tokens minimum + - Claude 3.5 models: 1,024 tokens minimum + - Nova models: 512 tokens minimum + + Cache has a 5-minute TTL that resets with each hit. + + #### Model Format Information + - Anthropic Claude models: Use "cache_control" with "ephemeral" type + - Amazon Nova models: Use "cachePoint" with "default" type + """) + + gr.Markdown("### Benchmark") + epochs_slider = gr.Slider( + minimum=1, + maximum=10, + value=3, + step=1, + label="Test Iterations" + ) + benchmark_btn = gr.Button("Run Benchmark") + + with gr.Column(scale=2): + gr.Markdown("### Chat") + chatbot = gr.Chatbot(height=500, type="messages") + msg = gr.Textbox(label="Your Question") + + with gr.Row(): + submit_btn = gr.Button("Submit") + clear_btn = gr.Button("Clear History") + + error_output = gr.Textbox(label="Error", visible=True, interactive=False) + usage_info = gr.Textbox(label="Usage Information", interactive=False) + cache_stats = gr.Textbox(label="Cache Statistics", interactive=False, lines=10) + + with gr.Accordion("Common Questions", open=False): + common_q_btns = [gr.Button(q) for q in app.common_questions] + + with gr.Accordion("Benchmark Results", open=False): + with gr.Row(): + benchmark_output = gr.Markdown(label="Benchmark Results") + benchmark_plot = gr.Image(label="Benchmark Plot") + + # Multi-Turn Chat Tab + 
with gr.TabItem("Multi-Turn Chat"): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("### 1. Load Document") + + with gr.Tab("From URL"): + mt_url_input = gr.Textbox(label="Enter URL") + mt_url_examples = gr.Examples( + examples=app.sample_urls, + inputs=mt_url_input + ) + mt_load_url_btn = gr.Button("Load from URL") + mt_url_status = gr.Textbox(label="Status", interactive=False) + + with gr.Tab("From File"): + mt_file_input = gr.File(label="Upload Document") + mt_load_file_btn = gr.Button("Load File") + mt_file_status = gr.Textbox(label="Status", interactive=False) + + gr.Markdown("### 2. Select Model") + mt_model_dropdown = gr.Dropdown( + choices=app.get_models(), + label="Model (with caching support)", + value=app.get_models()[0] if app.get_models() else None, + info="Select a model that supports prompt caching" + ) + + gr.Markdown("### 3. Model Parameters") + with gr.Row(): + mt_max_tokens = gr.Number(value=2048, label="Max Tokens", minimum=1, maximum=4096, step=1) + mt_temperature = gr.Slider(value=0.5, label="Temperature", minimum=0.0, maximum=1.0, step=0.1) + + with gr.Row(): + mt_top_p = gr.Slider(value=0.8, label="Top P", minimum=0.0, maximum=1.0, step=0.1) + mt_top_k = gr.Number(value=250, label="Top K (Anthropic only)", minimum=0, maximum=500, step=1) + + mt_stop_sequences = gr.Textbox(value="", label="Stop Sequences (comma-separated)") + + mt_set_model_btn = gr.Button("Set Model") + mt_model_status = gr.Textbox(label="Status", interactive=False) + + gr.Markdown("### Experiment Results") + mt_show_stats_btn = gr.Button("Show Summary Statistics") + mt_stats_output = gr.Markdown(label="Summary Statistics") + + with gr.Column(scale=2): + gr.Markdown("### Multi-Turn Chat with Prompt Caching") + mt_chatbot = gr.Chatbot(height=500, type="messages") + mt_msg = gr.Textbox(label="Your Question") + + with gr.Row(): + mt_submit_btn = gr.Button("Submit") + mt_clear_btn = gr.Button("Clear History") + + mt_error_output = gr.Textbox(label="Error", 
visible=True, interactive=False) + mt_usage_info = gr.Markdown(label="Cache & Performance Metrics") + + # Claude Code Tab + with gr.TabItem("Claude Code"): + with gr.Row(): + # Left column for reference information + with gr.Column(scale=1): + gr.Markdown("## Claude Code Reference") + + with gr.Accordion("Command Line Setup", open=True): + gr.Markdown(""" + ### Running Claude Code from the Command Line + + To use Claude Code, open a terminal and run: + + ```bash + # Navigate to the project directory + cd + + # Run the Claude Code setup script + python3 src/bedrock_claude_code.py + ``` + + This will guide you through setup and launch Claude Code. + """) + + # Keep hidden elements for backward compatibility + install_btn = gr.Button("Install Claude Code", visible=False) + install_status = gr.Textbox(label="Installation Status", interactive=False, visible=False) + check_aws_btn = gr.Button("Check AWS Configuration", visible=False) + aws_status = gr.Textbox(label="AWS Status", interactive=False, visible=False) + check_version_btn = gr.Button("Check Claude Version", visible=False) + version_status = gr.Textbox(label="Version Information", interactive=False, visible=False) + cc_model_dropdown = gr.Dropdown( + choices=[ + # gr.Dropdown expects strings or (label, value) tuples, not dicts + ("Claude 3.7 Sonnet (us.anthropic.claude-3-7-sonnet-20250219-v1:0)", "sonnet"), + ("Claude 3.5 Haiku (us.anthropic.claude-3-5-haiku-20241022-v1:0)", "haiku") + ], + label="Model", + value="sonnet", + info="Select Claude model to use", + visible=False + ) + cc_cache_checkbox = gr.Checkbox(label="Enable Prompt Caching", value=True, visible=False) + configure_btn = gr.Button("Configure Environment", visible=False) + configure_status = gr.Textbox(label="Configuration Status", interactive=False, visible=False) + generate_script_btn = gr.Button("Generate Script", visible=False) + script_output = gr.Code(language="shell", label="Environment Script", lines=10, visible=False) + + # Right column for 
instructions + with gr.Column(scale=2): + gr.Markdown("# Getting Started with Claude Code: Step-by-Step Guide") + + with gr.Accordion("Getting Started with Claude Code", open=True): + gr.Markdown(""" + ### Quick Start with bedrock_claude_code.py + + The easiest way to get started with Claude Code is to run the provided script: + + ```bash + # Navigate to the project root directory + cd + + # Run the Claude Code setup script + python3 src/bedrock_claude_code.py + ``` + + This script will: + 1. Install Claude Code if needed + 2. Let you select the Claude model (Sonnet or Haiku) + 3. Configure prompt caching + 4. Launch the Claude Code interactive shell + + ### What to Expect + + When you run the script, you'll see: + ``` + === Claude Code Setup Chat === + User: I need to set up Claude Code with Bedrock. + Assistant: I'll help you set up Claude Code with Bedrock. Running the setup now... + Installing Claude Code... + Claude Code installed successfully. + + Assistant: Which Claude model would you like to use? + 1. Claude 3.7 Sonnet (more capable) + 2. Claude 3.5 Haiku (faster) + Enter your choice (1 or 2): + ``` + + After making your selections, Claude Code will launch automatically. + """) + + with gr.Accordion("Learning Claude Code Commands", open=False): + gr.Markdown(""" + ### Initialize a Project + ```bash + # Inside your project directory + claude + > /init + ``` + This scans your project and creates a CLAUDE.md guide + + ### Get Help with Available Commands + ```bash + > /help + ``` + + ### Try Basic Coding Assistance + ```bash + > Create a simple HTML calculator + ``` + Review Claude's suggestions. When prompted to create files, type "yes" to approve. 
+ + ### View Project Files + ```bash + > /ls + ``` + + ### Examine File Contents + ```bash + > /cat calculator.html + ``` + + ### Edit a File + ```bash + > Edit calculator.html to add scientific functions + ``` + + ### Run Commands in the Terminal + ```bash + > /sh ls -la + ``` + """) + + with gr.Accordion("Managing Context and Costs", open=False): + gr.Markdown(""" + ### Check Token Usage + ```bash + > /cost + ``` + + ### Compact the Conversation + ```bash + > /compact + ``` + This preserves important context while reducing token usage + + ### Clear the Conversation + ```bash + > /clear + ``` + Use when starting a completely new task + """) + + with gr.Accordion("Advanced Usage", open=False): + gr.Markdown(""" + ### Work with Multiple Files + ```bash + > Create a complete web app with HTML, CSS, and JavaScript files for a todo list + ``` + Notice how Claude handles multiple file creation and relationships + + ### Debug Code Issues + ```bash + > There's a bug in my calculator.html file where division by zero doesn't show an error. Can you fix it? 
+ ``` + + ### Explain Code Architecture + ```bash + > Explain how the JavaScript functions in calculator.html work together + ``` + + ### Generate Tests + ```bash + > Create test cases for the calculator functions + ``` + + ### Optimize Code + ```bash + > Optimize the calculator.js file for better performance + ``` + """) + + with gr.Accordion("Tips for Effective Usage", open=False): + gr.Markdown(""" + - **Be Specific**: Provide clear, detailed instructions + - **Review Changes**: Always review code before approving file modifications + - **Use /compact Regularly**: Helps manage token usage during long sessions + - **Create Project Structure**: Start with a clear project outline for better results + - **Ask for Explanations**: If you don't understand Claude's suggestions, ask for clarification + - **Monitor Costs**: Use the /cost command periodically to track token usage + + This step-by-step guide will help you learn Claude Code effectively through hands-on practice with the command line interface. Each step builds on the previous one, allowing you to gradually explore more advanced features as you become comfortable with the basics. 
+ """) + + # Hidden element for backward compatibility + current_dir = gr.Textbox(label="Current Directory", value=os.getcwd(), interactive=False, visible=False) + + # Single-Turn Chat Event handlers + load_url_btn.click( + fn=app.load_document_from_url, + inputs=[url_input], + outputs=[url_status] + ) + + load_file_btn.click( + fn=app.load_document_from_file, + inputs=[file_input], + outputs=[file_status] + ) + + set_model_btn.click( + fn=app.set_model_and_temp, + inputs=[model_dropdown, temperature_slider], + outputs=[model_status] + ) + + cache_checkbox.change( + fn=app.toggle_cache, + inputs=[cache_checkbox], + outputs=[] + ) + + + + submit_btn.click( + fn=app.chat_with_document, + inputs=[msg], + outputs=[chatbot, error_output, usage_info, cache_stats], + api_name="chat" + ) + + clear_btn.click( + fn=app.clear_history, + inputs=[], + outputs=[chatbot, error_output, usage_info, cache_stats] + ) + + benchmark_btn.click( + fn=app.run_benchmark, + inputs=[epochs_slider], + outputs=[benchmark_output, benchmark_plot] + ) + + # Connect common question buttons + for btn in common_q_btns: + btn.click( + fn=app.chat_with_document, + inputs=[btn], + outputs=[chatbot, error_output, usage_info, cache_stats] + ) + + # Multi-Turn Chat Event handlers + mt_load_url_btn.click( + fn=app.load_multi_turn_document_from_url, + inputs=[mt_url_input], + outputs=[mt_url_status] + ) + + mt_load_file_btn.click( + fn=app.load_multi_turn_document_from_file, + inputs=[mt_file_input], + outputs=[mt_file_status] + ) + + mt_set_model_btn.click( + fn=app.set_multi_turn_model, + inputs=[mt_model_dropdown], + outputs=[mt_model_status] + ) + + mt_submit_btn.click( + fn=app.multi_turn_chat, + inputs=[mt_msg, mt_max_tokens, mt_temperature, mt_top_p, mt_top_k, mt_stop_sequences], + outputs=[mt_chatbot, mt_error_output, mt_usage_info], + api_name="multi_turn_chat" + ).then( + fn=lambda: "", + outputs=[mt_msg] + ) + + mt_clear_btn.click( + fn=app.clear_multi_turn_history, + inputs=[], + 
outputs=[mt_chatbot, mt_error_output] + ) + + mt_show_stats_btn.click( + fn=app.show_experiment_stats, + inputs=[], + outputs=[mt_stats_output] + ) + + # Claude Code Event handlers + install_btn.click( + fn=app.install_claude_code, + inputs=[], + outputs=[install_status] + ) + + configure_btn.click( + fn=app.configure_claude_environment, + inputs=[cc_model_dropdown, cc_cache_checkbox], + outputs=[configure_status] + ) + + # Keep handlers for backward compatibility (hidden elements) + check_aws_btn.click( + fn=app.check_aws_config, + inputs=[], + outputs=[aws_status] + ) + + check_version_btn.click( + fn=app.check_claude_version, + inputs=[], + outputs=[version_status] + ) + + generate_script_btn.click( + fn=app.generate_environment_script, + inputs=[cc_model_dropdown, cc_cache_checkbox], + outputs=[script_output] + ) + + return interface + + +if __name__ == "__main__": + interface = create_gradio_interface() + interface.launch(share=False) \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/prompt_caching_multi_turn.py b/BedrockPromptCachingRoutingDemo/src/prompt_caching_multi_turn.py new file mode 100644 index 000000000..99952180e --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/prompt_caching_multi_turn.py @@ -0,0 +1,1170 @@ +import time +import json +import os +import pandas as pd +import hashlib +import requests +from file_processor import FileProcessor +from bedrock_service import BedrockService +from model_manager import ModelManager + + +class PromptCachingExperiment: + """Handles prompt caching experiments with foundation models + + This class provides functionality for running multi-turn conversations + with prompt caching, collecting metrics, and analyzing cache performance. + It can be used standalone or integrated with other applications. 
+ """ + + def __init__(self, bedrock_service=None, model_manager=None): + """Initialize the experiment with Bedrock service and model manager + + Args: + bedrock_service: BedrockService instance for making API calls + If None, a new instance will be created + model_manager: ModelManager instance for model selection and resolution + If None, a new instance will be created + + Raises: + TypeError: If provided services are not of the correct type + """ + # Validate service types if provided + if bedrock_service is not None and not isinstance(bedrock_service, BedrockService): + raise TypeError("bedrock_service must be an instance of BedrockService") + + if model_manager is not None and not isinstance(model_manager, ModelManager): + raise TypeError("model_manager must be an instance of ModelManager") + + # Use provided services or create new ones if not provided + self.bedrock_service = bedrock_service if bedrock_service else BedrockService() + self.model_manager = model_manager if model_manager else ModelManager(self.bedrock_service) + + self.all_experiments_data = [] # Stores metrics for all conversation turns + self.cache_store = {} # Stores cache information by cache key + self.sample_text = "" # Context text, set by set_context_text + + # Default model parameters + self.model_params = { + "max_tokens": 2048, + "temperature": 0.5, + "top_p": 0.8, + "stop_sequences": None + } + + # Default questions for each turn in automated experiments + self.default_questions = [ + "Please summarize the story.", + "What is the subject of the story?", + "Where did Romeo and Juliet first meet?", + "What is the name of the woman Romeo loved before?", + "How does Mercutio die?", + "What method did Juliet use to fake her death?", + ] + + def set_context_text(self, text): + """Set the context text for the experiment + + Args: + text: The text to use as context + """ + self.sample_text = text + print(f"Context text set ({len(text)} characters)") + + def load_context_from_file(self, 
file_path): + """Load context text from a file using FileProcessor + + Args: + file_path: Path to the file to load + + Returns: + True if successful, False otherwise + """ + try: + # Check if file extension is supported + _, ext = os.path.splitext(file_path) + if ext.lower() not in FileProcessor.SUPPORTED_EXTENSIONS: + print(f"Unsupported file type. Supported types: {', '.join(FileProcessor.SUPPORTED_EXTENSIONS)}") + return False + + # Create a file-like object with name attribute for FileProcessor + class FileObj: + def __init__(self, path): + self.name = os.path.basename(path) + self._file = open(path, 'rb') + + def getvalue(self): + self._file.seek(0) + return self._file.read() + + def close(self): + self._file.close() + + # Process the file using FileProcessor + file_obj = FileObj(file_path) + self.sample_text = FileProcessor.process_uploaded_file(file_obj) + file_obj.close() + + if not self.sample_text: + print("No text extracted from file.") + return False + + print(f"Context loaded from file: {file_path} ({len(self.sample_text)} characters)") + return True + except Exception as e: + print(f"Error loading context from file: {e}") + return False + + def load_context_from_url(self, url): + """Load context text from a URL + + Args: + url: URL to fetch the context from + + Returns: + True if successful, False otherwise + """ + try: + response = requests.get(url) + response.raise_for_status() + self.sample_text = response.text + + if not self.sample_text: + print("Empty document received from URL.") + return False + + print(f"Context loaded from URL: {url} ({len(self.sample_text)} characters)") + return True + except Exception as e: + print(f"Error loading context from URL: {e}") + return False + + def run_experiments(self, n_experiments=1, n_turns=6, model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0"): + """Run multiple experiments with multi-turn conversations + + Args: + n_experiments: Number of experiment iterations to run + n_turns: Number of 
conversation turns in each experiment + model_id: The Bedrock model ID to use (default: Claude 3.7 Sonnet) + + Returns: + List of all turn data dictionaries from the experiments + + Raises: + ValueError: If parameters are invalid + """ + if not self.sample_text: + print("No context text set. Please set context text before running experiments.") + return [] + + if n_experiments < 1: + raise ValueError("Number of experiments must be at least 1") + + if n_turns < 1: + raise ValueError("Number of turns must be at least 1") + + if not model_id or not isinstance(model_id, str): + raise ValueError("Model ID must be a non-empty string") + + print(f"Running experiments with model: {model_id}") + print("Enabling cache testing mode - will repeat the same question twice to test caching") + print("\nCache Information:") + print("- First turn always includes context text which will be cached") + + all_experiment_data = [] + + for exp_num in range(n_experiments): + print(f"Running experiment {exp_num+1}/{n_experiments}") + experiment_data = [] + conversation = [] + + # Simulate n_turns + for turn in range(n_turns): + # Get the current question + question = self.default_questions[min(turn, len(self.default_questions)-1)] + + # For even turns after turn 0, repeat the previous question to test caching + if turn > 0 and turn % 2 == 0: + question = self.default_questions[min(turn-1, len(self.default_questions)-1)] + print(f" Turn {turn+1}/{n_turns}: {question} (REPEATED to test cache)") + else: + print(f" Turn {turn+1}/{n_turns}: {question}") + + turn_data = self.process_turn(turn, conversation, question, model_id) + experiment_data.append(turn_data) + time.sleep(30) # Wait between requests + + all_experiment_data.extend(experiment_data) + self.all_experiments_data.extend(experiment_data) + + # Save results + self.save_results() + + return all_experiment_data + + def interactive_chat(self, model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0"): + """Run an interactive chat session 
where user can type questions + + Args: + model_id: The Bedrock model ID to use (default: Claude 3.7 Sonnet) + + Returns: + List of turn data dictionaries containing metrics for all turns + + Raises: + ValueError: If model_id is invalid + """ + if not self.sample_text: + print("No context text set. Please set context text before starting chat.") + return [] + + if not model_id or not isinstance(model_id, str): + raise ValueError("Model ID must be a non-empty string") + + print(f"Starting interactive chat session with context ({len(self.sample_text)} characters)") + print(f"Using model: {model_id}") + print("Type 'exit' to end the conversation") + print("Type 'model:' to change the model") + print("Type 'select' to select a model from the list") + print("Type 'stats' to show experiment statistics") + print("\nCache Information:") + print("- First turn always includes context text which will be cached") + + conversation = [] + turn_data = [] + turn = 0 + + # First turn always includes the sample text + while True: + if turn == 0: + print("\nFirst message will include the context text") + + # Get user question + user_question = input("\nEnter your question: ") + if user_question.lower() == 'exit': + break + + # Check if user wants to change the model + if user_question.lower().startswith('model:'): + new_model_id = user_question[6:].strip() + model_id = self.model_manager.get_model_arn_from_inference_profiles(new_model_id) + print(f"Model changed to: {model_id}") + continue + + # Check if user wants to select a model from the list + if user_question.lower() == 'select': + model_id = self.model_manager.select_model() + print(f"Model selected: {model_id}") + continue + + # Check if user wants to see experiment statistics + if user_question.lower() == 'stats': + if turn_data: + self.display_metrics() + else: + print("No experiment data available yet.") + continue + + # Process the turn + data = self.process_turn(turn, conversation, user_question, model_id) + 
turn_data.append(data) + + # Print the response + print("\nAssistant:", conversation[-1]["content"][0]["text"]) + + # Print metrics for this turn + print("\n" + self.get_turn_metrics(data)) + + # Print cache information + self.print_cache_info(data) + + turn += 1 + + # Save results if any turns were processed + if turn_data: + self.all_experiments_data.extend(turn_data) + self.save_results() + self.display_metrics() + + def process_turn(self, turn, conversation, question, model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0"): + """Process a single conversation turn with prompt caching + + This method handles the core functionality of the experiment: + - Constructing the message with appropriate cache controls + - Invoking the model with the provided conversation history + - Collecting and returning detailed metrics about the interaction + - Updating the conversation history with the new turn + + Args: + turn: The turn number (0-based) in the conversation + conversation: The conversation history list that will be modified in-place + question: The user question to process in this turn + model_id: The Bedrock model ID to use (default: Claude 3.7 Sonnet) + + Returns: + Dictionary containing detailed metrics for this turn including: + - turn: Turn number (1-based) + - question: The question that was asked + - input_tokens: Number of input tokens processed + - output_tokens: Number of output tokens generated + - cache_creation_input_tokens: Number of tokens written to cache + - cache_read_input_tokens: Number of tokens read from cache + - invocation_latency: Total time taken for the request + - cache_key: Unique key for this content in the cache + - is_cache_hit: Boolean indicating if cache was hit + + Raises: + ValueError: If parameters are invalid or sample_text is not set + RuntimeError: If model invocation fails + """ + if not isinstance(conversation, list): + raise ValueError("Conversation must be a list") + + if not question or not isinstance(question, str): + 
raise ValueError("Question must be a non-empty string") + + if turn == 0 and not self.sample_text: + raise ValueError("Context text not set. Call set_context_text before processing the first turn.") + + if not isinstance(turn, int) or turn < 0: + raise ValueError("Turn must be a non-negative integer") + + # Get the resolved model ID from the model manager + model_id = self.model_manager.get_model_arn_from_inference_profiles(model_id) + + # Generate a cache key for this content + cache_key = self.generate_cache_key(self.sample_text if turn == 0 else "", question) + + # Determine if this is a Claude model or Nova model + is_claude_model = "anthropic" in model_id.lower() or "claude" in model_id.lower() # Used throughout the method + + # Construct message content for this turn based on model type + content = [] + if turn == 0: + if is_claude_model: + content.append({"type": "text", "text": self.sample_text}) + else: + # For Nova models, no "type" field + content.append({"text": self.sample_text}) + + # Add the current question - format depends on model type + if is_claude_model: + content.append({ + "type": "text", + "text": question + " " + }) + else: + content.append({ + "text": question + " " + }) + + # Construct full messages list with history + current message + messages = conversation.copy() + messages.append({"role": "user", "content": content}) + + # Prepare version without cache control for conversation history + content_for_saving = [] + if turn == 0: + if is_claude_model: + content_for_saving.append({"type": "text", "text": self.sample_text}) + else: + # For Nova models, no "type" field + content_for_saving.append({"text": self.sample_text}) + + # Add question with format based on model type + if is_claude_model: + content_for_saving.append({"type": "text", "text": question + " "}) + else: + content_for_saving.append({"text": question + " "}) + + # Print request info + print("\n" + "="*60) + print(f"🔄 PROCESSING TURN {turn+1}") + print("="*60) + 
print(f"Question: \"{question}\"") + print(f"Cache key: {cache_key}") + print(f"Model: {model_id}") + + # Check if this is a repeated question + if cache_key in [data.get("cache_key") for data in self.all_experiments_data]: + print("⚠️ This question was asked before - potential cache hit!") + + # Show what's being sent + if turn == 0: + print(f"📄 Including context text ({len(self.sample_text)} characters)") + context_preview = self.sample_text[:100] + "..." if len(self.sample_text) > 100 else self.sample_text + print(f"Context preview: \"{context_preview}\"") + else: + print("📝 Using conversation history from previous turns") + + # Record the start time for performance measurement + start_time = time.time() + + try: + response = self.invoke_model(messages, model_id, self.model_params) + except Exception as e: + raise RuntimeError(f"Model invocation failed: {str(e)}") + + # Record the end time + end_time = time.time() + invocation_latency = end_time - start_time + + # Validate response format + if not isinstance(response, dict) or "content" not in response or "usage" not in response: + raise RuntimeError("Invalid response format from model") + + # Ensure content has the expected structure + if not response["content"] or not isinstance(response["content"], list): + raise RuntimeError("Invalid content format in response") + + # Update conversation history - reuse the is_claude_model variable from earlier + + # Add user message to conversation history + if is_claude_model: + conversation.append({"role": "user", "content": content_for_saving}) + else: + # For Nova models, no "type" field in content + nova_content = [] + for item in content_for_saving: + if "text" in item: + nova_content.append({"text": item["text"]}) + conversation.append({"role": "user", "content": nova_content}) + + # Add assistant response to conversation history + try: + if is_claude_model: + conversation.append({ + "role": "assistant", + "content": [{"type": "text", "text": 
response["content"][0]["text"]}] + }) + else: + # For Nova models, no "type" field + conversation.append({ + "role": "assistant", + "content": [{"text": response["content"][0]["text"]}] + }) + except (KeyError, IndexError) as e: + raise RuntimeError(f"Failed to extract response text: {str(e)}") + + # Get metrics - handle different field names for different models + metrics = response["usage"] + + # Normalize metrics field names + input_tokens = metrics.get("input_tokens", metrics.get("inputTokens", 0)) + output_tokens = metrics.get("output_tokens", metrics.get("outputTokens", 0)) + cache_read_tokens = metrics.get("cache_read_input_tokens", metrics.get("cacheReadInputTokens", 0)) + cache_write_tokens = metrics.get("cache_creation_input_tokens", metrics.get("cacheWriteInputTokens", 0)) + + # Store cache information + is_cache_hit = cache_read_tokens > 0 + cache_info = { + "cache_key": cache_key, + "is_cache_hit": is_cache_hit, + "cached_content": self.sample_text if turn == 0 else "", + "question": question, + "cache_creation_tokens": cache_write_tokens, + "cache_read_tokens": cache_read_tokens + } + self.cache_store[cache_key] = cache_info + + # Return data for this turn with normalized field names + return { + "turn": turn + 1, + "question": question, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cache_creation_input_tokens": cache_write_tokens, + "cache_read_input_tokens": cache_read_tokens, + "invocation_latency": invocation_latency, + "cache_key": cache_key, + "is_cache_hit": is_cache_hit + } + + def generate_cache_key(self, context, question): + """Generate a simple cache key for tracking cached content + + Args: + context: The context text (if any) + question: The question text + + Returns: + A hash string to use as cache key + """ + content = (context + question).encode('utf-8') + return hashlib.md5(content).hexdigest()[:8] + + def get_cache_summary(self, turn_data): + """Get cache summary information as a formatted string + + This 
method analyzes the cache performance for a specific turn + and generates a detailed summary of cache usage, including: + - Cache hit/miss status + - Token savings from cache + - Percentage of prompt that was cached + - Description of what content was cached or retrieved + + Args: + turn_data: Dictionary containing the turn metrics from process_turn + + Returns: + Formatted string with cache summary information ready for display + """ + cache_key = turn_data["cache_key"] + is_cache_hit = turn_data["is_cache_hit"] + + # Calculate performance metrics + input_tokens = turn_data['input_tokens'] + input_tokens_cache_read = turn_data['cache_read_input_tokens'] + input_tokens_cache_create = turn_data['cache_creation_input_tokens'] + total_input_tokens = input_tokens + input_tokens_cache_read + percentage_cached = (input_tokens_cache_read / total_input_tokens * 100) if total_input_tokens > 0 else 0 + + summary = ["\n📊 Cache Summary:"] + summary.append(f" Cache key: {cache_key}") + + if is_cache_hit: + summary.append(f" ✅ CACHE HIT") + summary.append(f" Cache read tokens: {input_tokens_cache_read}") + summary.append(f" Input tokens saved: {input_tokens_cache_read}") + summary.append(f" {percentage_cached:.1f}% of input prompt cached ({total_input_tokens} tokens)") + + # Show what was retrieved from cache + if turn_data["turn"] == 1: + summary.append(" Content retrieved from cache: Context text (first turn)") + cached_content = self.sample_text[:100] + "..." 
if len(self.sample_text) > 100 else self.sample_text + summary.append(f" Cached content preview: \"{cached_content}\"") + else: + summary.append(" Content retrieved from cache: Previous question context") + + summary.append(" This means the model didn't need to process this content again,") + summary.append(" resulting in faster response time and lower token usage.") + else: + summary.append(f" ❌ CACHE MISS") + summary.append(f" Cache creation tokens: {input_tokens_cache_create}") + + # Show what was written to cache + if turn_data["turn"] == 1: + summary.append(" Content written to cache: Context text (first turn)") + cached_content = self.sample_text[:100] + "..." if len(self.sample_text) > 100 else self.sample_text + summary.append(f" Cached content preview: \"{cached_content}\"") + else: + summary.append(" Content written to cache: Current question context") + + summary.append(" This content will be cached for future similar queries.") + + return "\n".join(summary) + + def print_cache_info(self, turn_data): + """Print information about cache usage for this turn + + Args: + turn_data: Data for the current turn + """ + summary = self.get_cache_summary(turn_data) + print(summary) + + def invoke_model(self, messages, model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", model_params=None): + """Invoke foundation model through Bedrock with appropriate caching strategy + + This method handles the different invocation patterns required for different models: + - Anthropic Claude models: Uses invoke_model API with cache_control + - Amazon Nova models: Uses converse API with cachePoint + + Args: + messages: The conversation messages to send (list of message objects) + model_id: The Bedrock model ID to use (default: Claude 3.7 Sonnet) + + Returns: + Standardized response dictionary with: + - content: List of content blocks with text + - usage: Dictionary with token usage metrics + + Raises: + ValueError: If bedrock_service is not initialized or messages is invalid + 
boto3.exceptions.Boto3Error: For AWS service-related errors + """ + if not self.bedrock_service: + raise ValueError("BedrockService not initialized. Please provide a valid bedrock_service in the constructor.") + + if not messages or not isinstance(messages, list): + raise ValueError("Messages must be a non-empty list") + + try: + runtime_client = self.bedrock_service.get_runtime_client() + except Exception as e: + raise ValueError(f"Failed to get Bedrock runtime client: {str(e)}") + + # Get the resolved model ID from the model manager + resolved_model_id = self.model_manager.get_model_arn_from_inference_profiles(model_id) + if resolved_model_id != model_id: + print(f"Using resolved model ID: {resolved_model_id}") + model_id = resolved_model_id + + # For Anthropic Claude models + if "anthropic" in model_id.lower() or "claude" in model_id.lower(): + # Use the invoke_model API with the proper format for Claude models + # Prepare user message with cache control + user_message = None + for msg in messages: + if msg["role"] == "user": + user_message = msg + break + + if user_message and len(user_message["content"]) > 1: + # Format the content with cache_control for the second part + content_with_cache = [] + for i, content_item in enumerate(user_message["content"]): + if i == 0: # First item (context) + content_with_cache.append(content_item) + else: # Second item (question) + content_with_cache.append({ + "type": "text", + "text": content_item["text"], + "cache_control": { + "type": "ephemeral" + } + }) + user_message["content"] = content_with_cache + + # Get model parameters or use defaults + params = model_params or self.model_params or {} + max_tokens = params.get("max_tokens", 2048) + temperature = params.get("temperature", 0.5) + top_p = params.get("top_p", 0.8) + top_k = params.get("top_k", 250) + stop_sequences = params.get("stop_sequences") + + # Prepare the request body + request_body = { + "anthropic_version": "bedrock-2023-05-31", + "system": "Reply 
concisely", + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k + } + + # Add stop sequences if provided + if stop_sequences: + request_body["stop_sequences"] = stop_sequences + + # Print request details + print("\nSending request to Claude model:") + print(f" - Using invoke_model API with model: {model_id}") + print(" - Cache control set to 'ephemeral' for the question") + + response = runtime_client.invoke_model( + modelId=model_id, + body=json.dumps(request_body) + ) + + response_data = json.loads(response['body'].read()) + + # Check for cache metrics + input_tokens = response_data["usage"].get("input_tokens", 0) + output_tokens = response_data["usage"].get("output_tokens", 0) + cache_read = response_data["usage"].get("cache_read_input_tokens", 0) + cache_write = response_data["usage"].get("cache_creation_input_tokens", 0) + + if cache_read > 0: + total_input_tokens = input_tokens + cache_read + print("\n✅ CACHE HIT: Content was retrieved from cache") + print(f" - Cache read tokens: {cache_read}") + print(f" - {(cache_read / total_input_tokens * 100):.1f}% of input prompt cached ({total_input_tokens} tokens)") + elif cache_write > 0: + print("\n📝 CACHE WRITE: Content was written to cache") + print(f" - Cache write tokens: {cache_write}") + else: + print("\n❌ NO CACHE: No caching occurred") + + return response_data + + # For Amazon Nova models and other models that use converse API + else: + # Format messages for Nova models + nova_messages = [] + + # Process each message to create proper Nova format + for msg in messages: + if msg["role"] == "user": + # Create a new content array without "type" field + nova_content = [] + + # Process each content item - ensure document content is preserved + for i, content_item in enumerate(msg["content"]): + if "type" in content_item and content_item["type"] == "text": + # Convert Claude format to Nova format + nova_content.append({ + "text": content_item["text"] 
+ }) + elif "text" in content_item: + # Already in Nova format or simple text + nova_content.append({ + "text": content_item["text"] + }) + + # Add cachePoint between context and question if there are multiple content items + if len(nova_content) > 1: + nova_content.insert(1, { + "cachePoint": { + "type": "default" + } + }) + + # Add the properly formatted user message + nova_messages.append({ + "role": "user", + "content": nova_content + }) + elif msg["role"] == "assistant": + # Create assistant message with proper format + nova_content = [] + for content_item in msg["content"]: + if "type" in content_item and content_item["type"] == "text": + # Convert Claude format to Nova format + nova_content.append({ + "text": content_item["text"] + }) + elif "text" in content_item: + # Already in Nova format or simple text + nova_content.append({ + "text": content_item["text"] + }) + + nova_messages.append({ + "role": "assistant", + "content": nova_content + }) + + # Print request details + print("\nSending request to Amazon Nova model:") + print(f" - Using converse API with model: {model_id}") + print(" - cachePoint inserted between context and question") + + # Get model parameters or use defaults + params = model_params or self.model_params or {} + max_tokens = params.get("max_tokens", 300) + temperature = params.get("temperature", 0.3) + top_p = params.get("top_p", 0.1) + + # Create system message for Nova models + system_message = [{ + "text": "Reply Concisely" + }] + + # Call Bedrock with converse API for Nova models + response = runtime_client.converse( + modelId=model_id, + messages=nova_messages, + system=system_message, + inferenceConfig={ + "maxTokens": max_tokens, + "temperature": temperature, + "topP": top_p + } + ) + + # Process response + output_message = response["output"]["message"] + response_text = output_message["content"][0]["text"] + + # Create response_data in the same format as invoke_model would return + response_data = { + "content": [{"text": 
response_text}],
+            "usage": response["usage"]
+        }
+        
+        # Check for cache metrics
+        input_tokens = response_data["usage"].get("inputTokens", 0)
+        output_tokens = response_data["usage"].get("outputTokens", 0)
+        cache_read = response_data["usage"].get("cacheReadInputTokens", 0)
+        cache_write = response_data["usage"].get("cacheWriteInputTokens", 0)
+        
+        if cache_read > 0:
+            total_input_tokens = input_tokens + cache_read
+            print("\n✅ CACHE HIT: Content was retrieved from cache")
+            print(f"   - Cache read tokens: {cache_read}")
+            print(f"   - {(cache_read / total_input_tokens * 100):.1f}% of input prompt cached ({total_input_tokens} tokens)")
+        elif cache_write > 0:
+            print("\n📝 CACHE WRITE: Content was written to cache")
+            print(f"   - Cache write tokens: {cache_write}")
+        else:
+            print("\n❌ NO CACHE: No caching occurred")
+        
+        return response_data
+    
+    def save_results(self, filename="cache_experiment_results.csv"):
+        """Save experiment results to CSV file
+        
+        Args:
+            filename: Name of the CSV file to save results to
+            
+        Returns:
+            True if successful, False otherwise
+        """
+        try:
+            if not self.all_experiments_data:
+                print("No experiment data to save.")
+                return False
+                
+            pd.DataFrame(self.all_experiments_data).to_csv(filename, index=False)
+            print(f"Results saved to {filename}")
+            return True
+        except Exception as e:
+            print(f"Error saving results: {str(e)}")
+            return False
+    
+    def get_experiment_summary(self):
+        """Get experiment summary statistics as a formatted string
+        
+        This method analyzes all collected experiment data and generates
+        a comprehensive summary of the results, including:
+        - Statistical analysis of token usage and latency
+        - Cache hit rate calculation
+        - Overall experiment metrics
+        
+        Returns:
+            Formatted string with experiment summary statistics
+            that can be displayed to users or logged
+        """
+        if not self.all_experiments_data:
+            return "No experiment data available."
+ + df = pd.DataFrame(self.all_experiments_data) + + # Calculate cache hit information + cache_hits = sum(1 for data in self.all_experiments_data if data.get("is_cache_hit", False)) + total_turns = len(self.all_experiments_data) + hit_rate = (cache_hits / total_turns * 100) if total_turns > 0 else 0 + + # Calculate timing information + avg_latency = df["invocation_latency"].mean() if "invocation_latency" in df.columns else 0 + min_latency = df["invocation_latency"].min() if "invocation_latency" in df.columns else 0 + max_latency = df["invocation_latency"].max() if "invocation_latency" in df.columns else 0 + + # Calculate average times for cache hits vs misses + if cache_hits > 0 and (total_turns - cache_hits) > 0: + cache_hit_times = [data["invocation_latency"] for data in self.all_experiments_data if data.get("is_cache_hit", False)] + cache_miss_times = [data["invocation_latency"] for data in self.all_experiments_data if not data.get("is_cache_hit", False)] + + avg_hit_time = sum(cache_hit_times) / len(cache_hit_times) if cache_hit_times else 0 + avg_miss_time = sum(cache_miss_times) / len(cache_miss_times) if cache_miss_times else 0 + + if avg_miss_time > 0: + speedup = (avg_miss_time - avg_hit_time) / avg_miss_time * 100 if avg_miss_time > 0 else 0 + + # Format the output with Markdown tables + result = [] + result.append("## Experiment Results Summary") + + # Summary Statistics Table + result.append("\n### Summary Statistics") + result.append("| Metric | Mean | Median | Min | Max |") + result.append("|--------|------|--------|-----|-----|") + + metrics = ["input_tokens", "output_tokens", "cache_creation_input_tokens", + "cache_read_input_tokens", "invocation_latency"] + + for metric in metrics: + if metric in df.columns: + stats = df[metric].describe() + result.append(f"| {metric.replace('_', ' ').title()} | {stats['mean']:.2f} | {stats['50%']:.2f} | {stats['min']:.2f} | {stats['max']:.2f} |") + + # Cache Performance Table + result.append("\n### Cache 
Performance") + result.append("| Metric | Value |") + result.append("|--------|-------|") + result.append(f"| Cache Hit Rate | {hit_rate:.1f}% ({cache_hits}/{total_turns} turns) |") + result.append(f"| Total Turns | {total_turns} |") + result.append(f"| Cache Hits | {cache_hits} |") + result.append(f"| Cache Misses | {total_turns - cache_hits} |") + + # Timing Information Table + result.append("\n### Timing Information") + result.append("| Metric | Value (seconds) |") + result.append("|--------|----------------|") + result.append(f"| Average Response Time | {avg_latency:.2f} |") + result.append(f"| Minimum Response Time | {min_latency:.2f} |") + result.append(f"| Maximum Response Time | {max_latency:.2f} |") + + if cache_hits > 0 and (total_turns - cache_hits) > 0: + result.append(f"| Average Time with Cache Hit | {avg_hit_time:.2f} |") + result.append(f"| Average Time with Cache Miss | {avg_miss_time:.2f} |") + if avg_miss_time > 0: + result.append(f"| Cache Speedup | {speedup:.1f}% |") + + # Individual Turn Data Table + result.append("\n### Individual Turn Data") + result.append("| Turn | Question | Cache Hit | Input Tokens | Cache Read Tokens | Response Time (s) |") + result.append("|------|----------|-----------|--------------|-------------------|-------------------|") + + for data in self.all_experiments_data: + turn = data.get("turn", "N/A") + question = data.get("question", "N/A") + is_cache_hit = "✅" if data.get("is_cache_hit", False) else "❌" + input_tokens = data.get("input_tokens", 0) + cache_read = data.get("cache_read_input_tokens", 0) + latency = data.get("invocation_latency", 0) + + # Truncate question if too long + if len(question) > 30: + question = question[:27] + "..." 
+ + result.append(f"| {turn} | {question} | {is_cache_hit} | {input_tokens} | {cache_read} | {latency:.2f} |") + + return "\n".join(result) + + def get_turn_metrics(self, turn_data): + """Get metrics for a specific turn as a formatted string + + This method formats the metrics for a single conversation turn + into a human-readable string that can be displayed to users. + It includes token usage, cache performance, and timing information. + + Args: + turn_data: Dictionary containing the turn metrics from process_turn + + Returns: + Formatted string with turn metrics ready for display + """ + input_tokens = turn_data['input_tokens'] + output_tokens = turn_data['output_tokens'] + input_tokens_cache_read = turn_data['cache_read_input_tokens'] + input_tokens_cache_create = turn_data['cache_creation_input_tokens'] + elapsed_time = turn_data['invocation_latency'] + is_cache_hit = turn_data['is_cache_hit'] + turn_number = turn_data['turn'] + + # Calculate the percentage of input prompt cached + total_input_tokens = input_tokens + input_tokens_cache_read + percentage_cached = (input_tokens_cache_read / total_input_tokens * 100) if total_input_tokens > 0 else 0 + + # Format as markdown table + metrics = [] + metrics.append(f"## Turn {turn_number} Metrics") + + # Cache status with emoji + if is_cache_hit: + metrics.append(f"### ✅ CACHE HIT") + else: + metrics.append(f"### ❌ CACHE MISS") + + # Timing table + metrics.append("\n#### Timing Information") + metrics.append("| Metric | Value |") + metrics.append("|--------|-------|") + metrics.append(f"| Start time | {time.strftime('%H:%M:%S', time.localtime(time.time() - elapsed_time))} |") + metrics.append(f"| End time | {time.strftime('%H:%M:%S', time.localtime(time.time()))} |") + metrics.append(f"| Response time | {elapsed_time:.2f} seconds |") + + # Token usage table + metrics.append("\n#### Token Usage") + metrics.append("| Metric | Value |") + metrics.append("|--------|-------|") + metrics.append(f"| User input tokens | 
{input_tokens} |") + metrics.append(f"| Output tokens | {output_tokens} |") + + if is_cache_hit: + metrics.append(f"| Cache read tokens | {input_tokens_cache_read} |") + metrics.append(f"| Percentage cached | {percentage_cached:.1f}% of input prompt |") + metrics.append(f"| Total input tokens | {total_input_tokens} |") + else: + metrics.append(f"| Cache write tokens | {input_tokens_cache_create} |") + + return "\n".join(metrics) + + def display_metrics(self): + """Display all metrics including summary statistics and individual turn data + + Returns: + True if metrics were displayed, False if no data available + """ + if not self.all_experiments_data: + print("No experiment data available.") + return False + + try: + # Print summary statistics + print("\n===== Summary Statistics =====") + df = pd.DataFrame(self.all_experiments_data) + + # Check for required columns + required_columns = ["input_tokens", "output_tokens", "cache_creation_input_tokens", + "cache_read_input_tokens", "invocation_latency"] + missing_columns = [col for col in required_columns if col not in df.columns] + + if missing_columns: + print(f"Warning: Missing columns in experiment data: {', '.join(missing_columns)}") + # Use only available columns + available_columns = [col for col in required_columns if col in df.columns] + if available_columns: + print(df[available_columns].describe()) + else: + print("No metric columns available for summary statistics.") + else: + print(df[required_columns].describe()) + + # Print cache hit information + cache_hits = sum(1 for data in self.all_experiments_data if data.get("is_cache_hit", False)) + total_turns = len(self.all_experiments_data) + hit_rate = (cache_hits / total_turns * 100) if total_turns > 0 else 0 + + print(f"\n===== Cache Performance =====") + print(f"Cache Hit Rate: {hit_rate:.1f}% ({cache_hits}/{total_turns} turns)") + print(f"Total turns: {total_turns}") + print(f"Cache hits: {cache_hits}") + print(f"Cache misses: {total_turns - cache_hits}") 
+            
+            # Print timing information
+            print(f"\n===== Timing Information =====")
+            # Default to NaN (not the string "N/A") so the :.2f format specifiers below can't raise
+            avg_latency = df["invocation_latency"].mean() if "invocation_latency" in df.columns else float("nan")
+            min_latency = df["invocation_latency"].min() if "invocation_latency" in df.columns else float("nan")
+            max_latency = df["invocation_latency"].max() if "invocation_latency" in df.columns else float("nan")
+            
+            print(f"Average response time: {avg_latency:.2f} seconds")
+            print(f"Minimum response time: {min_latency:.2f} seconds")
+            print(f"Maximum response time: {max_latency:.2f} seconds")
+            
+            # Calculate average times for cache hits vs misses
+            if cache_hits > 0 and (total_turns - cache_hits) > 0:
+                cache_hit_times = [data["invocation_latency"] for data in self.all_experiments_data if data.get("is_cache_hit", False)]
+                cache_miss_times = [data["invocation_latency"] for data in self.all_experiments_data if not data.get("is_cache_hit", False)]
+                
+                avg_hit_time = sum(cache_hit_times) / len(cache_hit_times) if cache_hit_times else 0
+                avg_miss_time = sum(cache_miss_times) / len(cache_miss_times) if cache_miss_times else 0
+                
+                print(f"Average time with cache hit: {avg_hit_time:.2f} seconds")
+                print(f"Average time with cache miss: {avg_miss_time:.2f} seconds")
+                if avg_miss_time > 0:
+                    print(f"Cache speedup: {(avg_miss_time - avg_hit_time) / avg_miss_time * 100:.1f}%")
+            
+            # Print individual turn data
+            print("\n===== Individual Turn Data =====")
+            print(df)
+            
+            return True
+        except Exception as e:
+            print(f"Error displaying metrics: {str(e)}")
+            return False
+    
+    # Removed redundant print_results method that just called self.display_metrics()
+
+
+class ExperimentManager:
+    """Manages the creation and display of prompt caching experiments
+    
+    This class provides utility methods for creating properly configured
+    experiments and displaying their results in a structured way.
+ """ + + @staticmethod + def create_experiment(): + """Create and configure a PromptCachingExperiment with shared services + + This method creates the necessary services and experiment + instance with proper dependency injection, ensuring that all components + share the same service instances. + + Returns: + Configured PromptCachingExperiment instance ready to use + + Raises: + ImportError: If required modules are not available + RuntimeError: If service initialization fails + """ + try: + # Create shared services + bedrock_service = BedrockService() + + # Create model manager with the bedrock service + model_manager = ModelManager(bedrock_service=bedrock_service) + + # Create experiment with shared services + experiment = PromptCachingExperiment( + bedrock_service=bedrock_service, + model_manager=model_manager + ) + + return experiment + except ImportError as e: + raise ImportError(f"Required module not available: {str(e)}") + except Exception as e: + raise RuntimeError(f"Failed to create experiment: {str(e)}") + + # Removed redundant display_experiment_results method that just called experiment.display_metrics() + +if __name__ == "__main__": + # Create experiment using the ExperimentManager + experiment = ExperimentManager.create_experiment() + + # Default model + default_model_name = "Claude 3.7 Sonnet" + default_model = experiment.model_manager.get_model_arn_from_inference_profiles("us.anthropic.claude-3-7-sonnet-20250219-v1:0") + + # Load context text + print("Select context source:") + print("1. Load from file (RomeoAndJuliet.txt)") + print("2. Enter file path") + print("3. 
Enter URL") + + context_choice = input("Enter choice (1-3): ") + + if context_choice == "1": + experiment.load_context_from_file("RomeoAndJuliet.txt") + elif context_choice == "2": + file_path = input("Enter file path: ") + experiment.load_context_from_file(file_path) + elif context_choice == "3": + url = input("Enter URL: ") + experiment.load_context_from_url(url) + else: + print("Invalid choice. Using default Romeo and Juliet text.") + experiment.load_context_from_file("RomeoAndJuliet.txt") + + # Ask user which mode to run + print("\nSelect mode:") + print("1. Run predefined experiment") + print("2. Interactive chat mode") + print("3. Interactive chat mode (with metrics on demand)") + + choice = input("Enter choice (1-3): ") + + # Model selection + print("\nSelect model:") + print(f"1. Use default model [{default_model_name}]") + print("2. Select from available models") + + model_choice = input("Enter choice (1-2): ") + + if model_choice == "1": + model_id = default_model + print(f"Using default model: {default_model} [{default_model_name}]") + elif model_choice == "2": + # Add default model to the model list for selection + if "Anthropic Claude Models" in experiment.model_manager.models: + if default_model not in experiment.model_manager.models["Anthropic Claude Models"]: + experiment.model_manager.models["Anthropic Claude Models"].append(default_model) + model_id = experiment.model_manager.select_model() + else: + print(f"Invalid choice. 
Using default model: {default_model} [{default_model_name}]") + model_id = default_model + + if choice == "1": + # Run predefined experiment + experiment.run_experiments(n_experiments=1, n_turns=6, model_id=model_id) + # Display experiment results + experiment.display_metrics() + elif choice == "2": + # Interactive chat with metrics after each turn + experiment.interactive_chat(model_id=model_id) + # Display final experiment summary + experiment.display_metrics() + elif choice == "3": + # Interactive chat with metrics on demand (using 'stats' command) + print("\nType 'stats' during chat to see current metrics") + experiment.interactive_chat(model_id=model_id) + # Display final experiment summary + experiment.display_metrics() + else: + print("Invalid choice. Exiting.") \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/prompt_router_app.py b/BedrockPromptCachingRoutingDemo/src/prompt_router_app.py new file mode 100644 index 000000000..7cb9bb5ff --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/prompt_router_app.py @@ -0,0 +1,204 @@ +import streamlit as st +import os +from bedrock_prompt_routing import PromptRouterManager, ChatSession +from file_processor import FileProcessor +import time + +class BedrockChatUI: + def __init__(self): + self.setup_streamlit() + self.init_session_state() + self.region = os.getenv('AWS_REGION', 'us-east-1') + self.router_manager = PromptRouterManager(region=self.region) + + def setup_streamlit(self): + st.set_page_config( + page_title="Amazon Bedrock Prompt Router Demo", + layout="wide", + initial_sidebar_state="expanded" + ) + + st.markdown(""" + + """, unsafe_allow_html=True) + + def init_session_state(self): + if 'chat_session' not in st.session_state: + st.session_state.chat_session = None + if 'messages' not in st.session_state: + st.session_state.messages = [] + if 'current_router' not in st.session_state: + st.session_state.current_router = None + + def render_router_sidebar(self): + with 
st.sidebar:
+            st.title("🤖 Prompt Router Selection")
+            routers = self.router_manager.get_prompt_routers()
+            
+            router_options = [(r['name'], r['arn'], r['provider']) for r in routers]
+            selected_router = st.selectbox(
+                "Select a Prompt-Router",
+                options=router_options,
+                format_func=lambda x: f"{x[0]} ({x[2]})",
+                key="router_select"
+            )
+            
+            if selected_router:
+                if st.session_state.current_router != selected_router[1]:
+                    st.session_state.current_router = selected_router[1]
+                    st.session_state.chat_session = ChatSession(
+                        model_id=selected_router[1],
+                        region=self.region
+                    )
+                
+                router_details = self.router_manager.get_router_details(selected_router[1])
+                
+                st.divider()
+                st.subheader("📊 Prompt-Router Details")
+                st.markdown(f"**Provider:** {selected_router[2]}")
+                st.markdown(f"**Type:** {router_details.get('type', 'N/A')}")
+                
+                # Use .get() so a router without a supported_models entry doesn't raise KeyError
+                supported_models = router_details.get('supported_models')
+                if supported_models:
+                    st.divider()
+                    st.subheader("🔧 Available Models")
+                    for model in supported_models:
+                        st.markdown(f"- `{model}`")
+    
+    def render_chat_area(self):
+        chat_col, stats_col = st.columns([2, 1])
+        
+        with chat_col:
+            st.title("💬 Amazon Bedrock Prompt Router Demo")
+            
+            with st.container():
+                st.markdown("### 📎 File Upload")
+                uploaded_file = st.file_uploader(
+                    "Upload PDF, DOCX, or TXT file",
+                    type=['pdf', 'docx', 'txt']
+                )
+                
+                # Skip files already processed; otherwise the st.rerun() below
+                # would re-send the same file on every Streamlit rerun
+                if (uploaded_file and st.session_state.chat_session
+                        and st.session_state.get("processed_file") != uploaded_file.name):
+                    if FileProcessor.is_supported_file(uploaded_file.name):
+                        with st.spinner("Processing file..."):
+                            extracted_text = FileProcessor.process_uploaded_file(uploaded_file)
+                            if extracted_text:
+                                trace, model_used = st.session_state.chat_session.send_message(extracted_text)
+                                st.session_state.messages.append({
+                                    "role": "user",
+                                    "content": f"📄 *Content from {uploaded_file.name}*"
+                                })
+                                response = st.session_state.chat_session.messages[-1]["content"][0]["text"]
+                                st.session_state.messages.append({
+                                    "role": "assistant",
+                                    "content": response
+                                })
+                                st.session_state.processed_file = uploaded_file.name
+                                st.rerun()
+            
+            st.divider()
+            
+            chat_container = st.container()
+            
with chat_container: + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + if prompt := st.chat_input("Type your message..."): + if not st.session_state.chat_session: + st.error("⚠️ Please select a router first") + return + + st.session_state.messages.append({"role": "user", "content": prompt}) + + with st.chat_message("assistant"): + with st.spinner("Thinking..."): + trace, model_used = st.session_state.chat_session.send_message(prompt) + response = st.session_state.chat_session.messages[-1]["content"][0]["text"] + st.markdown(response) + st.session_state.messages.append({"role": "assistant", "content": response}) + + with stats_col: + if st.session_state.chat_session: + st.title("📈 Usage Stats") + stats = st.session_state.chat_session.usage_stats + + col1, col2 = st.columns(2) + with col1: + st.metric("Chats", stats.total_chats) + with col2: + st.metric("Total Tokens", stats.total_input_tokens + stats.total_output_tokens) + + st.divider() + + st.subheader("Token Usage") + st.metric("Input Tokens", stats.total_input_tokens) + st.metric("Output Tokens", stats.total_output_tokens) + + st.divider() + + st.subheader("⚡ Model Usage") + for model, count in stats.model_invocations.items(): + st.metric(f"{model}", count, "calls") + + st.divider() + elapsed_minutes = max(0.1, (time.time() - stats.start_time) / 60) + st.metric("Session Time", f"{elapsed_minutes:.2f} min") + + def run(self): + self.render_router_sidebar() + self.render_chat_area() + +if __name__ == "__main__": + app = BedrockChatUI() + app.run() \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/rag_evaluator.py b/BedrockPromptCachingRoutingDemo/src/rag_evaluator.py new file mode 100644 index 000000000..1e16ac07d --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/rag_evaluator.py @@ -0,0 +1,213 @@ +# rag_evaluator.py +import time +import boto3 +import pandas as pd +from typing import List, Dict, Any, Optional 
+ +# Import RAGAS components +from ragas import SingleTurnSample, EvaluationDataset +from ragas import evaluate +from ragas.metrics import ( + context_recall, + context_precision, + answer_correctness +) + +class RAGEvaluator: + """ + A class to evaluate RAG (Retrieval-Augmented Generation) systems using the RAGAS framework. + """ + + def __init__(self, + bedrock_runtime_client, + bedrock_agent_runtime_client, + text_generation_model_id: str = "anthropic.claude-3-haiku-20240307-v1:0", + evaluation_model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0", + embedding_model_id: str = "amazon.titan-embed-text-v2:0"): + """ + Initialize the RAG evaluator with AWS clients and model IDs. + + Args: + bedrock_runtime_client: Boto3 client for Bedrock runtime + bedrock_agent_runtime_client: Boto3 client for Bedrock agent runtime + text_generation_model_id: Model ID for text generation + evaluation_model_id: Model ID for evaluation + embedding_model_id: Model ID for embeddings + """ + self.bedrock_runtime_client = bedrock_runtime_client + self.bedrock_agent_runtime_client = bedrock_agent_runtime_client + self.text_generation_model_id = text_generation_model_id + self.evaluation_model_id = evaluation_model_id + self.embedding_model_id = embedding_model_id + + # Import LangChain components here to avoid circular imports + from langchain_aws import ChatBedrock + from langchain_aws import BedrockEmbeddings + + # Initialize LangChain components + self.llm_for_evaluation = ChatBedrock( + model_id=evaluation_model_id, + client=bedrock_runtime_client + ) + self.bedrock_embeddings = BedrockEmbeddings( + model_id=embedding_model_id, + client=bedrock_runtime_client + ) + + # Define metrics for evaluation + self.metrics = [ + context_recall, + context_precision, + answer_correctness + ] + + def retrieve_and_generate(self, query: str, kb_id: str) -> Dict[str, Any]: + """ + Perform a retrieve and generate operation using the knowledge base. 
+ + Args: + query: Query text + kb_id: Knowledge base ID + + Returns: + Response from the retrieve and generate operation + """ + start = time.time() + response = self.bedrock_agent_runtime_client.retrieve_and_generate( + input={ + 'text': query + }, + retrieveAndGenerateConfiguration={ + 'type': 'KNOWLEDGE_BASE', + 'knowledgeBaseConfiguration': { + 'knowledgeBaseId': kb_id, + 'modelArn': self.text_generation_model_id + } + } + ) + time_spent = time.time() - start + print(f"[Response]\n{response['output']['text']}\n") + print(f"[Invocation time]\n{time_spent}\n") + + return response + + def prepare_eval_dataset(self, kb_id: str, questions: List[str], ground_truths: List[str]) -> EvaluationDataset: + """ + Prepare an evaluation dataset for RAGAS. + + Args: + kb_id: Knowledge base ID + questions: List of questions + ground_truths: List of ground truth answers + + Returns: + RAGAS evaluation dataset + """ + # Lists to store SingleTurnSample objects + samples = [] + + for question, ground_truth in zip(questions, ground_truths): + # Get response and context from your retrieval system + response = self.retrieve_and_generate(question, kb_id) + answer = response["output"]["text"] + + # Process contexts + contexts = [] + for citation in response["citations"]: + context_texts = [ + ref["content"]["text"] + for ref in citation["retrievedReferences"] + if "content" in ref and "text" in ref["content"] + ] + contexts.extend(context_texts) + + # Create a SingleTurnSample + sample = SingleTurnSample( + user_input=question, + retrieved_contexts=contexts, + response=answer, + reference=ground_truth + ) + + # Add the sample to our list + samples.append(sample) + + # Rate limiting if needed + # time.sleep(10) + + # Create EvaluationDataset from samples + eval_dataset = EvaluationDataset(samples=samples) + + return eval_dataset + + def evaluate_kb(self, kb_id: str, questions: List[str], ground_truths: List[str]) -> pd.DataFrame: + """ + Evaluate a knowledge base using RAGAS. 
+ + Args: + kb_id: Knowledge base ID + questions: List of questions + ground_truths: List of ground truth answers + + Returns: + DataFrame with evaluation results + """ + # Prepare evaluation dataset + eval_dataset = self.prepare_eval_dataset(kb_id, questions, ground_truths) + + # Evaluate using RAGAS + result = evaluate( + dataset=eval_dataset, + metrics=self.metrics, + llm=self.llm_for_evaluation, + embeddings=self.bedrock_embeddings, + ) + + # Convert to DataFrame + result_df = result.to_pandas() + + return result_df + + def compare_kb_strategies(self, kb_ids: Dict[str, str], questions: List[str], ground_truths: List[str]) -> pd.DataFrame: + """ + Compare multiple knowledge base strategies. + + Args: + kb_ids: Dictionary mapping strategy names to knowledge base IDs + questions: List of questions + ground_truths: List of ground truth answers + + Returns: + DataFrame comparing the strategies + """ + results = {} + + # Evaluate each knowledge base + for strategy_name, kb_id in kb_ids.items(): + print(f"\n=== Evaluating {strategy_name} strategy ===") + result_df = self.evaluate_kb(kb_id, questions, ground_truths) + + # Calculate average metrics + avg_metrics = result_df[['context_recall', 'context_precision', 'answer_correctness']].mean() + results[strategy_name] = avg_metrics + + # Create comparison DataFrame + comparison_df = pd.DataFrame(results) + + return comparison_df + + def format_comparison(self, comparison_df: pd.DataFrame) -> pd.DataFrame: + """ + Format the comparison DataFrame with highlighting. 
+ + Args: + comparison_df: DataFrame comparing strategies + + Returns: + Styled DataFrame with highlighting + """ + def highlight_max(s): + is_max = s == s.max() + return ['background-color: #90EE90' if v else '' for v in is_max] + + return comparison_df.style.apply(highlight_max, axis=1) diff --git a/BedrockPromptCachingRoutingDemo/src/requirements.txt b/BedrockPromptCachingRoutingDemo/src/requirements.txt new file mode 100644 index 000000000..eb2116c0f --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/requirements.txt @@ -0,0 +1,14 @@ +boto3>=1.34.0 +requests>=2.31.0 +pandas>=2.0.0 +matplotlib>=3.7.0 +seaborn>=0.12.0 +numpy>=1.24.0 +PyPDF2>=3.0.0 +python-docx>=1.0.0 +opensearch-py +retrying +langchain_aws +langchain +streamlit +ragas \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/sec-10-k/2019-Annual-Report.pdf b/BedrockPromptCachingRoutingDemo/src/sec-10-k/2019-Annual-Report.pdf new file mode 100644 index 000000000..5fa3a5080 Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/sec-10-k/2019-Annual-Report.pdf differ diff --git a/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2020-Annual-Report.pdf b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2020-Annual-Report.pdf new file mode 100644 index 000000000..8fa8d7e31 Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2020-Annual-Report.pdf differ diff --git a/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2021-Annual-Report.pdf b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2021-Annual-Report.pdf new file mode 100644 index 000000000..c59128aac Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2022-Annual-Report.pdf b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2022-Annual-Report.pdf new file mode 100644 index 000000000..00a6ea759 Binary files /dev/null and 
b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-2022-Annual-Report.pdf differ diff --git a/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-com-Inc-2023-Annual-Report.pdf b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-com-Inc-2023-Annual-Report.pdf new file mode 100644 index 000000000..12e0f1285 Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/sec-10-k/Amazon-com-Inc-2023-Annual-Report.pdf differ diff --git a/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/LICENSE b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/LICENSE new file mode 100644 index 000000000..1625c1793 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/LICENSE @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. 
+These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. 
To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. 
In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. 
Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/NOTICE b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/NOTICE new file mode 100644 index 000000000..629a77538 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/NOTICE @@ -0,0 +1,9 @@ +Synthetic data for 10k reports by Amazon.com, Inc. or its affiliates. +SPDX-License-Identifier: CC0-1.0 + +@article{Mixtral-8x7B modelcard, +title={Mixtral-8x7B-Instruct-v0.1 Model Card}, +author={Mistral AI Team}, +year={2024}, +url = {https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1} +} \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/README.md b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/README.md new file mode 100644 index 000000000..0f16e2357 --- /dev/null +++ b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/README.md @@ -0,0 +1,22 @@ +# Synthetic data + +All the data in this folder is generated synthetically using `mistralai/Mixtral-8x7B-v0.1`from [Hugging Face](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1). + + +## Citation +Synthetic data for 10k reports by Amazon.com, Inc. or its affiliates. 
+SPDX-License-Identifier: CC0-1.0 + +``` +@article{Mixtral-8x7B modelcard, +title={Mixtral-8x7B-Instruct-v0.1 Model Card}, +author={Mistral AI Team}, +year={2024}, +url = {https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1} +} +``` +For more details about the Mistral AI Team, please refer to this [link](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1). + +## Contributing + +We welcome community contributions! Please ensure your sample aligns with [AWS best practices](https://aws.amazon.com/architecture/well-architected/), and please update the Contents section of this README file with a link to your sample, along with a description. \ No newline at end of file diff --git a/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/bda.m4v b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/bda.m4v new file mode 100644 index 000000000..86ee37e62 Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/bda.m4v differ diff --git a/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/octank_financial_10K.pdf b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/octank_financial_10K.pdf new file mode 100644 index 000000000..827fd9aeb Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/octank_financial_10K.pdf differ diff --git a/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/podcastdemo.mp3 b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/podcastdemo.mp3 new file mode 100644 index 000000000..29dfa67ee Binary files /dev/null and b/BedrockPromptCachingRoutingDemo/src/synthetic_dataset/podcastdemo.mp3 differ diff --git a/docs/bedrock_prompt_caching_routing.md b/docs/bedrock_prompt_caching_routing.md new file mode 100644 index 000000000..78c4e22e4 --- /dev/null +++ b/docs/bedrock_prompt_caching_routing.md @@ -0,0 +1,312 @@ +# Amazon Bedrock Prompt Caching and Routing Workshop + +[Open in GitHub](https://github.com/aws-samples/amazon-bedrock-samples/tree/main/BedrockPromptCachingRoutingDemo) + 
+**Tags:** bedrock, prompt-caching, prompt-routing, claude, optimization, cost-reduction, performance + +

## Overview

+ +This workshop demonstrates Amazon Bedrock's prompt caching and routing capabilities using the latest Claude models. You'll learn how to reduce latency and costs through intelligent prompt caching and how to route requests to optimal models based on your specific needs. + +**Key Learning Outcomes:** +- Implement prompt caching to reduce costs and latency by up to 90% +- Use prompt routing for intelligent model selection based on query complexity +- Understand best practices for Bedrock API integration +- Monitor performance and usage statistics for optimization + +

## Context

+ +### Prompt Caching +Prompt caching allows you to cache frequently used prompts, reducing both latency and costs for subsequent requests. This is particularly beneficial for: +- **Document analysis workflows** - Cache document context for multiple questions +- **Multi-turn conversations** - Maintain conversation history efficiently +- **Repetitive query patterns** - Avoid re-processing similar requests +- **Cost optimization** - Reduce API calls by up to 90% for repeated content + +### Prompt Routing +Prompt routing intelligently directs requests to the most appropriate model based on: +- **Query complexity** - Simple queries to fast models, complex ones to capable models +- **Cost optimization requirements** - Balance performance vs. cost based on business needs +- **Performance needs** - Route time-sensitive queries to fastest available models +- **Model capabilities** - Match query type to model strengths + +### Supported Models +- **Claude Haiku 4.5**: Fast, cost-effective for simple tasks and quick responses +- **Claude Sonnet 4.5**: Balanced performance and cost for general use cases +- **Claude Opus 4.1**: Most capable for complex reasoning and analysis tasks +- **Amazon Nova Models**: Latest AWS-native models with optimized performance + +
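Beyond the notebook's in-memory demo, Bedrock exposes server-side prompt caching through the Converse API's `cachePoint` content block. The sketch below only builds such a request; the helper name and document text are illustrative, and the model ID is one of the IDs listed in this repository's README — confirm availability in your account and region before relying on it:

```python
def build_cached_converse_request(model_id: str, document_text: str, question: str) -> dict:
    """Build a Converse API request whose system prefix is marked cacheable.

    Everything before the cachePoint block is written to the prompt cache on
    the first call and read back (at reduced cost and latency) on later calls
    that reuse the same prefix.
    """
    return {
        "modelId": model_id,
        "system": [
            {"text": "You are a financial analyst. Answer using the document below."},
            {"text": document_text},
            {"cachePoint": {"type": "default"}},  # cache boundary marker
        ],
        "messages": [
            {"role": "user", "content": [{"text": question}]},
        ],
    }

request = build_cached_converse_request(
    "anthropic.claude-haiku-4-5-20251001-v1:0",   # model ID from this repo's README
    "<large 10-K excerpt goes here>",
    "What were the main revenue drivers?",
)

# With AWS credentials configured, the request can be sent like this:
# import boto3
# client = boto3.client("bedrock-runtime", region_name="us-east-1")
# response = client.converse(**request)
# usage = response["usage"]   # cacheReadInputTokens / cacheWriteInputTokens
```

On the first invocation the usage block should report cache-write tokens; repeated invocations sharing the same prefix report cache-read tokens instead, which is where the savings come from.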

## Prerequisites

+ +Before running this workshop, ensure you have: + +1. **AWS Account** with appropriate permissions for Amazon Bedrock +2. **Amazon Bedrock access** with Claude models enabled in your region +3. **AWS CLI configured** with valid credentials +4. **Python 3.8+** installed on your system +5. **Jupyter Notebook** environment (JupyterLab, VS Code, or similar) + +### Required AWS Permissions +Your AWS credentials need the following IAM permissions: +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModel", + "bedrock:ListFoundationModels", + "bedrock:GetModelInvocationLoggingConfiguration" + ], + "Resource": "*" + } + ] +} +``` + +### Model Access +Ensure you have enabled access to Claude models in the Amazon Bedrock console: +- Navigate to Amazon Bedrock → Model access +- Request access to Anthropic Claude models +- Wait for approval (usually immediate for Claude models) + +

## Setup

+ +### 1. Clone the Repository +```bash +git clone https://github.com/aws-samples/amazon-bedrock-samples.git +cd amazon-bedrock-samples/BedrockPromptCachingRoutingDemo +``` + +### 2. Install Dependencies +```bash +pip install -r requirements.txt +``` + +### 3. Configure AWS Credentials +```bash +aws configure +# Enter your AWS Access Key ID, Secret Access Key, and preferred region +``` + +### 4. Verify Bedrock Access +```python +import boto3 +client = boto3.client('bedrock-runtime', region_name='us-east-1') +print("✅ Bedrock client initialized successfully") +``` + +

## Code Walkthrough

+ +The workshop is implemented as an interactive Jupyter notebook that demonstrates: + +### Core Components + +1. **ModelManager Class** - Handles different Claude model configurations and selection +2. **BedrockService Class** - Manages Bedrock API interactions with intelligent caching +3. **PromptRouter Class** - Implements smart routing logic based on query analysis + +### Key Features Demonstrated + +#### Prompt Caching Implementation +```python +class BedrockService: + def __init__(self, client): + self.client = client + self.cache = {} # In-memory cache for demo + self.cache_stats = {'hits': 0, 'misses': 0} + + def invoke_model_with_cache(self, model_id: str, prompt: str, use_cache: bool = True): + cache_key = f"{model_id}:{hash(prompt)}" + + # Check cache first + if use_cache and cache_key in self.cache: + self.cache_stats['hits'] += 1 + return self.cache[cache_key] # Cache hit! + + # Make API call and cache result + response = self._make_api_call(model_id, prompt) + if use_cache: + self.cache[cache_key] = response + + return response +``` + +#### Intelligent Prompt Routing +```python +class PromptRouter: + def route_prompt(self, prompt: str, priority: str = 'balanced'): + complexity = self.analyze_query_complexity(prompt) + + if priority == 'cost': + return 'haiku' # Cheapest option + elif priority == 'performance': + return 'opus' # Most capable + else: # balanced + if complexity == 'simple': + return 'haiku' + elif complexity == 'complex': + return 'opus' + else: + return 'sonnet' +``` + +### Interactive Demonstrations + +The notebook includes three hands-on demonstrations: + +1. **Prompt Caching Demo** - Shows cache hits vs misses with performance metrics +2. **Intelligent Routing Demo** - Demonstrates model selection for different query types +3. **Performance Comparison** - Quantifies the benefits of caching and smart routing + +
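The `PromptRouter` above depends on an `analyze_query_complexity` helper that the snippet doesn't show. A minimal sketch, assuming a purely heuristic classifier based on prompt length and a few reasoning keywords (the notebook's actual signals may differ):

```python
def analyze_query_complexity(prompt: str) -> str:
    """Classify a prompt as 'simple', 'moderate', or 'complex' for routing.

    Hypothetical heuristic: short prompts with no reasoning keywords are
    simple; long prompts or prompts with several reasoning keywords are
    complex; everything else is moderate (routed to the mid-tier model).
    """
    reasoning_keywords = {"analyze", "compare", "explain why",
                          "derive", "trade-off", "architecture"}
    lowered = prompt.lower()
    words = lowered.split()
    keyword_hits = sum(1 for kw in reasoning_keywords if kw in lowered)

    if len(words) < 20 and keyword_hits == 0:
        return "simple"
    if len(words) > 150 or keyword_hits >= 2:
        return "complex"
    return "moderate"

print(analyze_query_complexity("What is the capital of France?"))   # simple
print(analyze_query_complexity(
    "Analyze and compare the trade-offs between caching strategies "
    "and explain why one is better."))                              # complex
```

Any string return values here must stay in sync with the literals the router branches on (`'simple'`, `'complex'`, fallthrough to `'sonnet'` otherwise).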

## Best Practices and Advanced Considerations

+ +### Production Best Practices + +#### Cache Management +- **Persistent Storage**: Use Redis or DynamoDB for distributed caching in production +- **TTL Policies**: Implement time-to-live for cache entries to ensure data freshness +- **Memory Management**: Monitor cache size and implement LRU eviction policies +- **Cache Warming**: Pre-populate cache with frequently used prompts during deployment + +#### Routing Optimization +- **Machine Learning**: Develop ML models to predict optimal routing based on historical performance +- **Cost Tracking**: Implement real-time cost monitoring across different models +- **A/B Testing**: Continuously test routing strategies to optimize for your specific use cases +- **Fallback Strategies**: Ensure high availability with automatic failover to backup models + +#### Security Considerations +- **Data Privacy**: Ensure sensitive information is not cached inappropriately +- **Access Control**: Implement proper IAM policies for Bedrock access +- **Audit Logging**: Log all API calls and routing decisions for compliance +- **Encryption**: Use encryption at rest and in transit for cached data + +#### Monitoring and Observability +```python +# Example CloudWatch metrics integration +import boto3 +cloudwatch = boto3.client('cloudwatch') + +def publish_cache_metrics(cache_stats): + cloudwatch.put_metric_data( + Namespace='BedrockWorkshop/Cache', + MetricData=[ + { + 'MetricName': 'CacheHitRate', + 'Value': cache_stats['hit_rate_percent'], + 'Unit': 'Percent' + }, + { + 'MetricName': 'CacheSize', + 'Value': cache_stats['cached_items'], + 'Unit': 'Count' + } + ] + ) +``` + +### Advanced Features + +#### Multi-Modal Routing +Extend routing logic to handle different content types: +- Text-only queries → Claude models +- Image analysis → Claude Vision models +- Document processing → Specialized document models + +#### Streaming with Caching +Implement streaming responses while maintaining cache benefits: +```python +def stream_with_cache(self, 
model_id, prompt): + cache_key = self._generate_cache_key(model_id, prompt) + + if cache_key in self.cache: + # Stream cached response + for chunk in self._stream_cached_response(cache_key): + yield chunk + else: + # Stream live response and cache + full_response = "" + for chunk in self._stream_live_response(model_id, prompt): + full_response += chunk + yield chunk + self.cache[cache_key] = full_response +``` + +
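The cache-management practices listed above (TTL policies, LRU eviction) can be sketched with a small in-process cache; in production the same interface would sit in front of Redis or DynamoDB. Class and parameter names are illustrative:

```python
import time
from collections import OrderedDict

class TTLLRUCache:
    """In-process cache with TTL expiry and LRU eviction (illustrative only)."""

    def __init__(self, max_items: int = 128, ttl_seconds: float = 300.0):
        self.max_items = max_items
        self.ttl = ttl_seconds
        self._data = OrderedDict()  # key -> (stored_at, value)

    def get(self, key):
        entry = self._data.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if time.time() - stored_at > self.ttl:  # stale: honor TTL policy
            del self._data[key]
            return None
        self._data.move_to_end(key)             # mark as most recently used
        return value

    def put(self, key, value):
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = (time.time(), value)
        if len(self._data) > self.max_items:    # evict least recently used
            self._data.popitem(last=False)

cache = TTLLRUCache(max_items=2, ttl_seconds=60)
cache.put("a", "1"); cache.put("b", "2")
cache.get("a")         # touch "a" so "b" becomes least recently used
cache.put("c", "3")    # exceeds max_items -> evicts "b"
print(cache.get("b"))  # None
print(cache.get("a"))  # 1
```

Swapping `self._data` for a Redis hash or DynamoDB table with a TTL attribute keeps the calling code unchanged, which is the point of isolating eviction policy behind `get`/`put`.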

## Next Steps

+ +After completing this workshop, consider these advanced implementations: + +### 1. Production Deployment +- **Containerization**: Package the solution using Docker for consistent deployment +- **Serverless Architecture**: Deploy using AWS Lambda for automatic scaling +- **API Gateway Integration**: Create REST APIs for external application integration +- **Infrastructure as Code**: Use CloudFormation or CDK for reproducible deployments + +### 2. Enhanced Monitoring +- **Custom Dashboards**: Build CloudWatch dashboards for real-time monitoring +- **Alerting**: Set up alerts for cache performance degradation or routing failures +- **Cost Analysis**: Implement detailed cost tracking and optimization recommendations +- **Performance Benchmarking**: Establish baseline metrics and track improvements + +### 3. Advanced Features +- **User Personalization**: Learn user preferences to improve routing decisions +- **Multi-Region Deployment**: Implement cross-region caching and routing +- **Custom Model Integration**: Add support for fine-tuned models +- **Batch Processing**: Optimize for high-volume batch processing scenarios + +### 4. Integration Patterns +- **Microservices**: Integrate caching and routing as microservices in larger applications +- **Event-Driven Architecture**: Use EventBridge for asynchronous processing +- **Data Pipeline Integration**: Incorporate into ETL/ELT workflows +- **Real-Time Applications**: Build chat applications with optimized response times + +### 5. Machine Learning Enhancements +- **Predictive Routing**: Use ML to predict optimal model selection +- **Anomaly Detection**: Identify unusual patterns in cache performance +- **Auto-Scaling**: Dynamically adjust cache size based on usage patterns +- **Quality Scoring**: Implement response quality metrics for routing optimization + +

## Cleanup

+ +### Resource Cleanup +The workshop uses minimal AWS resources, but follow these steps to ensure clean termination: + +1. **Clear Cache Memory** + ```python + # Clear in-memory cache + bedrock_service.cache.clear() + bedrock_service.cache_stats = {'hits': 0, 'misses': 0} + ``` + +2. **Close Bedrock Client** + ```python + # Properly close the Bedrock client + bedrock_client.close() + ``` + +3. **Review CloudWatch Logs** (if logging was enabled) + - Check CloudWatch Logs for any error messages + - Review API call patterns and costs in AWS Cost Explorer + +### Cost Considerations +- **API Calls**: The workshop makes minimal API calls to Bedrock models +- **Caching Benefits**: Demonstrates significant cost savings through reduced API usage +- **Monitoring**: No additional charges for basic CloudWatch metrics + +### Final Statistics +The notebook displays comprehensive statistics including: +- Total API calls made vs. cached responses +- Cache hit rate percentage and performance improvements +- Model usage distribution across different query types +- Estimated cost savings from caching implementation + +### Verification Steps +1. Confirm all notebook cells executed successfully +2. Review cache performance metrics (should show >50% hit rate) +3. Verify routing decisions match query complexity expectations +4. Check that cleanup completed without errors + +**Workshop Complete!** 🎉 + +You've successfully learned how to implement prompt caching and intelligent routing with Amazon Bedrock, achieving significant performance improvements and cost optimizations for AI-powered applications. \ No newline at end of file