Commit 3a487bf

refactor(anthropic): AnthropicLLM to use Messages API (#32290)
re: #32189
1 parent: e5fd670
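For context, the legacy Text Completions endpoint and the Messages endpoint differ in both request and response shape. A minimal sketch of the two call styles against the anthropic SDK (model names and prompt text are illustrative, not taken from this commit):

    import anthropic

    client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

    # Legacy Text Completions style, as AnthropicLLM called it before this commit
    legacy = client.completions.create(
        model="claude-2",
        max_tokens_to_sample=256,
        prompt="\n\nHuman: Hello\n\nAssistant:",
    )
    print(legacy.completion)

    # Messages API style, as AnthropicLLM calls it after this commit
    response = client.messages.create(
        model="claude-3-5-sonnet-latest",
        max_tokens=256,
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.content[0].text)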

4 files changed: +137 additions, -115 deletions

libs/partners/anthropic/langchain_anthropic/llms.py

Lines changed: 92 additions & 72 deletions
@@ -3,11 +3,7 @@
 import re
 import warnings
 from collections.abc import AsyncIterator, Iterator, Mapping
-from typing import (
-    Any,
-    Callable,
-    Optional,
-)
+from typing import Any, Callable, Optional
 
 import anthropic
 from langchain_core._api.deprecation import deprecated
@@ -19,25 +15,19 @@
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.prompt_values import PromptValue
-from langchain_core.utils import (
-    get_pydantic_field_names,
-)
-from langchain_core.utils.utils import (
-    _build_model_kwargs,
-    from_env,
-    secret_from_env,
-)
+from langchain_core.utils import get_pydantic_field_names
+from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
 from pydantic import ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
 
 class _AnthropicCommon(BaseLanguageModel):
     client: Any = None  #: :meta private:
     async_client: Any = None  #: :meta private:
-    model: str = Field(default="claude-2", alias="model_name")
+    model: str = Field(default="claude-3-5-sonnet-latest", alias="model_name")
     """Model name to use."""
 
-    max_tokens_to_sample: int = Field(default=1024, alias="max_tokens")
+    max_tokens: int = Field(default=1024, alias="max_tokens_to_sample")
     """Denotes the number of tokens to predict per generation."""
 
     temperature: Optional[float] = None
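The two field changes above swap the primary name and the alias: max_tokens becomes the field and max_tokens_to_sample the alias, and the default model moves off claude-2. A rough sketch of what that means for constructor kwargs, assuming the class keeps populate_by_name enabled on its pydantic config (not shown in this hunk):

    from langchain_anthropic import AnthropicLLM

    # Both spellings should land on the same field after this change
    # (populate_by_name on the model config is assumed here).
    llm_a = AnthropicLLM(max_tokens=512)
    llm_b = AnthropicLLM(max_tokens_to_sample=512)
    assert llm_a.max_tokens == llm_b.max_tokens == 512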
@@ -104,15 +94,16 @@ def validate_environment(self) -> Self:
             timeout=self.default_request_timeout,
             max_retries=self.max_retries,
         )
-        self.HUMAN_PROMPT = anthropic.HUMAN_PROMPT
-        self.AI_PROMPT = anthropic.AI_PROMPT
+        # Keep for backward compatibility but not used in Messages API
+        self.HUMAN_PROMPT = getattr(anthropic, "HUMAN_PROMPT", None)
+        self.AI_PROMPT = getattr(anthropic, "AI_PROMPT", None)
         return self
 
     @property
     def _default_params(self) -> Mapping[str, Any]:
         """Get the default parameters for calling Anthropic API."""
         d = {
-            "max_tokens_to_sample": self.max_tokens_to_sample,
+            "max_tokens": self.max_tokens,
             "model": self.model,
         }
         if self.temperature is not None:
@@ -129,16 +120,8 @@ def _identifying_params(self) -> Mapping[str, Any]:
         return {**self._default_params}
 
     def _get_anthropic_stop(self, stop: Optional[list[str]] = None) -> list[str]:
-        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
-            msg = "Please ensure the anthropic package is loaded"
-            raise NameError(msg)
-
         if stop is None:
             stop = []
-
-        # Never want model to invent new turns of Human / Assistant dialog.
-        stop.extend([self.HUMAN_PROMPT])
-
         return stop
 
 
@@ -192,7 +175,7 @@ def _identifying_params(self) -> dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            "max_tokens": self.max_tokens_to_sample,
+            "max_tokens": self.max_tokens,
             "temperature": self.temperature,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -211,27 +194,51 @@ def _get_ls_params(
         params = super()._get_ls_params(stop=stop, **kwargs)
         identifying_params = self._identifying_params
         if max_tokens := kwargs.get(
-            "max_tokens_to_sample",
+            "max_tokens",
             identifying_params.get("max_tokens"),
         ):
             params["ls_max_tokens"] = max_tokens
         return params
 
-    def _wrap_prompt(self, prompt: str) -> str:
-        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
-            msg = "Please ensure the anthropic package is loaded"
-            raise NameError(msg)
-
-        if prompt.startswith(self.HUMAN_PROMPT):
-            return prompt  # Already wrapped.
-
-        # Guard against common errors in specifying wrong number of newlines.
-        corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
-        if n_subs == 1:
-            return corrected_prompt
-
-        # As a last resort, wrap the prompt ourselves to emulate instruct-style.
-        return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
+    def _format_messages(self, prompt: str) -> list[dict[str, str]]:
+        """Convert prompt to Messages API format."""
+        messages = []
+
+        # Handle legacy prompts that might have HUMAN_PROMPT/AI_PROMPT markers
+        if self.HUMAN_PROMPT and self.HUMAN_PROMPT in prompt:
+            # Split on human/assistant turns
+            parts = prompt.split(self.HUMAN_PROMPT)
+
+            for _, part in enumerate(parts):
+                if not part.strip():
+                    continue
+
+                if self.AI_PROMPT and self.AI_PROMPT in part:
+                    # Split human and assistant parts
+                    human_part, assistant_part = part.split(self.AI_PROMPT, 1)
+                    if human_part.strip():
+                        messages.append({"role": "user", "content": human_part.strip()})
+                    if assistant_part.strip():
+                        messages.append(
+                            {"role": "assistant", "content": assistant_part.strip()}
+                        )
+                else:
+                    # Just human content
+                    if part.strip():
+                        messages.append({"role": "user", "content": part.strip()})
+        else:
+            # Handle modern format or plain text
+            # Clean prompt for Messages API
+            content = re.sub(r"^\n*Human:\s*", "", prompt)
+            content = re.sub(r"\n*Assistant:\s*.*$", "", content)
+            if content.strip():
+                messages.append({"role": "user", "content": content.strip()})
+
+        # Ensure we have at least one message
+        if not messages:
+            messages = [{"role": "user", "content": prompt.strip() or "Hello"}]
+
+        return messages
 
     def _call(
         self,
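A rough walk-through of the new _format_messages helper on a legacy-style prompt, assuming the installed SDK still exposes the HUMAN_PROMPT and AI_PROMPT constants (the prompt text here is made up):

    # HUMAN_PROMPT == "\n\nHuman:" and AI_PROMPT == "\n\nAssistant:" in older SDKs.
    prompt = "\n\nHuman: What is 2 + 2?\n\nAssistant:"

    # The prompt is split on the human marker, each piece is split once on the
    # assistant marker, and empty pieces are stripped and dropped, so the helper
    # should return:
    expected = [{"role": "user", "content": "What is 2 + 2?"}]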
@@ -272,15 +279,19 @@ def _call(
 
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}
-        response = self.client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
+
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
+
+        response = self.client.messages.create(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
             **params,
         )
-        return response.completion
+        return response.content[0].text
 
     def convert_prompt(self, prompt: PromptValue) -> str:
-        return self._wrap_prompt(prompt.to_string())
+        return prompt.to_string()
 
     async def _acall(
         self,
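End to end, a plain invoke should now route through messages.create and return the first text block of the response; a minimal usage sketch (model name and prompt are illustrative):

    from langchain_anthropic import AnthropicLLM

    llm = AnthropicLLM(model="claude-3-5-sonnet-latest", max_tokens=256)
    # _call builds messages from the prompt string and returns response.content[0].text
    print(llm.invoke("Give me one fun fact about otters."))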
@@ -304,12 +315,15 @@ async def _acall(
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}
 
-        response = await self.async_client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
+
+        response = await self.async_client.messages.create(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
             **params,
         )
-        return response.completion
+        return response.content[0].text
 
     def _stream(
         self,
@@ -343,17 +357,20 @@ def _stream(
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}
 
-        for token in self.client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
-            stream=True,
-            **params,
-        ):
-            chunk = GenerationChunk(text=token.completion)
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
 
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-            yield chunk
+        with self.client.messages.stream(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
+            **params,
+        ) as stream:
+            for event in stream:
+                if event.type == "content_block_delta" and hasattr(event.delta, "text"):
+                    chunk = GenerationChunk(text=event.delta.text)
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    yield chunk
 
     async def _astream(
         self,
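Streaming now consumes SDK stream events rather than completion tokens; through the standard LLM interface that looks roughly like this (prompt is illustrative):

    from langchain_anthropic import AnthropicLLM

    llm = AnthropicLLM(model="claude-3-5-sonnet-latest", max_tokens=256)
    # Each content_block_delta event surfaces as a string chunk from stream().
    for chunk in llm.stream("Write a haiku about the sea."):
        print(chunk, end="", flush=True)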
@@ -386,17 +403,20 @@ async def _astream(
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}
 
-        async for token in await self.async_client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
-            stream=True,
-            **params,
-        ):
-            chunk = GenerationChunk(text=token.completion)
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
 
-            if run_manager:
-                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-            yield chunk
+        async with self.async_client.messages.stream(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
+            **params,
+        ) as stream:
+            async for event in stream:
+                if event.type == "content_block_delta" and hasattr(event.delta, "text"):
+                    chunk = GenerationChunk(text=event.delta.text)
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    yield chunk
 
     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""
