Skip to content

Commit 43569a4

Browse files
authored
feat(vcr): add support for proxying aws bedrock-runtime requests (#241)
* merge * more formatting * modified solution * comment * fmt * just use an open source library instead * readme * release note * update comment * try fixing mypy issues
1 parent 3e921af commit 43569a4

File tree

5 files changed

+93
-23
lines changed

5 files changed

+93
-23
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,13 @@ The cassettes are matched based on the path, method, and body of the request. To
146146

147147
Optionally specifying whatever mounted path is used for the cassettes directory. The test agent comes with a default set of cassettes for OpenAI, Azure OpenAI, and DeepSeek.
148148

149+
#### AWS Services
150+
AWS service proxying, specifically recording cassettes for the first time, requires a `AWS_SECRET_ACCESS_KEY` environment variable to be set for the container running the test agent. This is used to recalculate the AWS signature for the request, as the one generated client-side likely used `{test-agent-host}:{test-agent-port}/vcr/{aws-service}` as the host, and the signature will mismatch that on the actual AWS service.
151+
152+
Additionally, the `AWS_REGION` environment variable can be set, defaulting to `us-east-1`.
153+
154+
To add a new AWS service to proxy, add an entry in the `PROVIDER_BASE_URLS` for its provider url, and an entry in the `AWS_SERVICES` dictionary for the service name, since they are not always a one-to-one mapping with the implied provider url (e.g, `https://bedrock-runtime.{AWS_REGION}.amazonaws.com` is the provider url, but the service name is `bedrock`, as `bedrock` also has multiple sub services, like `converse`).
155+
149156
#### Usage in clients
150157

151158
To use this feature in your client, you can use the `/vcr/{provider}` endpoint to proxy requests to the provider API.

ddapm_test_agent/vcr_proxy.py

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,64 @@
11
import hashlib
22
import json
3+
import logging
34
import os
45
import re
6+
from typing import Any
7+
from typing import Dict
58
from typing import Optional
69
from urllib.parse import urljoin
710

811
from aiohttp.web import Request
912
from aiohttp.web import Response
1013
import requests
14+
from requests_aws4auth import AWS4Auth
1115
import vcr
1216

1317

18+
logger = logging.getLogger(__name__)
19+
20+
21+
# Used for AWS signature recalculation for aws services initial proxying
22+
AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")
23+
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
24+
25+
1426
def url_path_join(base_url: str, path: str) -> str:
1527
"""Join a base URL with a path, handling slashes automatically."""
1628
return urljoin(base_url.rstrip("/") + "/", path.lstrip("/"))
1729

1830

31+
AWS_SERVICES = {
32+
"bedrock-runtime": "bedrock",
33+
}
34+
35+
1936
PROVIDER_BASE_URLS = {
2037
"openai": "https://api.openai.com/v1",
2138
"azure-openai": "https://dd.openai.azure.com/",
2239
"deepseek": "https://api.deepseek.com/",
2340
"anthropic": "https://api.anthropic.com/",
2441
"datadog": "https://api.datadoghq.com/",
2542
"genai": "https://generativelanguage.googleapis.com/",
43+
"bedrock-runtime": f"https://bedrock-runtime.{AWS_REGION}.amazonaws.com",
2644
}
2745

46+
CASSETTE_FILTER_HEADERS = [
47+
"authorization",
48+
"OpenAI-Organization",
49+
"api-key",
50+
"x-api-key",
51+
"dd-api-key",
52+
"dd-application-key",
53+
"x-goog-api-key",
54+
"x-amz-security-token",
55+
"x-amz-content-sha256",
56+
"x-amz-date",
57+
"x-amz-user-agent",
58+
"amz-sdk-invocation-id",
59+
"amz-sdk-request",
60+
]
61+
2862
NORMALIZERS = [
2963
(
3064
r"--form-data-boundary-[^\r\n]+",
@@ -65,22 +99,29 @@ def normalize_multipart_body(body: bytes) -> str:
6599
return f"[binary_data_{hex_digest}]"
66100

67101

102+
def parse_authorization_header(auth_header: str) -> Dict[str, str]:
103+
"""Parse AWS Authorization header to extract components"""
104+
if not auth_header.startswith("AWS4-HMAC-SHA256 "):
105+
return {}
106+
107+
auth_parts = auth_header[len("AWS4-HMAC-SHA256 ") :].split(",")
108+
parsed = {}
109+
110+
for part in auth_parts:
111+
key, value = part.split("=", 1)
112+
parsed[key.strip()] = value.strip()
113+
114+
return parsed
115+
116+
68117
def get_vcr(subdirectory: str, vcr_cassettes_directory: str) -> vcr.VCR:
69118
cassette_dir = os.path.join(vcr_cassettes_directory, subdirectory)
70119

71120
return vcr.VCR(
72121
cassette_library_dir=cassette_dir,
73122
record_mode="once",
74123
match_on=["path", "method"],
75-
filter_headers=[
76-
"authorization",
77-
"OpenAI-Organization",
78-
"api-key",
79-
"x-api-key",
80-
"dd-api-key",
81-
"dd-application-key",
82-
"x-goog-api-key",
83-
],
124+
filter_headers=CASSETTE_FILTER_HEADERS,
84125
)
85126

86127

@@ -125,31 +166,47 @@ async def proxy_request(request: Request, vcr_cassettes_directory: str) -> Respo
125166
body_bytes = await request.read()
126167

127168
vcr_cassette_prefix = request.pop("vcr_cassette_prefix", None)
128-
129169
cassette_name = generate_cassette_name(path, request.method, body_bytes, vcr_cassette_prefix)
170+
171+
request_kwargs: Dict[str, Any] = {
172+
"method": request.method,
173+
"url": target_url,
174+
"headers": headers,
175+
"data": body_bytes,
176+
"cookies": dict(request.cookies),
177+
"allow_redirects": False,
178+
"stream": True,
179+
}
180+
181+
if provider in AWS_SERVICES and not os.path.exists(os.path.join(vcr_cassettes_directory, provider, cassette_name)):
182+
if not AWS_SECRET_ACCESS_KEY:
183+
return Response(
184+
body="AWS_SECRET_ACCESS_KEY environment variable not set for aws signature recalculation",
185+
status=400,
186+
)
187+
188+
auth_header = request.headers.get("Authorization", "")
189+
auth_parts = parse_authorization_header(auth_header)
190+
aws_access_key = auth_parts.get("Credential", "").split("/")[0]
191+
192+
auth = AWS4Auth(aws_access_key, AWS_SECRET_ACCESS_KEY, AWS_REGION, AWS_SERVICES[provider])
193+
request_kwargs["auth"] = auth
194+
130195
with get_vcr(provider, vcr_cassettes_directory).use_cassette(f"{cassette_name}.yaml"):
131-
oai_response = requests.request(
132-
method=request.method,
133-
url=target_url,
134-
headers=headers,
135-
data=body_bytes,
136-
cookies=dict(request.cookies),
137-
allow_redirects=False,
138-
stream=True,
139-
)
196+
provider_response = requests.request(**request_kwargs)
140197

141198
# Extract content type without charset
142-
content_type = oai_response.headers.get("content-type", "")
199+
content_type = provider_response.headers.get("content-type", "")
143200
if ";" in content_type:
144201
content_type = content_type.split(";")[0].strip()
145202

146203
response = Response(
147-
body=oai_response.content,
148-
status=oai_response.status_code,
204+
body=provider_response.content,
205+
status=provider_response.status_code,
149206
content_type=content_type,
150207
)
151208

152-
for key, value in oai_response.headers.items():
209+
for key, value in provider_response.headers.items():
153210
if key.lower() not in (
154211
"content-length",
155212
"transfer-encoding",

flake.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
requests
5959
yarl
6060
vcrpy
61+
requests-aws4auth
6162
protobuf
6263
opentelemetry-proto
6364
grpcio
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
vcr: adds support for proxying aws bedrock runtime. to record cassettes for the first time, the `AWS_SECRET_ACCESS_KEY` environment variable must be set for the container running the test agent, for request signature recalculation. Additionally, `AWS_REGION` can be set, defaulting to `us-east-1`.

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"typing_extensions",
3838
"yarl",
3939
"vcrpy",
40+
"requests-aws4auth",
4041
# ddtrace libraries officially support opentelemetry-proto 1.33.1
4142
# which implements the v1.7.0 spec
4243
"opentelemetry-proto>1.33.0,<1.37.0",

0 commit comments

Comments
 (0)