Skip to content

Commit 5fbb8fc

Browse files
authored
Merge pull request #15 from digitalocean/fix-stremaing-evaluations
Add stream chunk collection if evaluation-id is passed
2 parents 63cf65d + 9d8c8f9 commit 5fbb8fc

File tree

4 files changed

+335
-3
lines changed

4 files changed

+335
-3
lines changed

gradient_adk/decorator.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,49 @@ async def run(req: Request):
166166
logger.error("Error creating generator", error=str(e), exc_info=True)
167167
raise HTTPException(status_code=500, detail="Internal server error")
168168

169-
# Wrap in tracking iterator
169+
# If evaluation mode, collect all chunks and return as single response with trace ID
170+
if is_evaluation:
171+
from fastapi.responses import JSONResponse
172+
173+
collected_chunks: List[str] = []
174+
try:
175+
async for chunk in user_gen:
176+
if isinstance(chunk, bytes):
177+
chunk_str = chunk.decode("utf-8", errors="replace")
178+
elif isinstance(chunk, dict):
179+
chunk_str = json.dumps(chunk)
180+
elif chunk is None:
181+
continue
182+
else:
183+
chunk_str = str(chunk)
184+
collected_chunks.append(chunk_str)
185+
186+
result = "".join(collected_chunks)
187+
188+
# Submit tracking and get trace ID
189+
trace_id = None
190+
if tr:
191+
try:
192+
tr._req["outputs"] = result
193+
trace_id = await tr.submit_and_get_trace_id()
194+
except Exception:
195+
pass
196+
197+
headers = {"X-Gradient-Trace-Id": trace_id} if trace_id else {}
198+
return JSONResponse(content=result, headers=headers)
199+
200+
except Exception as e:
201+
if tr:
202+
try:
203+
tr._req["outputs"] = "".join(collected_chunks)
204+
tr._req["error"] = str(e)
205+
await tr._submit()
206+
except Exception:
207+
pass
208+
logger.error("Error in streaming evaluation", error=str(e), exc_info=True)
209+
raise HTTPException(status_code=500, detail="Internal server error")
210+
211+
# Normal streaming case - wrap in tracking iterator
170212
streaming_iter = _StreamingIteratorWithTracking(user_gen, tr, func.__name__)
171213

172214
return FastAPIStreamingResponse(
@@ -234,4 +276,4 @@ async def health():
234276

235277
def run_server(fastapi_app: FastAPI, host: str = "0.0.0.0", port: int = 8080, **kwargs):
236278
"""Run the FastAPI server with uvicorn."""
237-
uvicorn.run(fastapi_app, host=host, port=port, **kwargs)
279+
uvicorn.run(fastapi_app, host=host, port=port, **kwargs)
integration_tests/example_agents/streaming_echo_agent/main.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Streaming echo agent for integration testing.
3+
Does not make any external API calls - just echoes back the input in chunks.
4+
Used to test streaming vs non-streaming behavior with evaluation-id header.
5+
"""
6+
7+
from gradient_adk import entrypoint
8+
9+
10+
@entrypoint
11+
async def main(query, context):
12+
"""Streaming echo agent - yields the response in chunks."""
13+
prompt = query.get("prompt", "no prompt provided")
14+
# Stream the response in multiple chunks
15+
yield "Echo: "
16+
yield prompt
17+
yield " [DONE]"

integration_tests/run/test_adk_agents_run.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def echo_agent_dir(self):
5959
"""Get the path to the echo agent directory."""
6060
return Path(__file__).parent.parent / "example_agents" / "echo_agent"
6161

62+
@pytest.fixture
63+
def streaming_echo_agent_dir(self):
64+
"""Get the path to the streaming echo agent directory."""
65+
return Path(__file__).parent.parent / "example_agents" / "streaming_echo_agent"
66+
6267
@pytest.fixture
6368
def setup_agent_in_temp(self, echo_agent_dir):
6469
"""
@@ -86,6 +91,33 @@ def setup_agent_in_temp(self, echo_agent_dir):
8691

8792
yield temp_path
8893

94+
@pytest.fixture
95+
def setup_streaming_agent_in_temp(self, streaming_echo_agent_dir):
96+
"""
97+
Setup a temporary directory with the streaming echo agent and proper configuration.
98+
Yields the temp directory path and cleans up after.
99+
"""
100+
with tempfile.TemporaryDirectory() as temp_dir:
101+
temp_path = Path(temp_dir)
102+
103+
# Copy the streaming echo agent main.py
104+
shutil.copy(streaming_echo_agent_dir / "main.py", temp_path / "main.py")
105+
106+
# Create .gradient directory and config
107+
gradient_dir = temp_path / ".gradient"
108+
gradient_dir.mkdir()
109+
110+
config = {
111+
"agent_name": "test-streaming-echo-agent",
112+
"agent_environment": "main",
113+
"entrypoint_file": "main.py",
114+
}
115+
116+
with open(gradient_dir / "agent.yml", "w") as f:
117+
yaml.safe_dump(config, f)
118+
119+
yield temp_path
120+
89121
@pytest.mark.cli
90122
def test_agent_run_happy_path(self, setup_agent_in_temp):
91123
"""
@@ -386,5 +418,135 @@ def test_agent_run_run_endpoint_with_various_inputs(self, setup_agent_in_temp):
386418
assert data["echo"] == "Hello `} E1-('"
387419
logger.info("Unicode test passed")
388420

421+
finally:
422+
cleanup_process(process)
423+
424+
@pytest.mark.cli
425+
def test_streaming_agent_without_evaluation_id_streams_response(
426+
self, setup_streaming_agent_in_temp
427+
):
428+
"""
429+
Test that a streaming agent returns a streamed response when no evaluation-id header is sent.
430+
Verifies:
431+
- Response is streamed (text/event-stream content type)
432+
- Response contains the expected content
433+
"""
434+
logger = logging.getLogger(__name__)
435+
temp_dir = setup_streaming_agent_in_temp
436+
port = find_free_port()
437+
process = None
438+
439+
try:
440+
logger.info(f"Starting streaming agent on port {port} in {temp_dir}")
441+
442+
# Start the agent server
443+
process = subprocess.Popen(
444+
[
445+
"gradient",
446+
"agent",
447+
"run",
448+
"--port",
449+
str(port),
450+
"--no-dev",
451+
],
452+
cwd=temp_dir,
453+
start_new_session=True,
454+
)
455+
456+
# Wait for server to be ready
457+
server_ready = wait_for_server(port, timeout=30)
458+
assert server_ready, "Server did not start within timeout"
459+
460+
# Make a streaming request WITHOUT evaluation-id header
461+
with requests.post(
462+
f"http://localhost:{port}/run",
463+
json={"prompt": "Hello, World!"},
464+
stream=True,
465+
timeout=30,
466+
) as response:
467+
assert response.status_code == 200
468+
469+
# Verify it's a streaming response (text/event-stream)
470+
content_type = response.headers.get("content-type", "")
471+
assert "text/event-stream" in content_type, (
472+
f"Expected text/event-stream content type for streaming, got: {content_type}"
473+
)
474+
475+
# Collect chunks to verify content
476+
chunks = list(response.iter_content(decode_unicode=True))
477+
full_content = "".join(c for c in chunks if c)
478+
479+
# Verify the content contains the expected streamed output
480+
assert "Echo:" in full_content or "Hello, World!" in full_content, (
481+
f"Expected streamed content to contain prompt, got: {full_content}"
482+
)
483+
484+
logger.info(f"Streaming response received with {len(chunks)} chunks")
485+
logger.info(f"Full content: {full_content}")
486+
487+
finally:
488+
cleanup_process(process)
489+
490+
@pytest.mark.cli
491+
def test_streaming_agent_with_evaluation_id_returns_single_response(
492+
self, setup_streaming_agent_in_temp
493+
):
494+
"""
495+
Test that a streaming agent returns a single JSON response (not streamed)
496+
when the evaluation-id header is present.
497+
Verifies:
498+
- Response is NOT streamed (application/json content type)
499+
- Response contains the complete collected content
500+
"""
501+
logger = logging.getLogger(__name__)
502+
temp_dir = setup_streaming_agent_in_temp
503+
port = find_free_port()
504+
process = None
505+
506+
try:
507+
logger.info(f"Starting streaming agent on port {port} in {temp_dir}")
508+
509+
# Start the agent server
510+
process = subprocess.Popen(
511+
[
512+
"gradient",
513+
"agent",
514+
"run",
515+
"--port",
516+
str(port),
517+
"--no-dev",
518+
],
519+
cwd=temp_dir,
520+
start_new_session=True,
521+
)
522+
523+
# Wait for server to be ready
524+
server_ready = wait_for_server(port, timeout=30)
525+
assert server_ready, "Server did not start within timeout"
526+
527+
# Make a request WITH evaluation-id header
528+
response = requests.post(
529+
f"http://localhost:{port}/run",
530+
json={"prompt": "Hello, World!"},
531+
headers={"evaluation-id": "test-eval-123"},
532+
timeout=30,
533+
)
534+
assert response.status_code == 200
535+
536+
# Verify it's NOT a streaming response (should be application/json)
537+
content_type = response.headers.get("content-type", "")
538+
assert "application/json" in content_type, (
539+
f"Expected application/json content type for evaluation mode, got: {content_type}"
540+
)
541+
542+
# Verify the response contains the complete content
543+
result = response.json()
544+
expected_content = "Echo: Hello, World! [DONE]"
545+
assert result == expected_content, (
546+
f"Expected complete collected content '{expected_content}', got: {result}"
547+
)
548+
549+
logger.info(f"Single JSON response received: {result}")
550+
389551
finally:
390552
cleanup_process(process)

tests/decorator_test.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(self):
3131
self.ended = []
3232
self.closed = False
3333
self._req = {}
34+
self._submitted_trace_id = None
3435

3536
def on_request_start(self, name, inputs, is_evaluation=False):
3637
self.started.append((name, inputs, is_evaluation))
@@ -45,6 +46,12 @@ async def _submit(self):
4546
"""Simulate async submission."""
4647
await asyncio.sleep(0)
4748

49+
async def submit_and_get_trace_id(self):
50+
"""Simulate async submission and return trace ID."""
51+
await asyncio.sleep(0)
52+
self._submitted_trace_id = "test-trace-id-12345"
53+
return self._submitted_trace_id
54+
4855
async def aclose(self):
4956
"""Simulate async close."""
5057
await asyncio.sleep(0)
@@ -372,6 +379,110 @@ def handler(data, context):
372379
assert tracker.started[-1][2] is True # is_evaluation flag
373380

374381

382+
def test_streaming_with_evaluation_id_collects_and_returns_complete_response(
383+
patch_helpers,
384+
):
385+
"""Test that streaming with evaluation-id collects all chunks and returns complete response."""
386+
tracker = patch_helpers
387+
388+
@entrypoint
389+
async def handler(data):
390+
yield "hello"
391+
yield " "
392+
yield "world"
393+
394+
fastapi_app = globals()["fastapi_app"]
395+
with TestClient(fastapi_app) as client:
396+
# With evaluation-id header, response should NOT be streamed
397+
r = client.post(
398+
"/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
399+
)
400+
assert r.status_code == 200
401+
# Response should be the complete collected output
402+
assert r.json() == "hello world"
403+
# Trace ID should be in response headers
404+
assert r.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"
405+
406+
# Check that is_evaluation was passed correctly
407+
assert tracker.started
408+
assert tracker.started[-1][2] is True # is_evaluation flag
409+
# Tracker should have the collected output
410+
assert tracker._req.get("outputs") == "hello world"
411+
# submit_and_get_trace_id should have been called
412+
assert tracker._submitted_trace_id == "test-trace-id-12345"
413+
414+
415+
def test_streaming_without_evaluation_id_streams_normally(patch_helpers):
416+
"""Test that streaming without evaluation-id continues to stream normally."""
417+
tracker = patch_helpers
418+
419+
@entrypoint
420+
async def handler(data):
421+
yield "hello"
422+
yield " "
423+
yield "world"
424+
425+
fastapi_app = globals()["fastapi_app"]
426+
with TestClient(fastapi_app) as client:
427+
# Without evaluation-id header, response should be streamed
428+
with client.stream("POST", "/run", json={"test": 1}) as resp:
429+
assert resp.status_code == 200
430+
# Read the full stream by iterating
431+
body = "".join(chunk for chunk in resp.iter_text())
432+
assert body == "hello world"
433+
434+
# Check that is_evaluation was passed as False
435+
assert tracker.started
436+
assert tracker.started[-1][2] is False # is_evaluation flag
437+
# submit_and_get_trace_id should NOT have been called (normal streaming)
438+
assert tracker._submitted_trace_id is None
439+
440+
441+
def test_streaming_with_evaluation_id_handles_dict_chunks(patch_helpers):
442+
"""Test that streaming with evaluation-id properly handles dict chunks."""
443+
tracker = patch_helpers
444+
445+
@entrypoint
446+
async def handler(data):
447+
yield {"type": "start"}
448+
yield {"type": "data", "value": 42}
449+
yield {"type": "end"}
450+
451+
fastapi_app = globals()["fastapi_app"]
452+
with TestClient(fastapi_app) as client:
453+
r = client.post(
454+
"/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
455+
)
456+
assert r.status_code == 200
457+
# Dict chunks should be JSON serialized and concatenated
458+
result = r.json()
459+
assert '{"type": "start"}' in result
460+
assert '{"type": "data", "value": 42}' in result
461+
assert '{"type": "end"}' in result
462+
# Trace ID should be in response headers
463+
assert r.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"
464+
465+
466+
def test_streaming_with_evaluation_id_skips_none_chunks(patch_helpers):
467+
"""Test that streaming with evaluation-id properly skips None chunks."""
468+
tracker = patch_helpers
469+
470+
@entrypoint
471+
async def handler(data):
472+
yield "a"
473+
yield None # Should be skipped
474+
yield "b"
475+
476+
fastapi_app = globals()["fastapi_app"]
477+
with TestClient(fastapi_app) as client:
478+
r = client.post(
479+
"/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
480+
)
481+
assert r.status_code == 200
482+
assert r.json() == "ab" # None skipped
483+
assert r.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"
484+
485+
375486
def test_shutdown_event_calls_tracker_aclose(patch_helpers):
376487
"""Test that shutdown event calls tracker aclose."""
377488
tracker = patch_helpers
@@ -414,4 +525,4 @@ def handler(data, context):
414525
assert calls["host"] == "127.0.0.1"
415526
assert calls["port"] == 9999
416527
assert calls["kwargs"]["reload"] is True
417-
assert calls["kwargs"]["log_level"] == "debug"
528+
assert calls["kwargs"]["log_level"] == "debug"

0 commit comments

Comments (0)