|
9 | 9 | import base64 |
10 | 10 | import hashlib |
11 | 11 | import os |
| 12 | +import re |
12 | 13 | from urllib.parse import urlparse |
13 | 14 | from botocore.exceptions import ClientError |
14 | 15 | from idp_common.bedrock.client import BedrockClient |
15 | 16 |
|
16 | 17 | # Set up logging |
17 | 18 | logger = logging.getLogger() |
18 | 19 | logger.setLevel(os.environ.get("LOG_LEVEL", "INFO")) |
19 | | -# Get LOG_LEVEL from environment variable with INFO as default |
20 | | - |
21 | | -def get_guardrails(): |
22 | | - client = boto3.client("bedrock", region_name=os.environ['AWS_REGION']) |
23 | | - guardrails = client.list_guardrails( |
24 | | - maxResults=10 |
25 | | - ) |
26 | 20 |
|
27 | | - grs = [] |
28 | | - for gr in guardrails['guardrails']: |
29 | | - if gr['status'] == 'READY': |
30 | | - grs.append({'id': gr['id'], 'version': gr['version'], 'name': gr['name']}) |
| 21 | +def remove_text_between_brackets(text): |
| 22 | +    # Find the position of the first opening curly brace |
| 23 | +    start = text.find('{') |
| 24 | +    # Find the position of the last closing curly brace |
| 25 | +    end = text.rfind('}') |
31 | 26 |
|
32 | | - return grs |
33 | | - |
34 | | -def apply_guardrails_to_prompt(prompt): |
35 | | - br_client = boto3.client("bedrock-runtime", region_name=os.environ['AWS_REGION']) |
36 | | - |
37 | | - guardrails = get_guardrails() |
38 | | - |
39 | | - try: |
40 | | - for gr in guardrails: |
41 | | - response = br_client.apply_guardrail( |
42 | | - guardrailIdentifier=gr['id'], |
43 | | - guardrailVersion=gr['version'], |
44 | | - source="INPUT", |
45 | | - content=[ |
46 | | - { |
47 | | - "text": { |
48 | | - "text": prompt |
49 | | - } |
50 | | - } |
51 | | - ] |
52 | | - ) |
53 | | - |
54 | | - error_details = "" |
55 | | - if (response['action'] == 'GUARDRAIL_INTERVENED'): |
56 | | - for output in response['outputs']: |
57 | | - error_details += output['text'] |
58 | | - |
59 | | - if len(error_details): |
60 | | - raise Exception(f"{error_details}") |
61 | | - |
62 | | - return True |
63 | | - |
64 | | - except Exception as e: |
65 | | - print(f"Caught by Guardrail: {e}") |
66 | | - raise |
67 | | - |
68 | | - return False |
| 27 | +    # If both braces exist, remove everything between them, braces included |
| 28 | +    if start != -1 and end != -1: |
| 29 | +        return text[:start] + text[end+1:] |
| 30 | +    # Otherwise return the original string unchanged |
| 31 | + return text |
69 | 32 |
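A quick sanity check of the new helper (the JSON prefix below is an illustrative stand-in, not an actual Bedrock guardrail payload):

    raw = '{"guardrail": "intervened"}\nI cannot answer that.'
    print(remove_text_between_brackets(raw).strip("\n"))
    # -> I cannot answer that.

Note that because the helper spans from the first '{' to the last '}', any braces appearing inside the model's actual answer would be stripped along with the prefix.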
|
| 33 | +# Return True if the given S3 object exists in the bucket, False otherwise |
70 | 34 | def s3_object_exists(bucket, key): |
71 | 35 | try: |
72 | 36 | s3 = boto3.client('s3') |
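The remainder of this helper sits outside the diff context. A typical head_object-based existence check (an assumption about the elided body, not the file's actual code) would look like:

    def s3_object_exists(bucket, key):
        s3 = boto3.client('s3')
        try:
            # head_object is a cheap metadata-only call; it raises ClientError
            # with a 404 error code when the key is absent
            s3.head_object(Bucket=bucket, Key=key)
            return True
        except ClientError as e:
            if e.response['Error']['Code'] in ('404', 'NotFound'):
                return False
            raise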
@@ -147,77 +111,81 @@ def handler(event, context): |
147 | 111 |
|
148 | 112 | try: |
149 | 113 | # logger.info(f"Received event: {json.dumps(event)}") |
150 | | - |
151 | 114 | objectKey = event['arguments']['s3Uri'] |
152 | 115 | prompt = event['arguments']['prompt'] |
153 | 116 | history = event['arguments']['history'] |
154 | 117 |
|
155 | 118 | full_prompt = "The history JSON object is: " + json.dumps(history) + ".\n\n" |
156 | 119 | full_prompt += "The user's question is: " + prompt + "\n\n" |
157 | 120 |
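For reference, the arguments this handler reads would look roughly like the following (all field values are illustrative, not captured from a real request):

    example_arguments = {
        "s3Uri": "documents/report.pdf",  # hypothetical object key
        "prompt": "What is the invoice total?",
        "history": [
            {"role": "user", "content": "Summarize the document"},
            {"role": "assistant", "content": "The document is an invoice for..."}
        ]
    }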
|
158 | | - gr_test = apply_guardrails_to_prompt(full_prompt) |
159 | | - |
160 | | - if gr_test: |
161 | | - # this feature is not enabled until the model can be selected on the chat screen |
162 | | - # selectedModelId = event['arguments']['modelId'] |
163 | | - selectedModelId = get_summarization_model() |
164 | | - |
165 | | - logger.info(f"Processing S3 URI: {objectKey}") |
166 | | - logger.info(f"Region: {os.environ['AWS_REGION']}") |
167 | | - |
168 | | - output_bucket = os.environ['OUTPUT_BUCKET'] |
169 | | - |
170 | | - bedrock_runtime = boto3.client('bedrock-runtime', region_name=os.environ['AWS_REGION']) |
171 | | - |
172 | | - if (len(objectKey)): |
173 | | - fulltext_key = objectKey + '/summary/fulltext.txt' |
174 | | - content_str = "" |
175 | | - s3 = boto3.client('s3') |
176 | | - |
177 | | - if not s3_object_exists(output_bucket, fulltext_key): |
178 | | - logger.info(f"Creating full text file: {fulltext_key}") |
179 | | - content_str = get_full_text(output_bucket, objectKey) |
180 | | - |
181 | | - s3.put_object( |
182 | | - Bucket=output_bucket, |
183 | | - Key=fulltext_key, |
184 | | - Body=content_str.encode('utf-8') |
185 | | - ) |
186 | | - else: |
187 | | - # read full contents of the object as text |
188 | | - response = s3.get_object(Bucket=output_bucket, Key=fulltext_key) |
189 | | - content_str = response['Body'].read().decode('utf-8') |
190 | | - |
191 | | - logger.info(f"Model: {selectedModelId}") |
192 | | - logger.info(f"Output Bucket: {output_bucket}") |
193 | | - logger.info(f"Full Text Key: {fulltext_key}") |
194 | | - |
195 | | - client = BedrockClient() |
196 | | - # Content with cachepoint tags |
197 | | - content = [ |
198 | | - { |
199 | | - "text": content_str + """ |
200 | | - <<CACHEPOINT>> |
201 | | - """ + full_prompt |
202 | | - } |
203 | | - ] |
204 | | - |
205 | | - model_response = client.invoke_model( |
206 | | - model_id="us.amazon.nova-pro-v1:0", |
207 | | - system_prompt="You are an assistant that's responsible for getting details from document text attached here based on questions from the user.\n\nIf you don't know the answer, just say that you don't know. Don't try to make up an answer.\n\nAdditionally, use the user and assistant responses in the following JSON object to see what's been asked and what the resposes were in the past.\n\n", |
208 | | - content=content, |
209 | | - temperature=0.0 |
| 121 | + # this feature is not enabled until the model can be selected on the chat screen |
| 122 | + # selectedModelId = event['arguments']['modelId'] |
| 123 | + selectedModelId = get_summarization_model() |
| 124 | + |
| 125 | + logger.info(f"Processing S3 URI: {objectKey}") |
| 126 | + logger.info(f"Region: {os.environ['AWS_REGION']}") |
| 127 | + |
| 128 | + output_bucket = os.environ['OUTPUT_BUCKET'] |
| 129 | + |
| 130 | +        if objectKey: |
| 131 | + fulltext_key = objectKey + '/summary/fulltext.txt' |
| 132 | + content_str = "" |
| 133 | + s3 = boto3.client('s3') |
| 134 | + |
| 135 | + if not s3_object_exists(output_bucket, fulltext_key): |
| 136 | + logger.info(f"Creating full text file: {fulltext_key}") |
| 137 | + content_str = get_full_text(output_bucket, objectKey) |
| 138 | + |
| 139 | + s3.put_object( |
| 140 | + Bucket=output_bucket, |
| 141 | + Key=fulltext_key, |
| 142 | + Body=content_str.encode('utf-8') |
210 | 143 | ) |
| 144 | + else: |
| 145 | + # read full contents of the object as text |
| 146 | + response = s3.get_object(Bucket=output_bucket, Key=fulltext_key) |
| 147 | + content_str = response['Body'].read().decode('utf-8') |
| 148 | + |
| 149 | + logger.info(f"Model: {selectedModelId}") |
| 150 | + logger.info(f"Output Bucket: {output_bucket}") |
| 151 | + logger.info(f"Full Text Key: {fulltext_key}") |
| 152 | + |
| 153 | + # Content with cachepoint tags |
| 154 | + content = [ |
| 155 | + { |
| 156 | + "text": content_str + """ |
| 157 | + <<CACHEPOINT>> |
| 158 | + """ + full_prompt |
| 159 | + } |
| 160 | + ] |
| 161 | + |
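The <<CACHEPOINT>> tag marks the boundary between the large, cacheable document text and the per-request question. Assuming the idp_common client splits on the tag and emits a Converse-style cache point block (an assumption; the exact handling lives in BedrockClient), the transformation is roughly:

    # Hypothetical sketch of the tag handling, not idp_common's actual code
    def split_on_cachepoint(text, tag="<<CACHEPOINT>>"):
        prefix, _, suffix = text.partition(tag)
        return [
            {"text": prefix},
            {"cachePoint": {"type": "default"}},  # Bedrock Converse cache point
            {"text": suffix},
        ]

This lets repeated questions against the same document reuse the cached prefix instead of re-processing the full text on every turn.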
| 162 | + client = BedrockClient( |
| 163 | + region=os.environ['AWS_REGION'], |
| 164 | + max_retries=5, |
| 165 | + initial_backoff=1.5, |
| 166 | + max_backoff=300, |
| 167 | + metrics_enabled=True |
| 168 | + ) |
211 | 169 |
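Assuming the retry policy is plain exponential doubling capped at max_backoff (an assumption; the actual strategy and any jitter live in idp_common), these settings imply a delay schedule of roughly:

    # Hypothetical: initial_backoff doubled per attempt, capped at max_backoff
    delays = [min(1.5 * 2 ** attempt, 300) for attempt in range(5)]
    print(delays)  # [1.5, 3.0, 6.0, 12.0, 24.0]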
|
212 | | - text = client.extract_text_from_response(model_response) |
213 | | - logger.info(f"Response before guardrails check: {text}") |
214 | | - |
215 | | - chat_response = {"cr": {"content": [{"text": "I can't answer that right now"}]}} |
| 170 | + # Invoke a model |
| 171 | + response = client.invoke_model( |
| 172 | + model_id=selectedModelId, |
| 173 | +        system_prompt="You are an assistant that's responsible for getting details from document text attached here based on questions from the user.\n\nIf you don't know the answer, just say that you don't know. Don't try to make up an answer.\n\nAdditionally, use the user and assistant responses in the following JSON object to see what's been asked and what the responses were in the past.\n\n", |
| 174 | + content=content, |
| 175 | + temperature=0.0 |
| 176 | + ) |
216 | 177 |
|
217 | | - if apply_guardrails_to_prompt(text): |
218 | | - chat_response = {"cr": {"content": [{"text": text}]}} |
219 | | - |
220 | | - return json.dumps(chat_response) |
| 178 | + text = client.extract_text_from_response(response) |
| 179 | + logger.info(f"Full response: {text}") |
| 180 | + |
| 181 | +        # When a guardrail intervenes, a JSON object is prepended to the response text; |
| 182 | +        # strip it (and any surrounding newlines) before returning the answer |
| 183 | +        cleaned_up_text = remove_text_between_brackets(text).strip("\n") |
| 184 | +        logger.info(f"New response: {cleaned_up_text}") |
| 185 | + |
| 186 | + chat_response = {"cr": {"content": [{"text": cleaned_up_text}]}} |
| 187 | + |
| 188 | + return json.dumps(chat_response) |
221 | 189 |
|
222 | 190 | except ClientError as e: |
223 | 191 | error_code = e.response['Error']['Code'] |
|