
Commit f3faef7

Add headers to context (#26)
* Add headers to context
* Remove traces tests
* Add RequestContext type to examples
1 parent 2c6a401 commit f3faef7

File tree

7 files changed: +148 −35 lines
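In short: `RequestContext` — the object passed as the second argument to every `@entrypoint` function — now carries the raw request headers alongside `session_id`, so agent code can read arbitrary headers forwarded by callers. A minimal sketch of the new surface; the `x-request-id` header is a hypothetical example, not something the runtime sets:

```python
from gradient_adk import entrypoint, RequestContext


@entrypoint
async def main(input: dict, context: RequestContext):
    # session_id was already available; headers is new in this commit.
    # Header keys arrive lower-cased (see the integration test below).
    request_id = context.headers.get("x-request-id")  # hypothetical header
    return {"query": input.get("query"), "request_id": request_id}
```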

README.md

Lines changed: 21 additions & 17 deletions
@@ -11,12 +11,14 @@ The Gradient™ Agent Development Kit (ADK) is a comprehensive toolkit for build
 ## Features
 
 ### 🛠️ CLI (Command Line Interface)
+
 - **Local Development**: Run and test your agents locally with hot-reload support
 - **Seamless Deployment**: Deploy agents to DigitalOcean with a single command
 - **Evaluation Framework**: Run comprehensive evaluations with custom metrics and datasets
 - **Observability**: View traces and runtime logs directly from the CLI
 
 ### 🚀 Runtime Environment
+
 - **Framework Agnostic**: Works with any Python framework for building AI agents
 - **Automatic LangGraph Integration**: Built-in trace capture for LangGraph nodes and state transitions
 - **Custom Decorators**: Capture traces from any framework using `@trace` decorators
@@ -40,6 +42,7 @@ gradient agent init
 ```
 
 This creates a new agent project with:
+
 - `main.py` - Agent entrypoint with example code
 - `agents/` - Directory for agent implementations
 - `tools/` - Directory for custom tools
@@ -77,7 +80,7 @@ gradient agent evaluate \
 LangGraph agents automatically capture traces for all nodes and state transitions:
 
 ```python
-from gradient_adk import entrypoint
+from gradient_adk import entrypoint, RequestContext
 from langgraph.graph import StateGraph
 from typing import TypedDict
@@ -92,11 +95,11 @@ async def llm_call(state: State) -> State:
     return state
 
 @entrypoint
-async def main(input: dict, context: dict):
+async def main(input: dict, context: RequestContext):
     graph = StateGraph(State)
     graph.add_node("llm_call", llm_call)
     graph.set_entry_point("llm_call")
-
+
     graph = graph.compile()
     result = await graph.ainvoke({"input": input.get("query")})
     return result["output"]
@@ -107,7 +110,7 @@ async def main(input: dict, context: dict):
 For frameworks beyond LangGraph, use trace decorators to capture custom spans:
 
 ```python
-from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever
+from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever, RequestContext
 
 @trace_retriever("vector_search")
 async def search_knowledge_base(query: str):
@@ -127,7 +130,7 @@ async def calculate(x: int, y: int):
     return x + y
 
 @entrypoint
-async def main(input: dict, context: dict):
+async def main(input: dict, context: RequestContext):
     docs = await search_knowledge_base(input["query"])
     result = await calculate(5, 10)
     response = await generate_response(f"Context: {docs}")
@@ -139,10 +142,10 @@ async def main(input: dict, context: dict):
 The runtime supports streaming responses with automatic trace capture:
 
 ```python
-from gradient_adk import entrypoint
+from gradient_adk import entrypoint, RequestContext
 
 @entrypoint
-async def main(input: dict, context: dict):
+async def main(input: dict, context: RequestContext):
     # Stream text chunks
     async def generate_chunks():
         async for chunk in llm.stream(input["query"]):
@@ -190,12 +193,12 @@ gradient agent evaluate \
   --success-threshold 80.0
 ```
 
-
 ## Tracing
 
 The ADK provides comprehensive tracing capabilities to capture and analyze your agent's execution. You can use **decorators** for wrapping functions or **programmatic functions** for manual span creation.
 
 ### What Gets Traced Automatically
+
 - **LangGraph Nodes**: All node executions, state transitions, and edges (including LLM calls, tool calls, and DigitalOcean Knowledge Base calls)
 - **HTTP Requests**: Request/response payloads for LLM API calls
 - **Errors**: Full exception details and stack traces
@@ -206,7 +209,7 @@ The ADK provides comprehensive tracing capabilities to capture and analyze your
 Use decorators to automatically trace function executions:
 
 ```python
-from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever
+from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever, RequestContext
 
 @trace_llm("model_call")
 async def call_model(prompt: str):
@@ -226,7 +229,7 @@ async def search_docs(query: str):
     return results
 
 @entrypoint
-async def main(input: dict, context: dict):
+async def main(input: dict, context: RequestContext):
     docs = await search_docs(input["query"])
     result = await calculate(5, 10)
     response = await call_model(f"Context: {docs}")
@@ -238,10 +241,10 @@ async def main(input: dict, context: dict):
 For more control over span creation, use the programmatic functions. These are useful when you can't use decorators or need to add spans for code you don't control:
 
 ```python
-from gradient_adk import entrypoint, add_llm_span, add_tool_span, add_agent_span
+from gradient_adk import entrypoint, add_llm_span, add_tool_span, add_agent_span, RequestContext
 
 @entrypoint
-async def main(input: dict, context: dict):
+async def main(input: dict, context: RequestContext):
     # Add an LLM span with detailed metadata
     response = await external_llm_call(input["query"])
     add_llm_span(
@@ -279,17 +282,18 @@ async def main(input: dict, context: dict):
 
 #### Available Span Functions
 
-| Function | Description | Key Optional Fields |
-|----------|-------------|---------------------|
-| `add_llm_span()` | Record LLM/model calls | `model`, `temperature`, `num_input_tokens`, `num_output_tokens`, `total_tokens`, `tools`, `time_to_first_token_ns` |
-| `add_tool_span()` | Record tool/function executions | `tool_call_id` |
-| `add_agent_span()` | Record agent/sub-agent executions ||
+| Function           | Description                       | Key Optional Fields                                                                                                 |
+| ------------------ | --------------------------------- | ------------------------------------------------------------------------------------------------------------------- |
+| `add_llm_span()`   | Record LLM/model calls            | `model`, `temperature`, `num_input_tokens`, `num_output_tokens`, `total_tokens`, `tools`, `time_to_first_token_ns`  |
+| `add_tool_span()`  | Record tool/function executions   | `tool_call_id`                                                                                                      |
+| `add_agent_span()` | Record agent/sub-agent executions |                                                                                                                     |
 
 **Common optional fields for all span functions:** `duration_ns`, `metadata`, `tags`, `status_code`
 
 ### Viewing Traces
 
 Traces are:
+
 - Automatically sent to DigitalOcean's Gradient Platform
 - Available in real-time through the web console
 - Accessible via `gradient agent traces` command
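The span-functions table in the hunk above names the fields without showing a call. A hedged sketch of programmatic span creation; the keyword names come from that table, while the `output` argument and all values are illustrative assumptions, not confirmed by this diff:

```python
from gradient_adk import add_llm_span

add_llm_span(
    name="external_model_call",
    input={"prompt": "Summarize the release notes"},
    output={"text": "Headers are now exposed on RequestContext."},  # assumed kwarg
    model="llama-3.3-70b-instruct",  # hypothetical model name
    temperature=0.2,
    num_input_tokens=42,
    num_output_tokens=12,
    total_tokens=54,
    tags=["release-notes"],  # one of the common optional fields
)
```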

gradient_adk/cli/templates/main.py.template

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@ import os
 from typing import Dict, TypedDict
 
 from gradient import AsyncGradient
-from gradient_adk import entrypoint
+from gradient_adk import entrypoint, RequestContext
 from langgraph.graph import StateGraph
 
 
@@ -44,7 +44,7 @@ async def llm_call(state: State) -> State:
 
 
 @entrypoint
-async def main(input: Dict, context: Dict):
+async def main(input: Dict, context: RequestContext):
     """Entrypoint"""
 
     # Setup the graph
@@ -61,4 +61,4 @@ async def main(input: Dict, context: Dict):
 
     # Invoke the app
     result = await app.ainvoke(initial_state)
-    return result["output"]
+    return result["output"]

gradient_adk/decorator.py

Lines changed: 16 additions & 6 deletions
@@ -8,7 +8,7 @@
 from __future__ import annotations
 import inspect
 import json
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Callable, Optional, Any, Dict, List
 
 
@@ -18,9 +18,18 @@ class RequestContext:
 
     Attributes:
         session_id: The session ID for the request, if provided.
+        headers: Raw request headers as a dictionary.
     """
 
     session_id: Optional[str] = None
+    headers: Dict[str, str] = field(default_factory=dict)
+
+
+def _build_request_context(req: Request) -> RequestContext:
+    return RequestContext(
+        session_id=req.headers.get("session-id"),
+        headers=dict(req.headers.items()),
+    )
 
 
 from fastapi import FastAPI, HTTPException, Request
@@ -157,9 +166,8 @@ async def run(req: Request):
 
     is_evaluation = "evaluation-id" in req.headers
 
-    # Extract session ID from headers
-    session_id = req.headers.get("session-id")
-    context = RequestContext(session_id=session_id)
+    context = _build_request_context(req)
+    session_id = context.session_id
 
     # Initialize tracker
     tr = None
@@ -230,7 +238,9 @@ async def run(req: Request):
                 await tr._submit()
             except Exception:
                 pass
-            logger.error("Error in streaming evaluation", error=str(e), exc_info=True)
+            logger.error(
+                "Error in streaming evaluation", error=str(e), exc_info=True
+            )
             raise HTTPException(status_code=500, detail="Internal server error")
@@ -301,4 +311,4 @@ async def health():
 
 def run_server(fastapi_app: FastAPI, host: str = "0.0.0.0", port: int = 8080, **kwargs):
     """Run the FastAPI server with uvicorn."""
-    uvicorn.run(fastapi_app, host=host, port=port, **kwargs)
+    uvicorn.run(fastapi_app, host=host, port=port, **kwargs)
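One behavioral detail of `_build_request_context` worth calling out: Starlette, which backs FastAPI's `Request.headers`, stores header names lower-cased, so the dict that lands in `RequestContext.headers` has lowercase keys regardless of how the client spelled them. That is why the integration test further down lower-cases keys before asserting. A standalone sketch mirroring the dataclass from this file:

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

from starlette.datastructures import Headers


@dataclass
class RequestContext:  # mirrors the dataclass in gradient_adk/decorator.py
    session_id: Optional[str] = None
    headers: Dict[str, str] = field(default_factory=dict)


raw = Headers({"Session-Id": "abc-123", "X-Request-Id": "req-789"})
ctx = RequestContext(
    session_id=raw.get("session-id"),  # lookups are case-insensitive
    headers=dict(raw.items()),         # keys come back lower-cased
)
assert ctx.session_id == "abc-123"
assert ctx.headers["x-request-id"] == "req-789"
```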

gradient_adk/tracing.py

Lines changed: 25 additions & 9 deletions
@@ -4,7 +4,7 @@
 with the same kind of tracing automatically provided for some other frameworks.
 
 Example usage:
-    from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever
+    from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever, RequestContext
 
     @trace_retriever("fetch_data")
     async def fetch_data(query: str) -> dict:
@@ -22,7 +22,7 @@ async def calculate(x: int, y: int) -> int:
         return x + y
 
     @entrypoint
-    async def my_agent(input: dict, context: dict):
+    async def my_agent(input: dict, context: RequestContext):
         data = await fetch_data(input["query"])
         result = await calculate(5, 10)
         response = await call_model(data["prompt"])
@@ -240,7 +240,9 @@ async def async_gen_wrapper(*args, **kwargs):
                 if span_type is None and has_network_hits:
                     meta["is_llm_call"] = True
                     # Get captured request/response payloads for LLM metadata extraction
-                    captured = interceptor.get_captured_requests_since(network_token)
+                    captured = interceptor.get_captured_requests_since(
+                        network_token
+                    )
                     if captured:
                         call = captured[0]
                         if call.request_payload:
@@ -302,7 +304,9 @@ async def async_wrapper(*args, **kwargs):
                 if span_type is None and has_network_hits:
                     meta["is_llm_call"] = True
                     # Get captured request/response payloads for LLM metadata extraction
-                    captured = interceptor.get_captured_requests_since(network_token)
+                    captured = interceptor.get_captured_requests_since(
+                        network_token
+                    )
                     if captured:
                         call = captured[0]
                         if call.request_payload:
@@ -401,7 +405,9 @@ def sync_wrapper(*args, **kwargs):
                 if span_type is None and has_network_hits:
                     meta["is_llm_call"] = True
                     # Get captured request/response payloads for LLM metadata extraction
-                    captured = interceptor.get_captured_requests_since(network_token)
+                    captured = interceptor.get_captured_requests_since(
+                        network_token
+                    )
                     if captured:
                         call = captured[0]
                         if call.request_payload:
@@ -536,7 +542,9 @@ def add_llm_span(
     span = _create_span(name, _freeze(input))
     meta = _ensure_meta(span)
     meta["is_llm_call"] = True
-    meta["is_programmatic"] = True  # Mark as programmatic to skip auto-duration calculation
+    meta["is_programmatic"] = (
+        True  # Mark as programmatic to skip auto-duration calculation
+    )
 
     if model is not None:
         meta["model_name"] = model
@@ -548,7 +556,11 @@ def add_llm_span(
         meta["llm_request_payload"]["temperature"] = temperature
     if time_to_first_token_ns is not None:
         meta["time_to_first_token_ns"] = time_to_first_token_ns
-    if num_input_tokens is not None or num_output_tokens is not None or total_tokens is not None:
+    if (
+        num_input_tokens is not None
+        or num_output_tokens is not None
+        or total_tokens is not None
+    ):
         if "llm_response_payload" not in meta:
             meta["llm_response_payload"] = {}
         meta["llm_response_payload"]["usage"] = {
@@ -608,7 +620,9 @@ def add_tool_span(
     span = _create_span(name, _freeze(input))
     meta = _ensure_meta(span)
     meta["is_tool_call"] = True
-    meta["is_programmatic"] = True  # Mark as programmatic to skip auto-duration calculation
+    meta["is_programmatic"] = (
+        True  # Mark as programmatic to skip auto-duration calculation
+    )
 
     if tool_call_id is not None:
         meta["tool_call_id"] = tool_call_id
@@ -662,7 +676,9 @@ def add_agent_span(
     span = _create_span(name, _freeze(input))
     meta = _ensure_meta(span)
    meta["is_agent_call"] = True
-    meta["is_programmatic"] = True  # Mark as programmatic to skip auto-duration calculation
+    meta["is_programmatic"] = (
+        True  # Mark as programmatic to skip auto-duration calculation
+    )
 
     if tags is not None:
         meta["tags"] = tags

integration_tests/example_agents/echo_agent/main.py

Lines changed: 1 addition & 0 deletions
@@ -14,4 +14,5 @@ async def main(query, context: RequestContext):
         "echo": prompt,
         "received": query,
         "session_id": context.session_id if context else None,
+        "headers": context.headers if context else {},
     }
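The integration test below exercises exactly this path; for a request carrying `Session-Id` and `X-Custom` headers, the echo agent's response body would look roughly like this sketch (extra transport headers such as `host` and `content-type` depend on the client and server):

```python
# Hypothetical response shape; header keys are lower-cased by Starlette.
expected = {
    "echo": "...",  # whatever prompt the agent echoes back
    "received": {"prompt": "Hello headers"},
    "session_id": "session-headers-123",
    "headers": {
        "session-id": "session-headers-123",
        "x-custom": "custom-value",
        # plus transport headers like host, content-length, etc.
    },
}
```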

integration_tests/run/test_adk_agents_run.py

Lines changed: 49 additions & 0 deletions
@@ -526,6 +526,55 @@ def test_agent_run_session_id_header_passthrough(self, setup_agent_in_temp):
         finally:
             cleanup_process(process)
 
+    @pytest.mark.cli
+    def test_agent_run_headers_passthrough(self, setup_agent_in_temp):
+        """
+        Test that arbitrary headers are passed through to RequestContext.headers.
+        """
+        logger = logging.getLogger(__name__)
+        temp_dir = setup_agent_in_temp
+        port = find_free_port()
+        process = None
+
+        try:
+            logger.info(f"Starting agent on port {port} in {temp_dir}")
+
+            process = subprocess.Popen(
+                [
+                    "gradient",
+                    "agent",
+                    "run",
+                    "--port",
+                    str(port),
+                    "--no-dev",
+                ],
+                cwd=temp_dir,
+                start_new_session=True,
+            )
+
+            server_ready = wait_for_server(port, timeout=30)
+            assert server_ready, "Server did not start within timeout"
+
+            headers = {
+                "Session-Id": "session-headers-123",
+                "X-Request-Id": "req-789",
+                "X-Custom": "custom-value",
+            }
+            response = requests.post(
+                f"http://localhost:{port}/run",
+                json={"prompt": "Hello headers"},
+                headers=headers,
+                timeout=10,
+            )
+            assert response.status_code == 200
+            data = response.json()
+            lowered = {k.lower(): v for k, v in data["headers"].items()}
+            assert lowered["session-id"] == "session-headers-123"
+            assert lowered["x-request-id"] == "req-789"
+            assert lowered["x-custom"] == "custom-value"
+        finally:
+            cleanup_process(process)
+
     @pytest.mark.cli
     def test_streaming_agent_without_evaluation_id_streams_response(
         self, setup_streaming_agent_in_temp
