Skip to content

Commit ab93324

Browse files
committed
Create Bedrock Prompt Management resource for query reformulation
1 parent f465368 commit ab93324

File tree

6 files changed

+156
-1
lines changed

6 files changed

+156
-1
lines changed

packages/cdk/resources/IamResources.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ export class IamResources extends Construct {
150150
resources: [`arn:aws:bedrock:${props.region}:${props.account}:guardrail/*`]
151151
})
152152

153+
const slackBotPromptPolicy = new PolicyStatement({
154+
actions: ["bedrock:InvokeModel"],
155+
resources: [`arn:aws:bedrock:${props.region}:${props.account}:prompt/*`]
156+
})
157+
153158
const slackBotDynamoDbPolicy = new PolicyStatement({
154159
actions: [
155160
"dynamodb:GetItem",
@@ -183,6 +188,7 @@ export class IamResources extends Construct {
183188
slackBotSSMPolicy,
184189
slackBotLambdaPolicy,
185190
slackBotGuardrailPolicy,
191+
slackBotPromptPolicy,
186192
slackBotDynamoDbPolicy,
187193
slackBotKmsPolicy
188194
]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import {Construct} from "constructs"
2+
import {CfnPrompt} from "aws-cdk-lib/aws-bedrock"
3+
4+
export interface PromptManagementProps {
5+
readonly stackName: string
6+
}
7+
8+
export class PromptManagement extends Construct {
9+
public readonly queryReformulationPrompt: CfnPrompt
10+
11+
constructor(scope: Construct, id: string, props: PromptManagementProps) {
12+
super(scope, id)
13+
14+
this.queryReformulationPrompt = new CfnPrompt(this, "QueryReformulationPrompt", {
15+
name: `${props.stackName}-query-reformulation`,
16+
description: "Reformulates user queries for better RAG retrieval",
17+
defaultVariant: "default",
18+
variants: [{
19+
name: "default",
20+
templateType: "TEXT",
21+
templateConfiguration: {
22+
text: {
23+
text: `You are a query reformulation assistant for the NHS EPS (Electronic Prescription Service) API ` +
24+
`documentation system.
25+
26+
Your task is to reformulate user queries to improve retrieval from a knowledge base containing FHIR NHS EPS API
27+
documentation, onboarding guides, and technical specifications.
28+
29+
Guidelines:
30+
- Expand abbreviations (EPS = Electronic Prescription Service, FHIR = Fast Healthcare Interoperability Resources)
31+
- Add relevant technical context (API, prescription, dispensing, healthcare)
32+
- Convert casual language to technical terminology
33+
- Include synonyms for better matching
34+
- Keep the core intent intact
35+
- Focus on NHS, healthcare, prescription, and API-related terms
36+
37+
User Query: {{query}}
38+
39+
Reformulated Query:`
40+
}
41+
},
42+
modelId: "anthropic.claude-3-haiku-20240307-v1:0"
43+
}]
44+
})
45+
}
46+
}

packages/cdk/stacks/EpsAssistMeStack.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {Storage} from "../resources/Storage"
1111
import {Secrets} from "../resources/Secrets"
1212
import {OpenSearchResources} from "../resources/OpenSearchResources"
1313
import {VectorKnowledgeBaseResources} from "../resources/VectorKnowledgeBaseResources"
14+
import {PromptManagement} from "../resources/PromptManagement"
1415
import {IamResources} from "../resources/IamResources"
1516
import {VectorIndex} from "../resources/VectorIndex"
1617
import {DatabaseTables} from "../resources/DatabaseTables"
@@ -111,6 +112,11 @@ export class EpsAssistMeStack extends Stack {
111112
endpoint
112113
})
113114

115+
// Create Prompt Management resources
116+
const promptManagement = new PromptManagement(this, "PromptManagement", {
117+
stackName: props.stackName
118+
})
119+
114120
// Create VectorKnowledgeBase construct after vector index
115121
const vectorKB = new VectorKnowledgeBaseResources(this, "VectorKB", {
116122
stackName: props.stackName,
@@ -127,6 +133,10 @@ export class EpsAssistMeStack extends Stack {
127133
functions.functions.slackBot.function.addEnvironment("GUARD_RAIL_ID", vectorKB.guardrail.attrGuardrailId)
128134
functions.functions.slackBot.function.addEnvironment("GUARD_RAIL_VERSION", vectorKB.guardrail.attrVersion)
129135
functions.functions.slackBot.function.addEnvironment("KNOWLEDGEBASE_ID", vectorKB.knowledgeBase.attrKnowledgeBaseId)
136+
functions.functions.slackBot.function.addEnvironment(
137+
"QUERY_REFORMULATION_PROMPT_ARN",
138+
promptManagement.queryReformulationPrompt.attrArn
139+
)
130140

131141
// Create Apis and pass the Lambda function
132142
const apis = new Apis(this, "Apis", {

packages/slackBotFunction/app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from aws_lambda_powertools.utilities.parameters import get_parameter
1111
from aws_lambda_powertools.utilities.typing import LambdaContext
1212
from botocore.exceptions import ClientError
13+
from query_reformulator import reformulate_query
1314

1415

1516
# Initialize Powertools Logger
@@ -145,7 +146,9 @@ def process_async_slack_event(slack_event_data):
145146
)
146147
return
147148

148-
kb_response = get_bedrock_knowledgebase_response(user_query)
149+
# Reformulate query for better RAG retrieval
150+
reformulated_query = reformulate_query(user_query)
151+
kb_response = get_bedrock_knowledgebase_response(reformulated_query)
149152
response_text = kb_response["output"]["text"]
150153

151154
client.chat_postMessage(channel=channel, text=response_text, thread_ts=thread_ts)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import json
3+
import boto3
4+
from aws_lambda_powertools import Logger
5+
6+
logger = Logger(service="queryReformulator")
7+
8+
9+
def reformulate_query(user_query: str) -> str:
10+
"""
11+
Reformulate user query using Bedrock Prompt Management for better RAG retrieval.
12+
"""
13+
try:
14+
client = boto3.client("bedrock-runtime", region_name=os.environ["AWS_REGION"])
15+
prompt_arn = os.environ["QUERY_REFORMULATION_PROMPT_ARN"]
16+
17+
response = client.invoke_model(
18+
modelId="anthropic.claude-3-haiku-20240307-v1:0",
19+
body=json.dumps(
20+
{
21+
"anthropic_version": "bedrock-2023-05-31",
22+
"max_tokens": 200,
23+
"prompt": prompt_arn,
24+
"variables": {"query": user_query},
25+
}
26+
),
27+
)
28+
29+
result = json.loads(response["body"].read())
30+
reformulated_query = result["content"][0]["text"].strip()
31+
32+
logger.info(
33+
"Query reformulated", extra={"original_query": user_query, "reformulated_query": reformulated_query}
34+
)
35+
36+
return reformulated_query
37+
38+
except Exception as e:
39+
logger.error(f"Error reformulating query: {e}")
40+
return user_query # Fallback to original query
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from unittest.mock import patch, MagicMock
2+
import json
3+
from query_reformulator import reformulate_query
4+
5+
6+
@patch("query_reformulator.boto3.client")
7+
@patch.dict(
8+
"os.environ",
9+
{
10+
"AWS_REGION": "eu-west-2",
11+
"QUERY_REFORMULATION_PROMPT_ARN": "arn:aws:bedrock:eu-west-2:123456789012:prompt/test-prompt",
12+
},
13+
)
14+
def test_reformulate_query_success(mock_boto_client):
15+
# Mock Bedrock response
16+
mock_client = MagicMock()
17+
mock_boto_client.return_value = mock_client
18+
19+
mock_response = {"body": MagicMock()}
20+
mock_response["body"].read.return_value = json.dumps(
21+
{"content": [{"text": "NHS EPS Electronic Prescription Service API FHIR prescription dispensing"}]}
22+
).encode()
23+
24+
mock_client.invoke_model.return_value = mock_response
25+
26+
result = reformulate_query("How do I use EPS?")
27+
28+
assert result == "NHS EPS Electronic Prescription Service API FHIR prescription dispensing"
29+
mock_client.invoke_model.assert_called_once()
30+
31+
32+
@patch("query_reformulator.boto3.client")
33+
@patch.dict(
34+
"os.environ",
35+
{
36+
"AWS_REGION": "eu-west-2",
37+
"QUERY_REFORMULATION_PROMPT_ARN": "arn:aws:bedrock:eu-west-2:123456789012:prompt/test-prompt",
38+
},
39+
)
40+
def test_reformulate_query_fallback_on_error(mock_boto_client):
41+
# Mock Bedrock client to raise exception
42+
mock_client = MagicMock()
43+
mock_boto_client.return_value = mock_client
44+
mock_client.invoke_model.side_effect = Exception("Bedrock error")
45+
46+
original_query = "How do I use EPS?"
47+
result = reformulate_query(original_query)
48+
49+
# Should fallback to original query on error
50+
assert result == original_query

0 commit comments

Comments
 (0)