diff --git a/examples/max_budget_usd.py b/examples/max_budget_usd.py new file mode 100644 index 00000000..bb9777e8 --- /dev/null +++ b/examples/max_budget_usd.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +"""Example demonstrating max_budget_usd option for cost control.""" + +import anyio + +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ResultMessage, + TextBlock, + query, +) + + +async def without_budget(): + """Example without budget limit.""" + print("=== Without Budget Limit ===") + + async for message in query(prompt="What is 2 + 2?"): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(message, ResultMessage): + if message.total_cost_usd: + print(f"Total cost: ${message.total_cost_usd:.4f}") + print(f"Status: {message.subtype}") + print() + + +async def with_reasonable_budget(): + """Example with budget that won't be exceeded.""" + print("=== With Reasonable Budget ($0.10) ===") + + options = ClaudeAgentOptions( + max_budget_usd=0.10, # 10 cents - plenty for a simple query + ) + + async for message in query(prompt="What is 2 + 2?", options=options): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(message, ResultMessage): + if message.total_cost_usd: + print(f"Total cost: ${message.total_cost_usd:.4f}") + print(f"Status: {message.subtype}") + print() + + +async def with_tight_budget(): + """Example with very tight budget that will likely be exceeded.""" + print("=== With Tight Budget ($0.0001) ===") + + options = ClaudeAgentOptions( + max_budget_usd=0.0001, # Very small budget - will be exceeded quickly + ) + + async for message in query( + prompt="Read the README.md file and summarize it", options=options + ): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(message, ResultMessage): + if message.total_cost_usd: + print(f"Total cost: ${message.total_cost_usd:.4f}") + print(f"Status: {message.subtype}") + + # Check if budget was exceeded + if message.subtype == "error_max_budget_usd": + print("⚠️ Budget limit exceeded!") + print( + "Note: The cost may exceed the budget by up to one API call's worth" + ) + print() + + +async def main(): + """Run all examples.""" + print("This example demonstrates using max_budget_usd to control API costs.\n") + + await without_budget() + await with_reasonable_budget() + await with_tight_budget() + + print( + "\nNote: Budget checking happens after each API call completes,\n" + "so the final cost may slightly exceed the specified budget.\n" + ) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py index 1ec352f9..d1c53924 100644 --- a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py @@ -115,6 +115,9 @@ def _build_command(self) -> list[str]: if self._options.max_turns: cmd.extend(["--max-turns", str(self._options.max_turns)]) + if self._options.max_budget_usd is not None: + cmd.extend(["--max-budget-usd", str(self._options.max_budget_usd)]) + if self._options.disallowed_tools: cmd.extend(["--disallowedTools", ",".join(self._options.disallowed_tools)]) diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index efeaf70b..035dda44 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -518,6 +518,7 @@ class ClaudeAgentOptions: continue_conversation: bool = False resume: str | None = None max_turns: int | None = None + max_budget_usd: float | None = None disallowed_tools: list[str] = field(default_factory=list) model: str | None = None permission_prompt_tool_name: str | None = None diff --git a/tests/test_integration.py b/tests/test_integration.py index 8531c9e5..1f237dcc 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -212,3 +212,73 @@ async def mock_receive(): assert call_kwargs["options"].continue_conversation is True anyio.run(_test) + + def test_max_budget_usd_option(self): + """Test query with max_budget_usd option.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.client.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock the message stream that exceeds budget + async def mock_receive(): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [ + {"type": "text", "text": "Starting to read..."} + ], + "model": "claude-opus-4-1-20250805", + }, + } + yield { + "type": "result", + "subtype": "error_max_budget_usd", + "duration_ms": 500, + "duration_api_ms": 400, + "is_error": False, + "num_turns": 1, + "session_id": "test-session-budget", + "total_cost_usd": 0.0002, + "usage": { + "input_tokens": 100, + "output_tokens": 50, + }, + } + + mock_transport.read_messages = mock_receive + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + + # Run query with very small budget + messages = [] + async for msg in query( + prompt="Read the readme", + options=ClaudeAgentOptions(max_budget_usd=0.0001), + ): + messages.append(msg) + + # Verify results + assert len(messages) == 2 + + # Check result message + assert isinstance(messages[1], ResultMessage) + assert messages[1].subtype == "error_max_budget_usd" + assert messages[1].is_error is False + assert messages[1].total_cost_usd == 0.0002 + assert messages[1].total_cost_usd is not None + assert messages[1].total_cost_usd > 0 + + # Verify transport was created with max_budget_usd option + mock_transport_class.assert_called_once() + call_kwargs = mock_transport_class.call_args.kwargs + assert call_kwargs["options"].max_budget_usd == 0.0001 + + anyio.run(_test)