@@ -7,51 +7,40 @@ module Langchain::LLM
77 # gem 'aws-sdk-bedrockruntime', '~> 1.1'
88 #
99 # Usage:
10- # llm = Langchain::LLM::AwsBedrock.new(llm_options : {})
10+ # llm = Langchain::LLM::AwsBedrock.new(default_options : {})
1111 #
1212 class AwsBedrock < Base
1313 DEFAULTS = {
14- chat_model : "anthropic.claude-v2 " ,
15- completion_model : "anthropic.claude-v2" ,
14+ chat_model : "anthropic.claude-3-5-sonnet-20240620-v1:0 " ,
15+ completion_model : "anthropic.claude-v2:1 " ,
1616 embedding_model : "amazon.titan-embed-text-v1" ,
1717 max_tokens_to_sample : 300 ,
1818 temperature : 1 ,
1919 top_k : 250 ,
2020 top_p : 0.999 ,
2121 stop_sequences : [ "\n \n Human:" ] ,
22- anthropic_version : "bedrock-2023-05-31" ,
23- return_likelihoods : "NONE" ,
24- count_penalty : {
25- scale : 0 ,
26- apply_to_whitespaces : false ,
27- apply_to_punctuations : false ,
28- apply_to_numbers : false ,
29- apply_to_stopwords : false ,
30- apply_to_emojis : false
31- } ,
32- presence_penalty : {
33- scale : 0 ,
34- apply_to_whitespaces : false ,
35- apply_to_punctuations : false ,
36- apply_to_numbers : false ,
37- apply_to_stopwords : false ,
38- apply_to_emojis : false
39- } ,
40- frequency_penalty : {
41- scale : 0 ,
42- apply_to_whitespaces : false ,
43- apply_to_punctuations : false ,
44- apply_to_numbers : false ,
45- apply_to_stopwords : false ,
46- apply_to_emojis : false
47- }
22+ return_likelihoods : "NONE"
4823 } . freeze
4924
5025 attr_reader :client , :defaults
5126
52- SUPPORTED_COMPLETION_PROVIDERS = %i[ anthropic ai21 cohere meta ] . freeze
53- SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[ anthropic ] . freeze
54- SUPPORTED_EMBEDDING_PROVIDERS = %i[ amazon cohere ] . freeze
27+ SUPPORTED_COMPLETION_PROVIDERS = %i[
28+ anthropic
29+ ai21
30+ cohere
31+ meta
32+ ] . freeze
33+
34+ SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
35+ anthropic
36+ ai21
37+ mistral
38+ ] . freeze
39+
40+ SUPPORTED_EMBEDDING_PROVIDERS = %i[
41+ amazon
42+ cohere
43+ ] . freeze
5544
5645 def initialize ( aws_client_options : { } , default_options : { } )
5746 depends_on "aws-sdk-bedrockruntime" , req : "aws-sdk-bedrockruntime"
@@ -64,8 +53,7 @@ def initialize(aws_client_options: {}, default_options: {})
6453 temperature : { } ,
6554 max_tokens : { default : @defaults [ :max_tokens_to_sample ] } ,
6655 metadata : { } ,
67- system : { } ,
68- anthropic_version : { default : "bedrock-2023-05-31" }
56+ system : { }
6957 )
7058 chat_parameters . ignore ( :n , :user )
7159 chat_parameters . remap ( stop : :stop_sequences )
@@ -100,23 +88,25 @@ def embed(text:, **params)
10088 # @param params extra parameters passed to Aws::BedrockRuntime::Client#invoke_model
10189 # @return [Langchain::LLM::AnthropicResponse], [Langchain::LLM::CohereResponse] or [Langchain::LLM::AI21Response] Response object
10290 #
103- def complete ( prompt :, **params )
104- raise "Completion provider #{ completion_provider } is not supported." unless SUPPORTED_COMPLETION_PROVIDERS . include? ( completion_provider )
91+ def complete (
92+ prompt :,
93+ model : @defaults [ :completion_model ] ,
94+ **params
95+ )
96+ raise "Completion provider #{ model } is not supported." unless SUPPORTED_COMPLETION_PROVIDERS . include? ( provider_name ( model ) )
10597
106- raise "Model #{ @defaults [ :completion_model ] } only supports #chat." if @defaults [ :completion_model ] . include? ( "claude-3" )
107-
108- parameters = compose_parameters params
98+ parameters = compose_parameters ( params , model )
10999
110100 parameters [ :prompt ] = wrap_prompt prompt
111101
112102 response = client . invoke_model ( {
113- model_id : @defaults [ :completion_model ] ,
103+ model_id : model ,
114104 body : parameters . to_json ,
115105 content_type : "application/json" ,
116106 accept : "application/json"
117107 } )
118108
119- parse_response response
109+ parse_response ( response , model )
120110 end
121111
122112 # Generate a chat completion for a given prompt
@@ -137,10 +127,11 @@ def complete(prompt:, **params)
137127 # @return [Langchain::LLM::AnthropicResponse] Response object
138128 def chat ( params = { } , &block )
139129 parameters = chat_parameters . to_params ( params )
130+ parameters = compose_parameters ( parameters , parameters [ :model ] )
140131
141- raise ArgumentError . new ( "messages argument is required" ) if Array ( parameters [ :messages ] ) . empty?
142-
143- raise "Model #{ parameters [ :model ] } does not support chat completions." unless Langchain :: LLM :: AwsBedrock :: SUPPORTED_CHAT_COMPLETION_PROVIDERS . include? ( completion_provider )
132+ unless SUPPORTED_CHAT_COMPLETION_PROVIDERS . include? ( provider_name ( parameters [ :model ] ) )
133+ raise "Chat provider #{ parameters [ :model ] } is not supported."
134+ end
144135
145136 if block
146137 response_chunks = [ ]
@@ -168,12 +159,26 @@ def chat(params = {}, &block)
168159 accept : "application/json"
169160 } )
170161
171- parse_response response
162+ parse_response ( response , parameters [ :model ] )
172163 end
173164 end
174165
175166 private
176167
168+ def parse_model_id ( model_id )
169+ model_id
170+ . gsub ( "us." , "" ) # Meta append "us." to their model ids
171+ . split ( "." )
172+ end
173+
174+ def provider_name ( model_id )
175+ parse_model_id ( model_id ) . first . to_sym
176+ end
177+
178+ def model_name ( model_id )
179+ parse_model_id ( model_id ) . last
180+ end
181+
177182 def completion_provider
178183 @defaults [ :completion_model ] . split ( "." ) . first . to_sym
179184 end
@@ -200,15 +205,17 @@ def max_tokens_key
200205 end
201206 end
202207
203- def compose_parameters ( params )
204- if completion_provider == :anthropic
205- compose_parameters_anthropic params
206- elsif completion_provider == :cohere
207- compose_parameters_cohere params
208- elsif completion_provider == :ai21
209- compose_parameters_ai21 params
210- elsif completion_provider == :meta
211- compose_parameters_meta params
208+ def compose_parameters ( params , model_id )
209+ if provider_name ( model_id ) == :anthropic
210+ compose_parameters_anthropic ( params )
211+ elsif provider_name ( model_id ) == :cohere
212+ compose_parameters_cohere ( params )
213+ elsif provider_name ( model_id ) == :ai21
214+ params
215+ elsif provider_name ( model_id ) == :meta
216+ params
217+ elsif provider_name ( model_id ) == :mistral
218+ params
212219 end
213220 end
214221
@@ -220,15 +227,17 @@ def compose_embedding_parameters(params)
220227 end
221228 end
222229
223- def parse_response ( response )
224- if completion_provider == :anthropic
230+ def parse_response ( response , model_id )
231+ if provider_name ( model_id ) == :anthropic
225232 Langchain ::LLM ::AnthropicResponse . new ( JSON . parse ( response . body . string ) )
226- elsif completion_provider == :cohere
233+ elsif provider_name ( model_id ) == :cohere
227234 Langchain ::LLM ::CohereResponse . new ( JSON . parse ( response . body . string ) )
228- elsif completion_provider == :ai21
235+ elsif provider_name ( model_id ) == :ai21
229236 Langchain ::LLM ::AI21Response . new ( JSON . parse ( response . body . string , symbolize_names : true ) )
230- elsif completion_provider == :meta
237+ elsif provider_name ( model_id ) == :meta
231238 Langchain ::LLM ::AwsBedrockMetaResponse . new ( JSON . parse ( response . body . string ) )
239+ elsif provider_name ( model_id ) == :mistral
240+ Langchain ::LLM ::MistralAIResponse . new ( JSON . parse ( response . body . string ) )
232241 end
233242 end
234243
@@ -276,61 +285,7 @@ def compose_parameters_cohere(params)
276285 end
277286
278287 def compose_parameters_anthropic ( params )
279- default_params = @defaults . merge ( params )
280-
281- {
282- max_tokens_to_sample : default_params [ :max_tokens_to_sample ] ,
283- temperature : default_params [ :temperature ] ,
284- top_k : default_params [ :top_k ] ,
285- top_p : default_params [ :top_p ] ,
286- stop_sequences : default_params [ :stop_sequences ] ,
287- anthropic_version : default_params [ :anthropic_version ]
288- }
289- end
290-
291- def compose_parameters_ai21 ( params )
292- default_params = @defaults . merge ( params )
293-
294- {
295- maxTokens : default_params [ :max_tokens_to_sample ] ,
296- temperature : default_params [ :temperature ] ,
297- topP : default_params [ :top_p ] ,
298- stopSequences : default_params [ :stop_sequences ] ,
299- countPenalty : {
300- scale : default_params [ :count_penalty ] [ :scale ] ,
301- applyToWhitespaces : default_params [ :count_penalty ] [ :apply_to_whitespaces ] ,
302- applyToPunctuations : default_params [ :count_penalty ] [ :apply_to_punctuations ] ,
303- applyToNumbers : default_params [ :count_penalty ] [ :apply_to_numbers ] ,
304- applyToStopwords : default_params [ :count_penalty ] [ :apply_to_stopwords ] ,
305- applyToEmojis : default_params [ :count_penalty ] [ :apply_to_emojis ]
306- } ,
307- presencePenalty : {
308- scale : default_params [ :presence_penalty ] [ :scale ] ,
309- applyToWhitespaces : default_params [ :presence_penalty ] [ :apply_to_whitespaces ] ,
310- applyToPunctuations : default_params [ :presence_penalty ] [ :apply_to_punctuations ] ,
311- applyToNumbers : default_params [ :presence_penalty ] [ :apply_to_numbers ] ,
312- applyToStopwords : default_params [ :presence_penalty ] [ :apply_to_stopwords ] ,
313- applyToEmojis : default_params [ :presence_penalty ] [ :apply_to_emojis ]
314- } ,
315- frequencyPenalty : {
316- scale : default_params [ :frequency_penalty ] [ :scale ] ,
317- applyToWhitespaces : default_params [ :frequency_penalty ] [ :apply_to_whitespaces ] ,
318- applyToPunctuations : default_params [ :frequency_penalty ] [ :apply_to_punctuations ] ,
319- applyToNumbers : default_params [ :frequency_penalty ] [ :apply_to_numbers ] ,
320- applyToStopwords : default_params [ :frequency_penalty ] [ :apply_to_stopwords ] ,
321- applyToEmojis : default_params [ :frequency_penalty ] [ :apply_to_emojis ]
322- }
323- }
324- end
325-
326- def compose_parameters_meta ( params )
327- default_params = @defaults . merge ( params )
328-
329- {
330- temperature : default_params [ :temperature ] ,
331- top_p : default_params [ :top_p ] ,
332- max_gen_len : default_params [ :max_tokens_to_sample ]
333- }
288+ params . merge ( anthropic_version : "bedrock-2023-05-31" )
334289 end
335290
336291 def response_from_chunks ( chunks )
0 commit comments