diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..fb9ed473 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,27 @@ +# Workflow + +```bash +# Lint and style +# Check for issues and fix automatically +python -m ruff check src/ tests/ --fix +python -m ruff format src/ tests/ + +# Typecheck (only done for src/) +python -m mypy src/ + +# Run all tests +python -m pytest tests/ + +# Run specific test file +python -m pytest tests/test_client.py +``` + +# Codebase Structure + +- `src/claude_code_sdk/` - Main package + - `client.py` - ClaudeSDKClient for interactive sessions + - `query.py` - One-shot query function + - `types.py` - Type definitions + - `_internal/` - Internal implementation details + - `transport/subprocess_cli.py` - CLI subprocess management + - `message_parser.py` - Message parsing logic diff --git a/examples/streaming_mode.py b/examples/streaming_mode.py new file mode 100755 index 00000000..73eb4109 --- /dev/null +++ b/examples/streaming_mode.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +""" +Comprehensive examples of using ClaudeSDKClient for streaming mode. + +This file demonstrates various patterns for building applications with +the ClaudeSDKClient streaming interface. + +The queries are intentionally simplistic. In reality, a query can be a more +complex task that Claude SDK uses its agentic capabilities and tools (e.g. run +bash commands, edit files, search the web, fetch web content) to accomplish. + +Usage: +./examples/streaming_mode.py - List the examples +./examples/streaming_mode.py all - Run all examples +./examples/streaming_mode.py basic_streaming - Run a specific example +""" + +import asyncio +import contextlib +import sys + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ClaudeSDKClient, + CLIConnectionError, + ResultMessage, + SystemMessage, + TextBlock, + UserMessage, +) + + +def display_message(msg): + """Standardized message display function. + + - UserMessage: "User: " + - AssistantMessage: "Claude: " + - SystemMessage: ignored + - ResultMessage: "Result ended" + cost if available + """ + if isinstance(msg, UserMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"User: {block.text}") + elif isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(msg, SystemMessage): + # Ignore system messages + pass + elif isinstance(msg, ResultMessage): + print("Result ended") + + +async def example_basic_streaming(): + """Basic streaming with context manager.""" + print("=== Basic Streaming Example ===") + + async with ClaudeSDKClient() as client: + print("User: What is 2+2?") + await client.query("What is 2+2?") + + # Receive complete response using the helper method + async for msg in client.receive_response(): + display_message(msg) + + print("\n") + + +async def example_multi_turn_conversation(): + """Multi-turn conversation using receive_response helper.""" + print("=== Multi-Turn Conversation Example ===") + + async with ClaudeSDKClient() as client: + # First turn + print("User: What's the capital of France?") + await client.query("What's the capital of France?") + + # Extract and print response + async for msg in client.receive_response(): + display_message(msg) + + # Second turn - follow-up + print("\nUser: What's the population of that city?") + await client.query("What's the population of that city?") + + async for msg in client.receive_response(): + display_message(msg) + + print("\n") + + +async def example_concurrent_responses(): + """Handle responses while sending new messages.""" + print("=== Concurrent Send/Receive Example ===") + + async with ClaudeSDKClient() as client: + # Background task to continuously receive messages + async def receive_messages(): + async for message in client.receive_messages(): + display_message(message) + + # Start receiving in background + receive_task = asyncio.create_task(receive_messages()) + + # Send multiple messages with delays + questions = [ + "What is 2 + 2?", + "What is the square root of 144?", + "What is 10% of 80?", + ] + + for question in questions: + print(f"\nUser: {question}") + await client.query(question) + await asyncio.sleep(3) # Wait between messages + + # Give time for final responses + await asyncio.sleep(2) + + # Clean up + receive_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await receive_task + + print("\n") + + +async def example_with_interrupt(): + """Demonstrate interrupt capability.""" + print("=== Interrupt Example ===") + print("IMPORTANT: Interrupts require active message consumption.") + + async with ClaudeSDKClient() as client: + # Start a long-running task + print("\nUser: Count from 1 to 100 slowly") + await client.query( + "Count from 1 to 100 slowly, with a brief pause between each number" + ) + + # Create a background task to consume messages + messages_received = [] + interrupt_sent = False + + async def consume_messages(): + """Consume messages in the background to enable interrupt processing.""" + async for message in client.receive_messages(): + messages_received.append(message) + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + # Print first few numbers + print(f"Claude: {block.text[:50]}...") + elif isinstance(message, ResultMessage): + display_message(message) + if interrupt_sent: + break + + # Start consuming messages in the background + consume_task = asyncio.create_task(consume_messages()) + + # Wait 2 seconds then send interrupt + await asyncio.sleep(2) + print("\n[After 2 seconds, sending interrupt...]") + interrupt_sent = True + await client.interrupt() + + # Wait for the consume task to finish processing the interrupt + await consume_task + + # Send new instruction after interrupt + print("\nUser: Never mind, just tell me a quick joke") + await client.query("Never mind, just tell me a quick joke") + + # Get the joke + async for msg in client.receive_response(): + display_message(msg) + + print("\n") + + +async def example_manual_message_handling(): + """Manually handle message stream for custom logic.""" + print("=== Manual Message Handling Example ===") + + async with ClaudeSDKClient() as client: + await client.query( + "List 5 programming languages and their main use cases" + ) + + # Manually process messages with custom logic + languages_found = [] + + async for message in client.receive_messages(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + text = block.text + print(f"Claude: {text}") + # Custom logic: extract language names + for lang in [ + "Python", + "JavaScript", + "Java", + "C++", + "Go", + "Rust", + "Ruby", + ]: + if lang in text and lang not in languages_found: + languages_found.append(lang) + print(f"Found language: {lang}") + elif isinstance(message, ResultMessage): + display_message(message) + print(f"Total languages mentioned: {len(languages_found)}") + break + + print("\n") + + +async def example_with_options(): + """Use ClaudeCodeOptions to configure the client.""" + print("=== Custom Options Example ===") + + # Configure options + options = ClaudeCodeOptions( + allowed_tools=["Read", "Write"], # Allow file operations + max_thinking_tokens=10000, + system_prompt="You are a helpful coding assistant.", + ) + + async with ClaudeSDKClient(options=options) as client: + print("User: Create a simple hello.txt file with a greeting message") + await client.query( + "Create a simple hello.txt file with a greeting message" + ) + + tool_uses = [] + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + display_message(msg) + for block in msg.content: + if hasattr(block, "name") and not isinstance( + block, TextBlock + ): # ToolUseBlock + tool_uses.append(getattr(block, "name", "")) + else: + display_message(msg) + + if tool_uses: + print(f"Tools used: {', '.join(tool_uses)}") + + print("\n") + + +async def example_async_iterable_prompt(): + """Demonstrate send_message with async iterable.""" + print("=== Async Iterable Prompt Example ===") + + async def create_message_stream(): + """Generate a stream of messages.""" + print("User: Hello! I have multiple questions.") + yield { + "type": "user", + "message": {"role": "user", "content": "Hello! I have multiple questions."}, + "parent_tool_use_id": None, + "session_id": "qa-session", + } + + print("User: First, what's the capital of Japan?") + yield { + "type": "user", + "message": { + "role": "user", + "content": "First, what's the capital of Japan?", + }, + "parent_tool_use_id": None, + "session_id": "qa-session", + } + + print("User: Second, what's 15% of 200?") + yield { + "type": "user", + "message": {"role": "user", "content": "Second, what's 15% of 200?"}, + "parent_tool_use_id": None, + "session_id": "qa-session", + } + + async with ClaudeSDKClient() as client: + # Send async iterable of messages + await client.query(create_message_stream()) + + # Receive the three responses + async for msg in client.receive_response(): + display_message(msg) + async for msg in client.receive_response(): + display_message(msg) + async for msg in client.receive_response(): + display_message(msg) + + print("\n") + + +async def example_error_handling(): + """Demonstrate proper error handling.""" + print("=== Error Handling Example ===") + + client = ClaudeSDKClient() + + try: + await client.connect() + + # Send a message that will take time to process + print("User: Run a bash sleep command for 60 seconds") + await client.query("Run a bash sleep command for 60 seconds") + + # Try to receive response with a short timeout + try: + messages = [] + async with asyncio.timeout(10.0): + async for msg in client.receive_response(): + messages.append(msg) + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text[:50]}...") + elif isinstance(msg, ResultMessage): + display_message(msg) + break + + except asyncio.TimeoutError: + print( + "\nResponse timeout after 10 seconds - demonstrating graceful handling" + ) + print(f"Received {len(messages)} messages before timeout") + + except CLIConnectionError as e: + print(f"Connection error: {e}") + + except Exception as e: + print(f"Unexpected error: {e}") + + finally: + # Always disconnect + await client.disconnect() + + print("\n") + + +async def main(): + """Run all examples or a specific example based on command line argument.""" + examples = { + "basic_streaming": example_basic_streaming, + "multi_turn_conversation": example_multi_turn_conversation, + "concurrent_responses": example_concurrent_responses, + "with_interrupt": example_with_interrupt, + "manual_message_handling": example_manual_message_handling, + "with_options": example_with_options, + "async_iterable_prompt": example_async_iterable_prompt, + "error_handling": example_error_handling, + } + + if len(sys.argv) < 2: + # List available examples + print("Usage: python streaming_mode.py ") + print("\nAvailable examples:") + print(" all - Run all examples") + for name in examples: + print(f" {name}") + sys.exit(0) + + example_name = sys.argv[1] + + if example_name == "all": + # Run all examples + for example in examples.values(): + await example() + print("-" * 50 + "\n") + elif example_name in examples: + # Run specific example + await examples[example_name]() + else: + print(f"Error: Unknown example '{example_name}'") + print("\nAvailable examples:") + print(" all - Run all examples") + for name in examples: + print(f" {name}") + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/streaming_mode_ipython.py b/examples/streaming_mode_ipython.py new file mode 100644 index 00000000..7265afa4 --- /dev/null +++ b/examples/streaming_mode_ipython.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" +IPython-friendly code snippets for ClaudeSDKClient streaming mode. + +These examples are designed to be copy-pasted directly into IPython. +Each example is self-contained and can be run independently. + +The queries are intentionally simplistic. In reality, a query can be a more +complex task that Claude SDK uses its agentic capabilities and tools (e.g. run +bash commands, edit files, search the web, fetch web content) to accomplish. +""" + +# ============================================================================ +# BASIC STREAMING +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage + +async with ClaudeSDKClient() as client: + print("User: What is 2+2?") + await client.query("What is 2+2?") + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + +# ============================================================================ +# STREAMING WITH REAL-TIME DISPLAY +# ============================================================================ + +import asyncio +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +async with ClaudeSDKClient() as client: + async def send_and_receive(prompt): + print(f"User: {prompt}") + await client.query(prompt) + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + await send_and_receive("Tell me a short joke") + print("\n---\n") + await send_and_receive("Now tell me a fun fact") + + +# ============================================================================ +# PERSISTENT CLIENT FOR MULTIPLE QUESTIONS +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +# Create client +client = ClaudeSDKClient() +await client.connect() + + +# Helper to get response +async def get_response(): + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + +# Use it multiple times +print("User: What's 2+2?") +await client.query("What's 2+2?") +await get_response() + +print("User: What's 10*10?") +await client.query("What's 10*10?") +await get_response() + +# Don't forget to disconnect when done +await client.disconnect() + + +# ============================================================================ +# WITH INTERRUPT CAPABILITY +# ============================================================================ +# IMPORTANT: Interrupts require active message consumption. You must be +# consuming messages from the client for the interrupt to be processed. + +import asyncio +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage + +async with ClaudeSDKClient() as client: + print("\n--- Sending initial message ---\n") + + # Send a long-running task + print("User: Count from 1 to 100, run bash sleep for 1 second in between") + await client.query("Count from 1 to 100, run bash sleep for 1 second in between") + + # Create a background task to consume messages + messages_received = [] + interrupt_sent = False + + async def consume_messages(): + async for msg in client.receive_messages(): + messages_received.append(msg) + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Check if we got a result after interrupt + if isinstance(msg, ResultMessage) and interrupt_sent: + break + + # Start consuming messages in the background + consume_task = asyncio.create_task(consume_messages()) + + # Wait a bit then send interrupt + await asyncio.sleep(10) + print("\n--- Sending interrupt ---\n") + interrupt_sent = True + await client.interrupt() + + # Wait for the consume task to finish + await consume_task + + # Send a new message after interrupt + print("\n--- After interrupt, sending new message ---\n") + await client.query("Just say 'Hello! I was interrupted.'") + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + +# ============================================================================ +# ERROR HANDLING PATTERN +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +try: + async with ClaudeSDKClient() as client: + print("User: Run a bash sleep command for 60 seconds") + await client.query("Run a bash sleep command for 60 seconds") + + # Timeout after 20 seconds + messages = [] + async with asyncio.timeout(20.0): + async for msg in client.receive_response(): + messages.append(msg) + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + +except asyncio.TimeoutError: + print("Request timed out after 20 seconds") +except Exception as e: + print(f"Error: {e}") + + +# ============================================================================ +# SENDING ASYNC ITERABLE OF MESSAGES +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +async def message_generator(): + """Generate multiple messages as an async iterable.""" + print("User: I have two math questions.") + yield { + "type": "user", + "message": {"role": "user", "content": "I have two math questions."}, + "parent_tool_use_id": None, + "session_id": "math-session" + } + print("User: What is 25 * 4?") + yield { + "type": "user", + "message": {"role": "user", "content": "What is 25 * 4?"}, + "parent_tool_use_id": None, + "session_id": "math-session" + } + print("User: What is 100 / 5?") + yield { + "type": "user", + "message": {"role": "user", "content": "What is 100 / 5?"}, + "parent_tool_use_id": None, + "session_id": "math-session" + } + +async with ClaudeSDKClient() as client: + # Send async iterable instead of string + await client.query(message_generator()) + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + +# ============================================================================ +# COLLECTING ALL MESSAGES INTO A LIST +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage + +async with ClaudeSDKClient() as client: + print("User: What are the primary colors?") + await client.query("What are the primary colors?") + + # Collect all messages into a list + messages = [msg async for msg in client.receive_response()] + + # Process them afterwards + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(msg, ResultMessage): + print(f"Total messages: {len(messages)}") diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index b8a11525..1439937a 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -1,8 +1,5 @@ """Claude SDK for Python.""" -import os -from collections.abc import AsyncIterator - from ._errors import ( ClaudeSDKError, CLIConnectionError, @@ -10,7 +7,8 @@ CLINotFoundError, ProcessError, ) -from ._internal.client import InternalClient +from .client import ClaudeSDKClient +from .query import query from .types import ( AssistantMessage, ClaudeCodeOptions, @@ -29,8 +27,9 @@ __version__ = "0.0.14" __all__ = [ - # Main function + # Main exports "query", + "ClaudeSDKClient", # Types "PermissionMode", "McpServerConfig", @@ -51,52 +50,3 @@ "ProcessError", "CLIJSONDecodeError", ] - - -async def query( - *, prompt: str, options: ClaudeCodeOptions | None = None -) -> AsyncIterator[Message]: - """ - Query Claude Code. - - Python SDK for interacting with Claude Code. - - Args: - prompt: The prompt to send to Claude - options: Optional configuration (defaults to ClaudeCodeOptions() if None). - Set options.permission_mode to control tool execution: - - 'default': CLI prompts for dangerous tools - - 'acceptEdits': Auto-accept file edits - - 'bypassPermissions': Allow all tools (use with caution) - Set options.cwd for working directory. - - Yields: - Messages from the conversation - - - Example: - ```python - # Simple usage - async for message in query(prompt="Hello"): - print(message) - - # With options - async for message in query( - prompt="Hello", - options=ClaudeCodeOptions( - system_prompt="You are helpful", - cwd="/home/user" - ) - ): - print(message) - ``` - """ - if options is None: - options = ClaudeCodeOptions() - - os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py" - - client = InternalClient() - - async for message in client.process_query(prompt=prompt, options=options): - yield message diff --git a/src/claude_code_sdk/_errors.py b/src/claude_code_sdk/_errors.py index e8327577..c86bf235 100644 --- a/src/claude_code_sdk/_errors.py +++ b/src/claude_code_sdk/_errors.py @@ -1,5 +1,7 @@ """Error types for Claude SDK.""" +from typing import Any + class ClaudeSDKError(Exception): """Base exception for all Claude SDK errors.""" @@ -44,3 +46,11 @@ def __init__(self, line: str, original_error: Exception): self.line = line self.original_error = original_error super().__init__(f"Failed to decode JSON: {line[:100]}...") + + +class MessageParseError(ClaudeSDKError): + """Raised when unable to parse a message from CLI output.""" + + def __init__(self, message: str, data: dict[str, Any] | None = None): + self.data = data + super().__init__(message) diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index ef1070d0..715dab5e 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -1,20 +1,10 @@ """Internal client implementation.""" -from collections.abc import AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator from typing import Any -from ..types import ( - AssistantMessage, - ClaudeCodeOptions, - ContentBlock, - Message, - ResultMessage, - SystemMessage, - TextBlock, - ToolResultBlock, - ToolUseBlock, - UserMessage, -) +from ..types import ClaudeCodeOptions, Message +from .message_parser import parse_message from .transport.subprocess_cli import SubprocessCLITransport @@ -25,73 +15,19 @@ def __init__(self) -> None: """Initialize the internal client.""" async def process_query( - self, prompt: str, options: ClaudeCodeOptions + self, prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions ) -> AsyncIterator[Message]: """Process a query through transport.""" - transport = SubprocessCLITransport(prompt=prompt, options=options) + transport = SubprocessCLITransport( + prompt=prompt, options=options, close_stdin_after_prompt=True + ) try: await transport.connect() async for data in transport.receive_messages(): - message = self._parse_message(data) - if message: - yield message + yield parse_message(data) finally: await transport.disconnect() - - def _parse_message(self, data: dict[str, Any]) -> Message | None: - """Parse message from CLI output, trusting the structure.""" - - match data["type"]: - case "user": - return UserMessage(content=data["message"]["content"]) - - case "assistant": - content_blocks: list[ContentBlock] = [] - for block in data["message"]["content"]: - match block["type"]: - case "text": - content_blocks.append(TextBlock(text=block["text"])) - case "tool_use": - content_blocks.append( - ToolUseBlock( - id=block["id"], - name=block["name"], - input=block["input"], - ) - ) - case "tool_result": - content_blocks.append( - ToolResultBlock( - tool_use_id=block["tool_use_id"], - content=block.get("content"), - is_error=block.get("is_error"), - ) - ) - - return AssistantMessage(content=content_blocks) - - case "system": - return SystemMessage( - subtype=data["subtype"], - data=data, - ) - - case "result": - return ResultMessage( - subtype=data["subtype"], - duration_ms=data["duration_ms"], - duration_api_ms=data["duration_api_ms"], - is_error=data["is_error"], - num_turns=data["num_turns"], - session_id=data["session_id"], - total_cost_usd=data.get("total_cost_usd"), - usage=data.get("usage"), - result=data.get("result"), - ) - - case _: - return None diff --git a/src/claude_code_sdk/_internal/message_parser.py b/src/claude_code_sdk/_internal/message_parser.py new file mode 100644 index 00000000..858e24fa --- /dev/null +++ b/src/claude_code_sdk/_internal/message_parser.py @@ -0,0 +1,114 @@ +"""Message parser for Claude Code SDK responses.""" + +import logging +from typing import Any + +from .._errors import MessageParseError +from ..types import ( + AssistantMessage, + ContentBlock, + Message, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +logger = logging.getLogger(__name__) + + +def parse_message(data: dict[str, Any]) -> Message: + """ + Parse message from CLI output into typed Message objects. + + Args: + data: Raw message dictionary from CLI output + + Returns: + Parsed Message object + + Raises: + MessageParseError: If parsing fails or message type is unrecognized + """ + if not isinstance(data, dict): + raise MessageParseError( + f"Invalid message data type (expected dict, got {type(data).__name__})", + data, + ) + + message_type = data.get("type") + if not message_type: + raise MessageParseError("Message missing 'type' field", data) + + match message_type: + case "user": + try: + return UserMessage(content=data["message"]["content"]) + except KeyError as e: + raise MessageParseError( + f"Missing required field in user message: {e}", data + ) from e + + case "assistant": + try: + content_blocks: list[ContentBlock] = [] + for block in data["message"]["content"]: + match block["type"]: + case "text": + content_blocks.append(TextBlock(text=block["text"])) + case "tool_use": + content_blocks.append( + ToolUseBlock( + id=block["id"], + name=block["name"], + input=block["input"], + ) + ) + case "tool_result": + content_blocks.append( + ToolResultBlock( + tool_use_id=block["tool_use_id"], + content=block.get("content"), + is_error=block.get("is_error"), + ) + ) + + return AssistantMessage(content=content_blocks) + except KeyError as e: + raise MessageParseError( + f"Missing required field in assistant message: {e}", data + ) from e + + case "system": + try: + return SystemMessage( + subtype=data["subtype"], + data=data, + ) + except KeyError as e: + raise MessageParseError( + f"Missing required field in system message: {e}", data + ) from e + + case "result": + try: + return ResultMessage( + subtype=data["subtype"], + duration_ms=data["duration_ms"], + duration_api_ms=data["duration_api_ms"], + is_error=data["is_error"], + num_turns=data["num_turns"], + session_id=data["session_id"], + total_cost_usd=data.get("total_cost_usd"), + usage=data.get("usage"), + result=data.get("result"), + ) + except KeyError as e: + raise MessageParseError( + f"Missing required field in result message: {e}", data + ) from e + + case _: + raise MessageParseError(f"Unknown message type: {message_type}", data) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index c283f425..34b70342 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -4,14 +4,14 @@ import logging import os import shutil -from collections.abc import AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator from pathlib import Path from subprocess import PIPE from typing import Any import anyio from anyio.abc import Process -from anyio.streams.text import TextReceiveStream +from anyio.streams.text import TextReceiveStream, TextSendStream from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError @@ -28,17 +28,23 @@ class SubprocessCLITransport(Transport): def __init__( self, - prompt: str, + prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions, cli_path: str | Path | None = None, + close_stdin_after_prompt: bool = False, ): self._prompt = prompt + self._is_streaming = not isinstance(prompt, str) self._options = options self._cli_path = str(cli_path) if cli_path else self._find_cli() self._cwd = str(options.cwd) if options.cwd else None self._process: Process | None = None self._stdout_stream: TextReceiveStream | None = None self._stderr_stream: TextReceiveStream | None = None + self._stdin_stream: TextSendStream | None = None + self._pending_control_responses: dict[str, dict[str, Any]] = {} + self._request_counter = 0 + self._close_stdin_after_prompt = close_stdin_after_prompt def _find_cli(self) -> str: """Find Claude Code CLI binary.""" @@ -116,7 +122,14 @@ def _build_command(self) -> list[str]: ["--mcp-config", json.dumps({"mcpServers": self._options.mcp_servers})] ) - cmd.extend(["--print", self._prompt]) + # Add prompt handling based on mode + if self._is_streaming: + # Streaming mode: use --input-format stream-json + cmd.extend(["--input-format", "stream-json"]) + else: + # String mode: use --print with the prompt + cmd.extend(["--print", str(self._prompt)]) + return cmd async def connect(self) -> None: @@ -126,9 +139,10 @@ async def connect(self) -> None: cmd = self._build_command() try: + # Enable stdin pipe for both modes (but we'll close it for string mode) self._process = await anyio.open_process( cmd, - stdin=None, + stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=self._cwd, @@ -140,6 +154,20 @@ async def connect(self) -> None: if self._process.stderr: self._stderr_stream = TextReceiveStream(self._process.stderr) + # Handle stdin based on mode + if self._is_streaming: + # Streaming mode: keep stdin open and start streaming task + if self._process.stdin: + self._stdin_stream = TextSendStream(self._process.stdin) + # Start streaming messages to stdin in background + import asyncio + + asyncio.create_task(self._stream_to_stdin()) + else: + # String mode: close stdin immediately (backward compatible) + if self._process.stdin: + await self._process.stdin.aclose() + except FileNotFoundError as e: # Check if the error comes from the working directory or the CLI if self._cwd and not Path(self._cwd).exists(): @@ -169,9 +197,50 @@ async def disconnect(self) -> None: self._process = None self._stdout_stream = None self._stderr_stream = None + self._stdin_stream = None async def send_request(self, messages: list[Any], options: dict[str, Any]) -> None: - """Not used for CLI transport - args passed via command line.""" + """Send additional messages in streaming mode.""" + if not self._is_streaming: + raise CLIConnectionError("send_request only works in streaming mode") + + if not self._stdin_stream: + raise CLIConnectionError("stdin not available - stream may have ended") + + # Send each message as a user message + for message in messages: + # Ensure message has required structure + if not isinstance(message, dict): + message = { + "type": "user", + "message": {"role": "user", "content": str(message)}, + "parent_tool_use_id": None, + "session_id": options.get("session_id", "default"), + } + + await self._stdin_stream.send(json.dumps(message) + "\n") + + async def _stream_to_stdin(self) -> None: + """Stream messages to stdin for streaming mode.""" + if not self._stdin_stream or not isinstance(self._prompt, AsyncIterable): + return + + try: + async for message in self._prompt: + if not self._stdin_stream: + break + await self._stdin_stream.send(json.dumps(message) + "\n") + + # Close stdin after prompt if requested (e.g., for query() one-shot mode) + if self._close_stdin_after_prompt and self._stdin_stream: + await self._stdin_stream.aclose() + self._stdin_stream = None + # Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode) + except Exception as e: + logger.debug(f"Error streaming to stdin: {e}") + if self._stdin_stream: + await self._stdin_stream.aclose() + self._stdin_stream = None async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: """Receive messages from CLI.""" @@ -213,11 +282,24 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: try: data = json.loads(json_buffer) json_buffer = "" + + # Handle control responses separately + if data.get("type") == "control_response": + response = data.get("response", {}) + request_id = response.get("request_id") + if request_id: + # Store the response for the pending request + self._pending_control_responses[request_id] = response + continue + try: yield data except GeneratorExit: return except json.JSONDecodeError: + # We are speculatively decoding the buffer until we get + # a full JSON object. If there is an actual issue, we + # raise an error after _MAX_BUFFER_SIZE. continue except anyio.ClosedResourceError: @@ -280,3 +362,45 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: def is_connected(self) -> bool: """Check if subprocess is running.""" return self._process is not None and self._process.returncode is None + + async def interrupt(self) -> None: + """Send interrupt control request (only works in streaming mode).""" + if not self._is_streaming: + raise CLIConnectionError( + "Interrupt requires streaming mode (AsyncIterable prompt)" + ) + + if not self._stdin_stream: + raise CLIConnectionError("Not connected or stdin not available") + + await self._send_control_request({"subtype": "interrupt"}) + + async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send a control request and wait for response.""" + if not self._stdin_stream: + raise CLIConnectionError("Stdin not available") + + # Generate unique request ID + self._request_counter += 1 + request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}" + + # Build control request + control_request = { + "type": "control_request", + "request_id": request_id, + "request": request, + } + + # Send request + await self._stdin_stream.send(json.dumps(control_request) + "\n") + + # Wait for response + while request_id not in self._pending_control_responses: + await anyio.sleep(0.1) + + response = self._pending_control_responses.pop(request_id) + + if response.get("subtype") == "error": + raise CLIConnectionError(f"Control request failed: {response.get('error')}") + + return response diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py new file mode 100644 index 00000000..8e86ba7c --- /dev/null +++ b/src/claude_code_sdk/client.py @@ -0,0 +1,226 @@ +"""Claude SDK Client for interacting with Claude Code.""" + +import os +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any + +from ._errors import CLIConnectionError +from .types import ClaudeCodeOptions, Message, ResultMessage + + +class ClaudeSDKClient: + """ + Client for bidirectional, interactive conversations with Claude Code. + + This client provides full control over the conversation flow with support + for streaming, interrupts, and dynamic message sending. For simple one-shot + queries, consider using the query() function instead. + + Key features: + - **Bidirectional**: Send and receive messages at any time + - **Stateful**: Maintains conversation context across messages + - **Interactive**: Send follow-ups based on responses + - **Control flow**: Support for interrupts and session management + + When to use ClaudeSDKClient: + - Building chat interfaces or conversational UIs + - Interactive debugging or exploration sessions + - Multi-turn conversations with context + - When you need to react to Claude's responses + - Real-time applications with user input + - When you need interrupt capabilities + + When to use query() instead: + - Simple one-off questions + - Batch processing of prompts + - Fire-and-forget automation scripts + - When all inputs are known upfront + - Stateless operations + + Example - Interactive conversation: + ```python + # Automatically connects with empty stream for interactive use + async with ClaudeSDKClient() as client: + # Send initial message + await client.query("Let's solve a math problem step by step") + + # Receive and process response + async for message in client.receive_messages(): + if "ready" in str(message.content).lower(): + break + + # Send follow-up based on response + await client.query("What's 15% of 80?") + + # Continue conversation... + # Automatically disconnects + ``` + + Example - With interrupt: + ```python + async with ClaudeSDKClient() as client: + # Start a long task + await client.query("Count to 1000") + + # Interrupt after 2 seconds + await asyncio.sleep(2) + await client.interrupt() + + # Send new instruction + await client.query("Never mind, what's 2+2?") + ``` + + Example - Manual connection: + ```python + client = ClaudeSDKClient() + + # Connect with initial message stream + async def message_stream(): + yield {"type": "user", "message": {"role": "user", "content": "Hello"}} + + await client.connect(message_stream()) + + # Send additional messages dynamically + await client.query("What's the weather?") + + async for message in client.receive_messages(): + print(message) + + await client.disconnect() + ``` + """ + + def __init__(self, options: ClaudeCodeOptions | None = None): + """Initialize Claude SDK client.""" + if options is None: + options = ClaudeCodeOptions() + self.options = options + self._transport: Any | None = None + os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client" + + async def connect( + self, prompt: str | AsyncIterable[dict[str, Any]] | None = None + ) -> None: + """Connect to Claude with a prompt or message stream.""" + from ._internal.transport.subprocess_cli import SubprocessCLITransport + + # Auto-connect with empty async iterable if no prompt is provided + async def _empty_stream() -> AsyncIterator[dict[str, Any]]: + # Never yields, but indicates that this function is an iterator and + # keeps the connection open. + # This yield is never reached but makes this an async generator + return + yield {} # type: ignore[unreachable] + + self._transport = SubprocessCLITransport( + prompt=_empty_stream() if prompt is None else prompt, + options=self.options, + ) + await self._transport.connect() + + async def receive_messages(self) -> AsyncIterator[Message]: + """Receive all messages from Claude.""" + if not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + + from ._internal.message_parser import parse_message + + async for data in self._transport.receive_messages(): + yield parse_message(data) + + async def query( + self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default" + ) -> None: + """ + Send a new request in streaming mode. + + Args: + prompt: Either a string message or an async iterable of message dictionaries + session_id: Session identifier for the conversation + """ + if not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + + # Handle string prompts + if isinstance(prompt, str): + message = { + "type": "user", + "message": {"role": "user", "content": prompt}, + "parent_tool_use_id": None, + "session_id": session_id, + } + await self._transport.send_request([message], {"session_id": session_id}) + else: + # Handle AsyncIterable prompts + messages = [] + async for msg in prompt: + # Ensure session_id is set on each message + if "session_id" not in msg: + msg["session_id"] = session_id + messages.append(msg) + + if messages: + await self._transport.send_request(messages, {"session_id": session_id}) + + async def interrupt(self) -> None: + """Send interrupt signal (only works with streaming mode).""" + if not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + await self._transport.interrupt() + + async def receive_response(self) -> AsyncIterator[Message]: + """ + Receive messages from Claude until and including a ResultMessage. + + This async iterator yields all messages in sequence and automatically terminates + after yielding a ResultMessage (which indicates the response is complete). + It's a convenience method over receive_messages() for single-response workflows. + + **Stopping Behavior:** + - Yields each message as it's received + - Terminates immediately after yielding a ResultMessage + - The ResultMessage IS included in the yielded messages + - If no ResultMessage is received, the iterator continues indefinitely + + Yields: + Message: Each message received (UserMessage, AssistantMessage, SystemMessage, ResultMessage) + + Example: + ```python + async with ClaudeSDKClient() as client: + await client.query("What's the capital of France?") + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(msg, ResultMessage): + print(f"Cost: ${msg.total_cost_usd:.4f}") + # Iterator will terminate after this message + ``` + + Note: + To collect all messages: `messages = [msg async for msg in client.receive_response()]` + The final message in the list will always be a ResultMessage. + """ + async for message in self.receive_messages(): + yield message + if isinstance(message, ResultMessage): + return + + async def disconnect(self) -> None: + """Disconnect from Claude.""" + if self._transport: + await self._transport.disconnect() + self._transport = None + + async def __aenter__(self) -> "ClaudeSDKClient": + """Enter async context - automatically connects with empty stream for interactive use.""" + await self.connect() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + """Exit async context - always disconnects.""" + await self.disconnect() + return False diff --git a/src/claude_code_sdk/query.py b/src/claude_code_sdk/query.py new file mode 100644 index 00000000..ad77a1b1 --- /dev/null +++ b/src/claude_code_sdk/query.py @@ -0,0 +1,102 @@ +"""Query function for one-shot interactions with Claude Code.""" + +import os +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any + +from ._internal.client import InternalClient +from .types import ClaudeCodeOptions, Message + + +async def query( + *, + prompt: str | AsyncIterable[dict[str, Any]], + options: ClaudeCodeOptions | None = None, +) -> AsyncIterator[Message]: + """ + Query Claude Code for one-shot or unidirectional streaming interactions. + + This function is ideal for simple, stateless queries where you don't need + bidirectional communication or conversation management. For interactive, + stateful conversations, use ClaudeSDKClient instead. + + Key differences from ClaudeSDKClient: + - **Unidirectional**: Send all messages upfront, receive all responses + - **Stateless**: Each query is independent, no conversation state + - **Simple**: Fire-and-forget style, no connection management + - **No interrupts**: Cannot interrupt or send follow-up messages + + When to use query(): + - Simple one-off questions ("What is 2+2?") + - Batch processing of independent prompts + - Code generation or analysis tasks + - Automated scripts and CI/CD pipelines + - When you know all inputs upfront + + When to use ClaudeSDKClient: + - Interactive conversations with follow-ups + - Chat applications or REPL-like interfaces + - When you need to send messages based on responses + - When you need interrupt capabilities + - Long-running sessions with state + + Args: + prompt: The prompt to send to Claude. Can be a string for single-shot queries + or an AsyncIterable[dict] for streaming mode with continuous interaction. + In streaming mode, each dict should have the structure: + { + "type": "user", + "message": {"role": "user", "content": "..."}, + "parent_tool_use_id": None, + "session_id": "..." + } + options: Optional configuration (defaults to ClaudeCodeOptions() if None). + Set options.permission_mode to control tool execution: + - 'default': CLI prompts for dangerous tools + - 'acceptEdits': Auto-accept file edits + - 'bypassPermissions': Allow all tools (use with caution) + Set options.cwd for working directory. + + Yields: + Messages from the conversation + + Example - Simple query: + ```python + # One-off question + async for message in query(prompt="What is the capital of France?"): + print(message) + ``` + + Example - With options: + ```python + # Code generation with specific settings + async for message in query( + prompt="Create a Python web server", + options=ClaudeCodeOptions( + system_prompt="You are an expert Python developer", + cwd="/home/user/project" + ) + ): + print(message) + ``` + + Example - Streaming mode (still unidirectional): + ```python + async def prompts(): + yield {"type": "user", "message": {"role": "user", "content": "Hello"}} + yield {"type": "user", "message": {"role": "user", "content": "How are you?"}} + + # All prompts are sent, then all responses received + async for message in query(prompt=prompts()): + print(message) + ``` + """ + if options is None: + options = ClaudeCodeOptions() + + os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py" + + client = InternalClient() + + async for message in client.process_query(prompt=prompt, options=options): + yield message diff --git a/tests/test_message_parser.py b/tests/test_message_parser.py new file mode 100644 index 00000000..0eb43542 --- /dev/null +++ b/tests/test_message_parser.py @@ -0,0 +1,121 @@ +"""Tests for message parser error handling.""" + +import pytest + +from claude_code_sdk._errors import MessageParseError +from claude_code_sdk._internal.message_parser import parse_message +from claude_code_sdk.types import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolUseBlock, + UserMessage, +) + + +class TestMessageParser: + """Test message parsing with the new exception behavior.""" + + def test_parse_valid_user_message(self): + """Test parsing a valid user message.""" + data = { + "type": "user", + "message": {"content": [{"type": "text", "text": "Hello"}]}, + } + message = parse_message(data) + assert isinstance(message, UserMessage) + + def test_parse_valid_assistant_message(self): + """Test parsing a valid assistant message.""" + data = { + "type": "assistant", + "message": { + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "tool_use", + "id": "tool_123", + "name": "Read", + "input": {"file_path": "/test.txt"}, + }, + ] + }, + } + message = parse_message(data) + assert isinstance(message, AssistantMessage) + assert len(message.content) == 2 + assert isinstance(message.content[0], TextBlock) + assert isinstance(message.content[1], ToolUseBlock) + + def test_parse_valid_system_message(self): + """Test parsing a valid system message.""" + data = {"type": "system", "subtype": "start"} + message = parse_message(data) + assert isinstance(message, SystemMessage) + assert message.subtype == "start" + + def test_parse_valid_result_message(self): + """Test parsing a valid result message.""" + data = { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 500, + "is_error": False, + "num_turns": 2, + "session_id": "session_123", + } + message = parse_message(data) + assert isinstance(message, ResultMessage) + assert message.subtype == "success" + + def test_parse_invalid_data_type(self): + """Test that non-dict data raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message("not a dict") # type: ignore + assert "Invalid message data type" in str(exc_info.value) + assert "expected dict, got str" in str(exc_info.value) + + def test_parse_missing_type_field(self): + """Test that missing 'type' field raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"message": {"content": []}}) + assert "Message missing 'type' field" in str(exc_info.value) + + def test_parse_unknown_message_type(self): + """Test that unknown message type raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "unknown_type"}) + assert "Unknown message type: unknown_type" in str(exc_info.value) + + def test_parse_user_message_missing_fields(self): + """Test that user message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "user"}) + assert "Missing required field in user message" in str(exc_info.value) + + def test_parse_assistant_message_missing_fields(self): + """Test that assistant message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "assistant"}) + assert "Missing required field in assistant message" in str(exc_info.value) + + def test_parse_system_message_missing_fields(self): + """Test that system message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "system"}) + assert "Missing required field in system message" in str(exc_info.value) + + def test_parse_result_message_missing_fields(self): + """Test that result message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "result", "subtype": "success"}) + assert "Missing required field in result message" in str(exc_info.value) + + def test_message_parse_error_contains_data(self): + """Test that MessageParseError contains the original data.""" + data = {"type": "unknown", "some": "data"} + with pytest.raises(MessageParseError) as exc_info: + parse_message(data) + assert exc_info.value.data == data diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py new file mode 100644 index 00000000..884d7c4e --- /dev/null +++ b/tests/test_streaming_client.py @@ -0,0 +1,567 @@ +"""Tests for ClaudeSDKClient streaming functionality and query() with async iterables.""" + +import asyncio +import sys +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import anyio +import pytest + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ClaudeSDKClient, + CLIConnectionError, + ResultMessage, + TextBlock, + UserMessage, + query, +) +from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport + + +class TestClaudeSDKClientStreaming: + """Test ClaudeSDKClient streaming functionality.""" + + def test_auto_connect_with_context_manager(self): + """Test automatic connection when using context manager.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + # Verify connect was called + mock_transport.connect.assert_called_once() + assert client._transport is mock_transport + + # Verify disconnect was called on exit + mock_transport.disconnect.assert_called_once() + + anyio.run(_test) + + def test_manual_connect_disconnect(self): + """Test manual connect and disconnect.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient() + await client.connect() + + # Verify connect was called + mock_transport.connect.assert_called_once() + assert client._transport is mock_transport + + await client.disconnect() + # Verify disconnect was called + mock_transport.disconnect.assert_called_once() + assert client._transport is None + + anyio.run(_test) + + def test_connect_with_string_prompt(self): + """Test connecting with a string prompt.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient() + await client.connect("Hello Claude") + + # Verify transport was created with string prompt + call_kwargs = mock_transport_class.call_args.kwargs + assert call_kwargs["prompt"] == "Hello Claude" + + anyio.run(_test) + + def test_connect_with_async_iterable(self): + """Test connecting with an async iterable.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async def message_stream(): + yield {"type": "user", "message": {"role": "user", "content": "Hi"}} + yield { + "type": "user", + "message": {"role": "user", "content": "Bye"}, + } + + client = ClaudeSDKClient() + stream = message_stream() + await client.connect(stream) + + # Verify transport was created with async iterable + call_kwargs = mock_transport_class.call_args.kwargs + # Should be the same async iterator + assert call_kwargs["prompt"] is stream + + anyio.run(_test) + + def test_query(self): + """Test sending a query.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.query("Test message") + + # Verify send_request was called with correct format + mock_transport.send_request.assert_called_once() + call_args = mock_transport.send_request.call_args + messages, options = call_args[0] + assert len(messages) == 1 + assert messages[0]["type"] == "user" + assert messages[0]["message"]["content"] == "Test message" + assert options["session_id"] == "default" + + anyio.run(_test) + + def test_send_message_with_session_id(self): + """Test sending a message with custom session ID.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.query("Test", session_id="custom-session") + + call_args = mock_transport.send_request.call_args + messages, options = call_args[0] + assert messages[0]["session_id"] == "custom-session" + assert options["session_id"] == "custom-session" + + anyio.run(_test) + + def test_send_message_not_connected(self): + """Test sending message when not connected raises error.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.query("Test") + + anyio.run(_test) + + def test_receive_messages(self): + """Test receiving messages.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock the message stream + async def mock_receive(): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Hello!"}], + }, + } + yield { + "type": "user", + "message": {"role": "user", "content": "Hi there"}, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + messages = [] + async for msg in client.receive_messages(): + messages.append(msg) + if len(messages) == 2: + break + + assert len(messages) == 2 + assert isinstance(messages[0], AssistantMessage) + assert isinstance(messages[0].content[0], TextBlock) + assert messages[0].content[0].text == "Hello!" + assert isinstance(messages[1], UserMessage) + assert messages[1].content == "Hi there" + + anyio.run(_test) + + def test_receive_response(self): + """Test receive_response stops at ResultMessage.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock the message stream + async def mock_receive(): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Answer"}], + }, + } + yield { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 800, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + # This should not be yielded + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [ + {"type": "text", "text": "Should not see this"} + ], + }, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + messages = [] + async for msg in client.receive_response(): + messages.append(msg) + + # Should only get 2 messages (assistant + result) + assert len(messages) == 2 + assert isinstance(messages[0], AssistantMessage) + assert isinstance(messages[1], ResultMessage) + + anyio.run(_test) + + def test_interrupt(self): + """Test interrupt functionality.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.interrupt() + mock_transport.interrupt.assert_called_once() + + anyio.run(_test) + + def test_interrupt_not_connected(self): + """Test interrupt when not connected raises error.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.interrupt() + + anyio.run(_test) + + def test_client_with_options(self): + """Test client initialization with options.""" + + async def _test(): + options = ClaudeCodeOptions( + cwd="/custom/path", + allowed_tools=["Read", "Write"], + system_prompt="Be helpful", + ) + + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient(options=options) + await client.connect() + + # Verify options were passed to transport + call_kwargs = mock_transport_class.call_args.kwargs + assert call_kwargs["options"] is options + + anyio.run(_test) + + def test_concurrent_send_receive(self): + """Test concurrent sending and receiving messages.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock receive to wait then yield messages + async def mock_receive(): + await asyncio.sleep(0.1) + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Response 1"}], + }, + } + await asyncio.sleep(0.1) + yield { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 800, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + # Helper to get next message + async def get_next_message(): + return await client.receive_response().__anext__() + + # Start receiving in background + receive_task = asyncio.create_task(get_next_message()) + + # Send message while receiving + await client.query("Question 1") + + # Wait for first message + first_msg = await receive_task + assert isinstance(first_msg, AssistantMessage) + + anyio.run(_test) + + +class TestQueryWithAsyncIterable: + """Test query() function with async iterable inputs.""" + + def test_query_with_async_iterable(self): + """Test query with async iterable of messages.""" + + async def _test(): + async def message_stream(): + yield {"type": "user", "message": {"role": "user", "content": "First"}} + yield {"type": "user", "message": {"role": "user", "content": "Second"}} + + # Create a simple test script that validates stdin and outputs a result + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + test_script = f.name + f.write("""#!/usr/bin/env python3 +import sys +import json + +# Read stdin messages +stdin_messages = [] +while True: + line = sys.stdin.readline() + if not line: + break + stdin_messages.append(line.strip()) + +# Verify we got 2 messages +assert len(stdin_messages) == 2 +assert '"First"' in stdin_messages[0] +assert '"Second"' in stdin_messages[1] + +# Output a valid result +print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}') +""") + + Path(test_script).chmod(0o755) + + try: + # Mock _find_cli to return python executing our test script + with patch.object( + SubprocessCLITransport, "_find_cli", return_value=sys.executable + ): + # Mock _build_command to add our test script as first argument + original_build_command = SubprocessCLITransport._build_command + + def mock_build_command(self): + # Get original command + cmd = original_build_command(self) + # Replace the CLI path with python + script + cmd[0] = test_script + return cmd + + with patch.object( + SubprocessCLITransport, "_build_command", mock_build_command + ): + # Run query with async iterable + messages = [] + async for msg in query(prompt=message_stream()): + messages.append(msg) + + # Should get the result message + assert len(messages) == 1 + assert isinstance(messages[0], ResultMessage) + assert messages[0].subtype == "success" + finally: + # Clean up + Path(test_script).unlink() + + anyio.run(_test) + + +class TestClaudeSDKClientEdgeCases: + """Test edge cases and error scenarios.""" + + def test_receive_messages_not_connected(self): + """Test receiving messages when not connected.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + async for _ in client.receive_messages(): + pass + + anyio.run(_test) + + def test_receive_response_not_connected(self): + """Test receive_response when not connected.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + async for _ in client.receive_response(): + pass + + anyio.run(_test) + + def test_double_connect(self): + """Test connecting twice.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient() + await client.connect() + # Second connect should create new transport + await client.connect() + + # Should have been called twice + assert mock_transport_class.call_count == 2 + + anyio.run(_test) + + def test_disconnect_without_connect(self): + """Test disconnecting without connecting first.""" + + async def _test(): + client = ClaudeSDKClient() + # Should not raise error + await client.disconnect() + + anyio.run(_test) + + def test_context_manager_with_exception(self): + """Test context manager cleans up on exception.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + with pytest.raises(ValueError): + async with ClaudeSDKClient(): + raise ValueError("Test error") + + # Disconnect should still be called + mock_transport.disconnect.assert_called_once() + + anyio.run(_test) + + def test_receive_response_list_comprehension(self): + """Test collecting messages with list comprehension as shown in examples.""" + + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock the message stream + async def mock_receive(): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Hello"}], + }, + } + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "World"}], + }, + } + yield { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 800, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + # Test list comprehension pattern from docstring + messages = [msg async for msg in client.receive_response()] + + assert len(messages) == 3 + assert all( + isinstance(msg, AssistantMessage | ResultMessage) + for msg in messages + ) + assert isinstance(messages[-1], ResultMessage) + + anyio.run(_test) diff --git a/tests/test_transport.py b/tests/test_transport.py index c8d8e51f..aa9e4328 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -103,6 +103,12 @@ async def _test(): mock_process.wait = AsyncMock() mock_process.stdout = MagicMock() mock_process.stderr = MagicMock() + + # Mock stdin with aclose method + mock_stdin = MagicMock() + mock_stdin.aclose = AsyncMock() + mock_process.stdin = mock_stdin + mock_exec.return_value = mock_process transport = SubprocessCLITransport(