11from __future__ import annotations
2+ import importlib
3+ import os
24import threading
35import json
4- from typing import Set , List , Dict , Any , Optional
6+ from typing import Set , List , Dict , Any , Optional , Callable
57import httpx , requests
68
79
10+ def _get_adk_version () -> str :
11+ """Get the version from package metadata."""
12+ try :
13+ return importlib .metadata .version ("gradient-adk" )
14+ except importlib .metadata .PackageNotFoundError :
15+ return "unknown"
16+
17+
18+ # Type for request hooks: (url, headers) -> modified_headers
19+ RequestHook = Callable [[str , Dict [str , str ]], Dict [str , str ]]
20+
21+
822class CapturedRequest :
923 """Represents a captured HTTP request/response."""
1024
@@ -32,6 +46,7 @@ def __init__(self):
3246 self ._captured_requests : List [CapturedRequest ] = (
3347 []
3448 ) # Capture request/response pairs
49+ self ._request_hooks : List [RequestHook ] = [] # Hooks to modify outgoing requests
3550 self ._lock = threading .Lock ()
3651 self ._active = False
3752 # originals
@@ -73,6 +88,20 @@ def clear_hits(self) -> None:
7388 self ._hit_count = 0
7489 self ._captured_requests .clear ()
7590
91+ def add_request_hook (self , hook : RequestHook ) -> None :
92+ """Register a hook to modify outgoing request headers."""
93+ self ._request_hooks .append (hook )
94+
95+ def _apply_request_hooks (self , url : str , headers : Dict [str , str ]) -> Dict [str , str ]:
96+ """Apply all registered request hooks to headers."""
97+ headers = dict (headers ) if headers else {}
98+ for hook in self ._request_hooks :
99+ try :
100+ headers = hook (url , headers )
101+ except Exception :
102+ pass # Never break requests due to hook errors
103+ return headers
104+
76105 def start_intercepting (self ) -> None :
77106 if self ._active :
78107 return
@@ -87,6 +116,19 @@ def start_intercepting(self) -> None:
87116 # patch httpx (async)
88117 async def intercepted_httpx_send (self_client , request , ** kwargs ):
89118 url_str = str (request .url )
119+
120+ # Apply request hooks to modify headers
121+ new_headers = _global_interceptor ._apply_request_hooks (
122+ url_str , dict (request .headers )
123+ )
124+ if new_headers != dict (request .headers ):
125+ request = httpx .Request (
126+ request .method ,
127+ request .url ,
128+ headers = new_headers ,
129+ content = request .content ,
130+ )
131+
90132 request_payload = _global_interceptor ._extract_request_payload (request )
91133 _global_interceptor ._record_request (url_str , request_payload )
92134
@@ -97,12 +139,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs):
97139 # Don't read response body for streaming responses - it would buffer the entire stream!
98140 # Check if this is a streaming response by looking at headers or response type
99141 is_streaming = (
100- response .headers .get ("transfer-encoding" ) == "chunked" or
101- "text/event-stream" in response .headers .get ("content-type" , "" ) or
102- hasattr (response , "aiter_bytes" ) or
103- hasattr (response , "aiter_lines" )
142+ response .headers .get ("transfer-encoding" ) == "chunked"
143+ or "text/event-stream" in response .headers .get ("content-type" , "" )
144+ or hasattr (response , "aiter_bytes" )
145+ or hasattr (response , "aiter_lines" )
104146 )
105-
147+
106148 if not is_streaming :
107149 response_payload = await _global_interceptor ._extract_response_payload (
108150 response
@@ -114,6 +156,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs):
114156
115157 def intercepted_httpx_request (self_client , method , url , ** kwargs ):
116158 url_str = str (url )
159+
160+ # Apply request hooks to modify headers
161+ kwargs ["headers" ] = _global_interceptor ._apply_request_hooks (
162+ url_str , kwargs .get ("headers" , {})
163+ )
164+
117165 request_payload = _global_interceptor ._extract_request_payload_from_kwargs (
118166 kwargs
119167 )
@@ -130,6 +178,19 @@ def intercepted_httpx_request(self_client, method, url, **kwargs):
130178 # patch httpx (sync)
131179 def intercepted_httpx_sync_send (self_client , request , ** kwargs ):
132180 url_str = str (request .url )
181+
182+ # Apply request hooks to modify headers
183+ new_headers = _global_interceptor ._apply_request_hooks (
184+ url_str , dict (request .headers )
185+ )
186+ if new_headers != dict (request .headers ):
187+ request = httpx .Request (
188+ request .method ,
189+ request .url ,
190+ headers = new_headers ,
191+ content = request .content ,
192+ )
193+
133194 request_payload = _global_interceptor ._extract_request_payload (request )
134195 _global_interceptor ._record_request (url_str , request_payload )
135196
@@ -146,6 +207,12 @@ def intercepted_httpx_sync_send(self_client, request, **kwargs):
146207
147208 def intercepted_httpx_sync_request (self_client , method , url , ** kwargs ):
148209 url_str = str (url )
210+
211+ # Apply request hooks to modify headers
212+ kwargs ["headers" ] = _global_interceptor ._apply_request_hooks (
213+ url_str , kwargs .get ("headers" , {})
214+ )
215+
149216 request_payload = _global_interceptor ._extract_request_payload_from_kwargs (
150217 kwargs
151218 )
@@ -160,6 +227,12 @@ def intercepted_httpx_sync_request(self_client, method, url, **kwargs):
160227 # patch requests
161228 def intercepted_requests_request (self_session , method , url , ** kwargs ):
162229 url_str = str (url )
230+
231+ # Apply request hooks to modify headers
232+ kwargs ["headers" ] = _global_interceptor ._apply_request_hooks (
233+ url_str , kwargs .get ("headers" , {})
234+ )
235+
163236 request_payload = _global_interceptor ._extract_request_payload_from_kwargs (
164237 kwargs
165238 )
@@ -290,6 +363,44 @@ def _extract_response_payload_from_requests(
290363 return None
291364
292365
366+ def create_adk_user_agent_hook (version : str , url_patterns : List [str ]) -> RequestHook :
367+ """
368+ Factory to create a User-Agent hook for specific URL patterns.
369+
370+ Completely replaces the User-Agent header with the Gradient ADK identifier
371+ for requests matching the specified URL patterns.
372+
373+ Format: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}
374+
375+ Args:
376+ version: The ADK version string (e.g., "0.0.5")
377+ url_patterns: List of URL substrings to match (e.g., ["inference.do-ai.run"])
378+
379+ Returns:
380+ A request hook function that can be registered with NetworkInterceptor
381+ """
382+
383+ def hook (url : str , headers : Dict [str , str ]) -> Dict [str , str ]:
384+ # Check if URL matches any pattern
385+ if not any (pattern in url for pattern in url_patterns ):
386+ return headers
387+
388+ # Remove old User-Agent keys (both cases) to avoid duplicates
389+ headers .pop ("User-Agent" , None )
390+ headers .pop ("user-agent" , None )
391+
392+ # Build new User-Agent: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}
393+ user_agent = f"Gradient/adk/{ version } "
394+ deployment_uuid = os .environ .get ("AGENT_WORKSPACE_DEPLOYMENT_UUID" )
395+ if deployment_uuid :
396+ user_agent += f"/{ deployment_uuid } "
397+
398+ headers ["User-Agent" ] = user_agent
399+ return headers
400+
401+ return hook
402+
403+
293404# Global instance
294405_global_interceptor = NetworkInterceptor ()
295406
@@ -302,4 +413,12 @@ def setup_digitalocean_interception() -> None:
302413 intr = get_network_interceptor ()
303414 intr .add_endpoint_pattern ("inference.do-ai.run" )
304415 intr .add_endpoint_pattern ("inference.do-ai-test.run" )
416+
417+ # Register User-Agent hook for ADK identification
418+ ua_hook = create_adk_user_agent_hook (
419+ version = _get_adk_version (),
420+ url_patterns = ["inference.do-ai.run" , "inference.do-ai-test.run" ],
421+ )
422+ intr .add_request_hook (ua_hook )
423+
305424 intr .start_intercepting ()
0 commit comments