17 changes: 17 additions & 0 deletions lib/langchain/llm/aws_bedrock.rb
@@ -32,6 +32,7 @@ class AwsBedrock < Base
].freeze

SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
  amazon
  anthropic
  ai21
  mistral
@@ -216,6 +217,8 @@ def compose_parameters(params, model_id)
    params
  elsif provider_name(model_id) == :mistral
    params
  elsif provider_name(model_id) == :amazon
    compose_parameters_amazon(params)
  end
end

@@ -238,6 +241,8 @@ def parse_response(response, model_id)
    Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
  elsif provider_name(model_id) == :mistral
    Langchain::LLM::MistralAIResponse.new(JSON.parse(response.body.string))
  elsif provider_name(model_id) == :amazon
    Langchain::LLM::AwsBedrockAmazonResponse.new(JSON.parse(response.body.string))
  end
end

@@ -288,6 +293,18 @@ def compose_parameters_anthropic(params)
  params.merge(anthropic_version: "bedrock-2023-05-31")
end

def compose_parameters_amazon(params)
  params = params.merge(inferenceConfig: {
    maxTokens: params[:max_tokens],
    temperature: params[:temperature],
    topP: params[:top_p],
    topK: params[:top_k],
    stopSequences: params[:stop_sequences]
  }.compact)

  # Drop every key that was nested under inferenceConfig so the same
  # setting is not also sent at the top level of the request body.
  params.except(:max_tokens, :temperature, :top_p, :top_k, :stop_sequences)
end

def response_from_chunks(chunks)
  raw_response = {}

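As a sanity check, here is a minimal sketch of what compose_parameters_amazon does to a request body (the parameter values are illustrative assumptions, not taken from the diff):

params = {
  messages: [{role: "user", content: [{text: "Hello"}]}],
  max_tokens: 300,
  temperature: 0.5
}

compose_parameters_amazon(params)
# => {
#      messages: [{role: "user", content: [{text: "Hello"}]}],
#      inferenceConfig: {maxTokens: 300, temperature: 0.5}
#    }
# The generation settings move under inferenceConfig, matching the
# request schema Amazon models expect, and are removed from the top level.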
37 changes: 37 additions & 0 deletions lib/langchain/llm/response/aws_bedrock_amazon_response.rb
@@ -0,0 +1,37 @@
# frozen_string_literal: true

module Langchain::LLM
  class AwsBedrockAmazonResponse < BaseResponse
    def completion
      raw_response.dig("output", "message", "content", 0, "text")
    end

    def chat_completion
      completion
    end

    def chat_completions
      completions
    end

    def completions
      nil
    end

    def stop_reason
      raw_response["stopReason"]
    end

    def prompt_tokens
      raw_response.dig("usage", "inputTokens").to_i
    end

    def completion_tokens
      raw_response.dig("usage", "outputTokens").to_i
    end

    def total_tokens
      raw_response.dig("usage", "totalTokens").to_i
    end
  end
end
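A quick usage sketch for the new response class, assuming a payload shaped like the one Bedrock returns for Amazon Nova models (the values below are illustrative):

raw = {
  "output" => {"message" => {"content" => [{"text" => "Paris"}]}},
  "stopReason" => "end_turn",
  "usage" => {"inputTokens" => 14, "outputTokens" => 10, "totalTokens" => 24}
}

response = Langchain::LLM::AwsBedrockAmazonResponse.new(raw)
response.chat_completion # => "Paris"
response.stop_reason     # => "end_turn"
response.total_tokens    # => 24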
44 changes: 44 additions & 0 deletions spec/langchain/llm/aws_bedrock_spec.rb
@@ -119,6 +119,50 @@
end
end
end

context "with amazon provider" do
let(:response) do
{
output: {
message: {
content: [
{text: "The capital of France is Paris."}
]
}
},
usage: {inputTokens: 14, outputTokens: 10}
}.to_json
end

let(:model_id) { "amazon.nova-pro-v1:0" }

before do
response_object = double("response_object")
allow(response_object).to receive(:body).and_return(StringIO.new(response))
allow(subject.client).to receive(:invoke_model)
.with(matching(
model_id:,
body: {
messages: [{role: "user", content: [{text: "What is the capital of France?"}]}],
inferenceConfig: {
maxTokens: 300
}
}.to_json,
content_type: "application/json",
accept: "application/json"
))
.and_return(response_object)
end

it "returns a completion" do
expect(
subject.chat(
messages: [{role: "user", content: [{text: "What is the capital of France?"}]}],
model: model_id
).chat_completion
).to eq("The capital of France is Paris.")
end
end
end

describe "#complete" do
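End to end, the behavior under test can be exercised roughly like this; a hedged sketch assuming AWS credentials are configured in the environment, default client options, and that the Nova model ID from the spec is enabled in the account's region:

require "langchain"

llm = Langchain::LLM::AwsBedrock.new

response = llm.chat(
  messages: [{role: "user", content: [{text: "What is the capital of France?"}]}],
  model: "amazon.nova-pro-v1:0"
)
response.chat_completion # => e.g. "The capital of France is Paris."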