Skip to content

Commit e665f2f

Browse files
fix(streaming): accumulate citations (#844)
1 parent fb10a7d commit e665f2f

File tree

4 files changed

+150
-64
lines changed

4 files changed

+150
-64
lines changed

src/anthropic/lib/streaming/_beta_messages.py

Lines changed: 64 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,13 @@
55
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never
66

77
import httpx
8+
from pydantic import BaseModel
89

910
from ..._utils import consume_sync_iterator, consume_async_iterator
1011
from ..._models import build, construct_type
1112
from ._beta_types import (
1213
BetaTextEvent,
14+
BetaCitationEvent,
1315
BetaInputJsonEvent,
1416
BetaMessageStopEvent,
1517
BetaMessageStreamEvent,
@@ -314,24 +316,40 @@ def build_events(
314316
events_to_fire.append(event)
315317

316318
content_block = message_snapshot.content[event.index]
317-
if event.delta.type == "text_delta" and content_block.type == "text":
318-
events_to_fire.append(
319-
build(
320-
BetaTextEvent,
321-
type="text",
322-
text=event.delta.text,
323-
snapshot=content_block.text,
319+
if event.delta.type == "text_delta":
320+
if content_block.type == "text":
321+
events_to_fire.append(
322+
build(
323+
BetaTextEvent,
324+
type="text",
325+
text=event.delta.text,
326+
snapshot=content_block.text,
327+
)
324328
)
325-
)
326-
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
327-
events_to_fire.append(
328-
build(
329-
BetaInputJsonEvent,
330-
type="input_json",
331-
partial_json=event.delta.partial_json,
332-
snapshot=content_block.input,
329+
elif event.delta.type == "input_json_delta":
330+
if content_block.type == "tool_use":
331+
events_to_fire.append(
332+
build(
333+
BetaInputJsonEvent,
334+
type="input_json",
335+
partial_json=event.delta.partial_json,
336+
snapshot=content_block.input,
337+
)
333338
)
334-
)
339+
elif event.delta.type == "citations_delta":
340+
if content_block.type == "text":
341+
events_to_fire.append(
342+
build(
343+
BetaCitationEvent,
344+
type="citation",
345+
citation=event.delta.citation,
346+
snapshot=content_block.citations or [],
347+
)
348+
)
349+
else:
350+
# we only want exhaustive checking for linters, not at runtime
351+
if TYPE_CHECKING: # type: ignore[unreachable]
352+
assert_never(event.delta)
335353
elif event.type == "content_block_stop":
336354
content_block = message_snapshot.content[event.index]
337355

@@ -354,6 +372,9 @@ def accumulate_event(
354372
event: BetaRawMessageStreamEvent,
355373
current_snapshot: BetaMessage | None,
356374
) -> BetaMessage:
375+
if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance]
376+
raise TypeError(f"Unexpected event runtime type - {event}")
377+
357378
if current_snapshot is None:
358379
if event.type == "message_start":
359380
return BetaMessage.construct(**cast(Any, event.message.to_dict()))
@@ -370,21 +391,33 @@ def accumulate_event(
370391
)
371392
elif event.type == "content_block_delta":
372393
content = current_snapshot.content[event.index]
373-
if content.type == "text" and event.delta.type == "text_delta":
374-
content.text += event.delta.text
375-
elif content.type == "tool_use" and event.delta.type == "input_json_delta":
376-
from jiter import from_json
377-
378-
# we need to keep track of the raw JSON string as well so that we can
379-
# re-parse it for each delta, for now we just store it as an untyped
380-
# property on the snapshot
381-
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
382-
json_buf += bytes(event.delta.partial_json, "utf-8")
383-
384-
if json_buf:
385-
content.input = from_json(json_buf, partial_mode=True)
386-
387-
setattr(content, JSON_BUF_PROPERTY, json_buf)
394+
if event.delta.type == "text_delta":
395+
if content.type == "text":
396+
content.text += event.delta.text
397+
elif event.delta.type == "input_json_delta":
398+
if content.type == "tool_use":
399+
from jiter import from_json
400+
401+
# we need to keep track of the raw JSON string as well so that we can
402+
# re-parse it for each delta, for now we just store it as an untyped
403+
# property on the snapshot
404+
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
405+
json_buf += bytes(event.delta.partial_json, "utf-8")
406+
407+
if json_buf:
408+
content.input = from_json(json_buf, partial_mode=True)
409+
410+
setattr(content, JSON_BUF_PROPERTY, json_buf)
411+
elif event.delta.type == "citations_delta":
412+
if content.type == "text":
413+
if not content.citations:
414+
content.citations = [event.delta.citation]
415+
else:
416+
content.citations.append(event.delta.citation)
417+
else:
418+
# we only want exhaustive checking for linters, not at runtime
419+
if TYPE_CHECKING: # type: ignore[unreachable]
420+
assert_never(event.delta)
388421
elif event.type == "message_delta":
389422
current_snapshot.stop_reason = event.delta.stop_reason
390423
current_snapshot.stop_sequence = event.delta.stop_sequence

src/anthropic/lib/streaming/_beta_types.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from typing import Union
2-
from typing_extensions import Literal, Annotated
2+
from typing_extensions import List, Literal, Annotated
33

44
from ..._models import BaseModel
55
from ...types.beta import (
@@ -13,6 +13,7 @@
1313
BetaRawContentBlockStartEvent,
1414
)
1515
from ..._utils._transform import PropertyInfo
16+
from ...types.beta.beta_citations_delta import Citation
1617

1718

1819
class BetaTextEvent(BaseModel):
@@ -25,6 +26,16 @@ class BetaTextEvent(BaseModel):
2526
"""The entire accumulated text"""
2627

2728

29+
class BetaCitationEvent(BaseModel):
30+
type: Literal["citation"]
31+
32+
citation: Citation
33+
"""The new citation"""
34+
35+
snapshot: List[Citation]
36+
"""All of the accumulated citations"""
37+
38+
2839
class BetaInputJsonEvent(BaseModel):
2940
type: Literal["input_json"]
3041

@@ -57,6 +68,7 @@ class BetaContentBlockStopEvent(BetaRawContentBlockStopEvent):
5768
BetaMessageStreamEvent = Annotated[
5869
Union[
5970
BetaTextEvent,
71+
BetaCitationEvent,
6072
BetaInputJsonEvent,
6173
BetaRawMessageStartEvent,
6274
BetaRawMessageDeltaEvent,

src/anthropic/lib/streaming/_messages.py

Lines changed: 60 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
from ._types import (
1111
TextEvent,
12+
CitationEvent,
1213
InputJsonEvent,
1314
MessageStopEvent,
1415
MessageStreamEvent,
@@ -315,24 +316,40 @@ def build_events(
315316
events_to_fire.append(event)
316317

317318
content_block = message_snapshot.content[event.index]
318-
if event.delta.type == "text_delta" and content_block.type == "text":
319-
events_to_fire.append(
320-
build(
321-
TextEvent,
322-
type="text",
323-
text=event.delta.text,
324-
snapshot=content_block.text,
319+
if event.delta.type == "text_delta":
320+
if content_block.type == "text":
321+
events_to_fire.append(
322+
build(
323+
TextEvent,
324+
type="text",
325+
text=event.delta.text,
326+
snapshot=content_block.text,
327+
)
325328
)
326-
)
327-
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
328-
events_to_fire.append(
329-
build(
330-
InputJsonEvent,
331-
type="input_json",
332-
partial_json=event.delta.partial_json,
333-
snapshot=content_block.input,
329+
elif event.delta.type == "input_json_delta":
330+
if content_block.type == "tool_use":
331+
events_to_fire.append(
332+
build(
333+
InputJsonEvent,
334+
type="input_json",
335+
partial_json=event.delta.partial_json,
336+
snapshot=content_block.input,
337+
)
334338
)
335-
)
339+
elif event.delta.type == "citations_delta":
340+
if content_block.type == "text":
341+
events_to_fire.append(
342+
build(
343+
CitationEvent,
344+
type="citation",
345+
citation=event.delta.citation,
346+
snapshot=content_block.citations or [],
347+
)
348+
)
349+
else:
350+
# we only want exhaustive checking for linters, not at runtime
351+
if TYPE_CHECKING: # type: ignore[unreachable]
352+
assert_never(event.delta)
336353
elif event.type == "content_block_stop":
337354
content_block = message_snapshot.content[event.index]
338355

@@ -374,21 +391,33 @@ def accumulate_event(
374391
)
375392
elif event.type == "content_block_delta":
376393
content = current_snapshot.content[event.index]
377-
if content.type == "text" and event.delta.type == "text_delta":
378-
content.text += event.delta.text
379-
elif content.type == "tool_use" and event.delta.type == "input_json_delta":
380-
from jiter import from_json
381-
382-
# we need to keep track of the raw JSON string as well so that we can
383-
# re-parse it for each delta, for now we just store it as an untyped
384-
# property on the snapshot
385-
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
386-
json_buf += bytes(event.delta.partial_json, "utf-8")
387-
388-
if json_buf:
389-
content.input = from_json(json_buf, partial_mode=True)
390-
391-
setattr(content, JSON_BUF_PROPERTY, json_buf)
394+
if event.delta.type == "text_delta":
395+
if content.type == "text":
396+
content.text += event.delta.text
397+
elif event.delta.type == "input_json_delta":
398+
if content.type == "tool_use":
399+
from jiter import from_json
400+
401+
# we need to keep track of the raw JSON string as well so that we can
402+
# re-parse it for each delta, for now we just store it as an untyped
403+
# property on the snapshot
404+
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
405+
json_buf += bytes(event.delta.partial_json, "utf-8")
406+
407+
if json_buf:
408+
content.input = from_json(json_buf, partial_mode=True)
409+
410+
setattr(content, JSON_BUF_PROPERTY, json_buf)
411+
elif event.delta.type == "citations_delta":
412+
if content.type == "text":
413+
if not content.citations:
414+
content.citations = [event.delta.citation]
415+
else:
416+
content.citations.append(event.delta.citation)
417+
else:
418+
# we only want exhaustive checking for linters, not at runtime
419+
if TYPE_CHECKING: # type: ignore[unreachable]
420+
assert_never(event.delta)
392421
elif event.type == "message_delta":
393422
current_snapshot.stop_reason = event.delta.stop_reason
394423
current_snapshot.stop_sequence = event.delta.stop_sequence

src/anthropic/lib/streaming/_types.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from typing import Union
2-
from typing_extensions import Literal, Annotated
2+
from typing_extensions import List, Literal, Annotated
33

44
from ...types import (
55
Message,
@@ -13,6 +13,7 @@
1313
)
1414
from ..._models import BaseModel
1515
from ..._utils._transform import PropertyInfo
16+
from ...types.citations_delta import Citation
1617

1718

1819
class TextEvent(BaseModel):
@@ -25,6 +26,16 @@ class TextEvent(BaseModel):
2526
"""The entire accumulated text"""
2627

2728

29+
class CitationEvent(BaseModel):
30+
type: Literal["citation"]
31+
32+
citation: Citation
33+
"""The new citation"""
34+
35+
snapshot: List[Citation]
36+
"""All of the accumulated citations"""
37+
38+
2839
class InputJsonEvent(BaseModel):
2940
type: Literal["input_json"]
3041

@@ -57,6 +68,7 @@ class ContentBlockStopEvent(RawContentBlockStopEvent):
5768
MessageStreamEvent = Annotated[
5869
Union[
5970
TextEvent,
71+
CitationEvent,
6072
InputJsonEvent,
6173
RawMessageStartEvent,
6274
RawMessageDeltaEvent,

0 commit comments

Comments
 (0)