Skip to content

Commit d1794c4

Browse files
committed
chore: apply anthropics#527
1 parent b6701d2 commit d1794c4

File tree

3 files changed

+297
-13
lines changed

3 files changed

+297
-13
lines changed

src/clawd_code_sdk/_internal/client.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Internal client implementation."""
22

3+
import logging
34
from collections.abc import AsyncIterable, AsyncIterator
5+
from contextlib import aclosing
46
from dataclasses import asdict, replace
57
from typing import Any
68

@@ -27,6 +29,8 @@
2729
from .transport import Transport
2830
from .transport.subprocess_cli import SubprocessCLITransport
2931

32+
logger = logging.getLogger(__name__)
33+
3034
# Map error types to exception classes
3135
_ERROR_TYPE_TO_EXCEPTION: dict[str, type[APIError]] = {
3236
"authentication_failed": AuthenticationError,
@@ -190,11 +194,17 @@ async def process_query(
190194
query._tg.start_soon(query.stream_input, prompt)
191195

192196
# Yield parsed messages
193-
async for data in query.receive_messages():
194-
message = parse_message(data)
195-
# Check for API errors and raise appropriate exceptions
196-
_raise_if_api_error(message)
197-
yield message
198-
197+
# Use aclosing() for proper async generator cleanup
198+
async with aclosing(query.receive_messages()) as messages:
199+
async for data in messages:
200+
message = parse_message(data)
201+
# Check for API errors and raise appropriate exceptions
202+
_raise_if_api_error(message)
203+
yield message
204+
205+
except GeneratorExit:
206+
# Handle early termination of the async generator gracefully
207+
# This occurs when the caller breaks out of the async for loop
208+
logger.debug("process_query generator closed early by caller")
199209
finally:
200210
await query.close()

src/clawd_code_sdk/_internal/query.py

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ def __init__(
125125
float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0
126126
) # Convert ms to seconds
127127

128+
# Cancel scope for the reader task - can be cancelled from any task context
129+
# This fixes the RuntimeError when async generator cleanup happens in a different task
130+
self._reader_cancel_scope: CancelScope | None = None
131+
self._reader_task_started = anyio.Event()
132+
133+
# Track whether we entered the task group in this task
134+
# Used to determine if we can safely call __aexit__()
135+
self._tg_entered_in_current_task = False
136+
128137
async def initialize(self) -> dict[str, Any] | None:
129138
"""Initialize control protocol if in streaming mode.
130139
@@ -172,11 +181,33 @@ async def initialize(self) -> dict[str, Any] | None:
172181
return response
173182

174183
async def start(self) -> None:
175-
"""Start reading messages from transport."""
184+
"""Start reading messages from transport.
185+
186+
This method starts background tasks for reading messages. The task lifecycle
187+
is managed using a CancelScope that can be safely cancelled from any async
188+
task context, avoiding the RuntimeError that occurs when task group
189+
__aexit__() is called from a different task than __aenter__().
190+
"""
176191
if self._tg is None:
192+
# Create a task group for spawning background tasks
177193
self._tg = anyio.create_task_group()
178194
await self._tg.__aenter__()
179-
self._tg.start_soon(self._read_messages)
195+
self._tg_entered_in_current_task = True
196+
197+
# Start the reader with its own cancel scope that can be cancelled safely
198+
self._tg.start_soon(self._read_messages_with_cancel_scope)
199+
200+
async def _read_messages_with_cancel_scope(self) -> None:
201+
"""Wrapper for _read_messages that sets up a cancellable scope.
202+
203+
This wrapper creates a CancelScope that can be cancelled from any task
204+
context, solving the issue where async generator cleanup happens in a
205+
different task than where the task group was entered.
206+
"""
207+
self._reader_cancel_scope = anyio.CancelScope()
208+
self._reader_task_started.set()
209+
with self._reader_cancel_scope:
210+
await self._read_messages()
180211

181212
async def _read_messages(self) -> None:
182213
"""Read messages from transport and route them."""
@@ -667,15 +698,69 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
667698
yield message
668699

669700
async def close(self) -> None:
670-
"""Close the query and transport."""
701+
"""Close the query and transport.
702+
703+
This method safely cleans up resources, handling the case where cleanup
704+
happens in a different async task context than where start() was called.
705+
This commonly occurs during async generator cleanup (e.g., when breaking
706+
out of an `async for` loop or when asyncio.run() shuts down).
707+
708+
The fix uses two mechanisms:
709+
1. A CancelScope for the reader task that can be cancelled from any context
710+
2. Suppressing the RuntimeError that occurs when task group __aexit__()
711+
is called from a different task than __aenter__()
712+
"""
713+
if self._closed:
714+
return
671715
self._closed = True
672-
if self._tg:
716+
717+
# Cancel the reader task via its cancel scope (safe from any task context)
718+
if self._reader_cancel_scope is not None:
719+
self._reader_cancel_scope.cancel()
720+
721+
# Handle task group cleanup
722+
if self._tg is not None:
723+
# Always cancel the task group's scope to stop any running tasks
673724
self._tg.cancel_scope.cancel()
674-
# Wait for task group to complete cancellation
675-
with suppress(anyio.get_cancelled_exc_class()):
676-
await self._tg.__aexit__(None, None, None)
725+
726+
# Try to properly exit the task group, but handle the case where
727+
# we're in a different task context than where __aenter__() was called
728+
try:
729+
with suppress(anyio.get_cancelled_exc_class()):
730+
await self._tg.__aexit__(None, None, None)
731+
except RuntimeError as e:
732+
# Handle "Attempted to exit cancel scope in a different task"
733+
# This happens during async generator cleanup when Python's GC
734+
# runs the finally block in a different task context.
735+
if "different task" in str(e):
736+
logger.debug(
737+
"Task group cleanup skipped due to cross-task context "
738+
"(this is expected during async generator cleanup)"
739+
)
740+
else:
741+
raise
742+
finally:
743+
self._tg = None
744+
self._tg_entered_in_current_task = False
745+
677746
await self.transport.close()
678747

748+
# Make Query an async context manager
749+
async def __aenter__(self) -> "Query":
750+
"""Enter async context - starts reading messages."""
751+
await self.start()
752+
return self
753+
754+
async def __aexit__(
755+
self,
756+
exc_type: type[BaseException] | None,
757+
exc_val: BaseException | None,
758+
exc_tb: Any,
759+
) -> bool:
760+
"""Exit async context - closes the query."""
761+
await self.close()
762+
return False
763+
679764
# Make Query an async iterator
680765
def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
681766
"""Return async iterator for messages."""

tests/test_streaming_client.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,3 +833,192 @@ async def mock_receive():
833833
assert isinstance(messages[-1], ResultMessage)
834834

835835
anyio.run(_test)
836+
837+
838+
class TestAsyncGeneratorCleanup:
839+
"""Tests for async generator cleanup behavior (issue #454).
840+
841+
These tests verify that the RuntimeError "Attempted to exit cancel scope
842+
in a different task" does not occur during async generator cleanup.
843+
844+
The key behavior we're testing is that cleanup doesn't raise RuntimeError,
845+
not that specific mock methods are called (which depends on mock setup).
846+
"""
847+
848+
def test_streaming_client_early_disconnect(self):
849+
"""Test ClaudeSDKClient early disconnect doesn't raise RuntimeError.
850+
851+
This is the primary test case from issue #454 - breaking out of an
852+
async for loop should not cause RuntimeError during cleanup.
853+
"""
854+
855+
async def _test():
856+
with patch(
857+
"clawd_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
858+
) as mock_transport_class:
859+
mock_transport = create_mock_transport()
860+
mock_transport_class.return_value = mock_transport
861+
862+
async def mock_receive():
863+
# Send init response
864+
await asyncio.sleep(0.01)
865+
written = mock_transport.write.call_args_list
866+
for call in written:
867+
if call:
868+
data = call[0][0]
869+
try:
870+
msg = json.loads(data.strip())
871+
if (
872+
msg.get("type") == "control_request"
873+
and msg.get("request", {}).get("subtype")
874+
== "initialize"
875+
):
876+
yield {
877+
"type": "control_response",
878+
"response": {
879+
"request_id": msg.get("request_id"),
880+
"subtype": "success",
881+
"commands": [],
882+
},
883+
}
884+
break
885+
except (json.JSONDecodeError, KeyError, AttributeError):
886+
pass
887+
888+
# Yield some messages
889+
for i in range(5):
890+
yield {
891+
"type": "assistant",
892+
"message": {
893+
"role": "assistant",
894+
"content": [{"type": "text", "text": f"Message {i}"}],
895+
"model": "claude-opus-4-1-20250805",
896+
},
897+
}
898+
899+
mock_transport.read_messages = mock_receive
900+
901+
# Connect, get one message, then disconnect early
902+
client = ClaudeSDKClient()
903+
await client.connect()
904+
905+
count = 0
906+
async for msg in client.receive_messages():
907+
count += 1
908+
if count >= 2:
909+
break # Early exit - this should NOT raise RuntimeError
910+
911+
# Early disconnect should not raise RuntimeError
912+
# The key assertion is that we reach this point without exception
913+
await client.disconnect()
914+
915+
assert count == 2
916+
# Transport close is called by disconnect
917+
mock_transport.close.assert_called()
918+
919+
anyio.run(_test)
920+
921+
def test_query_cancel_scope_can_be_cancelled(self):
922+
"""Test that Query's cancel scope can be safely cancelled from any context.
923+
924+
This verifies the fix for issue #454 where the cancel scope mechanism
925+
allows cleanup without RuntimeError.
926+
"""
927+
928+
async def _test():
929+
from clawd_code_sdk._internal.query import Query
930+
from clawd_code_sdk._internal.transport import Transport
931+
932+
# Create a mock transport
933+
mock_transport = AsyncMock(spec=Transport)
934+
mock_transport.connect = AsyncMock()
935+
mock_transport.close = AsyncMock()
936+
mock_transport.write = AsyncMock()
937+
938+
messages_to_yield = [
939+
{"type": "system", "subtype": "init"},
940+
{
941+
"type": "assistant",
942+
"message": {
943+
"content": [{"type": "text", "text": "Hello"}],
944+
"model": "test",
945+
},
946+
},
947+
]
948+
message_index = 0
949+
950+
async def mock_read():
951+
nonlocal message_index
952+
while message_index < len(messages_to_yield):
953+
yield messages_to_yield[message_index]
954+
message_index += 1
955+
await asyncio.sleep(0.01)
956+
957+
mock_transport.read_messages = mock_read
958+
959+
# Create Query
960+
q = Query(
961+
transport=mock_transport,
962+
is_streaming_mode=False,
963+
)
964+
965+
# Start the query
966+
await q.start()
967+
968+
# Give reader time to start
969+
await asyncio.sleep(0.05)
970+
971+
# Cancel scope should exist
972+
assert q._reader_cancel_scope is not None
973+
974+
# Close should work without RuntimeError
975+
# This is the key test - close() used to raise RuntimeError
976+
await q.close()
977+
978+
# Verify closed state
979+
assert q._closed is True
980+
mock_transport.close.assert_called()
981+
982+
anyio.run(_test)
983+
984+
def test_query_as_async_context_manager(self):
985+
"""Test using Query as an async context manager for proper cleanup."""
986+
987+
async def _test():
988+
from clawd_code_sdk._internal.query import Query
989+
from clawd_code_sdk._internal.transport import Transport
990+
991+
mock_transport = AsyncMock(spec=Transport)
992+
mock_transport.connect = AsyncMock()
993+
mock_transport.close = AsyncMock()
994+
mock_transport.write = AsyncMock()
995+
996+
async def mock_read():
997+
yield {"type": "system", "subtype": "init"}
998+
yield {
999+
"type": "assistant",
1000+
"message": {
1001+
"content": [{"type": "text", "text": "Hello"}],
1002+
"model": "test",
1003+
},
1004+
}
1005+
1006+
mock_transport.read_messages = mock_read
1007+
1008+
# Use Query as async context manager
1009+
q = Query(
1010+
transport=mock_transport,
1011+
is_streaming_mode=False,
1012+
)
1013+
1014+
async with q:
1015+
# Query should be started
1016+
assert q._tg is not None
1017+
# Get one message
1018+
msg = await q.__anext__()
1019+
assert msg["type"] == "system"
1020+
1021+
# After context exit, should be closed
1022+
assert q._closed is True
1023+
1024+
anyio.run(_test)

0 commit comments

Comments
 (0)