55from typing_extensions import Self , Iterator , Awaitable , AsyncIterator , assert_never
66
77import httpx
8+ from pydantic import BaseModel
89
910from ..._utils import consume_sync_iterator , consume_async_iterator
1011from ..._models import build , construct_type
1112from ._beta_types import (
1213 BetaTextEvent ,
14+ BetaCitationEvent ,
1315 BetaInputJsonEvent ,
1416 BetaMessageStopEvent ,
1517 BetaMessageStreamEvent ,
@@ -314,24 +316,40 @@ def build_events(
314316 events_to_fire .append (event )
315317
316318 content_block = message_snapshot .content [event .index ]
317- if event .delta .type == "text_delta" and content_block .type == "text" :
318- events_to_fire .append (
319- build (
320- BetaTextEvent ,
321- type = "text" ,
322- text = event .delta .text ,
323- snapshot = content_block .text ,
319+ if event .delta .type == "text_delta" :
320+ if content_block .type == "text" :
321+ events_to_fire .append (
322+ build (
323+ BetaTextEvent ,
324+ type = "text" ,
325+ text = event .delta .text ,
326+ snapshot = content_block .text ,
327+ )
324328 )
325- )
326- elif event .delta .type == "input_json_delta" and content_block .type == "tool_use" :
327- events_to_fire .append (
328- build (
329- BetaInputJsonEvent ,
330- type = "input_json" ,
331- partial_json = event .delta .partial_json ,
332- snapshot = content_block .input ,
329+ elif event .delta .type == "input_json_delta" :
330+ if content_block .type == "tool_use" :
331+ events_to_fire .append (
332+ build (
333+ BetaInputJsonEvent ,
334+ type = "input_json" ,
335+ partial_json = event .delta .partial_json ,
336+ snapshot = content_block .input ,
337+ )
333338 )
334- )
339+ elif event .delta .type == "citations_delta" :
340+ if content_block .type == "text" :
341+ events_to_fire .append (
342+ build (
343+ BetaCitationEvent ,
344+ type = "citation" ,
345+ citation = event .delta .citation ,
346+ snapshot = content_block .citations or [],
347+ )
348+ )
349+ else :
350+ # we only want exhaustive checking for linters, not at runtime
351+ if TYPE_CHECKING : # type: ignore[unreachable]
352+ assert_never (event .delta )
335353 elif event .type == "content_block_stop" :
336354 content_block = message_snapshot .content [event .index ]
337355
@@ -354,6 +372,9 @@ def accumulate_event(
354372 event : BetaRawMessageStreamEvent ,
355373 current_snapshot : BetaMessage | None ,
356374) -> BetaMessage :
375+ if not isinstance (event , BaseModel ): # pyright: ignore[reportUnnecessaryIsInstance]
376+ raise TypeError (f"Unexpected event runtime type - { event } " )
377+
357378 if current_snapshot is None :
358379 if event .type == "message_start" :
359380 return BetaMessage .construct (** cast (Any , event .message .to_dict ()))
@@ -370,21 +391,33 @@ def accumulate_event(
370391 )
371392 elif event .type == "content_block_delta" :
372393 content = current_snapshot .content [event .index ]
373- if content .type == "text" and event .delta .type == "text_delta" :
374- content .text += event .delta .text
375- elif content .type == "tool_use" and event .delta .type == "input_json_delta" :
376- from jiter import from_json
377-
378- # we need to keep track of the raw JSON string as well so that we can
379- # re-parse it for each delta, for now we just store it as an untyped
380- # property on the snapshot
381- json_buf = cast (bytes , getattr (content , JSON_BUF_PROPERTY , b"" ))
382- json_buf += bytes (event .delta .partial_json , "utf-8" )
383-
384- if json_buf :
385- content .input = from_json (json_buf , partial_mode = True )
386-
387- setattr (content , JSON_BUF_PROPERTY , json_buf )
394+ if event .delta .type == "text_delta" :
395+ if content .type == "text" :
396+ content .text += event .delta .text
397+ elif event .delta .type == "input_json_delta" :
398+ if content .type == "tool_use" :
399+ from jiter import from_json
400+
401+ # we need to keep track of the raw JSON string as well so that we can
402+ # re-parse it for each delta, for now we just store it as an untyped
403+ # property on the snapshot
404+ json_buf = cast (bytes , getattr (content , JSON_BUF_PROPERTY , b"" ))
405+ json_buf += bytes (event .delta .partial_json , "utf-8" )
406+
407+ if json_buf :
408+ content .input = from_json (json_buf , partial_mode = True )
409+
410+ setattr (content , JSON_BUF_PROPERTY , json_buf )
411+ elif event .delta .type == "citations_delta" :
412+ if content .type == "text" :
413+ if not content .citations :
414+ content .citations = [event .delta .citation ]
415+ else :
416+ content .citations .append (event .delta .citation )
417+ else :
418+ # we only want exhaustive checking for linters, not at runtime
419+ if TYPE_CHECKING : # type: ignore[unreachable]
420+ assert_never (event .delta )
388421 elif event .type == "message_delta" :
389422 current_snapshot .stop_reason = event .delta .stop_reason
390423 current_snapshot .stop_sequence = event .delta .stop_sequence
0 commit comments