Skip to content

Commit cb783d6

Browse files
Resolve the issue of the AWS Bedrock chat model always returning nil (#801)
* Fix issue where the AWS Bedrock chat model always returned `nil` * Add `chat_model` arg for `Langchain::LLM::AwsBedrock#initialize` * Fix the AwsBedrock constructor * CHANGELOG entry --------- Co-authored-by: Andrei Bondarev <[email protected]> Co-authored-by: Andrei Bondarev <[email protected]>
1 parent 1b8338a commit cb783d6

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
## [Unreleased]
2+
- [BREAKING] Modify `Langchain::LLM::AwsBedrock` constructor to pass model options via default_options: {...}
23
- Added support for streaming with Anthropic
34
- Bump anthropic gem
45
- Default Langchain::LLM::Anthropic chat model is "claude-3-5-sonnet-20240620" now

lib/langchain/llm/aws_bedrock.rb

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ module Langchain::LLM
1111
#
1212
class AwsBedrock < Base
1313
DEFAULTS = {
14+
chat_completion_model_name: "anthropic.claude-v2",
1415
completion_model_name: "anthropic.claude-v2",
15-
embedding_model_name: "amazon.titan-embed-text-v1",
16+
embeddings_model_name: "amazon.titan-embed-text-v1",
1617
max_tokens_to_sample: 300,
1718
temperature: 1,
1819
top_k: 250,
@@ -52,13 +53,11 @@ class AwsBedrock < Base
5253
SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
5354
SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon cohere].freeze
5455

55-
def initialize(completion_model: DEFAULTS[:completion_model_name], embedding_model: DEFAULTS[:embedding_model_name], aws_client_options: {}, default_options: {})
56+
def initialize(aws_client_options: {}, default_options: {})
5657
depends_on "aws-sdk-bedrockruntime", req: "aws-sdk-bedrockruntime"
5758

5859
@client = ::Aws::BedrockRuntime::Client.new(**aws_client_options)
5960
@defaults = DEFAULTS.merge(default_options)
60-
.merge(completion_model_name: completion_model)
61-
.merge(embedding_model_name: embedding_model)
6261

6362
chat_parameters.update(
6463
model: {default: @defaults[:chat_completion_model_name]},
@@ -85,7 +84,7 @@ def embed(text:, **params)
8584
parameters = compose_embedding_parameters params.merge(text:)
8685

8786
response = client.invoke_model({
88-
model_id: @defaults[:embedding_model_name],
87+
model_id: @defaults[:embeddings_model_name],
8988
body: parameters.to_json,
9089
content_type: "application/json",
9190
accept: "application/json"
@@ -180,7 +179,7 @@ def completion_provider
180179
end
181180

182181
def embedding_provider
183-
@defaults[:embedding_model_name].split(".").first.to_sym
182+
@defaults[:embeddings_model_name].split(".").first.to_sym
184183
end
185184

186185
def wrap_prompt(prompt)

spec/langchain/llm/aws_bedrock_spec.rb

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424
}.to_json
2525
end
2626

27+
let(:model_id) { "anthropic.claude-3-sonnet-20240229-v1:0" }
28+
2729
before do
2830
response_object = double("response_object")
2931
allow(response_object).to receive(:body).and_return(StringIO.new(response))
3032
allow(subject.client).to receive(:invoke_model)
3133
.with(matching(
32-
model_id: "anthropic.claude-3-sonnet-20240229-v1:0",
34+
model_id:,
3335
body: {
3436
messages: [{role: "user", content: "What is the capital of France?"}],
3537
stop_sequences: ["stop"],
@@ -46,12 +48,25 @@
4648
expect(
4749
subject.chat(
4850
messages: [{role: "user", content: "What is the capital of France?"}],
49-
model: "anthropic.claude-3-sonnet-20240229-v1:0",
51+
model: model_id,
5052
stop_sequences: ["stop"]
5153
).chat_completion
5254
).to eq("The capital of France is Paris.")
5355
end
5456

57+
context "without default model" do
58+
let(:model_id) { "anthropic.claude-v2" }
59+
60+
it "returns a completion" do
61+
expect(
62+
subject.chat(
63+
messages: [{role: "user", content: "What is the capital of France?"}],
64+
stop_sequences: ["stop"]
65+
).chat_completion
66+
).to eq("The capital of France is Paris.")
67+
end
68+
end
69+
5570
context "with streaming" do
5671
let(:chunks) do
5772
[
@@ -201,7 +216,7 @@
201216
end
202217

203218
context "with ai21 provider" do
204-
let(:subject) { described_class.new(completion_model: "ai21.j2-ultra-v1") }
219+
let(:subject) { described_class.new(default_options: {completion_model_name: "ai21.j2-ultra-v1"}) }
205220

206221
let(:response) do
207222
StringIO.new("{\"completions\":[{\"data\":{\"text\":\"\\nWhat is the meaning of life? What is the meaning of life?\\nWhat is the meaning\"}}]}")
@@ -308,8 +323,11 @@
308323
context "with custom default_options" do
309324
let(:subject) {
310325
described_class.new(
311-
completion_model: "ai21.j2-ultra-v1",
312-
default_options: {max_tokens_to_sample: 100, temperature: 0.7}
326+
default_options: {
327+
completion_model_name: "ai21.j2-ultra-v1",
328+
max_tokens_to_sample: 100,
329+
temperature: 0.7
330+
}
313331
)
314332
}
315333
let(:response_object) { double("response_object") }
@@ -363,7 +381,7 @@
363381
end
364382

365383
context "with cohere provider" do
366-
let(:subject) { described_class.new(completion_model: "cohere.command-text-v14") }
384+
let(:subject) { described_class.new(default_options: {completion_model_name: "cohere.command-text-v14"}) }
367385

368386
let(:response) do
369387
StringIO.new("{\"generations\":[{\"text\":\"\\nWhat is the meaning of life? What is the meaning of life?\\nWhat is the meaning\"}]}")
@@ -424,8 +442,11 @@
424442
context "with custom default_options" do
425443
let(:subject) {
426444
described_class.new(
427-
completion_model: "cohere.command-text-v14",
428-
default_options: {max_tokens_to_sample: 100, temperature: 0.7}
445+
default_options: {
446+
completion_model_name: "cohere.command-text-v14",
447+
max_tokens_to_sample: 100,
448+
temperature: 0.7
449+
}
429450
)
430451
}
431452
let(:response_object) { double("response_object") }
@@ -456,7 +477,7 @@
456477
end
457478

458479
context "with unsupported provider" do
459-
let(:subject) { described_class.new(completion_model: "unsupported.provider") }
480+
let(:subject) { described_class.new(default_options: {completion_model_name: "unsupported.provider"}) }
460481

461482
it "raises an exception" do
462483
expect { subject.complete(prompt: "Hello World") }.to raise_error("Completion provider unsupported is not supported.")
@@ -492,7 +513,7 @@
492513
end
493514

494515
context "with cohere provider" do
495-
let(:subject) { described_class.new(embedding_model: "cohere.embed-multilingual-v3") }
516+
let(:subject) { described_class.new(default_options: {embeddings_model_name: "cohere.embed-multilingual-v3"}) }
496517

497518
let(:response) do
498519
StringIO.new("{\"embeddings\":[[0.1,0.2,0.3,0.4,0.5]]}")
@@ -522,7 +543,7 @@
522543
end
523544

524545
context "with unsupported provider" do
525-
let(:subject) { described_class.new(embedding_model: "unsupported.provider") }
546+
let(:subject) { described_class.new(default_options: {embeddings_model_name: "unsupported.provider"}) }
526547

527548
it "raises an exception" do
528549
expect { subject.embed(text: "Hello World") }.to raise_error("Completion provider unsupported is not supported.")

0 commit comments

Comments (0)