diff --git a/backend/config.py b/backend/config.py
index f742e70f5..e9bcc89b3 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -2,6 +2,15 @@
 # Setting to True will stream a mock response instead of calling the OpenAI API
 # TODO: Should only be set to true when value is 'True', not any arbitrary truthy value
 import os
+import boto3
+from botocore.exceptions import BotoCoreError, ClientError
+
+def has_valid_aws_credentials():
+    # Probe for usable AWS credentials once at startup via STS
+    sts = boto3.client('sts')
+    try:
+        sts.get_caller_identity()
+        return True
+    except (BotoCoreError, ClientError):
+        return False
 
 NUM_VARIANTS = 2
@@ -10,6 +19,7 @@
 ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
 GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", None)
 OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
+AWS_CREDENTIALS = has_valid_aws_credentials()
 
 # Image generation (optional)
 REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
diff --git a/backend/evals/core.py b/backend/evals/core.py
index 0fab5948f..451ad530d 100644
--- a/backend/evals/core.py
+++ b/backend/evals/core.py
@@ -1,9 +1,10 @@
-from config import ANTHROPIC_API_KEY, GEMINI_API_KEY, OPENAI_API_KEY
+from config import ANTHROPIC_API_KEY, GEMINI_API_KEY, OPENAI_API_KEY, AWS_CREDENTIALS
 from llm import (
     Llm,
     stream_claude_response,
     stream_gemini_response,
     stream_openai_response,
+    stream_bedrock_response,
 )
 from prompts import assemble_prompt
 from prompts.types import Stack
@@ -46,6 +47,15 @@ async def process_chunk(_: str):
             callback=lambda x: process_chunk(x),
             model=model,
         )
+    elif model == Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20:
+        if not AWS_CREDENTIALS:
+            raise Exception("AWS credentials not found")
+
+        completion = await stream_bedrock_response(
+            prompt_messages,
+            callback=lambda x: process_chunk(x),
+            model=model,
+        )
     else:
         if not OPENAI_API_KEY:
             raise Exception("OpenAI API key not found")
diff --git a/backend/llm.py b/backend/llm.py
index 244f09446..613c18e18 100644
--- a/backend/llm.py
+++ b/backend/llm.py
@@ -11,6 +11,8 @@
 from image_processing.utils import process_image
 from google import genai
 from google.genai import types
+from boto3 import Session
+import json
 
 from utils import pprint_prompt
@@ -29,6 +31,7 @@ class Llm(Enum):
     CLAUDE_3_5_SONNET_2024_10_22 = "claude-3-5-sonnet-20241022"
     GEMINI_2_0_FLASH_EXP = "gemini-2.0-flash-exp"
     O1_2024_12_17 = "o1-2024-12-17"
+    BEDROCK_CLAUDE_3_5_SONNET_2024_06_20 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
 
 
 class Completion(TypedDict):
@@ -91,24 +94,13 @@ async def stream_openai_response(
     completion_time = time.time() - start_time
     return {"duration": completion_time, "code": full_response}
+
-
-# TODO: Have a seperate function that translates OpenAI messages to Claude messages
-async def stream_claude_response(
-    messages: List[ChatCompletionMessageParam],
-    api_key: str,
-    callback: Callable[[str], Awaitable[None]],
-    model: Llm,
-) -> Completion:
-    start_time = time.time()
-    client = AsyncAnthropic(api_key=api_key)
-
-    # Base parameters
-    max_tokens = 8192
-    temperature = 0.0
-
-    # Translate OpenAI messages to Claude messages
-
+def process_claude_messages(messages: List[ChatCompletionMessageParam]) -> tuple[str, List[dict]]:
+    """
+    Process messages for Claude by converting image URLs to base64 data
+    and removing the image URL parameter from the message.
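+
+    Returns a (system_prompt, claude_messages) tuple shared by the
+    Anthropic and Bedrock streaming paths.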
+ """ # Deep copy messages to avoid modifying the original list cloned_messages = copy.deepcopy(messages) @@ -139,6 +131,27 @@ async def stream_claude_response( "data": base64_data, } + return system_prompt, claude_messages + + +# TODO: Have a seperate function that translates OpenAI messages to Claude messages +async def stream_claude_response( + messages: List[ChatCompletionMessageParam], + api_key: str, + callback: Callable[[str], Awaitable[None]], + model: Llm, +) -> Completion: + start_time = time.time() + client = AsyncAnthropic(api_key=api_key) + + # Base parameters + max_tokens = 8192 + temperature = 0.0 + + # Translate OpenAI messages to Claude messages + + system_prompt, claude_messages = process_claude_messages(messages) + # Stream Claude response async with client.messages.stream( model=model.value, @@ -300,3 +313,69 @@ async def stream_gemini_response( await callback(response.text) # type: ignore completion_time = time.time() - start_time return {"duration": completion_time, "code": full_response} + + +async def stream_bedrock_response( + messages: List[ChatCompletionMessageParam], + callback: Callable[[str], Awaitable[None]], + model: Llm, +) -> Completion: + print(f"Invoking {model} on AWS Bedrock") + start_time = time.time() + + # Initialize Bedrock runtime client + session = Session() + + # Expect configuration from environment variables or /.aws/credentials + bedrock_client = session.client( + service_name='bedrock-runtime', + ) + + full_response = "" + + system_prompt, claude_messages = process_claude_messages(messages) + + body = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 8192, + "messages": claude_messages, + "temperature":0.0, + "system":system_prompt, + } + + # Convert the payload to bytes + body_bytes = json.dumps(body).encode('utf-8') + + response = bedrock_client.invoke_model_with_response_stream( + body=body_bytes, + contentType='application/json', + accept='application/json', + modelId=model.value, + trace='DISABLED', + ) + + for event in response['body']: + if 'chunk' in event: + chunk = event['chunk']['bytes'].decode('utf-8') + if chunk: + chunk_obj = json.loads(chunk) + if chunk_obj.get('delta', {}).get('type') == 'text_delta': + response_text = chunk_obj['delta']['text'] + full_response += response_text + await callback(response_text) + elif 'internalServerException' in event: + raise Exception(event['internalServerException']['message']) + elif 'modelStreamErrorException' in event: + raise Exception(event['modelStreamErrorException']['message']) + elif 'validationException' in event: + raise Exception(event['validationException']['message']) + elif 'throttlingException' in event: + raise Exception(event['throttlingException']['message']) + elif 'modelTimeoutException' in event: + raise Exception(event['modelTimeoutException']['message']) + elif 'serviceUnavailableException' in event: + raise Exception(event['serviceUnavailableException']['message']) + + completion_time = time.time() - start_time + + return {"duration": completion_time, "code": full_response} diff --git a/backend/poetry.lock b/backend/poetry.lock index 7f7e39fbe..aa5dd8b62 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -232,6 +232,44 @@ charset-normalizer = ["charset-normalizer"]
 html5lib = ["html5lib"]
 lxml = ["lxml"]
 
+[[package]]
+name = "boto3"
+version = "1.36.11"
+description = "The AWS SDK for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "boto3-1.36.11-py3-none-any.whl", hash = "sha256:641dd772eac111d9443258f0f5491c57c2af47bddae94a8d32de19edb5bf7b1c"},
+    {file = "boto3-1.36.11.tar.gz", hash = "sha256:b40fbf2c0f22e55b67df95475a68bb72be5169097180a875726b6b884339ac8b"},
+]
+
+[package.dependencies]
+botocore = ">=1.36.11,<1.37.0"
+jmespath = ">=0.7.1,<2.0.0"
+s3transfer = ">=0.11.0,<0.12.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
+
+[[package]]
+name = "botocore"
+version = "1.36.11"
+description = "Low-level, data-driven core of boto 3."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "botocore-1.36.11-py3-none-any.whl", hash = "sha256:82c5660027f696608d0e55feb08c146c11c7ebeba7615961c7765dcf6009a00d"},
+    {file = "botocore-1.36.11.tar.gz", hash = "sha256:c919be883f95b9e0c3021429a365d40cd7944b8345a07af30dc8d891ceefe07a"},
+]
+
+[package.dependencies]
+jmespath = ">=0.7.1,<2.0.0"
+python-dateutil = ">=2.1,<3.0.0"
+urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}
+
+[package.extras]
+crt = ["awscrt (==0.23.8)"]
+
 [[package]]
 name = "cachetools"
 version = "5.5.0"
@@ -934,6 +972,17 @@ files = [
     {file = "jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d"},
 ]
 
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+description = "JSON Matching Expressions"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
+    {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
+]
+
 [[package]]
 name = "moviepy"
 version = "1.0.3"
@@ -1623,6 +1672,20 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
 [package.extras]
 testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "python-dateutil"
+version = "2.9.0.post0"
+description = "Extensions to the standard Python datetime module"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+files = [
+    {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
+    {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
+]
+
+[package.dependencies]
+six = ">=1.5"
+
 [[package]]
 name = "python-dotenv"
 version = "1.0.1"
@@ -1734,6 +1797,23 @@ files = [
 [package.dependencies]
 pyasn1 = ">=0.1.3"
 
+[[package]]
+name = "s3transfer"
+version = "0.11.2"
+description = "An Amazon S3 Transfer Manager"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "s3transfer-0.11.2-py3-none-any.whl", hash = "sha256:be6ecb39fadd986ef1701097771f87e4d2f821f27f6071c872143884d2950fbc"},
+    {file = "s3transfer-0.11.2.tar.gz", hash = "sha256:3b39185cb72f5acc77db1a58b6e25b977f28d20496b6e58d6813d75f464d632f"},
+]
+
+[package.dependencies]
+botocore = ">=1.36.0,<2.0a.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.36.0,<2.0a.0)"]
+
 [[package]]
 name = "setuptools"
 version = "75.6.0"
@@ -1754,6 +1834,17 @@ enabler = ["pytest-enabler (>=2.2)"]
["pytest-enabler (>=2.2)"] test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] +[[package]] +name = "six" +version = "1.17.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, + {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -2142,4 +2233,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "93868f80563b845bf5aacd5a752a1dcdb957c59357b2b31b7a95685bcfefb5c2" +content-hash = "5d0538309fe40ebb2e3e661019faa83b83de18dccd490d32cde30016bb4b2966" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 0c636452a..e02c3b3d6 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -22,6 +22,7 @@ types-pillow = "^10.2.0.20240520" aiohttp = "^3.9.5" pydantic = "^2.10" google-genai = "^0.3.0" +boto3 = "^1.36.11" [tool.poetry.group.dev.dependencies] pytest = "^7.4.3" diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 4afb413a5..f4677455c 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -13,6 +13,7 @@ OPENAI_BASE_URL, REPLICATE_API_KEY, SHOULD_MOCK_AI_RESPONSE, + AWS_CREDENTIALS ) from custom_types import InputMode from llm import ( @@ -22,6 +23,7 @@ stream_claude_response_native, stream_gemini_response, stream_openai_response, + stream_bedrock_response, ) from fs_logging.core import write_logs from mock_llm import mock_completion @@ -80,6 +82,7 @@ class ExtractedParams: should_generate_images: bool openai_api_key: str | None anthropic_api_key: str | None + aws_credentials: bool openai_base_url: str | None generation_type: Literal["create", "update"] @@ -110,6 +113,11 @@ async def extract_params( params, "anthropicApiKey", ANTHROPIC_API_KEY ) + # If neither is provided, we throw an error later only if Claude is used. 
+    aws_credentials = bool(
+        get_from_settings_dialog_or_env(params, "awsCredentials", AWS_CREDENTIALS)
+    )
+
     # Base URL for OpenAI API
     openai_base_url: str | None = None
     # Disable user-specified OpenAI Base URL in prod
@@ -136,13 +144,14 @@
         should_generate_images=should_generate_images,
         openai_api_key=openai_api_key,
         anthropic_api_key=anthropic_api_key,
+        aws_credentials=aws_credentials,
         openai_base_url=openai_base_url,
         generation_type=generation_type,
     )
 
 
 def get_from_settings_dialog_or_env(
-    params: dict[str, str], key: str, env_var: str | None
-) -> str | None:
+    params: dict[str, str], key: str, env_var: str | bool | None
+) -> str | bool | None:
     value = params.get(key)
     if value:
         return value
@@ -196,6 +205,7 @@ async def send_message(
     openai_api_key = extracted_params.openai_api_key
     openai_base_url = extracted_params.openai_base_url
     anthropic_api_key = extracted_params.anthropic_api_key
+    aws_credentials = extracted_params.aws_credentials
     should_generate_images = extracted_params.should_generate_images
     generation_type = extracted_params.generation_type
 
@@ -277,11 +287,16 @@ async def process_chunk(content: str, variantIndex: int):
             claude_model,
             Llm.CLAUDE_3_5_SONNET_2024_06_20,
         ]
+    elif aws_credentials:
+        variant_models = [
+            Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20,
+            Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20,
+        ]
     else:
         await throw_error(
-            "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server."
+            "No OpenAI or Anthropic API key found, and no AWS credentials are configured. Please add OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or the settings dialog, or configure AWS credentials for Bedrock as described in the AWS documentation. If you add a key to .env, make sure to restart the backend server."
         )
-        raise Exception("No OpenAI or Anthropic key")
+        raise Exception("No OpenAI, Anthropic, or AWS credentials")
 
     tasks: List[Coroutine[Any, Any, Completion]] = []
     for index, model in enumerate(variant_models):
@@ -331,6 +346,18 @@
                     model=claude_model,
                 )
             )
+        elif model == Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20:
+            if not aws_credentials:
+                await throw_error("AWS credentials are missing.")
+                raise Exception("AWS credentials are missing.")
+
+            tasks.append(
+                stream_bedrock_response(
+                    prompt_messages,
+                    callback=lambda x, i=index: process_chunk(x, i),
+                    model=Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20,
+                )
+            )
 
     # Run the models in parallel and capture exceptions if any
     completions = await asyncio.gather(*tasks, return_exceptions=True)
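
Reviewer note: a minimal, untested sketch for exercising the new Bedrock path from backend/, assuming the environment carries valid AWS credentials and a region (e.g. AWS_REGION) with access to this Bedrock model. The two-message prompt assumes process_claude_messages, like the original Claude path, treats the first message as the system prompt; the prompt text and on_chunk callback are illustrative only.

import asyncio

from llm import Llm, stream_bedrock_response


async def main() -> None:
    # Print tokens as they stream back from Bedrock
    async def on_chunk(text: str) -> None:
        print(text, end="", flush=True)

    completion = await stream_bedrock_response(
        [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        callback=on_chunk,
        model=Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20,
    )
    print(f"\nDone in {completion['duration']:.1f}s")


asyncio.run(main())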