-
Notifications
You must be signed in to change notification settings - Fork 394
Accumulate extra pydantic fields from the sample event #1070
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
1617027
4d4227f
efb4ae1
55501ef
b780db7
8c5b3fa
728433e
2f51ae8
8294861
4a6c317
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -401,6 +401,31 @@ def build_events( | |
| ) | ||
|
|
||
|
|
||
| def _deep_merge_extra_fields(existing: object, new: object) -> object: | ||
| """Deep merge new data into existing data, mutating containers in place. | ||
|
|
||
| - Dicts: recursively merge keys (mutates existing dict) | ||
| - Lists: extend existing with new items (mutates existing list) | ||
| - Other: replace with new value | ||
| """ | ||
| if isinstance(existing, dict) and isinstance(new, dict): | ||
| existing_dict = cast("dict[str, object]", existing) | ||
| new_dict = cast("dict[str, object]", new) | ||
| for key, value in new_dict.items(): | ||
| if key in existing_dict: | ||
| existing_dict[key] = _deep_merge_extra_fields(existing_dict[key], value) | ||
| else: | ||
| existing_dict[key] = value | ||
| return existing_dict # Return mutated dict | ||
| elif isinstance(existing, list) and isinstance(new, list): | ||
| existing_list = cast("list[object]", existing) | ||
| new_list = cast("list[object]", new) | ||
evanmiller-anthropic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| existing_list.extend(new_list) | ||
| return existing_list # Return mutated list | ||
| else: | ||
| return new | ||
|
|
||
|
|
||
| def accumulate_event( | ||
| *, | ||
| event: RawMessageStreamEvent, | ||
|
|
@@ -481,4 +506,19 @@ def accumulate_event( | |
| if event.usage.server_tool_use is not None: | ||
| current_snapshot.usage.server_tool_use = event.usage.server_tool_use | ||
|
|
||
| # Accumulate any extra fields from the event into the snapshot | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question: should we just accumulate all the extra fields from events into the message?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe non-streaming requests already copy over these extra fields (the code I'm working with broke when we moved from non-streaming to streaming). I'm not sure exactly where that happens, though.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the confusion! Let me explain what I'm seeing. Also, could you share the SDK version and the requests you made for both streaming and non-streaming? That would be super helpful for figuring this out!
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be enough to gate this logic on the event.type? I will have to dig up the version numbers and test cases later.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It'd be great to understand which particular events are responsible for transferring message extra fields. From the snapshots you've created for the test, I see that these are
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P.S. I see that |
||
| if hasattr(event, '__pydantic_extra__') and event.__pydantic_extra__: | ||
| if not hasattr(current_snapshot, '__pydantic_extra__') or current_snapshot.__pydantic_extra__ is None: | ||
| current_snapshot.__pydantic_extra__ = {} | ||
|
|
||
| snapshot_extra = current_snapshot.__pydantic_extra__ | ||
| for key, value in event.__pydantic_extra__.items(): | ||
| if key in snapshot_extra: | ||
| snapshot_extra[key] = _deep_merge_extra_fields( | ||
| snapshot_extra[key], | ||
| value | ||
| ) | ||
| else: | ||
| snapshot_extra[key] = value | ||
|
|
||
| return current_snapshot | ||
evanmiller-anthropic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| event: message_start | ||
| data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-3-opus-latest","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":1},"private_field":{"nested":{"values":[1,2]}}}} | ||
|
|
||
| event: content_block_start | ||
| data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} | ||
|
|
||
| event: content_block_delta | ||
| data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"},"private_field":{"nested":{"values":[3],"metadata":"chunk1"}}} | ||
|
|
||
| event: content_block_delta | ||
| data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"!"},"private_field":{"nested":{"values":[4,5],"metadata":"chunk2"}}} | ||
|
|
||
| event: content_block_stop | ||
| data: {"type":"content_block_stop","index":0} | ||
|
|
||
| event: message_delta | ||
| data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":3},"private_field":{"nested":{"values":[6]}}} | ||
|
|
||
| event: message_stop | ||
| data: {"type":"message_stop"} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| """Tests for accumulating extra fields in streaming responses. | ||
| This tests that pydantic extra fields (fields not in the schema) are properly | ||
| accumulated during streaming, without exposing specific field names in the SDK. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| import asyncio | ||
| from typing import Any, cast | ||
|
|
||
| import httpx | ||
| import respx | ||
|
|
||
| from anthropic import Anthropic, AsyncAnthropic | ||
|
|
||
| from .helpers import get_response, to_async_iter | ||
|
|
||
| base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") | ||
| api_key = "my-anthropic-api-key" | ||
|
|
||
| sync_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) | ||
| async_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) | ||
evanmiller-anthropic marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def assert_extra_fields_accumulated(message: Any) -> None:
    """Verify that extra fields were deep-merged across streaming events.

    Intentionally generic: it does not care which extra field names exist,
    only that nested dicts were merged and nested lists were extended.
    """
    # Pydantic stores unknown fields here when extra="allow" is configured.
    assert hasattr(message, '__pydantic_extra__'), "Message should have __pydantic_extra__"
    extra = message.__pydantic_extra__
    assert 'private_field' in extra, "Extra fields should be accumulated"

    raw_private = extra['private_field']
    assert isinstance(raw_private, dict), "Extra field should be a dict"
    outer = cast(dict[str, object], raw_private)
    assert 'nested' in outer, "Nested structure should be present"

    raw_nested = outer['nested']
    assert isinstance(raw_nested, dict), "Nested field should be a dict"
    inner = cast(dict[str, object], raw_nested)
    assert 'values' in inner, "Nested values should be present"

    # Each streaming event in the fixture contributed a slice of 'values':
    #   message_start [1, 2] + delta [3] + delta [4, 5] + message_delta [6]
    raw_values = inner['values']
    assert isinstance(raw_values, list), "Nested values should be a list"
    accumulated = cast(list[int], raw_values)
    assert accumulated == [1, 2, 3, 4, 5, 6], "Lists should be extended, not replaced"

    # Colliding dict keys take the value from the most recent event.
    assert inner.get('metadata') == 'chunk2', "Dict values should be merged"
|
|
||
|
|
||
class TestSyncExtraFields:
    def test_extra_fields_accumulation(self) -> None:
        """Extra fields must accumulate while consuming a synchronous stream."""
        with respx.mock(base_url=base_url) as respx_mock:
            mocked = httpx.Response(200, content=get_response("extra_fields_response.txt"))
            respx_mock.post("/v1/messages").mock(return_value=mocked)

            with sync_client.messages.stream(
                max_tokens=1024,
                messages=[{"role": "user", "content": "Say hello!"}],
                model="claude-3-opus-latest",
            ) as stream:
                # Drain every event so the final snapshot is fully built.
                for _event in stream:
                    pass

                final = stream.get_final_message()
                assert_extra_fields_accumulated(final)
|
|
||
|
|
||
class TestAsyncExtraFields:
    def test_extra_fields_accumulation(self) -> None:
        """Extra fields must accumulate while consuming an async stream."""

        async def run_test() -> None:
            with respx.mock(base_url=base_url) as respx_mock:
                body = to_async_iter(get_response("extra_fields_response.txt"))
                respx_mock.post("/v1/messages").mock(
                    return_value=httpx.Response(200, content=body)
                )

                async with async_client.messages.stream(
                    max_tokens=1024,
                    messages=[{"role": "user", "content": "Say hello!"}],
                    model="claude-3-opus-latest",
                ) as stream:
                    # Drain every event so the final snapshot is fully built.
                    async for _event in stream:
                        pass

                    final = await stream.get_final_message()
                    assert_extra_fields_accumulated(final)

        asyncio.run(run_test())
|
|
||
|
|
||
def test_deep_merge_extra_fields_function() -> None:
    """Exercise the _deep_merge_extra_fields helper directly."""
    from anthropic.lib.streaming._messages import _deep_merge_extra_fields

    # Dicts: keys are merged recursively and the left operand is mutated.
    left = {"a": 1, "b": {"c": 2}}
    merged = _deep_merge_extra_fields(left, {"b": {"d": 3}, "e": 4})
    assert merged == {"a": 1, "b": {"c": 2, "d": 3}, "e": 4}
    assert merged is left, "Should mutate in place"

    # Lists: right-hand items are appended onto the left operand.
    seq = [1, 2, 3]
    extended = _deep_merge_extra_fields(seq, [4, 5])
    assert extended == [1, 2, 3, 4, 5]
    assert extended is seq, "Should mutate in place"

    # Nested mixture: inner lists extend while sibling keys merge in.
    tree = {"data": {"values": [1, 2]}}
    merged_tree = _deep_merge_extra_fields(tree, {"data": {"values": [3, 4], "count": 4}})
    assert merged_tree == {"data": {"values": [1, 2, 3, 4], "count": 4}}
    assert merged_tree is tree, "Should mutate in place"

    # Scalars (or mismatched types): the new value simply replaces the old.
    assert _deep_merge_extra_fields(1, 2) == 2
    assert _deep_merge_extra_fields("old", "new") == "new"
    assert _deep_merge_extra_fields(None, {"a": 1}) == {"a": 1}
Uh oh!
There was an error while loading. Please reload this page.