
Commit e7f0f19

Add files via upload
1 parent 8663517 commit e7f0f19

File tree: 1 file changed (+45, -13 lines)

notebooks/scripts/helper_functions.py

Lines changed: 45 additions & 13 deletions
@@ -3,8 +3,9 @@
 SPDX-License-Identifier: MIT-0
 """
 import json
-from langchain.llms import Bedrock
 from langchain.prompts.prompt import PromptTemplate
+from langchain_community.llms import Bedrock
+from langchain_community.chat_models import BedrockChat
 import logging
 import os
 from timeit import default_timer as timer
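The Bedrock and BedrockChat imports above come from langchain_community, which is distributed as its own package alongside langchain. A minimal import check for the notebook environment (a sketch; the install command is an assumption, not part of this commit):

# Both imports should resolve once the langchain-community package is installed,
# e.g. via: pip install langchain langchain-community boto3
from langchain_community.llms import Bedrock
from langchain_community.chat_models import BedrockChat
print(Bedrock.__name__, BedrockChat.__name__)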
@@ -67,6 +68,8 @@ def get_max_output_length(model_id):
         'anthropic.claude-instant-v1': 4000,
         'anthropic.claude-v2': 4000,
         'anthropic.claude-v2:1': 4000,
+        'anthropic.claude-3-sonnet-20240229-v1:0': 4000,
+        'anthropic.claude-3-haiku-20240307-v1:0': 4000,
         'ai21.j2-mid-v1': 8191,
         'ai21.j2-ultra-v1': 8191,
         'cohere.command-light-text-v14': 4096,
@@ -96,6 +99,11 @@ def get_model_kwargs(model_id, temperature, max_response_token_length):
                 "temperature": temperature,
                 "max_tokens_to_sample": max_response_token_length
             }
+        case 'anthropic.claude-3-sonnet-20240229-v1:0' | 'anthropic.claude-3-haiku-20240307-v1:0':
+            model_kwargs = {
+                "temperature": temperature,
+                "max_tokens": max_response_token_length
+            }
         case 'ai21.j2-mid-v1' | 'ai21.j2-ultra-v1':
             model_kwargs = {
                 "temperature": temperature,
@@ -183,17 +191,18 @@ def invoke_llm_with_bedrock_rt(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
     return prompt_response
 
 
-# Function to invoke the specified LLM through the LangChain client and
+# Function to invoke the specified LLM through the LangChain LLM client and
 # using the specified prompt
-def invoke_llm_with_lc(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
+def invoke_llm_with_lc_llm(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
     # Create the LangChain LLM client
-    logging.info('Creating LangChain client for LLM "{}"...'.format(model_id))
+    logging.info('Creating LangChain LLM client for LLM "{}"...'.format(model_id))
     llm = Bedrock(
-        model_id = model_id,
-        model_kwargs = get_model_kwargs(model_id, temperature, max_response_token_length),
-        client = bedrock_rt_client
+        model_id=model_id,
+        model_kwargs=get_model_kwargs(model_id, temperature, max_response_token_length),
+        client=bedrock_rt_client,
+        streaming=False
     )
-    logging.info('Completed creating LangChain client for LLM.')
+    logging.info('Completed creating LangChain LLM client for LLM.')
     logging.info('Invoking LLM "{}" with specified inference parameters "{}"...'.
                  format(llm.model_id, llm.model_kwargs))
     start = timer()
@@ -205,6 +214,29 @@ def invoke_llm_with_lc(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
     return prompt_response
 
 
+# Function to invoke the specified LLM through the LangChain ChatModel client and
+# using the specified prompt
+def invoke_llm_with_lc_cm(model_id, bedrock_rt_client, temperature, max_response_token_length, prompt):
+    # Create the LangChain ChatModel client
+    logging.info('Creating LangChain ChatModel client for LLM "{}"...'.format(model_id))
+    llm = BedrockChat(
+        model_id=model_id,
+        model_kwargs=get_model_kwargs(model_id, temperature, max_response_token_length),
+        client=bedrock_rt_client,
+        streaming=False
+    )
+    logging.info('Completed creating LangChain ChatModel client for LLM.')
+    logging.info('Invoking LLM "{}" with specified inference parameters "{}"...'.
+                 format(llm.model_id, llm.model_kwargs))
+    start = timer()
+    prompt_response = llm.invoke(prompt).content
+    end = timer()
+    logging.info(prompt + prompt_response)
+    logging.info('Completed invoking LLM.')
+    logging.info('Prompt processing duration = {} second(s)'.format(end - start))
+    return prompt_response
+
+
 # Function to process the steps required in the example prompt 1
 def process_prompt_1(model_id, bedrock_rt_client, temperature, max_response_token_length,
                      prompt_templates_dir, prompt_template_file, prompt_data, call_to_action):
@@ -213,9 +245,9 @@ def process_prompt_1(model_id, bedrock_rt_client, temperature, max_response_token_length,
         DATA=prompt_data, CALL_TO_ACTION=call_to_action)
     # Invoke the LLM and print the response
     match model_id:
-        case 'mistral.mistral-7b-instruct-v0:2' | 'mistral.mixtral-8x7b-instruct-v0:1':
-            return invoke_llm_with_bedrock_rt(model_id, bedrock_rt_client, temperature,
-                                              max_response_token_length, prompt)
+        case 'anthropic.claude-3-sonnet-20240229-v1:0' | 'anthropic.claude-3-haiku-20240307-v1:0':
+            return invoke_llm_with_lc_cm(model_id, bedrock_rt_client, temperature,
+                                         max_response_token_length, prompt)
         case _:
-            return invoke_llm_with_lc(model_id, bedrock_rt_client, temperature,
-                                      max_response_token_length, prompt)
+            return invoke_llm_with_lc_llm(model_id, bedrock_rt_client, temperature,
+                                          max_response_token_length, prompt)
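With these changes, process_prompt_1 routes the Claude 3 Sonnet and Haiku model IDs through the new BedrockChat-based helper, while every other model ID (including the Mistral models previously sent to the raw Bedrock runtime helper) goes through the renamed invoke_llm_with_lc_llm. An end-to-end sketch of the new chat-model path (assumptions: a standard boto3 bedrock-runtime client, helper_functions.py importable from the working directory, and placeholder region, prompt, and parameter values):

import boto3
from helper_functions import invoke_llm_with_lc_cm

# Bedrock runtime client that the helper functions expect as an argument;
# the region is a placeholder.
bedrock_rt_client = boto3.client('bedrock-runtime', region_name='us-east-1')

# Invoke Claude 3 Haiku through the new LangChain ChatModel path.
response = invoke_llm_with_lc_cm(
    model_id='anthropic.claude-3-haiku-20240307-v1:0',
    bedrock_rt_client=bedrock_rt_client,
    temperature=0.0,
    max_response_token_length=4000,
    prompt='In two sentences, explain what a prompt template is.'
)
print(response)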
