Skip to content

Commit 5cbe6ab

Browse files
authored
anthropic[patch]: support citations in streaming (#29591)
1 parent 5ae4ed7 commit 5cbe6ab

File tree

4 files changed

+59
-17
lines changed

4 files changed

+59
-17
lines changed

libs/partners/anthropic/langchain_anthropic/chat_models.py

Lines changed: 45 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -718,7 +718,9 @@ def _stream(
718718
kwargs["stream"] = True
719719
payload = self._get_request_payload(messages, stop=stop, **kwargs)
720720
stream = self._client.messages.create(**payload)
721-
coerce_content_to_string = not _tools_in_params(payload)
721+
coerce_content_to_string = not _tools_in_params(
722+
payload
723+
) and not _documents_in_params(payload)
722724
for event in stream:
723725
msg = _make_message_chunk_from_anthropic_event(
724726
event,
@@ -745,7 +747,9 @@ async def _astream(
745747
kwargs["stream"] = True
746748
payload = self._get_request_payload(messages, stop=stop, **kwargs)
747749
stream = await self._async_client.messages.create(**payload)
748-
coerce_content_to_string = not _tools_in_params(payload)
750+
coerce_content_to_string = not _tools_in_params(
751+
payload
752+
) and not _documents_in_params(payload)
749753
async for event in stream:
750754
msg = _make_message_chunk_from_anthropic_event(
751755
event,
@@ -761,6 +765,16 @@ async def _astream(
761765
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
762766
data_dict = data.model_dump()
763767
content = data_dict["content"]
768+
769+
# Remove citations if they are None - introduced in anthropic sdk 0.45
770+
for block in content:
771+
if (
772+
isinstance(block, dict)
773+
and "citations" in block
774+
and block["citations"] is None
775+
):
776+
block.pop("citations")
777+
764778
llm_output = {
765779
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
766780
}
@@ -1254,6 +1268,19 @@ def _tools_in_params(params: dict) -> bool:
12541268
)
12551269

12561270

1271+
def _documents_in_params(params: dict) -> bool:
1272+
for message in params.get("messages", []):
1273+
if isinstance(message.get("content"), list):
1274+
for block in message["content"]:
1275+
if (
1276+
isinstance(block, dict)
1277+
and block.get("type") == "document"
1278+
and block.get("citations", {}).get("enabled")
1279+
):
1280+
return True
1281+
return False
1282+
1283+
12571284
class _AnthropicToolUse(TypedDict):
12581285
type: Literal["tool_use"]
12591286
name: str
@@ -1299,31 +1326,37 @@ def _make_message_chunk_from_anthropic_event(
12991326
elif (
13001327
event.type == "content_block_start"
13011328
and event.content_block is not None
1302-
and event.content_block.type == "tool_use"
1329+
and event.content_block.type in ("tool_use", "document")
13031330
):
13041331
if coerce_content_to_string:
13051332
warnings.warn("Received unexpected tool content block.")
13061333
content_block = event.content_block.model_dump()
13071334
content_block["index"] = event.index
1308-
tool_call_chunk = create_tool_call_chunk(
1309-
index=event.index,
1310-
id=event.content_block.id,
1311-
name=event.content_block.name,
1312-
args="",
1313-
)
1335+
if event.content_block.type == "tool_use":
1336+
tool_call_chunk = create_tool_call_chunk(
1337+
index=event.index,
1338+
id=event.content_block.id,
1339+
name=event.content_block.name,
1340+
args="",
1341+
)
1342+
tool_call_chunks = [tool_call_chunk]
1343+
else:
1344+
tool_call_chunks = []
13141345
message_chunk = AIMessageChunk(
13151346
content=[content_block],
1316-
tool_call_chunks=[tool_call_chunk], # type: ignore
1347+
tool_call_chunks=tool_call_chunks, # type: ignore
13171348
)
13181349
elif event.type == "content_block_delta":
1319-
if event.delta.type == "text_delta":
1320-
if coerce_content_to_string:
1350+
if event.delta.type in ("text_delta", "citations_delta"):
1351+
if coerce_content_to_string and hasattr(event.delta, "text"):
13211352
text = event.delta.text
13221353
message_chunk = AIMessageChunk(content=text)
13231354
else:
13241355
content_block = event.delta.model_dump()
13251356
content_block["index"] = event.index
13261357
content_block["type"] = "text"
1358+
if "citation" in content_block:
1359+
content_block["citations"] = [content_block.pop("citation")]
13271360
message_chunk = AIMessageChunk(content=[content_block])
13281361
elif event.delta.type == "input_json_delta":
13291362
content_block = event.delta.model_dump()

libs/partners/anthropic/poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/partners/anthropic/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ plugins = ['pydantic.mypy']
2121

2222
[tool.poetry.dependencies]
2323
python = ">=3.9,<4.0"
24-
anthropic = ">=0.41.0,<1"
24+
anthropic = ">=0.45.0,<1"
2525
langchain-core = "^0.3.33"
2626
pydantic = "^2.7.4"
2727

libs/partners/anthropic/tests/integration_tests/test_chat_models.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -649,3 +649,12 @@ def test_citations() -> None:
649649
assert isinstance(response, AIMessage)
650650
assert isinstance(response.content, list)
651651
assert any("citations" in block for block in response.content)
652+
653+
# Test streaming
654+
full: Optional[BaseMessageChunk] = None
655+
for chunk in llm.stream(messages):
656+
full = chunk if full is None else full + chunk
657+
assert isinstance(full, AIMessageChunk)
658+
assert isinstance(full.content, list)
659+
assert any("citations" in block for block in full.content)
660+
assert not any("citation" in block for block in full.content)

0 commit comments

Comments
 (0)