33import typing as t
44from abc import ABC , abstractmethod
55
6- from langchain .chat_models import AzureChatOpenAI , ChatOpenAI
6+ from langchain .chat_models import AzureChatOpenAI , ChatOpenAI , BedrockChat
77from langchain .chat_models .base import BaseChatModel
8- from langchain .llms import AzureOpenAI , OpenAI
8+ from langchain .llms import AzureOpenAI , OpenAI , Bedrock
99from langchain .llms .base import BaseLLM
1010from langchain .prompts import ChatPromptTemplate
1111from langchain .schema import LLMResult
def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
    """Return True if *llm* is a (non-Azure) OpenAI completion or chat model.

    Args:
        llm: Any langchain LLM or chat-model instance.

    Returns:
        True for ``OpenAI`` / ``ChatOpenAI`` instances, False otherwise
        (including the Azure variants, which are distinct classes).
    """
    # Single isinstance call with a tuple of types instead of two
    # isinstance checks joined with `or`.
    return isinstance(llm, (OpenAI, ChatOpenAI))
2121
def isBedrock(llm: BaseLLM | BaseChatModel) -> bool:
    """Return True if *llm* is an Amazon Bedrock completion or chat model.

    Args:
        llm: Any langchain LLM or chat-model instance.

    Returns:
        True for ``Bedrock`` / ``BedrockChat`` instances, False otherwise.
    """
    # Mirrors isOpenAI: one isinstance call with a tuple of types.
    return isinstance(llm, (Bedrock, BedrockChat))
2224
# NOTE: the set of models supporting n>1 completions is listed explicitly
# (rather than derived) so that both the runtime isinstance checks and
# static type checkers can see the concrete classes.
MULTIPLE_COMPLETION_SUPPORTED = [OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI]
@@ -115,7 +117,10 @@ def generate(
115117 ) -> LLMResult :
116118 # set temperature to 0.2 for multiple completions
117119 temperature = 0.2 if n > 1 else 0
118- self .llm .temperature = temperature
120+ if isBedrock (self .llm ) and ("model_kwargs" in self .llm .__dict__ ):
121+ self .llm .model_kwargs = {"temperature" : temperature }
122+ else :
123+ self .llm .temperature = temperature
119124
120125 if self .llm_supports_completions (self .llm ):
121126 return self .generate_multiple_completions (prompts , n , callbacks )
@@ -134,7 +139,7 @@ def generate(
134139
135140 # compute total token usage by adding individual token usage
136141 llm_output = list_llmresults [0 ].llm_output
137- if "token_usage" in llm_output :
142+ if ( llm_output is not None ) and ( "token_usage" in llm_output ) :
138143 sum_prompt_tokens = 0
139144 sum_completion_tokens = 0
140145 sum_total_tokens = 0
0 commit comments