Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ContentBlockStopEvent,
)
from ...types import Message, ContentBlock, RawMessageStreamEvent
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type, construct_type_unchecked
from ..._streaming import Stream, AsyncStream

Expand Down Expand Up @@ -401,6 +401,27 @@ def build_events(
)


def _deep_merge_extra_fields(existing: object, new: object) -> object:
    """Recursively fold ``new`` into ``existing``, reusing containers in place.

    Merge rules:
      - dict + dict: merge key-by-key, recursing on shared keys (mutates ``existing``)
      - list + list: append the new items onto ``existing`` (mutates ``existing``)
      - any other combination: the new value simply wins
    """
    if is_dict(existing) and is_dict(new):
        for key, incoming in new.items():
            if key not in existing:
                existing[key] = incoming
            else:
                existing[key] = _deep_merge_extra_fields(existing[key], incoming)
        # The (mutated) original dict is returned so callers can rebind safely.
        return existing

    if is_list(existing) and is_list(new):
        # In-place extend; identical to existing.extend(new).
        existing += new
        return existing

    return new


def accumulate_event(
*,
event: RawMessageStreamEvent,
Expand Down Expand Up @@ -481,4 +502,19 @@ def accumulate_event(
if event.usage.server_tool_use is not None:
current_snapshot.usage.server_tool_use = event.usage.server_tool_use

# Accumulate any extra fields from the event into the snapshot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: should we just accumulate all the extra fields from events into the message?
TBH, what I really expect here is to see the same extra fields I would get if I made a non-streaming request. Maybe we shouldn't accumulate all the extra kwargs from events, but only the ones that make sense (which are extra kwargs for the non-streaming case too).

  • Does that make sense?
  • Do we know which events are responsible for transferring the extra fields we would see in the non-streaming case?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe non-streaming requests already copy over these extra fields (the code I'm working with broke when we moved from non-streaming to streaming). I'm not sure exactly where that happens, though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion! Let me explain what I'm seeing.
So in the non-streaming case, we have extra fields A and B, right? I'd expect the streaming case to have those same fields – no more, no less. But right now it looks like it might be getting more.
It seems like we're accumulating extra fields from all the events into the final message, including stuff from the message-stop event and other events that don't actually hold message data. I don't think that's what we want here – or am I missing something?

Also, could you share the SDK version and the requests you made for both streaming and non-streaming? That would be super helpful for figuring this out!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be enough to gate this logic on the event.type?

I will have to dig up the version numbers and test cases later.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be great to understand which particular events are responsible for transferring message extra fields. From the snapshots you've created for the test, I see that these are message_start, content_block_delta, and message_delta. We can potentially update kwargs only if the event is one of these as you mentioned. It would help us avoid filling up the extra kwargs of the final message with unnecessary data.
Also, I noticed that the same kwargs could be sent through different event types. For example, in your snapshot, message_start, content_block_delta, and message_delta are all sending the private_field extra field. Is that intentional?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P.S. I see that message_delta can send some fields which have already been sent by message_start, such as usage or context management fields, so we need to accumulate at least these two events. But I'm not sure about content_block_delta.

# Carry over any pydantic "extra" fields (fields not declared in the model
# schema) from this event into the accumulated snapshot.
# NOTE(review): this merges extras from *every* event type — confirm this is
# intended rather than limiting to message-level events (message_start /
# message_delta), as discussed in review.
if hasattr(event, '__pydantic_extra__') and event.__pydantic_extra__:
    # Lazily create the snapshot's extra-field store; pydantic v2 may leave
    # __pydantic_extra__ unset/None when no extras were captured.
    if not hasattr(current_snapshot, '__pydantic_extra__') or current_snapshot.__pydantic_extra__ is None:
        current_snapshot.__pydantic_extra__ = {}

    snapshot_extra = current_snapshot.__pydantic_extra__
    for key, value in event.__pydantic_extra__.items():
        if key in snapshot_extra:
            # Existing key: deep-merge (dicts merge, lists extend, scalars replace).
            snapshot_extra[key] = _deep_merge_extra_fields(
                snapshot_extra[key],
                value
            )
        else:
            snapshot_extra[key] = value

return current_snapshot
130 changes: 130 additions & 0 deletions tests/lib/streaming/test_extra_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Tests for accumulating extra fields in streaming responses.

This tests that pydantic extra fields (fields not in the schema) are properly
accumulated during streaming, without exposing specific field names in the SDK.
"""

from __future__ import annotations

from typing import Any, cast

import pytest

from anthropic.types import Usage, Message, TextBlock, TextDelta
from anthropic._compat import PYDANTIC_V1
from anthropic.lib.streaming._messages import accumulate_event
from anthropic.types.message_delta_usage import MessageDeltaUsage
from anthropic.types.raw_message_delta_event import Delta, RawMessageDeltaEvent
from anthropic.types.raw_message_start_event import RawMessageStartEvent
from anthropic.types.raw_content_block_delta_event import RawContentBlockDeltaEvent
from anthropic.types.raw_content_block_start_event import RawContentBlockStartEvent


@pytest.mark.skipif(PYDANTIC_V1, reason="Extra fields accumulation not supported in Pydantic v1")
def test_extra_fields_accumulation() -> None:
    """Test that extra fields are accumulated across streaming events.

    Streams message_start -> content_block_start -> two content_block_deltas ->
    message_delta, each carrying the same ``private_field`` extra, then checks
    the merge semantics: nested lists are extended across events while scalar
    values take the last write.
    """
    # Build message with extra field via message_start
    message_start = RawMessageStartEvent(
        type="message_start",
        message=Message(
            id="msg_123",
            type="message",
            role="assistant",
            content=[],
            model="claude-3-opus-latest",
            stop_reason=None,
            stop_sequence=None,
            usage=Usage(input_tokens=11, output_tokens=1),
            # Extra field with nested structure
            private_field={"nested": {"values": [1, 2]}},  # type: ignore[call-arg]
        ),
    )
    snapshot = accumulate_event(event=message_start, current_snapshot=None)

    # content_block_start
    content_block_start = RawContentBlockStartEvent(
        type="content_block_start",
        index=0,
        content_block=TextBlock(type="text", text=""),
    )
    snapshot = accumulate_event(event=content_block_start, current_snapshot=snapshot)

    # First content_block_delta with extra field
    delta1 = RawContentBlockDeltaEvent(
        type="content_block_delta",
        index=0,
        delta=TextDelta(type="text_delta", text="Hello"),
        private_field={"nested": {"values": [3], "metadata": "chunk1"}},  # type: ignore[call-arg]
    )
    snapshot = accumulate_event(event=delta1, current_snapshot=snapshot)

    # Second content_block_delta with extra field
    delta2 = RawContentBlockDeltaEvent(
        type="content_block_delta",
        index=0,
        delta=TextDelta(type="text_delta", text="!"),
        private_field={"nested": {"values": [4, 5], "metadata": "chunk2"}},  # type: ignore[call-arg]
    )
    snapshot = accumulate_event(event=delta2, current_snapshot=snapshot)

    # message_delta with extra field
    message_delta = RawMessageDeltaEvent(
        type="message_delta",
        delta=Delta(stop_reason="end_turn", stop_sequence=None),
        usage=MessageDeltaUsage(output_tokens=3),
        private_field={"nested": {"values": [6]}},  # type: ignore[call-arg]
    )
    snapshot = accumulate_event(event=message_delta, current_snapshot=snapshot)

    # No runtime PYDANTIC_V1 guard needed here — the skipif marker above
    # already prevents this test from running under Pydantic v1.

    # Verify extra fields were accumulated
    assert hasattr(snapshot, "__pydantic_extra__"), "Message should have __pydantic_extra__"
    extra = snapshot.__pydantic_extra__
    assert extra is not None
    assert "private_field" in extra, "Extra fields should be accumulated"

    private_field = cast(dict[str, Any], extra["private_field"])
    assert "nested" in private_field

    nested = cast(dict[str, Any], private_field["nested"])
    assert "values" in nested

    # Lists should be extended across all events: [1,2] + [3] + [4,5] + [6]
    assert nested["values"] == [1, 2, 3, 4, 5, 6], "Lists should be extended, not replaced"

    # Scalar dict values should be overwritten by the last event that sent them
    assert nested.get("metadata") == "chunk2", "Scalar values should take the last write"


def test_deep_merge_extra_fields_function() -> None:
    """Exercise the _deep_merge_extra_fields helper directly."""
    from anthropic.lib.streaming._messages import _deep_merge_extra_fields

    # Dict merging: shared keys recurse, new keys are added, and the
    # original dict object is the one returned.
    base = {"a": 1, "b": {"c": 2}}
    incoming = {"b": {"d": 3}, "e": 4}
    merged = _deep_merge_extra_fields(base, incoming)
    assert merged == {"a": 1, "b": {"c": 2, "d": 3}, "e": 4}
    assert merged is base, "Should mutate in place"

    # List merging: new items are appended onto the original list object.
    base_items = [1, 2, 3]
    merged_items = _deep_merge_extra_fields(base_items, [4, 5])
    assert merged_items == [1, 2, 3, 4, 5]
    assert merged_items is base_items, "Should mutate in place"

    # Nested dicts containing lists: inner lists extend, new scalars are added.
    base_tree = {"data": {"values": [1, 2]}}
    merged_tree = _deep_merge_extra_fields(base_tree, {"data": {"values": [3, 4], "count": 4}})
    assert merged_tree == {"data": {"values": [1, 2, 3, 4], "count": 4}}
    assert merged_tree is base_tree, "Should mutate in place"

    # Mismatched or scalar types: the new value replaces the old one outright.
    assert _deep_merge_extra_fields(1, 2) == 2
    assert _deep_merge_extra_fields("old", "new") == "new"
    assert _deep_merge_extra_fields(None, {"a": 1}) == {"a": 1}