diff --git a/README.md b/README.md index 89c76a15c..166e243e0 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ Zen instruments the following AI SDKs to track which models are used and how man * ✅ [`openai`](https://pypi.org/project/openai) ^1.0 * ✅ [`anthropic`](https://pypi.org/project/anthropic/) * ✅ [`mistralai`](https://pypi.org/project/mistralai) ^1.0.0 +* ✅ [`boto3`](https://pypi.org/project/boto3) (AWS Bedrock) Zen is compatible with Python 3.8-3.13 and can run on Windows, Linux, and Mac OS X. diff --git a/aikido_zen/__init__.py b/aikido_zen/__init__.py index 7835a37a9..37b5ef16e 100644 --- a/aikido_zen/__init__.py +++ b/aikido_zen/__init__.py @@ -80,5 +80,6 @@ def protect(mode="daemon", token=""): import aikido_zen.sinks.openai import aikido_zen.sinks.anthropic import aikido_zen.sinks.mistralai + import aikido_zen.sinks.botocore logger.info("Zen by Aikido v%s starting.", PKG_VERSION) diff --git a/aikido_zen/sinks/botocore.py b/aikido_zen/sinks/botocore.py new file mode 100644 index 000000000..165195c75 --- /dev/null +++ b/aikido_zen/sinks/botocore.py @@ -0,0 +1,49 @@ +from aikido_zen.helpers.get_argument import get_argument +from aikido_zen.helpers.on_ai_call import on_ai_call +from aikido_zen.helpers.register_call import register_call +from aikido_zen.sinks import after, on_import, patch_function, before + + +def get_tokens_from_converse(api_response): + usage = api_response.get("usage", {}) + input_tokens = usage.get("inputTokens", 0) + output_tokens = usage.get("outputTokens", 0) + return int(input_tokens), int(output_tokens) + + +def get_tokens_from_invoke_model(api_response): + headers = api_response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) + input_tokens_str = headers.get("x-amzn-bedrock-input-token-count", "0") + output_tokens_str = headers.get("x-amzn-bedrock-output-token-count", "0") + return int(input_tokens_str), int(output_tokens_str) + + +@after +def make_api_call_after(func, instance, args, kwargs, return_value): + # Extract arguments to validate later + operation_name = get_argument(args, kwargs, 0, "operation_name") + api_params = get_argument(args, kwargs, 1, "api_params") + if not operation_name or not api_params or not return_value: + return + + # Validate arguments, we only want to check operations related to AI. + if operation_name not in ["Converse", "InvokeModel"]: + return + register_call(f"botocore.client.{operation_name}", "ai_op") + + model_id = str(api_params.get("modelId", "")) + if not model_id: + return None + + input_tokens, output_tokens = (0, 0) + if operation_name == "Converse": + input_tokens, output_tokens = get_tokens_from_converse(return_value) + elif operation_name == "InvokeModel": + input_tokens, output_tokens = get_tokens_from_invoke_model(return_value) + + on_ai_call("bedrock", model_id, input_tokens, output_tokens) + + +@on_import("botocore.client") +def patch(m): + patch_function(m, "BaseClient._make_api_call", make_api_call_after) diff --git a/aikido_zen/sinks/tests/aws_bedrock_test.py b/aikido_zen/sinks/tests/aws_bedrock_test.py new file mode 100644 index 000000000..b4c8a8356 --- /dev/null +++ b/aikido_zen/sinks/tests/aws_bedrock_test.py @@ -0,0 +1,81 @@ +import json +import os +import aikido_zen.sinks.botocore +import pytest + +from aikido_zen.thread.thread_cache import get_cache + +skip_no_api_key = pytest.mark.skipif( + "AWS_BEDROCK_TEST" not in os.environ, + reason="AWS_BEDROCK_TEST environment variable not set, run `export AWS_BEDROCK_TEST=1`", +) + + +@pytest.fixture(autouse=True) +def setup(): + get_cache().reset() + yield + get_cache().reset() + + +@pytest.fixture +def client(): + import boto3 + + client = boto3.client(service_name="bedrock-runtime", region_name="us-east-1") + return client + + +def get_ai_stats(): + return get_cache().ai_stats.get_stats() + + +@skip_no_api_key +def test_boto3_converse(client): + metadata = { + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "prompt": "Are tomatoes a fruit?", + "max_tokens": 20, + } + + response = client.converse( + modelId="anthropic.claude-3-sonnet-20240229-v1:0", + messages=[{"role": "user", "content": [{"text": metadata["prompt"]}]}], + inferenceConfig={ + "temperature": 0.7, + "topP": 0.9, + "maxTokens": metadata["max_tokens"], + }, + ) + output = response["output"]["message"]["content"][0]["text"] + + assert get_ai_stats()[0]["model"] == "anthropic.claude-3-sonnet-20240229-v1:0" + assert get_ai_stats()[0]["calls"] == 1 + assert get_ai_stats()[0]["provider"] == "bedrock" + assert get_ai_stats()[0]["tokens"]["input"] == 13 + assert get_ai_stats()[0]["tokens"]["output"] == 20 + assert get_ai_stats()[0]["tokens"]["total"] == 33 + + +@skip_no_api_key +def test_boto3_invoke_model_claude_3_sonnet(client): + model_id = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" # Example model ID for Amazon Bedrock + input_payload = { + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Are tomatoes a vegetable?"}], + } + ], + "max_tokens": 20, + "anthropic_version": "bedrock-2023-05-31", + } + response = client.invoke_model(modelId=model_id, body=json.dumps(input_payload)) + print(response) + stats = get_ai_stats()[0] + assert stats["model"] == "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + assert stats["calls"] == 1 + assert stats["provider"] == "bedrock" + assert stats["tokens"]["input"] == 14 + assert stats["tokens"]["output"] == 20 + assert stats["tokens"]["total"] == 34 diff --git a/poetry.lock b/poetry.lock index 1948a3b50..ac21ed6bb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -103,7 +103,7 @@ description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "python_version < \"3.11.0\"" +markers = "python_version < \"3.11\"" files = [ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, @@ -265,6 +265,51 @@ files = [ {file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"}, ] +[[package]] +name = "boto3" +version = "1.40.17" +description = "The AWS SDK for Python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.9\"" +files = [ + {file = "boto3-1.40.17-py3-none-any.whl", hash = "sha256:2cacecd689cb51d81fbf54f84b64d0e6e922fbc18ee513c568b9f61caf4221e0"}, + {file = "boto3-1.40.17.tar.gz", hash = "sha256:e115dc87d5975d32dfa0ebaf19c39e360665317a350004fa94b03200fe853f2e"}, +] + +[package.dependencies] +botocore = ">=1.40.17,<1.41.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.13.0,<0.14.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.40.17" +description = "Low-level, data-driven core of boto 3." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.9\"" +files = [ + {file = "botocore-1.40.17-py3-none-any.whl", hash = "sha256:603951935c1a741ae70236bf15725c5293074f28503e7029ad0e24ece476a342"}, + {file = "botocore-1.40.17.tar.gz", hash = "sha256:769cd04a6a612f2d48b5f456c676fd81733fab682870952f7e2887260ea6a2bc"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.27.6)"] + [[package]] name = "certifi" version = "2024.8.30" @@ -1113,6 +1158,19 @@ files = [ {file = "jiter-0.9.1.tar.gz", hash = "sha256:7852990068b6e06102ecdc44c1619855a2af63347bfb5e7e009928dcacf04fdd"}, ] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version >= \"3.9\"" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "lxml" version = "5.4.0" @@ -2139,6 +2197,25 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "s3transfer" +version = "0.13.1" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.9\"" +files = [ + {file = "s3transfer-0.13.1-py3-none-any.whl", hash = "sha256:a981aa7429be23fe6dfc13e80e4020057cbab622b08c0315288758d67cabc724"}, + {file = "s3transfer-0.13.1.tar.gz", hash = "sha256:c3fdba22ba1bd367922f27ec8032d6a1cf5f10c934fb5d68cf60fd5a23d936cf"}, +] + +[package.dependencies] +botocore = ">=1.37.4,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"] + [[package]] name = "six" version = "1.17.0" @@ -2187,7 +2264,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "python_full_version <= \"3.11.0a6\"" +markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, @@ -2287,13 +2364,32 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""} [package.extras] devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] +[[package]] +name = "urllib3" +version = "1.26.20" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["main", "dev"] +files = [ + {file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"}, + {file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"}, +] +markers = {main = "python_version < \"3.10\"", dev = "python_version >= \"3.9\" and python_version < \"3.10\""} + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "urllib3" version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] +markers = "python_version >= \"3.10\"" files = [ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, @@ -2436,4 +2532,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.8" -content-hash = "900189063d5d4428f548bff101057097fa80402112d772d36cfe93a4109f1997" +content-hash = "b615a51e940f3cf790a337226fb39d4b3bd509bc29127d18c54ac5abd6ed042a" diff --git a/pyproject.toml b/pyproject.toml index c9210d60a..c78de41fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,8 @@ clickhouse-driver = "^0.2.9" openai = "^1.85.0" anthropic = "^0.54.0" mistralai = { version = "^1.8.2", python = ">=3.9,<4.0" } +boto3 = { version = "^1.40.17", python = ">=3.9,<4.0" } + django = "4" [build-system]