 SPDX-License-Identifier: MIT-0
 """
 import json
-from langchain.llms import Bedrock
 from langchain.prompts.prompt import PromptTemplate
+from langchain_community.llms import Bedrock
+from langchain_community.chat_models import BedrockChat
 import logging
 import os
 from timeit import default_timer as timer
@@ -67,6 +68,8 @@ def get_max_output_length(model_id):
         'anthropic.claude-instant-v1': 4000,
         'anthropic.claude-v2': 4000,
         'anthropic.claude-v2:1': 4000,
+        'anthropic.claude-3-sonnet-20240229-v1:0': 4000,
+        'anthropic.claude-3-haiku-20240307-v1:0': 4000,
         'ai21.j2-mid-v1': 8191,
         'ai21.j2-ultra-v1': 8191,
         'cohere.command-light-text-v14': 4096,
@@ -96,6 +99,11 @@ def get_model_kwargs(model_id, temperature, max_response_token_length):
                 "temperature": temperature,
                 "max_tokens_to_sample": max_response_token_length
             }
+        case 'anthropic.claude-3-sonnet-20240229-v1:0' | 'anthropic.claude-3-haiku-20240307-v1:0':
+            model_kwargs = {
+                "temperature": temperature,
+                "max_tokens": max_response_token_length
+            }
         case 'ai21.j2-mid-v1' | 'ai21.j2-ultra-v1':
             model_kwargs = {
                 "temperature": temperature,
@@ -183,17 +191,18 @@ def invoke_llm_with_bedrock_rt(model_id, bedrock_rt_client, temperature, max_res
     return prompt_response


-# Function to invoke the specified LLM through the LangChain client and
+# Function to invoke the specified LLM through the LangChain LLM client and
 # using the specified prompt
-def invoke_llm_with_lc(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
+def invoke_llm_with_lc_llm(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
     # Create the LangChain LLM client
-    logging.info('Creating LangChain client for LLM "{}"...'.format(model_id))
+    logging.info('Creating LangChain LLM client for LLM "{}"...'.format(model_id))
     llm = Bedrock(
-        model_id=model_id,
-        model_kwargs=get_model_kwargs(model_id, temperature, max_response_token_length),
-        client=bedrock_rt_client
+        model_id=model_id,
+        model_kwargs=get_model_kwargs(model_id, temperature, max_response_token_length),
+        client=bedrock_rt_client,
+        streaming=False
     )
-    logging.info('Completed creating LangChain client for LLM.')
+    logging.info('Completed creating LangChain LLM client for LLM.')
     logging.info('Invoking LLM "{}" with specified inference parameters "{}"...'.
                  format(llm.model_id, llm.model_kwargs))
     start = timer()
@@ -205,6 +214,29 @@ def invoke_llm_with_lc(model_id, bedrock_rt_client, temperature, max_response_to
     return prompt_response


+# Function to invoke the specified LLM through the LangChain ChatModel client and
+# using the specified prompt
+def invoke_llm_with_lc_cm(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
+    # Create the LangChain ChatModel client
+    logging.info('Creating LangChain ChatModel client for LLM "{}"...'.format(model_id))
+    llm = BedrockChat(
+        model_id=model_id,
+        model_kwargs=get_model_kwargs(model_id, temperature, max_response_token_length),
+        client=bedrock_rt_client,
+        streaming=False
+    )
+    logging.info('Completed creating LangChain ChatModel client for LLM.')
+    logging.info('Invoking LLM "{}" with specified inference parameters "{}"...'.
+                 format(llm.model_id, llm.model_kwargs))
+    start = timer()
+    prompt_response = llm.invoke(prompt).content
+    end = timer()
+    logging.info(prompt + prompt_response)
+    logging.info('Completed invoking LLM.')
+    logging.info('Prompt processing duration = {} second(s)'.format(end - start))
+    return prompt_response
+
+
 # Function to process the steps required in the example prompt 1
 def process_prompt_1(model_id, bedrock_rt_client, temperature, max_response_token_length,
                      prompt_templates_dir, prompt_template_file, prompt_data, call_to_action):
@@ -213,9 +245,9 @@ def process_prompt_1(model_id, bedrock_rt_client, temperature, max_response_toke
                                          DATA=prompt_data, CALL_TO_ACTION=call_to_action)
     # Invoke the LLM and print the response
     match model_id:
-        case 'mistral.mistral-7b-instruct-v0:2' | 'mistral.mixtral-8x7b-instruct-v0:1':
-            return invoke_llm_with_bedrock_rt(model_id, bedrock_rt_client, temperature,
-                                              max_response_token_length, prompt)
+        case 'anthropic.claude-3-sonnet-20240229-v1:0' | 'anthropic.claude-3-haiku-20240307-v1:0':
+            return invoke_llm_with_lc_cm(model_id, bedrock_rt_client, temperature,
+                                         max_response_token_length, prompt)
         case _:
-            return invoke_llm_with_lc(model_id, bedrock_rt_client, temperature,
-                                      max_response_token_length, prompt)
+            return invoke_llm_with_lc_llm(model_id, bedrock_rt_client, temperature,
+                                          max_response_token_length, prompt)
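For anyone trying this change locally, below is a minimal sketch (not part of the commit) of how the new ChatModel path could be exercised. The function name invoke_llm_with_lc_cm and the Claude 3 model IDs come from the diff above; the region, temperature, and prompt text are illustrative assumptions. Claude 3 models on Bedrock only expose the Messages API, which is why they are routed through BedrockChat rather than the plain Bedrock LLM client.

    # Minimal sketch, assuming valid AWS credentials and Bedrock model access.
    import boto3

    # The region is an assumption; use whichever region hosts your Bedrock models.
    bedrock_rt_client = boto3.client('bedrock-runtime', region_name='us-east-1')

    # Claude 3 model IDs are routed to the new LangChain ChatModel path.
    response = invoke_llm_with_lc_cm(
        model_id='anthropic.claude-3-haiku-20240307-v1:0',
        bedrock_rt_client=bedrock_rt_client,
        temperature=0.0,
        max_response_token_length=4000,
        prompt='Summarize what Amazon Bedrock is in one sentence.'
    )
    print(response)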