Skip to content

Commit b02bd67

Browse files
authored
anthropic[patch]: cache clients (#31659)
1 parent e3f1ce0 commit b02bd67

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Helpers for creating Anthropic API clients.
2+
3+
This module allows for the caching of httpx clients to avoid creating new instances
4+
for each instance of ChatAnthropic.
5+
6+
Logic is largely replicated from anthropic._base_client.
7+
"""
8+
9+
import asyncio
10+
import os
11+
from functools import lru_cache
12+
from typing import Any, Optional
13+
14+
import anthropic
15+
16+
_NOT_GIVEN: Any = object()
17+
18+
19+
class _SyncHttpxClientWrapper(anthropic.DefaultHttpxClient):
20+
"""Borrowed from anthropic._base_client"""
21+
22+
def __del__(self) -> None:
23+
if self.is_closed:
24+
return
25+
26+
try:
27+
self.close()
28+
except Exception:
29+
pass
30+
31+
32+
class _AsyncHttpxClientWrapper(anthropic.DefaultAsyncHttpxClient):
33+
"""Borrowed from anthropic._base_client"""
34+
35+
def __del__(self) -> None:
36+
if self.is_closed:
37+
return
38+
39+
try:
40+
# TODO(someday): support non asyncio runtimes here
41+
asyncio.get_running_loop().create_task(self.aclose())
42+
except Exception:
43+
pass
44+
45+
46+
@lru_cache
47+
def _get_default_httpx_client(
48+
*,
49+
base_url: Optional[str],
50+
timeout: Any = _NOT_GIVEN,
51+
) -> _SyncHttpxClientWrapper:
52+
kwargs: dict[str, Any] = {
53+
"base_url": base_url
54+
or os.environ.get("ANTHROPIC_BASE_URL")
55+
or "https://api.anthropic.com",
56+
}
57+
if timeout is not _NOT_GIVEN:
58+
kwargs["timeout"] = timeout
59+
return _SyncHttpxClientWrapper(**kwargs)
60+
61+
62+
@lru_cache
63+
def _get_default_async_httpx_client(
64+
*,
65+
base_url: Optional[str],
66+
timeout: Any = _NOT_GIVEN,
67+
) -> _AsyncHttpxClientWrapper:
68+
kwargs: dict[str, Any] = {
69+
"base_url": base_url
70+
or os.environ.get("ANTHROPIC_BASE_URL")
71+
or "https://api.anthropic.com",
72+
}
73+
if timeout is not _NOT_GIVEN:
74+
kwargs["timeout"] = timeout
75+
return _AsyncHttpxClientWrapper(**kwargs)

libs/partners/anthropic/langchain_anthropic/chat_models.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@
6969
)
7070
from typing_extensions import NotRequired, TypedDict
7171

72+
from langchain_anthropic._client_utils import (
73+
_get_default_async_httpx_client,
74+
_get_default_httpx_client,
75+
)
7276
from langchain_anthropic.output_parsers import extract_tool_calls
7377

7478
_message_type_lookups = {
@@ -1300,11 +1304,29 @@ def _client_params(self) -> dict[str, Any]:
13001304

13011305
@cached_property
13021306
def _client(self) -> anthropic.Client:
1303-
return anthropic.Client(**self._client_params)
1307+
client_params = self._client_params
1308+
http_client_params = {"base_url": client_params["base_url"]}
1309+
if "timeout" in client_params:
1310+
http_client_params["timeout"] = client_params["timeout"]
1311+
http_client = _get_default_httpx_client(**http_client_params)
1312+
params = {
1313+
**client_params,
1314+
"http_client": http_client,
1315+
}
1316+
return anthropic.Client(**params)
13041317

13051318
@cached_property
13061319
def _async_client(self) -> anthropic.AsyncClient:
1307-
return anthropic.AsyncClient(**self._client_params)
1320+
client_params = self._client_params
1321+
http_client_params = {"base_url": client_params["base_url"]}
1322+
if "timeout" in client_params:
1323+
http_client_params["timeout"] = client_params["timeout"]
1324+
http_client = _get_default_async_httpx_client(**http_client_params)
1325+
params = {
1326+
**client_params,
1327+
"http_client": http_client,
1328+
}
1329+
return anthropic.AsyncClient(**params)
13081330

13091331
def _get_request_payload(
13101332
self,

libs/partners/anthropic/tests/integration_tests/test_chat_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test ChatAnthropic chat model."""
22

3+
import asyncio
34
import json
45
import os
56
from base64 import b64encode
@@ -1082,3 +1083,10 @@ def test_files_api_pdf(block_format: str) -> None:
10821083
],
10831084
}
10841085
_ = llm.invoke([input_message])
1086+
1087+
1088+
def test_async_shared_client() -> None:
1089+
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
1090+
llm._async_client # Instantiates lazily
1091+
_ = asyncio.run(llm.ainvoke("Hello"))
1092+
_ = asyncio.run(llm.ainvoke("Hello"))

libs/partners/anthropic/tests/unit_tests/test_chat_models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,22 @@ def test_initialization() -> None:
4444
assert model.anthropic_api_url == "https://api.anthropic.com"
4545

4646

47+
def test_anthropic_client_caching() -> None:
48+
"""Test that the OpenAI client is cached."""
49+
llm1 = ChatAnthropic(model="claude-3-5-sonnet-latest")
50+
llm2 = ChatAnthropic(model="claude-3-5-sonnet-latest")
51+
assert llm1._client._client is llm2._client._client
52+
53+
llm3 = ChatAnthropic(model="claude-3-5-sonnet-latest", base_url="foo")
54+
assert llm1._client._client is not llm3._client._client
55+
56+
llm4 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=None)
57+
assert llm1._client._client is llm4._client._client
58+
59+
llm5 = ChatAnthropic(model="claude-3-5-sonnet-latest", timeout=3)
60+
assert llm1._client._client is not llm5._client._client
61+
62+
4763
@pytest.mark.requires("anthropic")
4864
def test_anthropic_model_name_param() -> None:
4965
llm = ChatAnthropic(model_name="foo") # type: ignore[call-arg, call-arg]

0 commit comments

Comments
 (0)