Skip to content

Commit 13fe3b1

Browse files
feat: Add RAG environment variables and update test cases accordingly
1 parent f4a79ef commit 13fe3b1

File tree

4 files changed

+73
-7
lines changed

4 files changed

+73
-7
lines changed

packages/slackBotFunction/app/core/config.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,28 @@ def get_bot_token() -> str:
7171

7272

7373
@lru_cache
74-
def get_guardrail_config() -> Tuple[str, str, str, str, str]:
74+
def get_retrieve_generate_config() -> Tuple[str, str, str, str, str, str, str]:
7575
# Bedrock configuration from environment
7676
KNOWLEDGEBASE_ID = os.environ["KNOWLEDGEBASE_ID"]
7777
RAG_MODEL_ID = os.environ["RAG_MODEL_ID"]
7878
AWS_REGION = os.environ["AWS_REGION"]
7979
GUARD_RAIL_ID = os.environ["GUARD_RAIL_ID"]
8080
GUARD_VERSION = os.environ["GUARD_RAIL_VERSION"]
81+
RAG_RESPONSE_PROMPT_NAME = os.environ["RAG_RESPONSE_PROMPT_NAME"]
82+
RAG_RESPONSE_PROMPT_VERSION = os.environ["RAG_RESPONSE_PROMPT_VERSION"]
8183

8284
logger.info(
8385
"Guardrail configuration loaded", extra={"guardrail_id": GUARD_RAIL_ID, "guardrail_version": GUARD_VERSION}
8486
)
85-
return KNOWLEDGEBASE_ID, RAG_MODEL_ID, AWS_REGION, GUARD_RAIL_ID, GUARD_VERSION
87+
return (
88+
KNOWLEDGEBASE_ID,
89+
RAG_MODEL_ID,
90+
AWS_REGION,
91+
GUARD_RAIL_ID,
92+
GUARD_VERSION,
93+
RAG_RESPONSE_PROMPT_NAME,
94+
RAG_RESPONSE_PROMPT_VERSION,
95+
)
8696

8797

8898
@dataclass

packages/slackBotFunction/app/services/bedrock.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from mypy_boto3_bedrock_runtime.client import BedrockRuntimeClient
66
from mypy_boto3_bedrock_agent_runtime.type_defs import RetrieveAndGenerateResponseTypeDef
77

8-
from app.core.config import get_guardrail_config, get_logger
8+
from app.core.config import get_retrieve_generate_config, get_logger
9+
from app.services.prompt_loader import load_prompt
910

1011

1112
logger = get_logger()
@@ -19,7 +20,18 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
1920
a response using the configured LLM model with guardrails for safety.
2021
"""
2122

22-
KNOWLEDGEBASE_ID, RAG_MODEL_ID, AWS_REGION, GUARD_RAIL_ID, GUARD_VERSION = get_guardrail_config()
23+
(
24+
KNOWLEDGEBASE_ID,
25+
RAG_MODEL_ID,
26+
AWS_REGION,
27+
GUARD_RAIL_ID,
28+
GUARD_VERSION,
29+
RAG_RESPONSE_PROMPT_NAME,
30+
RAG_RESPONSE_PROMPT_VERSION,
31+
) = get_retrieve_generate_config()
32+
33+
prompt_template = load_prompt(RAG_RESPONSE_PROMPT_NAME, RAG_RESPONSE_PROMPT_VERSION)
34+
2335
client: AgentsforBedrockRuntimeClient = boto3.client(
2436
service_name="bedrock-agent-runtime",
2537
region_name=AWS_REGION,
@@ -41,6 +53,14 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
4153
},
4254
}
4355

56+
if prompt_template:
57+
request_params["retrieveAndGenerateConfiguration"]["knowledgeBaseConfiguration"]["generationConfiguration"][
58+
"promptTemplate"
59+
] = {"textPromptTemplate": prompt_template}
60+
logger.info(
61+
"Using prompt template for RAG response generation", extra={"prompt_name": RAG_RESPONSE_PROMPT_NAME}
62+
)
63+
4464
# Include session ID for conversation continuity across messages
4565
if session_id:
4666
request_params["sessionId"] = session_id

packages/slackBotFunction/tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def mock_env():
2424
"QUERY_REFORMULATION_MODEL_ID": "test-model",
2525
"QUERY_REFORMULATION_PROMPT_NAME": "test-prompt",
2626
"QUERY_REFORMULATION_PROMPT_VERSION": "DRAFT",
27+
"RAG_RESPONSE_PROMPT_NAME": "test-rag-prompt",
28+
"RAG_RESPONSE_PROMPT_VERSION": "DRAFT",
2729
}
2830
env_vars["AWS_DEFAULT_REGION"] = env_vars["AWS_REGION"]
2931
with patch.dict(os.environ, env_vars, clear=False):

packages/slackBotFunction/tests/test_bedrock_integration.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from unittest.mock import Mock, patch
33

44

5+
@patch("app.services.prompt_loader.load_prompt")
56
@patch("boto3.client")
6-
def test_get_bedrock_knowledgebase_response(mock_boto_client: Mock, mock_env: Mock):
7+
def test_get_bedrock_knowledgebase_response(mock_boto_client: Mock, mock_load_prompt: Mock, mock_env: Mock):
78
"""Test Bedrock knowledge base integration"""
89
# set up mocks
910
mock_client = Mock()
@@ -19,13 +20,15 @@ def test_get_bedrock_knowledgebase_response(mock_boto_client: Mock, mock_env: Mo
1920
result = query_bedrock("test query")
2021

2122
# assertions
23+
mock_load_prompt.assert_called_once_with("test-rag-prompt", "DRAFT")
2224
mock_boto_client.assert_called_once_with(service_name="bedrock-agent-runtime", region_name="eu-west-2")
2325
mock_client.retrieve_and_generate.assert_called_once()
2426
assert result["output"]["text"] == "bedrock response"
2527

2628

29+
@patch("app.services.prompt_loader.load_prompt")
2730
@patch("boto3.client")
28-
def test_query_bedrock_with_session(mock_boto_client: Mock, mock_env: Mock):
31+
def test_query_bedrock_with_session(mock_boto_client: Mock, mock_load_prompt: Mock, mock_env: Mock):
2932
"""Test query_bedrock with existing session"""
3033
# set up mocks
3134
mock_client = Mock()
@@ -42,13 +45,15 @@ def test_query_bedrock_with_session(mock_boto_client: Mock, mock_env: Mock):
4245
result = query_bedrock("test query", session_id="existing_session")
4346

4447
# assertions
48+
mock_load_prompt.assert_called_once_with("test-rag-prompt", "DRAFT")
4549
assert result == mock_response
4650
call_args = mock_client.retrieve_and_generate.call_args[1]
4751
assert call_args["sessionId"] == "existing_session"
4852

4953

54+
@patch("app.services.prompt_loader.load_prompt")
5055
@patch("boto3.client")
51-
def test_query_bedrock_without_session(mock_boto_client: Mock, mock_env: Mock):
56+
def test_query_bedrock_without_session(mock_boto_client: Mock, mock_load_prompt: Mock, mock_env: Mock):
5257
"""Test query_bedrock without session"""
5358
# set up mocks
5459
mock_client = Mock()
@@ -65,6 +70,35 @@ def test_query_bedrock_without_session(mock_boto_client: Mock, mock_env: Mock):
6570
result = query_bedrock("test query")
6671

6772
# assertions
73+
mock_load_prompt.assert_called_once_with("test-rag-prompt", "DRAFT")
6874
assert result == mock_response
6975
call_args = mock_client.retrieve_and_generate.call_args[1]
7076
assert "sessionId" not in call_args
77+
78+
79+
@patch("app.services.prompt_loader.load_prompt")
80+
@patch("boto3.client")
81+
def test_query_bedrock_check_prompt(mock_boto_client: Mock, mock_load_prompt: Mock, mock_env: Mock):
82+
"""Test query_bedrock prompt loading"""
83+
# set up mocks
84+
mock_client = Mock()
85+
mock_boto_client.return_value = mock_client
86+
mock_client.retrieve_and_generate.return_value = {"output": {"text": "response"}}
87+
mock_load_prompt.return_value = "Test prompt template"
88+
89+
# delete and import module to test
90+
if "app.services.bedrock" in sys.modules:
91+
del sys.modules["app.services.bedrock"]
92+
from app.services.bedrock import query_bedrock
93+
94+
# perform operation
95+
result = query_bedrock("test query")
96+
97+
# assertions
98+
mock_load_prompt.assert_called_once_with("test-rag-prompt", "DRAFT")
99+
call_args = mock_client.retrieve_and_generate.call_args[1]
100+
prompt_template = call_args["retrieveAndGenerateConfiguration"]["knowledgeBaseConfiguration"][
101+
"generationConfiguration"
102+
]["promptTemplate"]["textPromptTemplate"]
103+
assert prompt_template == "Test prompt template"
104+
assert result["output"]["text"] == "response"

0 commit comments

Comments
 (0)