Skip to content

Commit 59acc53

Browse files
SkylarKelty and claude committed
Cancel research tasks on client disconnect
Streaming: wrap the generator in try/finally so when the client disconnects (GeneratorExit/CancelledError), the research task is cancelled immediately instead of running to completion. Non-streaming: poll request.is_disconnected() while awaiting the research task, cancel and return 499 if the client goes away. Both paths prevent wasted LLM calls and Playwright fetches when nginx/LiteLLM closes the upstream connection. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1e4209c commit 59acc53

File tree

1 file changed

+94
-58
lines changed

1 file changed

+94
-58
lines changed

artemis/main.py

Lines changed: 94 additions & 58 deletions
Original file line number · Diff line number · Diff line change
@@ -357,64 +357,77 @@ async def _stream_responses(request: ResponsesRequest) -> StreamingResponse:
357357
358358
Uses an asyncio.Queue so progress events from deep_research are
359359
yielded to the client in real-time rather than collected and replayed.
360+
361+
If the client disconnects mid-stream, the finally block cancels any
362+
in-flight research task to avoid wasting LLM/Playwright resources.
360363
"""
361364
async def event_generator():
362365
settings = get_settings()
363-
364-
yield f"[Starting research on: {request.input}]\n\n"
365-
366-
if request.preset in {"deep-research", "shallow-research"}:
367-
queue: asyncio.Queue[str] = asyncio.Queue()
368-
preset_config = _research_preset_config(settings, request.preset)
369-
370-
def progress_cb(stage: str, msg: str):
371-
queue.put_nowait(f"[{stage.upper()}] {msg}")
372-
373-
research_task = asyncio.create_task(deep_research(
374-
request.input,
375-
stages=request.max_steps or preset_config.stages,
376-
passes=preset_config.passes,
377-
sub_queries_per_stage=preset_config.subqueries,
378-
results_per_query=preset_config.results_per_query,
379-
max_tokens=preset_config.max_tokens,
380-
outline=request.outline,
381-
content_extraction=preset_config.content_extraction,
382-
pages_per_section=preset_config.pages_per_section,
383-
content_max_chars=preset_config.content_max_chars,
384-
progress_callback=progress_cb,
385-
))
386-
387-
# Yield progress events as they arrive
388-
while not research_task.done():
366+
research_task: asyncio.Task | None = None
367+
368+
try:
369+
yield f"[Starting research on: {request.input}]\n\n"
370+
371+
if request.preset in {"deep-research", "shallow-research"}:
372+
queue: asyncio.Queue[str] = asyncio.Queue()
373+
preset_config = _research_preset_config(settings, request.preset)
374+
375+
def progress_cb(stage: str, msg: str):
376+
queue.put_nowait(f"[{stage.upper()}] {msg}")
377+
378+
research_task = asyncio.create_task(deep_research(
379+
request.input,
380+
stages=request.max_steps or preset_config.stages,
381+
passes=preset_config.passes,
382+
sub_queries_per_stage=preset_config.subqueries,
383+
results_per_query=preset_config.results_per_query,
384+
max_tokens=preset_config.max_tokens,
385+
outline=request.outline,
386+
content_extraction=preset_config.content_extraction,
387+
pages_per_section=preset_config.pages_per_section,
388+
content_max_chars=preset_config.content_max_chars,
389+
progress_callback=progress_cb,
390+
))
391+
392+
# Yield progress events as they arrive
393+
while not research_task.done():
394+
try:
395+
msg = await asyncio.wait_for(queue.get(), timeout=0.5)
396+
yield msg + "\n\n"
397+
except asyncio.TimeoutError:
398+
continue
399+
400+
# Drain any remaining queued messages
401+
while not queue.empty():
402+
yield queue.get_nowait() + "\n\n"
403+
404+
# Propagate any exception from the research task
405+
research_run = research_task.result()
406+
407+
yield research_run.essay
408+
yield f"\n\n[Found {len(research_run.results)} sources]"
409+
yield f"\n[USAGE] {research_run.usage.model_dump_json()}"
410+
else:
411+
yield "[Searching...]\n"
412+
results = await search_searxng(query=request.input, max_results=10)
413+
yield f"[Found {len(results)} results]\n\n"
414+
415+
summary, usage, warnings = await _build_summary(request.input, results)
416+
yield summary or _fallback_text(results)
417+
418+
response_usage = usage or TokenUsage()
419+
response_usage.search_requests = 1
420+
citation_chars = sum(len(r.snippet or "") + len(r.title) for r in results)
421+
response_usage.citation_tokens = citation_chars // 4
422+
yield f"\n[USAGE] {response_usage.model_dump_json()}"
423+
finally:
424+
if research_task is not None and not research_task.done():
425+
research_task.cancel()
389426
try:
390-
msg = await asyncio.wait_for(queue.get(), timeout=0.5)
391-
yield msg + "\n\n"
392-
except asyncio.TimeoutError:
393-
continue
394-
395-
# Drain any remaining queued messages
396-
while not queue.empty():
397-
yield queue.get_nowait() + "\n\n"
398-
399-
# Propagate any exception from the research task
400-
research_run = research_task.result()
401-
402-
yield research_run.essay
403-
yield f"\n\n[Found {len(research_run.results)} sources]"
404-
yield f"\n[USAGE] {research_run.usage.model_dump_json()}"
405-
else:
406-
yield "[Searching...]\n"
407-
results = await search_searxng(query=request.input, max_results=10)
408-
yield f"[Found {len(results)} results]\n\n"
409-
410-
summary, usage, warnings = await _build_summary(request.input, results)
411-
yield summary or _fallback_text(results)
412-
413-
response_usage = usage or TokenUsage()
414-
response_usage.search_requests = 1
415-
citation_chars = sum(len(r.snippet or "") + len(r.title) for r in results)
416-
response_usage.citation_tokens = citation_chars // 4
417-
yield f"\n[USAGE] {response_usage.model_dump_json()}"
427+
await research_task
428+
except asyncio.CancelledError:
429+
pass
430+
logger.info("Research task cancelled due to client disconnect")
418431

419432
return StreamingResponse(event_generator(), media_type="text/plain")
420433

@@ -544,18 +557,22 @@ async def search(request: SearchRequest) -> SearchResponse:
544557
response_model=None,
545558
dependencies=[Depends(verify_api_key)],
546559
)
547-
async def responses(request: ResponsesRequest) -> ResponsesAPIResponse | StreamingResponse:
560+
async def responses(
561+
request: ResponsesRequest, http_request: Request
562+
) -> ResponsesAPIResponse | StreamingResponse:
548563
"""Perplexity-compatible responses endpoint.
549564
550565
Supports streaming via SSE when streaming=true.
566+
For non-streaming research requests, polls for client disconnect
567+
and cancels the research task if the client goes away.
551568
"""
552569
if request.streaming:
553570
return await _stream_responses(request)
554-
571+
555572
if request.preset in {"deep-research", "shallow-research"}:
556573
settings = get_settings()
557574
preset_config = _research_preset_config(settings, request.preset)
558-
research_run = await deep_research(
575+
research_task = asyncio.create_task(deep_research(
559576
request.input,
560577
stages=request.max_steps or preset_config.stages,
561578
passes=preset_config.passes,
@@ -566,7 +583,26 @@ async def responses(request: ResponsesRequest) -> ResponsesAPIResponse | Streami
566583
content_extraction=preset_config.content_extraction,
567584
pages_per_section=preset_config.pages_per_section,
568585
content_max_chars=preset_config.content_max_chars,
569-
)
586+
))
587+
try:
588+
while not research_task.done():
589+
if await http_request.is_disconnected():
590+
research_task.cancel()
591+
try:
592+
await research_task
593+
except asyncio.CancelledError:
594+
pass
595+
logger.info("Research task cancelled: client disconnected")
596+
return JSONResponse(
597+
status_code=499,
598+
content={"detail": "Client disconnected"},
599+
)
600+
await asyncio.sleep(0.5)
601+
research_run = research_task.result()
602+
except asyncio.CancelledError:
603+
research_task.cancel()
604+
raise
605+
570606
return ResponsesAPIResponse(
571607
id=str(uuid.uuid4()),
572608
created=_created_timestamp(),

0 commit comments

Comments (0)