Skip to content

Commit 883cc68

Browse files
authored
Merge pull request #10 from digitalocean/add-user-agent
Add ADK User agent to Serverless inference requests
2 parents 773b63c + 21e9dec commit 883cc68

File tree

2 files changed

+367
-8
lines changed

2 files changed

+367
-8
lines changed

gradient_adk/runtime/network_interceptor.py

Lines changed: 125 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
from __future__ import annotations
2+
import importlib
3+
import os
24
import threading
35
import json
4-
from typing import Set, List, Dict, Any, Optional
6+
from typing import Set, List, Dict, Any, Optional, Callable
57
import httpx, requests
68

79

10+
def _get_adk_version() -> str:
11+
"""Get the version from package metadata."""
12+
try:
13+
return importlib.metadata.version("gradient-adk")
14+
except importlib.metadata.PackageNotFoundError:
15+
return "unknown"
16+
17+
18+
# Type for request hooks: (url, headers) -> modified_headers
19+
RequestHook = Callable[[str, Dict[str, str]], Dict[str, str]]
20+
21+
822
class CapturedRequest:
923
"""Represents a captured HTTP request/response."""
1024

@@ -32,6 +46,7 @@ def __init__(self):
3246
self._captured_requests: List[CapturedRequest] = (
3347
[]
3448
) # Capture request/response pairs
49+
self._request_hooks: List[RequestHook] = [] # Hooks to modify outgoing requests
3550
self._lock = threading.Lock()
3651
self._active = False
3752
# originals
@@ -73,6 +88,20 @@ def clear_hits(self) -> None:
7388
self._hit_count = 0
7489
self._captured_requests.clear()
7590

91+
def add_request_hook(self, hook: RequestHook) -> None:
92+
"""Register a hook to modify outgoing request headers."""
93+
self._request_hooks.append(hook)
94+
95+
def _apply_request_hooks(self, url: str, headers: Dict[str, str]) -> Dict[str, str]:
96+
"""Apply all registered request hooks to headers."""
97+
headers = dict(headers) if headers else {}
98+
for hook in self._request_hooks:
99+
try:
100+
headers = hook(url, headers)
101+
except Exception:
102+
pass # Never break requests due to hook errors
103+
return headers
104+
76105
def start_intercepting(self) -> None:
77106
if self._active:
78107
return
@@ -87,6 +116,19 @@ def start_intercepting(self) -> None:
87116
# patch httpx (async)
88117
async def intercepted_httpx_send(self_client, request, **kwargs):
89118
url_str = str(request.url)
119+
120+
# Apply request hooks to modify headers
121+
new_headers = _global_interceptor._apply_request_hooks(
122+
url_str, dict(request.headers)
123+
)
124+
if new_headers != dict(request.headers):
125+
request = httpx.Request(
126+
request.method,
127+
request.url,
128+
headers=new_headers,
129+
content=request.content,
130+
)
131+
90132
request_payload = _global_interceptor._extract_request_payload(request)
91133
_global_interceptor._record_request(url_str, request_payload)
92134

@@ -97,12 +139,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs):
97139
# Don't read response body for streaming responses - it would buffer the entire stream!
98140
# Check if this is a streaming response by looking at headers or response type
99141
is_streaming = (
100-
response.headers.get("transfer-encoding") == "chunked" or
101-
"text/event-stream" in response.headers.get("content-type", "") or
102-
hasattr(response, "aiter_bytes") or
103-
hasattr(response, "aiter_lines")
142+
response.headers.get("transfer-encoding") == "chunked"
143+
or "text/event-stream" in response.headers.get("content-type", "")
144+
or hasattr(response, "aiter_bytes")
145+
or hasattr(response, "aiter_lines")
104146
)
105-
147+
106148
if not is_streaming:
107149
response_payload = await _global_interceptor._extract_response_payload(
108150
response
@@ -114,6 +156,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs):
114156

115157
def intercepted_httpx_request(self_client, method, url, **kwargs):
116158
url_str = str(url)
159+
160+
# Apply request hooks to modify headers
161+
kwargs["headers"] = _global_interceptor._apply_request_hooks(
162+
url_str, kwargs.get("headers", {})
163+
)
164+
117165
request_payload = _global_interceptor._extract_request_payload_from_kwargs(
118166
kwargs
119167
)
@@ -130,6 +178,19 @@ def intercepted_httpx_request(self_client, method, url, **kwargs):
130178
# patch httpx (sync)
131179
def intercepted_httpx_sync_send(self_client, request, **kwargs):
132180
url_str = str(request.url)
181+
182+
# Apply request hooks to modify headers
183+
new_headers = _global_interceptor._apply_request_hooks(
184+
url_str, dict(request.headers)
185+
)
186+
if new_headers != dict(request.headers):
187+
request = httpx.Request(
188+
request.method,
189+
request.url,
190+
headers=new_headers,
191+
content=request.content,
192+
)
193+
133194
request_payload = _global_interceptor._extract_request_payload(request)
134195
_global_interceptor._record_request(url_str, request_payload)
135196

@@ -146,6 +207,12 @@ def intercepted_httpx_sync_send(self_client, request, **kwargs):
146207

147208
def intercepted_httpx_sync_request(self_client, method, url, **kwargs):
148209
url_str = str(url)
210+
211+
# Apply request hooks to modify headers
212+
kwargs["headers"] = _global_interceptor._apply_request_hooks(
213+
url_str, kwargs.get("headers", {})
214+
)
215+
149216
request_payload = _global_interceptor._extract_request_payload_from_kwargs(
150217
kwargs
151218
)
@@ -160,6 +227,12 @@ def intercepted_httpx_sync_request(self_client, method, url, **kwargs):
160227
# patch requests
161228
def intercepted_requests_request(self_session, method, url, **kwargs):
162229
url_str = str(url)
230+
231+
# Apply request hooks to modify headers
232+
kwargs["headers"] = _global_interceptor._apply_request_hooks(
233+
url_str, kwargs.get("headers", {})
234+
)
235+
163236
request_payload = _global_interceptor._extract_request_payload_from_kwargs(
164237
kwargs
165238
)
@@ -290,6 +363,44 @@ def _extract_response_payload_from_requests(
290363
return None
291364

292365

366+
def create_adk_user_agent_hook(version: str, url_patterns: List[str]) -> RequestHook:
367+
"""
368+
Factory to create a User-Agent hook for specific URL patterns.
369+
370+
Completely replaces the User-Agent header with the Gradient ADK identifier
371+
for requests matching the specified URL patterns.
372+
373+
Format: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}
374+
375+
Args:
376+
version: The ADK version string (e.g., "0.0.5")
377+
url_patterns: List of URL substrings to match (e.g., ["inference.do-ai.run"])
378+
379+
Returns:
380+
A request hook function that can be registered with NetworkInterceptor
381+
"""
382+
383+
def hook(url: str, headers: Dict[str, str]) -> Dict[str, str]:
384+
# Check if URL matches any pattern
385+
if not any(pattern in url for pattern in url_patterns):
386+
return headers
387+
388+
# Remove old User-Agent keys (both cases) to avoid duplicates
389+
headers.pop("User-Agent", None)
390+
headers.pop("user-agent", None)
391+
392+
# Build new User-Agent: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}
393+
user_agent = f"Gradient/adk/{version}"
394+
deployment_uuid = os.environ.get("AGENT_WORKSPACE_DEPLOYMENT_UUID")
395+
if deployment_uuid:
396+
user_agent += f"/{deployment_uuid}"
397+
398+
headers["User-Agent"] = user_agent
399+
return headers
400+
401+
return hook
402+
403+
293404
# Global instance
294405
_global_interceptor = NetworkInterceptor()
295406

@@ -302,4 +413,12 @@ def setup_digitalocean_interception() -> None:
302413
intr = get_network_interceptor()
303414
intr.add_endpoint_pattern("inference.do-ai.run")
304415
intr.add_endpoint_pattern("inference.do-ai-test.run")
416+
417+
# Register User-Agent hook for ADK identification
418+
ua_hook = create_adk_user_agent_hook(
419+
version=_get_adk_version(),
420+
url_patterns=["inference.do-ai.run", "inference.do-ai-test.run"],
421+
)
422+
intr.add_request_hook(ua_hook)
423+
305424
intr.start_intercepting()

0 commit comments

Comments
 (0)