Commit 13f3f8e

updates chat with document to use BedrockClient and take guardrails into account
1 parent e364005 commit 13f3f8e

File tree

2 files changed: +78 −109 lines changed


src/lambda/chat_with_document_resolver/index.py

Lines changed: 77 additions & 109 deletions
@@ -9,64 +9,28 @@
 import base64
 import hashlib
 import os
+import re
 from urllib.parse import urlparse
 from botocore.exceptions import ClientError
 from idp_common.bedrock.client import BedrockClient

 # Set up logging
 logger = logging.getLogger()
 logger.setLevel(os.environ.get("LOG_LEVEL", "INFO"))
-# Get LOG_LEVEL from environment variable with INFO as default
-
-def get_guardrails():
-    client = boto3.client("bedrock", region_name=os.environ['AWS_REGION'])
-    guardrails = client.list_guardrails(
-        maxResults=10
-    )

-    grs = []
-    for gr in guardrails['guardrails']:
-        if gr['status'] == 'READY':
-            grs.append({'id': gr['id'], 'version': gr['version'], 'name': gr['name']})
+def remove_text_between_brackets(text):
+    # Find position of first opening bracket
+    start = text.find('{')
+    # Find position of last closing bracket
+    end = text.rfind('}')

-    return grs
-
-def apply_guardrails_to_prompt(prompt):
-    br_client = boto3.client("bedrock-runtime", region_name=os.environ['AWS_REGION'])
-
-    guardrails = get_guardrails()
-
-    try:
-        for gr in guardrails:
-            response = br_client.apply_guardrail(
-                guardrailIdentifier=gr['id'],
-                guardrailVersion=gr['version'],
-                source="INPUT",
-                content=[
-                    {
-                        "text": {
-                            "text": prompt
-                        }
-                    }
-                ]
-            )
-
-            error_details = ""
-            if (response['action'] == 'GUARDRAIL_INTERVENED'):
-                for output in response['outputs']:
-                    error_details += output['text']
-
-            if len(error_details):
-                raise Exception(f"{error_details}")
-
-        return True
-
-    except Exception as e:
-        print(f"Caught by Guardrail: {e}")
-        raise
-
-    return False
+    # If both brackets exist, remove text between them including brackets
+    if start != -1 and end != -1:
+        return text[:start] + text[end+1:]
+    # If brackets not found, return original string
+    return text

+# Get LOG_LEVEL from environment variable with INFO as default
 def s3_object_exists(bucket, key):
     try:
         s3 = boto3.client('s3')
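The new remove_text_between_brackets helper drops everything from the first '{' through the last '}' in a string. Per the comments later in this diff, a guardrail intervention currently prepends a JSON trace object to the model's response text, and this helper strips that prefix. A minimal sketch of the intended behavior, using an invented response string (the real guardrail trace format may differ):

# Hypothetical illustration only; the sample string below is made up
sample = '{"guardrail_action": "INTERVENED"}\nSorry, I can\'t help with that request.'
cleaned = remove_text_between_brackets(sample).strip("\n")
print(cleaned)  # prints: Sorry, I can't help with that request.

Note that because the helper spans from the first '{' to the last '}', any literal braces inside the model's own answer would be removed as well.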
@@ -147,77 +111,81 @@ def handler(event, context):

     try:
         # logger.info(f"Received event: {json.dumps(event)}")
-
         objectKey = event['arguments']['s3Uri']
         prompt = event['arguments']['prompt']
         history = event['arguments']['history']

         full_prompt = "The history JSON object is: " + json.dumps(history) + ".\n\n"
         full_prompt += "The user's question is: " + prompt + "\n\n"

-        gr_test = apply_guardrails_to_prompt(full_prompt)
-
-        if gr_test:
-            # this feature is not enabled until the model can be selected on the chat screen
-            # selectedModelId = event['arguments']['modelId']
-            selectedModelId = get_summarization_model()
-
-            logger.info(f"Processing S3 URI: {objectKey}")
-            logger.info(f"Region: {os.environ['AWS_REGION']}")
-
-            output_bucket = os.environ['OUTPUT_BUCKET']
-
-            bedrock_runtime = boto3.client('bedrock-runtime', region_name=os.environ['AWS_REGION'])
-
-            if (len(objectKey)):
-                fulltext_key = objectKey + '/summary/fulltext.txt'
-                content_str = ""
-                s3 = boto3.client('s3')
-
-                if not s3_object_exists(output_bucket, fulltext_key):
-                    logger.info(f"Creating full text file: {fulltext_key}")
-                    content_str = get_full_text(output_bucket, objectKey)
-
-                    s3.put_object(
-                        Bucket=output_bucket,
-                        Key=fulltext_key,
-                        Body=content_str.encode('utf-8')
-                    )
-                else:
-                    # read full contents of the object as text
-                    response = s3.get_object(Bucket=output_bucket, Key=fulltext_key)
-                    content_str = response['Body'].read().decode('utf-8')
-
-                logger.info(f"Model: {selectedModelId}")
-                logger.info(f"Output Bucket: {output_bucket}")
-                logger.info(f"Full Text Key: {fulltext_key}")
-
-                client = BedrockClient()
-                # Content with cachepoint tags
-                content = [
-                    {
-                        "text": content_str + """
-<<CACHEPOINT>>
-""" + full_prompt
-                    }
-                ]
-
-                model_response = client.invoke_model(
-                    model_id="us.amazon.nova-pro-v1:0",
-                    system_prompt="You are an assistant that's responsible for getting details from document text attached here based on questions from the user.\n\nIf you don't know the answer, just say that you don't know. Don't try to make up an answer.\n\nAdditionally, use the user and assistant responses in the following JSON object to see what's been asked and what the resposes were in the past.\n\n",
-                    content=content,
-                    temperature=0.0
+        # this feature is not enabled until the model can be selected on the chat screen
+        # selectedModelId = event['arguments']['modelId']
+        selectedModelId = get_summarization_model()
+
+        logger.info(f"Processing S3 URI: {objectKey}")
+        logger.info(f"Region: {os.environ['AWS_REGION']}")
+
+        output_bucket = os.environ['OUTPUT_BUCKET']
+
+        if (len(objectKey)):
+            fulltext_key = objectKey + '/summary/fulltext.txt'
+            content_str = ""
+            s3 = boto3.client('s3')
+
+            if not s3_object_exists(output_bucket, fulltext_key):
+                logger.info(f"Creating full text file: {fulltext_key}")
+                content_str = get_full_text(output_bucket, objectKey)
+
+                s3.put_object(
+                    Bucket=output_bucket,
+                    Key=fulltext_key,
+                    Body=content_str.encode('utf-8')
                 )
+            else:
+                # read full contents of the object as text
+                response = s3.get_object(Bucket=output_bucket, Key=fulltext_key)
+                content_str = response['Body'].read().decode('utf-8')
+
+            logger.info(f"Model: {selectedModelId}")
+            logger.info(f"Output Bucket: {output_bucket}")
+            logger.info(f"Full Text Key: {fulltext_key}")
+
+            # Content with cachepoint tags
+            content = [
+                {
+                    "text": content_str + """
+<<CACHEPOINT>>
+""" + full_prompt
+                }
+            ]
+
+            client = BedrockClient(
+                region=os.environ['AWS_REGION'],
+                max_retries=5,
+                initial_backoff=1.5,
+                max_backoff=300,
+                metrics_enabled=True
+            )

-            text = client.extract_text_from_response(model_response)
-            logger.info(f"Response before guardrails check: {text}")
-
-            chat_response = {"cr": {"content": [{"text": "I can't answer that right now"}]}}
+            # Invoke a model
+            response = client.invoke_model(
+                model_id=selectedModelId,
+                system_prompt="You are an assistant that's responsible for getting details from document text attached here based on questions from the user.\n\nIf you don't know the answer, just say that you don't know. Don't try to make up an answer.\n\nAdditionally, use the user and assistant responses in the following JSON object to see what's been asked and what the resposes were in the past.\n\n",
+                content=content,
+                temperature=0.0
+            )

-            if apply_guardrails_to_prompt(text):
-                chat_response = {"cr": {"content": [{"text": text}]}}
-
-            return json.dumps(chat_response)
+            text = client.extract_text_from_response(response)
+            logger.info(f"Full response: {text}")
+
+            # right now, there is a JSON object before the full response when a guardrail is tripped
+            # need to remove that JSON object first
+            logger.info(f"New response: {remove_text_between_brackets(text).strip("\n")}")
+            cleaned_up_text = remove_text_between_brackets(text).strip("\n")
+
+            chat_response = {"cr": {"content": [{"text": cleaned_up_text}]}}
+
+            return json.dumps(chat_response)

     except ClientError as e:
         error_code = e.response['Error']['Code']
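For reference, a rough sketch of how this resolver might be exercised locally. The event shape (arguments.s3Uri, arguments.prompt, arguments.history) and the returned {"cr": {"content": [...]}} JSON string come from the handler above; the concrete values, the index module import, and the empty Lambda context are hypothetical, and a real call still needs OUTPUT_BUCKET plus AWS credentials configured.

import json
from index import handler  # hypothetical import; assumes the Lambda module is importable as index

# Invented AppSync-style event; field names match what the resolver reads
event = {
    "arguments": {
        "s3Uri": "documents/sample-invoice",                        # hypothetical object key prefix
        "prompt": "What is the total amount on the invoice?",       # hypothetical question
        "history": [{"user": "Hello", "assistant": "Hi, how can I help?"}],
    }
}

result = json.loads(handler(event, context=None))
print(result["cr"]["content"][0]["text"])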

template.yaml

Lines changed: 1 addition & 0 deletions
@@ -6107,6 +6107,7 @@ Resources:
           OUTPUT_BUCKET: !Ref OutputBucket
           CONFIGURATION_TABLE_NAME: !Ref ConfigurationTable
           TRACKING_TABLE_NAME: !Ref TrackingTable
+          GUARDRAIL_ID_AND_VERSION: !If [HasGuardrailConfig, !Sub "${BedrockGuardrailId}:${BedrockGuardrailVersion}", ""]
       LoggingConfig:
         LogGroup: !Ref ChatWithDocumentResolverFunctionLogGroup
       Policies:
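The diff only shows the new GUARDRAIL_ID_AND_VERSION variable being injected into the function's environment; where it is consumed is not visible in this commit (presumably inside idp_common's BedrockClient). Purely to illustrate the "id:version" encoding the template produces, a hypothetical helper that splits it back apart could look like this:

import os

def parse_guardrail_id_and_version(raw=None):
    # Hypothetical helper, not part of this commit: split "id:version" into its parts,
    # returning (None, None) when the template's !If fell back to an empty string.
    raw = raw if raw is not None else os.environ.get("GUARDRAIL_ID_AND_VERSION", "")
    if not raw:
        return None, None
    guardrail_id, _, guardrail_version = raw.partition(":")
    return guardrail_id, guardrail_version or None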
