 import re
 import warnings
 from collections.abc import AsyncIterator, Iterator, Mapping
-from typing import (
-    Any,
-    Callable,
-    Optional,
-)
+from typing import Any, Callable, Optional

 import anthropic
 from langchain_core._api.deprecation import deprecated
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.prompt_values import PromptValue
-from langchain_core.utils import (
-    get_pydantic_field_names,
-)
-from langchain_core.utils.utils import (
-    _build_model_kwargs,
-    from_env,
-    secret_from_env,
-)
+from langchain_core.utils import get_pydantic_field_names
+from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
 from pydantic import ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self


 class _AnthropicCommon(BaseLanguageModel):
     client: Any = None  #: :meta private:
     async_client: Any = None  #: :meta private:
-    model: str = Field(default="claude-2", alias="model_name")
+    model: str = Field(default="claude-3-5-sonnet-latest", alias="model_name")
     """Model name to use."""

-    max_tokens_to_sample: int = Field(default=1024, alias="max_tokens")
+    max_tokens: int = Field(default=1024, alias="max_tokens_to_sample")
     """Denotes the number of tokens to predict per generation."""

     temperature: Optional[float] = None
@@ -104,15 +94,16 @@ def validate_environment(self) -> Self:
             timeout=self.default_request_timeout,
             max_retries=self.max_retries,
         )
-        self.HUMAN_PROMPT = anthropic.HUMAN_PROMPT
-        self.AI_PROMPT = anthropic.AI_PROMPT
+        # Keep for backward compatibility but not used in Messages API
+        self.HUMAN_PROMPT = getattr(anthropic, "HUMAN_PROMPT", None)
+        self.AI_PROMPT = getattr(anthropic, "AI_PROMPT", None)
         return self

     @property
     def _default_params(self) -> Mapping[str, Any]:
         """Get the default parameters for calling Anthropic API."""
         d = {
-            "max_tokens_to_sample": self.max_tokens_to_sample,
+            "max_tokens": self.max_tokens,
             "model": self.model,
         }
         if self.temperature is not None:
@@ -129,16 +120,8 @@ def _identifying_params(self) -> Mapping[str, Any]:
         return {**self._default_params}

     def _get_anthropic_stop(self, stop: Optional[list[str]] = None) -> list[str]:
-        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
-            msg = "Please ensure the anthropic package is loaded"
-            raise NameError(msg)
-
         if stop is None:
             stop = []
-
-        # Never want model to invent new turns of Human / Assistant dialog.
-        stop.extend([self.HUMAN_PROMPT])
-
         return stop

@@ -192,7 +175,7 @@ def _identifying_params(self) -> dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            "max_tokens": self.max_tokens_to_sample,
+            "max_tokens": self.max_tokens,
             "temperature": self.temperature,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -211,27 +194,51 @@ def _get_ls_params(
         params = super()._get_ls_params(stop=stop, **kwargs)
         identifying_params = self._identifying_params
         if max_tokens := kwargs.get(
-            "max_tokens_to_sample",
+            "max_tokens",
            identifying_params.get("max_tokens"),
         ):
             params["ls_max_tokens"] = max_tokens
         return params

-    def _wrap_prompt(self, prompt: str) -> str:
-        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
-            msg = "Please ensure the anthropic package is loaded"
-            raise NameError(msg)
-
-        if prompt.startswith(self.HUMAN_PROMPT):
-            return prompt  # Already wrapped.
-
-        # Guard against common errors in specifying wrong number of newlines.
-        corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
-        if n_subs == 1:
-            return corrected_prompt
-
-        # As a last resort, wrap the prompt ourselves to emulate instruct-style.
-        return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
+    def _format_messages(self, prompt: str) -> list[dict[str, str]]:
+        """Convert prompt to Messages API format."""
+        messages = []
+
+        # Handle legacy prompts that might have HUMAN_PROMPT/AI_PROMPT markers
+        if self.HUMAN_PROMPT and self.HUMAN_PROMPT in prompt:
+            # Split on human/assistant turns
+            parts = prompt.split(self.HUMAN_PROMPT)
+
+            for _, part in enumerate(parts):
+                if not part.strip():
+                    continue
+
+                if self.AI_PROMPT and self.AI_PROMPT in part:
+                    # Split human and assistant parts
+                    human_part, assistant_part = part.split(self.AI_PROMPT, 1)
+                    if human_part.strip():
+                        messages.append({"role": "user", "content": human_part.strip()})
+                    if assistant_part.strip():
+                        messages.append(
+                            {"role": "assistant", "content": assistant_part.strip()}
+                        )
+                else:
+                    # Just human content
+                    if part.strip():
+                        messages.append({"role": "user", "content": part.strip()})
+        else:
+            # Handle modern format or plain text
+            # Clean prompt for Messages API
+            content = re.sub(r"^\n*Human:\s*", "", prompt)
+            content = re.sub(r"\n*Assistant:\s*.*$", "", content)
+            if content.strip():
+                messages.append({"role": "user", "content": content.strip()})
+
+        # Ensure we have at least one message
+        if not messages:
+            messages = [{"role": "user", "content": prompt.strip() or "Hello"}]
+
+        return messages

     def _call(
         self,
@@ -272,15 +279,19 @@ def _call(

         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}
-        response = self.client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
+
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
+
+        response = self.client.messages.create(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
             **params,
         )
-        return response.completion
+        return response.content[0].text

     def convert_prompt(self, prompt: PromptValue) -> str:
-        return self._wrap_prompt(prompt.to_string())
+        return prompt.to_string()

     async def _acall(
         self,
@@ -304,12 +315,15 @@ async def _acall(
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}

-        response = await self.async_client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
+
+        response = await self.async_client.messages.create(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
             **params,
         )
-        return response.completion
+        return response.content[0].text

     def _stream(
         self,
@@ -343,17 +357,20 @@ def _stream(
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}

-        for token in self.client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
-            stream=True,
-            **params,
-        ):
-            chunk = GenerationChunk(text=token.completion)
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}

-            if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-            yield chunk
+        with self.client.messages.stream(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
+            **params,
+        ) as stream:
+            for event in stream:
+                if event.type == "content_block_delta" and hasattr(event.delta, "text"):
+                    chunk = GenerationChunk(text=event.delta.text)
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    yield chunk

     async def _astream(
         self,
@@ -386,17 +403,20 @@ async def _astream(
         stop = self._get_anthropic_stop(stop)
         params = {**self._default_params, **kwargs}

-        async for token in await self.async_client.completions.create(
-            prompt=self._wrap_prompt(prompt),
-            stop_sequences=stop,
-            stream=True,
-            **params,
-        ):
-            chunk = GenerationChunk(text=token.completion)
+        # Remove parameters not supported by Messages API
+        params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}

-            if run_manager:
-                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-            yield chunk
+        async with self.async_client.messages.stream(
+            messages=self._format_messages(prompt),
+            stop_sequences=stop if stop else None,
+            **params,
+        ) as stream:
+            async for event in stream:
+                if event.type == "content_block_delta" and hasattr(event.delta, "text"):
+                    chunk = GenerationChunk(text=event.delta.text)
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    yield chunk

     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""