|
1 | 1 | import hashlib |
2 | 2 | import json |
| 3 | +import logging |
3 | 4 | import os |
4 | 5 | import re |
| 6 | +from typing import Any |
| 7 | +from typing import Dict |
5 | 8 | from typing import Optional |
6 | 9 | from urllib.parse import urljoin |
7 | 10 |
|
8 | 11 | from aiohttp.web import Request |
9 | 12 | from aiohttp.web import Response |
10 | 13 | import requests |
| 14 | +from requests_aws4auth import AWS4Auth |
11 | 15 | import vcr |
12 | 16 |
|
13 | 17 |
|
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
| 20 | + |
| 21 | +# Used for AWS signature recalculation for aws services initial proxying |
| 22 | +AWS_REGION = os.environ.get("AWS_REGION", "us-east-1") |
| 23 | +AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY") |
| 24 | + |
| 25 | + |
14 | 26 | def url_path_join(base_url: str, path: str) -> str: |
15 | 27 | """Join a base URL with a path, handling slashes automatically.""" |
16 | 28 | return urljoin(base_url.rstrip("/") + "/", path.lstrip("/")) |
17 | 29 |
|
18 | 30 |
|
| 31 | +AWS_SERVICES = { |
| 32 | + "bedrock-runtime": "bedrock", |
| 33 | +} |
| 34 | + |
| 35 | + |
19 | 36 | PROVIDER_BASE_URLS = { |
20 | 37 | "openai": "https://api.openai.com/v1", |
21 | 38 | "azure-openai": "https://dd.openai.azure.com/", |
22 | 39 | "deepseek": "https://api.deepseek.com/", |
23 | 40 | "anthropic": "https://api.anthropic.com/", |
24 | 41 | "datadog": "https://api.datadoghq.com/", |
25 | 42 | "genai": "https://generativelanguage.googleapis.com/", |
| 43 | + "bedrock-runtime": f"https://bedrock-runtime.{AWS_REGION}.amazonaws.com", |
26 | 44 | } |
27 | 45 |
|
| 46 | +CASSETTE_FILTER_HEADERS = [ |
| 47 | + "authorization", |
| 48 | + "OpenAI-Organization", |
| 49 | + "api-key", |
| 50 | + "x-api-key", |
| 51 | + "dd-api-key", |
| 52 | + "dd-application-key", |
| 53 | + "x-goog-api-key", |
| 54 | + "x-amz-security-token", |
| 55 | + "x-amz-content-sha256", |
| 56 | + "x-amz-date", |
| 57 | + "x-amz-user-agent", |
| 58 | + "amz-sdk-invocation-id", |
| 59 | + "amz-sdk-request", |
| 60 | +] |
| 61 | + |
28 | 62 | NORMALIZERS = [ |
29 | 63 | ( |
30 | 64 | r"--form-data-boundary-[^\r\n]+", |
@@ -65,22 +99,29 @@ def normalize_multipart_body(body: bytes) -> str: |
65 | 99 | return f"[binary_data_{hex_digest}]" |
66 | 100 |
|
67 | 101 |
|
| 102 | +def parse_authorization_header(auth_header: str) -> Dict[str, str]: |
| 103 | + """Parse AWS Authorization header to extract components""" |
| 104 | + if not auth_header.startswith("AWS4-HMAC-SHA256 "): |
| 105 | + return {} |
| 106 | + |
| 107 | + auth_parts = auth_header[len("AWS4-HMAC-SHA256 ") :].split(",") |
| 108 | + parsed = {} |
| 109 | + |
| 110 | + for part in auth_parts: |
| 111 | + key, value = part.split("=", 1) |
| 112 | + parsed[key.strip()] = value.strip() |
| 113 | + |
| 114 | + return parsed |
| 115 | + |
| 116 | + |
68 | 117 | def get_vcr(subdirectory: str, vcr_cassettes_directory: str) -> vcr.VCR: |
69 | 118 | cassette_dir = os.path.join(vcr_cassettes_directory, subdirectory) |
70 | 119 |
|
71 | 120 | return vcr.VCR( |
72 | 121 | cassette_library_dir=cassette_dir, |
73 | 122 | record_mode="once", |
74 | 123 | match_on=["path", "method"], |
75 | | - filter_headers=[ |
76 | | - "authorization", |
77 | | - "OpenAI-Organization", |
78 | | - "api-key", |
79 | | - "x-api-key", |
80 | | - "dd-api-key", |
81 | | - "dd-application-key", |
82 | | - "x-goog-api-key", |
83 | | - ], |
| 124 | + filter_headers=CASSETTE_FILTER_HEADERS, |
84 | 125 | ) |
85 | 126 |
|
86 | 127 |
|
@@ -125,31 +166,47 @@ async def proxy_request(request: Request, vcr_cassettes_directory: str) -> Respo |
125 | 166 | body_bytes = await request.read() |
126 | 167 |
|
127 | 168 | vcr_cassette_prefix = request.pop("vcr_cassette_prefix", None) |
128 | | - |
129 | 169 | cassette_name = generate_cassette_name(path, request.method, body_bytes, vcr_cassette_prefix) |
| 170 | + |
| 171 | + request_kwargs: Dict[str, Any] = { |
| 172 | + "method": request.method, |
| 173 | + "url": target_url, |
| 174 | + "headers": headers, |
| 175 | + "data": body_bytes, |
| 176 | + "cookies": dict(request.cookies), |
| 177 | + "allow_redirects": False, |
| 178 | + "stream": True, |
| 179 | + } |
| 180 | + |
| 181 | + if provider in AWS_SERVICES and not os.path.exists(os.path.join(vcr_cassettes_directory, provider, cassette_name)): |
| 182 | + if not AWS_SECRET_ACCESS_KEY: |
| 183 | + return Response( |
| 184 | + body="AWS_SECRET_ACCESS_KEY environment variable not set for aws signature recalculation", |
| 185 | + status=400, |
| 186 | + ) |
| 187 | + |
| 188 | + auth_header = request.headers.get("Authorization", "") |
| 189 | + auth_parts = parse_authorization_header(auth_header) |
| 190 | + aws_access_key = auth_parts.get("Credential", "").split("/")[0] |
| 191 | + |
| 192 | + auth = AWS4Auth(aws_access_key, AWS_SECRET_ACCESS_KEY, AWS_REGION, AWS_SERVICES[provider]) |
| 193 | + request_kwargs["auth"] = auth |
| 194 | + |
130 | 195 | with get_vcr(provider, vcr_cassettes_directory).use_cassette(f"{cassette_name}.yaml"): |
131 | | - oai_response = requests.request( |
132 | | - method=request.method, |
133 | | - url=target_url, |
134 | | - headers=headers, |
135 | | - data=body_bytes, |
136 | | - cookies=dict(request.cookies), |
137 | | - allow_redirects=False, |
138 | | - stream=True, |
139 | | - ) |
| 196 | + provider_response = requests.request(**request_kwargs) |
140 | 197 |
|
141 | 198 | # Extract content type without charset |
142 | | - content_type = oai_response.headers.get("content-type", "") |
| 199 | + content_type = provider_response.headers.get("content-type", "") |
143 | 200 | if ";" in content_type: |
144 | 201 | content_type = content_type.split(";")[0].strip() |
145 | 202 |
|
146 | 203 | response = Response( |
147 | | - body=oai_response.content, |
148 | | - status=oai_response.status_code, |
| 204 | + body=provider_response.content, |
| 205 | + status=provider_response.status_code, |
149 | 206 | content_type=content_type, |
150 | 207 | ) |
151 | 208 |
|
152 | | - for key, value in oai_response.headers.items(): |
| 209 | + for key, value in provider_response.headers.items(): |
153 | 210 | if key.lower() not in ( |
154 | 211 | "content-length", |
155 | 212 | "transfer-encoding", |
|
0 commit comments