Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aikido_zen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,6 @@ def protect(mode="daemon", token=""):
import aikido_zen.sinks.openai
import aikido_zen.sinks.anthropic
import aikido_zen.sinks.mistralai
import aikido_zen.sinks.botocore

logger.info("Zen by Aikido v%s starting.", PKG_VERSION)
49 changes: 49 additions & 0 deletions aikido_zen/sinks/botocore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.helpers.on_ai_call import on_ai_call
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import after, on_import, patch_function, before


def get_tokens_from_converse(api_response):
usage = api_response.get("usage", {})
input_tokens = usage.get("inputTokens", 0)
output_tokens = usage.get("outputTokens", 0)
return int(input_tokens), int(output_tokens)


def get_tokens_from_invoke_model(api_response):
headers = api_response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
input_tokens_str = headers.get("x-amzn-bedrock-input-token-count", "0")
output_tokens_str = headers.get("x-amzn-bedrock-output-token-count", "0")
return int(input_tokens_str), int(output_tokens_str)


@after
def make_api_call_after(func, instance, args, kwargs, return_value):
# Extract arguments to validate later
operation_name = get_argument(args, kwargs, 0, "operation_name")
api_params = get_argument(args, kwargs, 1, "api_params")
if not operation_name or not api_params or not return_value:
return

# Validate arguments, we only want to check operations related to AI.
if operation_name not in ["Converse", "InvokeModel"]:
return
register_call(f"botocore.client.{operation_name}", "ai_op")

model_id = str(api_params.get("modelId", ""))
if not model_id:
return None

input_tokens, output_tokens = (0, 0)
if operation_name == "Converse":
input_tokens, output_tokens = get_tokens_from_converse(return_value)
elif operation_name == "InvokeModel":
input_tokens, output_tokens = get_tokens_from_invoke_model(return_value)

on_ai_call("bedrock", model_id, input_tokens, output_tokens)


@on_import("botocore.client")
def patch(m):
patch_function(m, "BaseClient._make_api_call", make_api_call_after)
80 changes: 80 additions & 0 deletions aikido_zen/sinks/tests/aws_bedrock_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import json
import os
import aikido_zen.sinks.botocore
import pytest
import boto3

from aikido_zen.thread.thread_cache import get_cache

skip_no_api_key = pytest.mark.skipif(
"AWS_BEDROCK_TEST" not in os.environ,
reason="AWS_BEDROCK_TEST environment variable not set, run `export AWS_BEDROCK_TEST=1`",
)


@pytest.fixture(autouse=True)
def setup():
get_cache().reset()
yield
get_cache().reset()


@pytest.fixture
def client():
client = boto3.client(service_name="bedrock-runtime", region_name="us-east-1")
return client


def get_ai_stats():
return get_cache().ai_stats.get_stats()


@skip_no_api_key
def test_boto3_converse(client):
metadata = {
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
"prompt": "Are tomatoes a fruit?",
"max_tokens": 20,
}

response = client.converse(
modelId="anthropic.claude-3-sonnet-20240229-v1:0",
messages=[{"role": "user", "content": [{"text": metadata["prompt"]}]}],
inferenceConfig={
"temperature": 0.7,
"topP": 0.9,
"maxTokens": metadata["max_tokens"],
},
)
output = response["output"]["message"]["content"][0]["text"]

assert get_ai_stats()[0]["model"] == "anthropic.claude-3-sonnet-20240229-v1:0"
assert get_ai_stats()[0]["calls"] == 1
assert get_ai_stats()[0]["provider"] == "bedrock"
assert get_ai_stats()[0]["tokens"]["input"] == 13
assert get_ai_stats()[0]["tokens"]["output"] == 20
assert get_ai_stats()[0]["tokens"]["total"] == 33


@skip_no_api_key
def test_boto3_invoke_model_claude_3_sonnet(client):
model_id = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" # Example model ID for Amazon Bedrock
input_payload = {
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Are tomatoes a vegetable?"}],
}
],
"max_tokens": 20,
"anthropic_version": "bedrock-2023-05-31",
}
response = client.invoke_model(modelId=model_id, body=json.dumps(input_payload))
print(response)
stats = get_ai_stats()[0]
assert stats["model"] == "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
assert stats["calls"] == 1
assert stats["provider"] == "bedrock"
assert stats["tokens"]["input"] == 14
assert stats["tokens"]["output"] == 20
assert stats["tokens"]["total"] == 34
104 changes: 100 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ clickhouse-driver = "^0.2.9"
openai = "^1.85.0"
anthropic = "^0.54.0"
mistralai = { version = "^1.8.2", python = ">=3.9,<4.0" }
boto3 = { version = "^1.40.17", python = ">=3.9,<4.0" }

django = "4"

[build-system]
Expand Down
Loading