
Commit 1b2c79b
feat: add support for llms on AWS Bedrock (#226) (#227)
Hi, I added support for models on Bedrock. Please check the changes. I'll add examples in a separate issue to help users who are on AWS. Thanks!
1 parent 405dae2 commit 1b2c79b

File tree

2 files changed: +18 −8 lines changed

src/ragas/llms/base.py
src/ragas/metrics/llms/base.py

src/ragas/llms/base.py

Lines changed: 9 additions & 4 deletions
@@ -4,9 +4,9 @@
 import typing as t
 from abc import ABC, abstractmethod
 
-from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI, BedrockChat
 from langchain.chat_models.base import BaseChatModel
-from langchain.llms import AzureOpenAI, OpenAI
+from langchain.llms import AzureOpenAI, OpenAI, Bedrock
 from langchain.llms.base import BaseLLM
 from langchain.schema import LLMResult
 
@@ -20,6 +20,8 @@
 def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
     return isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)
 
+def isBedrock(llm: BaseLLM | BaseChatModel) -> bool:
+    return isinstance(llm, Bedrock) or isinstance(llm, BedrockChat)
 
 # have to specify it twice for runtime and static checks
 MULTIPLE_COMPLETION_SUPPORTED = [OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI]
@@ -116,7 +118,10 @@ def generate(
     ) -> LLMResult:
         # set temperature to 0.2 for multiple completions
         temperature = 0.2 if n > 1 else 0
-        self.llm.temperature = temperature
+        if isBedrock(self.llm) and ("model_kwargs" in self.llm.__dict__):
+            self.llm.model_kwargs = {"temperature": temperature}
+        else:
+            self.llm.temperature = temperature
 
         if self.llm_supports_completions(self.llm):
             return self.generate_multiple_completions(prompts, n, callbacks)
@@ -135,7 +140,7 @@ def generate(
 
         # compute total token usage by adding individual token usage
         llm_output = list_llmresults[0].llm_output
-        if "token_usage" in llm_output:
+        if (llm_output is not None) and ("token_usage" in llm_output):
             sum_prompt_tokens = 0
             sum_completion_tokens = 0
             sum_total_tokens = 0
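
Why the model_kwargs branch: langchain's Bedrock and BedrockChat wrappers do not expose a top-level temperature attribute; sampling parameters are forwarded to the Bedrock API through a model_kwargs dict. A minimal standalone sketch of the same dispatch (the model id below is illustrative, not part of this commit, and a working call requires AWS credentials with Bedrock access):

from langchain.chat_models import BedrockChat
from langchain.llms import Bedrock

def set_temperature(llm, temperature: float) -> None:
    # Bedrock models take sampling params via model_kwargs;
    # OpenAI-style models expose a plain temperature attribute.
    if isinstance(llm, (Bedrock, BedrockChat)) and "model_kwargs" in llm.__dict__:
        llm.model_kwargs = {"temperature": temperature}
    else:
        llm.temperature = temperature

# Illustrative model id; it must be enabled in your AWS account.
llm = Bedrock(model_id="anthropic.claude-v2")
set_temperature(llm, 0.2)  # lands in llm.model_kwargs, not llm.temperature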

src/ragas/metrics/llms/base.py

Lines changed: 9 additions & 4 deletions
@@ -3,9 +3,9 @@
 import typing as t
 from abc import ABC, abstractmethod
 
-from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI, BedrockChat
 from langchain.chat_models.base import BaseChatModel
-from langchain.llms import AzureOpenAI, OpenAI
+from langchain.llms import AzureOpenAI, OpenAI, Bedrock
 from langchain.llms.base import BaseLLM
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema import LLMResult
@@ -19,6 +19,8 @@
 def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
     return isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)
 
+def isBedrock(llm: BaseLLM | BaseChatModel) -> bool:
+    return isinstance(llm, Bedrock) or isinstance(llm, BedrockChat)
 
 # have to specify it twice for runtime and static checks
 MULTIPLE_COMPLETION_SUPPORTED = [OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI]
@@ -115,7 +117,10 @@ def generate(
     ) -> LLMResult:
         # set temperature to 0.2 for multiple completions
        temperature = 0.2 if n > 1 else 0
-        self.llm.temperature = temperature
+        if isBedrock(self.llm) and ("model_kwargs" in self.llm.__dict__):
+            self.llm.model_kwargs = {"temperature": temperature}
+        else:
+            self.llm.temperature = temperature
 
         if self.llm_supports_completions(self.llm):
             return self.generate_multiple_completions(prompts, n, callbacks)
@@ -134,7 +139,7 @@ def generate(
 
         # compute total token usage by adding individual token usage
         llm_output = list_llmresults[0].llm_output
-        if "token_usage" in llm_output:
+        if (llm_output is not None) and ("token_usage" in llm_output):
             sum_prompt_tokens = 0
             sum_completion_tokens = 0
             sum_total_tokens = 0
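
With these changes in place, a Bedrock-backed model can be handed to ragas wherever a langchain BaseLLM or BaseChatModel is accepted, and the new isBedrock() check routes its temperature through model_kwargs. A hedged usage sketch (the model ids and region are illustrative; wire the resulting object into ragas per its documentation):

from langchain.chat_models import BedrockChat
from langchain.llms import Bedrock

# Chat-style Bedrock model; the model id must be enabled in your AWS account.
chat_llm = BedrockChat(model_id="anthropic.claude-v2", region_name="us-east-1")

# Completion-style Bedrock model.
text_llm = Bedrock(model_id="amazon.titan-text-express-v1", region_name="us-east-1")

# Both satisfy isBedrock(), so ragas now sets their temperature via
# model_kwargs instead of a (nonexistent) temperature attribute.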
