10 changes: 10 additions & 0 deletions backend/config.py
@@ -2,6 +2,15 @@
# Setting to True will stream a mock response instead of calling the OpenAI API
# TODO: Should only be set to true when value is 'True', not any arbitrary truthy value
import os
import boto3
from botocore.exceptions import BotoCoreError, ClientError


def has_valid_aws_credentials():
    # Check whether boto3 can resolve working AWS credentials (environment
    # variables, ~/.aws/credentials, or an instance profile) by calling STS.
    sts = boto3.client('sts')
    try:
        sts.get_caller_identity()
        return True
    except (BotoCoreError, ClientError):
        return False

NUM_VARIANTS = 2

@@ -10,6 +19,7 @@
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", None)
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
AWS_CREDENTIALS = has_valid_aws_credentials()

# Image generation (optional)
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
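A note on the credential check above: boto3 resolves credentials through its standard chain (environment variables, ~/.aws/credentials, or an instance/role profile), so the STS call succeeds whenever any of those are configured. A minimal sketch for exercising it locally, assuming the backend package is on the path; the region value is only an example, not something this PR requires:

# Hypothetical smoke test, not part of this PR. Assumes AWS_ACCESS_KEY_ID /
# AWS_SECRET_ACCESS_KEY (or a shared credentials file) are already set.
import os

os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")

from config import has_valid_aws_credentials

print("AWS credentials detected:", has_valid_aws_credentials())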
12 changes: 11 additions & 1 deletion backend/evals/core.py
@@ -1,9 +1,10 @@
from config import ANTHROPIC_API_KEY, GEMINI_API_KEY, OPENAI_API_KEY
from config import ANTHROPIC_API_KEY, GEMINI_API_KEY, OPENAI_API_KEY, AWS_CREDENTIALS
from llm import (
    Llm,
    stream_claude_response,
    stream_gemini_response,
    stream_openai_response,
    stream_bedrock_response,
)
from prompts import assemble_prompt
from prompts.types import Stack
@@ -46,6 +47,15 @@ async def process_chunk(_: str):
            callback=lambda x: process_chunk(x),
            model=model,
        )
    elif model == Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20:
        if not AWS_CREDENTIALS:
            raise Exception("AWS credentials not found")

        completion = await stream_bedrock_response(
            prompt_messages,
            callback=lambda x: process_chunk(x),
            model=model,
        )
    else:
        if not OPENAI_API_KEY:
            raise Exception("OpenAI API key not found")
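For context, a minimal sketch of driving the new branch end to end; it calls stream_bedrock_response directly with the enum member added in this PR (the prompt content is illustrative, not from the diff):

# Sketch only: exercise the Bedrock path the same way the elif branch above does.
import asyncio

from llm import Llm, stream_bedrock_response


async def print_chunk(chunk: str) -> None:
    print(chunk, end="", flush=True)


async def main() -> None:
    messages = [{"role": "user", "content": "Generate a minimal HTML page."}]
    completion = await stream_bedrock_response(
        messages,
        callback=print_chunk,
        model=Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20,
    )
    print(f"\nDuration: {completion['duration']:.1f}s")


asyncio.run(main())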
113 changes: 96 additions & 17 deletions backend/llm.py
@@ -11,6 +11,8 @@
from image_processing.utils import process_image
from google import genai
from google.genai import types
from boto3 import Session
import json

from utils import pprint_prompt

@@ -29,6 +31,7 @@ class Llm(Enum):
    CLAUDE_3_5_SONNET_2024_10_22 = "claude-3-5-sonnet-20241022"
    GEMINI_2_0_FLASH_EXP = "gemini-2.0-flash-exp"
    O1_2024_12_17 = "o1-2024-12-17"
    BEDROCK_CLAUDE_3_5_SONNET_2024_06_20 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
@marcelovicentegc (Author) commented on Feb 4, 2025:

Not using the latest Claude 3.5 Sonnet here (v2, 2024-10-22) because I still don't have access to it on AWS to test it, but it should work either way.

@abi we could work on adding a dropdown menu (or something similar) on the front-end to let users select the model they want to use based on what's available in their environment. Let me know if something like this is already on the radar. It occurs to me that this would be especially useful for Bedrock, since equivalent models such as DeepSeek R1 should be available there soon.
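One possible backend half of that idea, sketched from the flags already in config.py; the helper name and return shape are hypothetical, not part of this PR. The front-end dropdown could then be populated from an endpoint that returns this list.

# Hypothetical helper: report which models this deployment can serve, based on
# the credentials/keys that config.py already detects at startup.
from config import ANTHROPIC_API_KEY, AWS_CREDENTIALS, GEMINI_API_KEY, OPENAI_API_KEY
from llm import Llm


def available_models() -> list[str]:
    models: list[str] = []
    if OPENAI_API_KEY:
        models.append(Llm.O1_2024_12_17.value)
    if ANTHROPIC_API_KEY:
        models.append(Llm.CLAUDE_3_5_SONNET_2024_10_22.value)
    if GEMINI_API_KEY:
        models.append(Llm.GEMINI_2_0_FLASH_EXP.value)
    if AWS_CREDENTIALS:
        models.append(Llm.BEDROCK_CLAUDE_3_5_SONNET_2024_06_20.value)
    return models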



class Completion(TypedDict):
@@ -91,24 +94,13 @@ async def stream_openai_response(

completion_time = time.time() - start_time
return {"duration": completion_time, "code": full_response}



# TODO: Have a seperate function that translates OpenAI messages to Claude messages
async def stream_claude_response(
messages: List[ChatCompletionMessageParam],
api_key: str,
callback: Callable[[str], Awaitable[None]],
model: Llm,
) -> Completion:
start_time = time.time()
client = AsyncAnthropic(api_key=api_key)

# Base parameters
max_tokens = 8192
temperature = 0.0

# Translate OpenAI messages to Claude messages

def process_claude_messages(messages: List[ChatCompletionMessageParam]) -> tuple[str, List[dict]]:
    """
    Process messages for Claude by converting image URLs to base64 data
    and removing the image URL parameter from the message.
    """
    # Deep copy messages to avoid modifying the original list
    cloned_messages = copy.deepcopy(messages)

@@ -139,6 +131,27 @@ async def stream_claude_response(
"data": base64_data,
}

return system_prompt, claude_messages


# TODO: Have a separate function that translates OpenAI messages to Claude messages
async def stream_claude_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    callback: Callable[[str], Awaitable[None]],
    model: Llm,
) -> Completion:
    start_time = time.time()
    client = AsyncAnthropic(api_key=api_key)

    # Base parameters
    max_tokens = 8192
    temperature = 0.0

    # Translate OpenAI messages to Claude messages
    system_prompt, claude_messages = process_claude_messages(messages)

    # Stream Claude response
    async with client.messages.stream(
        model=model.value,
@@ -300,3 +313,69 @@ async def stream_gemini_response(
await callback(response.text) # type: ignore
completion_time = time.time() - start_time
return {"duration": completion_time, "code": full_response}


async def stream_bedrock_response(
    messages: List[ChatCompletionMessageParam],
    callback: Callable[[str], Awaitable[None]],
    model: Llm,
) -> Completion:
    print(f"Invoking {model} on AWS Bedrock")
    start_time = time.time()

    # Initialize Bedrock runtime client
    session = Session()

    # Expect configuration from environment variables or ~/.aws/credentials
    bedrock_client = session.client(
        service_name='bedrock-runtime',
    )

    full_response = ""

    system_prompt, claude_messages = process_claude_messages(messages)

    body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 8192,
        "messages": claude_messages,
        "temperature": 0.0,
        "system": system_prompt,
    }

    # Convert the payload to bytes
    body_bytes = json.dumps(body).encode('utf-8')

    response = bedrock_client.invoke_model_with_response_stream(
        body=body_bytes,
        contentType='application/json',
        accept='application/json',
        modelId=model.value,
        trace='DISABLED',
    )

    # Stream text deltas back through the callback; surface any error events
    for event in response['body']:
        if 'chunk' in event:
            chunk = event['chunk']['bytes'].decode('utf-8')
            if chunk:
                chunk_obj = json.loads(chunk)
                if chunk_obj.get('delta', {}).get('type') == 'text_delta':
                    response_text = chunk_obj['delta']['text']
                    full_response += response_text
                    await callback(response_text)
        elif 'internalServerException' in event:
            raise Exception(event['internalServerException']['message'])
        elif 'modelStreamErrorException' in event:
            raise Exception(event['modelStreamErrorException']['message'])
        elif 'validationException' in event:
            raise Exception(event['validationException']['message'])
        elif 'throttlingException' in event:
            raise Exception(event['throttlingException']['message'])
        elif 'modelTimeoutException' in event:
            raise Exception(event['modelTimeoutException']['message'])
        elif 'serviceUnavailableException' in event:
            raise Exception(event['serviceUnavailableException']['message'])

    completion_time = time.time() - start_time

    return {"duration": completion_time, "code": full_response}
95 changes: 93 additions & 2 deletions backend/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions backend/pyproject.toml
@@ -22,6 +22,7 @@ types-pillow = "^10.2.0.20240520"
aiohttp = "^3.9.5"
pydantic = "^2.10"
google-genai = "^0.3.0"
boto3 = "^1.36.11"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"