Skip to content

Commit 3408cb6

Browse files
dicksontsairushilpatel0
authored andcommitted
PR feedback
Signed-off-by: Rushil Patel <[email protected]>
1 parent 8796e2d commit 3408cb6

File tree

3 files changed

+115
-59
lines changed

3 files changed

+115
-59
lines changed
Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Message parser for Claude Code SDK responses."""
22

3+
import logging
34
from typing import Any
45

56
from ..types import (
@@ -14,6 +15,8 @@
1415
UserMessage,
1516
)
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
def parse_message(data: dict[str, Any]) -> Message | None:
1922
"""
@@ -23,55 +26,82 @@ def parse_message(data: dict[str, Any]) -> Message | None:
2326
data: Raw message dictionary from CLI output
2427
2528
Returns:
26-
Parsed Message object or None if type is unrecognized
29+
Parsed Message object or None if type is unrecognized or parsing fails
2730
"""
28-
match data["type"]:
31+
try:
32+
message_type = data.get("type")
33+
if not message_type:
34+
logger.warning("Message missing 'type' field: %s", data)
35+
return None
36+
37+
except AttributeError:
38+
logger.error("Invalid message data type (expected dict): %s", type(data))
39+
return None
40+
41+
match message_type:
2942
case "user":
30-
return UserMessage(content=data["message"]["content"])
43+
try:
44+
return UserMessage(content=data["message"]["content"])
45+
except KeyError as e:
46+
logger.error("Missing required field in user message: %s", e)
47+
return None
3148

3249
case "assistant":
33-
content_blocks: list[ContentBlock] = []
34-
for block in data["message"]["content"]:
35-
match block["type"]:
36-
case "text":
37-
content_blocks.append(TextBlock(text=block["text"]))
38-
case "tool_use":
39-
content_blocks.append(
40-
ToolUseBlock(
41-
id=block["id"],
42-
name=block["name"],
43-
input=block["input"],
50+
try:
51+
content_blocks: list[ContentBlock] = []
52+
for block in data["message"]["content"]:
53+
match block["type"]:
54+
case "text":
55+
content_blocks.append(TextBlock(text=block["text"]))
56+
case "tool_use":
57+
content_blocks.append(
58+
ToolUseBlock(
59+
id=block["id"],
60+
name=block["name"],
61+
input=block["input"],
62+
)
4463
)
45-
)
46-
case "tool_result":
47-
content_blocks.append(
48-
ToolResultBlock(
49-
tool_use_id=block["tool_use_id"],
50-
content=block.get("content"),
51-
is_error=block.get("is_error"),
64+
case "tool_result":
65+
content_blocks.append(
66+
ToolResultBlock(
67+
tool_use_id=block["tool_use_id"],
68+
content=block.get("content"),
69+
is_error=block.get("is_error"),
70+
)
5271
)
53-
)
5472

55-
return AssistantMessage(content=content_blocks)
73+
return AssistantMessage(content=content_blocks)
74+
except KeyError as e:
75+
logger.error("Missing required field in assistant message: %s", e)
76+
return None
5677

5778
case "system":
58-
return SystemMessage(
59-
subtype=data["subtype"],
60-
data=data,
61-
)
79+
try:
80+
return SystemMessage(
81+
subtype=data["subtype"],
82+
data=data,
83+
)
84+
except KeyError as e:
85+
logger.error("Missing required field in system message: %s", e)
86+
return None
6287

6388
case "result":
64-
return ResultMessage(
65-
subtype=data["subtype"],
66-
duration_ms=data["duration_ms"],
67-
duration_api_ms=data["duration_api_ms"],
68-
is_error=data["is_error"],
69-
num_turns=data["num_turns"],
70-
session_id=data["session_id"],
71-
total_cost_usd=data.get("total_cost_usd"),
72-
usage=data.get("usage"),
73-
result=data.get("result"),
74-
)
89+
try:
90+
return ResultMessage(
91+
subtype=data["subtype"],
92+
duration_ms=data["duration_ms"],
93+
duration_api_ms=data["duration_api_ms"],
94+
is_error=data["is_error"],
95+
num_turns=data["num_turns"],
96+
session_id=data["session_id"],
97+
total_cost_usd=data.get("total_cost_usd"),
98+
usage=data.get("usage"),
99+
result=data.get("result"),
100+
)
101+
except KeyError as e:
102+
logger.error("Missing required field in result message: %s", e)
103+
return None
75104

76105
case _:
106+
logger.debug("Unknown message type: %s", message_type)
77107
return None

src/claude_code_sdk/_internal/transport/subprocess_cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,12 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
306306
yield data
307307
except GeneratorExit:
308308
return
309-
except json.JSONDecodeError:
309+
except json.JSONDecodeError as e:
310+
logger.warning(
311+
f"Failed to parse JSON from CLI output: {e}. Buffer content: {json_buffer[:200]}..."
312+
)
313+
# Clear buffer to avoid repeated parse attempts on malformed data
314+
json_buffer = ""
310315
continue
311316

312317
except anyio.ClosedResourceError:

src/claude_code_sdk/client.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,37 @@ async def receive_messages(self) -> AsyncIterator[Message]:
128128
if message:
129129
yield message
130130

131-
async def send_message(self, content: str, session_id: str = "default") -> None:
132-
"""Send a new message in streaming mode."""
131+
async def send_message(self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default") -> None:
132+
"""
133+
Send a new message in streaming mode.
134+
135+
Args:
136+
prompt: Either a string message or an async iterable of message dictionaries
137+
session_id: Session identifier for the conversation
138+
"""
133139
if not self._transport:
134140
raise CLIConnectionError("Not connected. Call connect() first.")
135141

136-
message = {
137-
"type": "user",
138-
"message": {"role": "user", "content": content},
139-
"parent_tool_use_id": None,
140-
"session_id": session_id,
141-
}
142-
143-
await self._transport.send_request([message], {"session_id": session_id})
142+
# Handle string prompts
143+
if isinstance(prompt, str):
144+
message = {
145+
"type": "user",
146+
"message": {"role": "user", "content": prompt},
147+
"parent_tool_use_id": None,
148+
"session_id": session_id,
149+
}
150+
await self._transport.send_request([message], {"session_id": session_id})
151+
else:
152+
# Handle AsyncIterable prompts
153+
messages = []
154+
async for msg in prompt:
155+
# Ensure session_id is set on each message
156+
if "session_id" not in msg:
157+
msg["session_id"] = session_id
158+
messages.append(msg)
159+
160+
if messages:
161+
await self._transport.send_request(messages, {"session_id": session_id})
144162

145163
async def interrupt(self) -> None:
146164
"""Send interrupt signal (only works with streaming mode)."""
@@ -150,19 +168,24 @@ async def interrupt(self) -> None:
150168

151169
async def receive_response(self) -> AsyncIterator[Message]:
152170
"""
153-
Receive messages from Claude until a ResultMessage is received.
171+
Receive messages from Claude until and including a ResultMessage.
154172
155-
This is an async iterator that yields all messages including the final ResultMessage.
156-
It's a convenience method over receive_messages() that automatically stops iteration
157-
after receiving a ResultMessage.
173+
This async iterator yields all messages in sequence and automatically terminates
174+
after yielding a ResultMessage (which indicates the response is complete).
175+
It's a convenience method over receive_messages() for single-response workflows.
176+
177+
**Stopping Behavior:**
178+
- Yields each message as it's received
179+
- Terminates immediately after yielding a ResultMessage
180+
- The ResultMessage IS included in the yielded messages
181+
- If no ResultMessage is received, the iterator continues indefinitely
158182
159183
Yields:
160184
Message: Each message received (UserMessage, AssistantMessage, SystemMessage, ResultMessage)
161185
162186
Example:
163187
```python
164188
async with ClaudeSDKClient() as client:
165-
# Send message and process response
166189
await client.send_message("What's the capital of France?")
167190
168191
async for msg in client.receive_response():
@@ -172,14 +195,12 @@ async def receive_response(self) -> AsyncIterator[Message]:
172195
print(f"Claude: {block.text}")
173196
elif isinstance(msg, ResultMessage):
174197
print(f"Cost: ${msg.total_cost_usd:.4f}")
198+
# Iterator will terminate after this message
175199
```
176200
177201
Note:
178-
The iterator will automatically stop after yielding a ResultMessage.
179-
If you need to collect all messages into a list, use:
180-
```python
181-
messages = [msg async for msg in client.receive_response()]
182-
```
202+
To collect all messages: `messages = [msg async for msg in client.receive_response()]`
203+
The final message in the list will always be a ResultMessage.
183204
"""
184205
async for message in self.receive_messages():
185206
yield message

0 commit comments

Comments
 (0)