Skip to content

Commit e42bd42

Browse files
committed
Use prompt template replace method instead of format
1 parent 96c101f commit e42bd42

File tree

4 files changed

+134
-76
lines changed

4 files changed

+134
-76
lines changed

packages/slackBotFunction/app/services/prompt_loader.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@ def load_prompt(prompt_name: str, prompt_version: str = None) -> str:
2727
raise Exception(f"Could not find prompt ID for name '{prompt_name}'")
2828

2929
# Load the prompt with the specified version
30-
if prompt_version == "DRAFT":
31-
logger.info(
32-
f"Loading DRAFT version of prompt '{prompt_name}' (ID: {prompt_id})",
33-
extra={"prompt_name": prompt_name, "prompt_id": prompt_id, "prompt_version": "DRAFT"},
34-
)
35-
response = client.get_prompt(promptIdentifier=prompt_id)
36-
else:
30+
if prompt_version and prompt_version != "DRAFT":
3731
logger.info(
3832
f"Loading version {prompt_version} of prompt '{prompt_name}' (ID: {prompt_id})",
3933
extra={"prompt_name": prompt_name, "prompt_id": prompt_id, "prompt_version": prompt_version},
4034
)
4135
response = client.get_prompt(promptIdentifier=prompt_id, promptVersion=str(prompt_version))
36+
else:
37+
logger.info(
38+
f"Loading DRAFT version of prompt '{prompt_name}' (ID: {prompt_id})",
39+
extra={"prompt_name": prompt_name, "prompt_id": prompt_id, "prompt_version": "DRAFT"},
40+
)
41+
response = client.get_prompt(promptIdentifier=prompt_id)
4242

4343
prompt_text = response["variants"][0]["templateConfiguration"]["text"]["text"]
4444
actual_version = response.get("version", "DRAFT")
@@ -48,9 +48,7 @@ def load_prompt(prompt_name: str, prompt_version: str = None) -> str:
4848
extra={
4949
"prompt_name": prompt_name,
5050
"prompt_id": prompt_id,
51-
"version_requested": prompt_version,
52-
"version_actual": actual_version,
53-
"selection_method": "default" if prompt_version is None else "explicit",
51+
"version_used": actual_version,
5452
},
5553
)
5654
return prompt_text

packages/slackBotFunction/app/services/query_reformulator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def reformulate_query(user_query: str) -> str:
3333
extra={"prompt_name": prompt_name, "version_used": prompt_version},
3434
)
3535

36-
# Format the prompt with the user query
37-
prompt = prompt_template.format(user_query=user_query)
36+
# Format the prompt with the user query (using double braces from Bedrock template)
37+
prompt = prompt_template.replace("{{user_query}}", user_query)
3838

3939
response = client.invoke_model(
4040
modelId=model_id,
Lines changed: 96 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,106 @@
11
import pytest
2-
from unittest.mock import patch
3-
from app.services.prompt_loader import load_prompt
2+
from unittest.mock import patch, MagicMock
3+
from botocore.exceptions import ClientError
4+
from app.services.prompt_loader import load_prompt, get_prompt_id_from_name
45

56

6-
def test_load_prompt_function_exists():
7-
"""Test that the load_prompt function exists and is callable"""
8-
assert callable(load_prompt)
7+
@patch("app.services.prompt_loader.boto3.client")
8+
@patch.dict("os.environ", {"AWS_REGION": "eu-west-2"})
9+
def test_load_prompt_success_draft(mock_boto_client):
10+
mock_client = MagicMock()
11+
mock_boto_client.return_value = mock_client
912

13+
# Mock list_prompts to return prompt ID
14+
mock_client.list_prompts.return_value = {"promptSummaries": [{"name": "test-prompt", "id": "ABC1234567"}]}
1015

11-
def test_load_prompt_requires_prompt_name():
12-
"""Test that load_prompt requires a prompt name parameter"""
13-
with pytest.raises(TypeError):
14-
load_prompt() # Should fail without prompt_name
16+
# Mock get_prompt for DRAFT version
17+
mock_client.get_prompt.return_value = {
18+
"variants": [{"templateConfiguration": {"text": {"text": "Test prompt"}}}],
19+
"version": "DRAFT",
20+
}
1521

22+
result = load_prompt("test-prompt")
23+
assert result == "Test prompt"
24+
mock_client.get_prompt.assert_called_once_with(promptIdentifier="ABC1234567")
1625

17-
def test_load_prompt_handles_missing_environment():
18-
"""Test that load_prompt handles missing AWS_REGION environment variable"""
26+
27+
@patch("app.services.prompt_loader.boto3.client")
28+
@patch.dict("os.environ", {"AWS_REGION": "eu-west-2"})
29+
def test_load_prompt_success_versioned(mock_boto_client):
30+
mock_client = MagicMock()
31+
mock_boto_client.return_value = mock_client
32+
33+
mock_client.list_prompts.return_value = {"promptSummaries": [{"name": "test-prompt", "id": "ABC1234567"}]}
34+
35+
mock_client.get_prompt.return_value = {
36+
"variants": [{"templateConfiguration": {"text": {"text": "Versioned prompt"}}}],
37+
"version": "1",
38+
}
39+
40+
result = load_prompt("test-prompt", "1")
41+
assert result == "Versioned prompt"
42+
mock_client.get_prompt.assert_called_once_with(promptIdentifier="ABC1234567", promptVersion="1")
43+
44+
45+
@patch("app.services.prompt_loader.boto3.client")
46+
@patch.dict("os.environ", {"AWS_REGION": "eu-west-2"})
47+
def test_load_prompt_not_found(mock_boto_client):
48+
mock_client = MagicMock()
49+
mock_boto_client.return_value = mock_client
50+
51+
mock_client.list_prompts.return_value = {"promptSummaries": []}
52+
53+
with pytest.raises(Exception, match="Could not find prompt ID"):
54+
load_prompt("nonexistent-prompt")
55+
56+
57+
@patch("app.services.prompt_loader.boto3.client")
58+
@patch.dict("os.environ", {"AWS_REGION": "eu-west-2"})
59+
def test_load_prompt_client_error(mock_boto_client):
60+
mock_client = MagicMock()
61+
mock_boto_client.return_value = mock_client
62+
63+
mock_client.list_prompts.return_value = {"promptSummaries": [{"name": "test-prompt", "id": "ABC1234567"}]}
64+
65+
error = ClientError({"Error": {"Code": "ValidationException", "Message": "Invalid prompt"}}, "GetPrompt")
66+
mock_client.get_prompt.side_effect = error
67+
68+
with pytest.raises(Exception, match="ValidationException - Invalid prompt"):
69+
load_prompt("test-prompt")
70+
71+
72+
@patch("app.services.prompt_loader.boto3.client")
73+
@patch.dict("os.environ", {"AWS_REGION": "eu-west-2"})
74+
def test_get_prompt_id_from_name_success(mock_boto_client):
75+
mock_client = MagicMock()
76+
mock_client.list_prompts.return_value = {"promptSummaries": [{"name": "test-prompt", "id": "ABC1234567"}]}
77+
78+
result = get_prompt_id_from_name(mock_client, "test-prompt")
79+
assert result == "ABC1234567"
80+
81+
82+
@patch("app.services.prompt_loader.boto3.client")
83+
@patch.dict("os.environ", {"AWS_REGION": "eu-west-2"})
84+
def test_get_prompt_id_from_name_not_found(mock_boto_client):
85+
mock_client = MagicMock()
86+
mock_client.list_prompts.return_value = {"promptSummaries": []}
87+
88+
result = get_prompt_id_from_name(mock_client, "nonexistent")
89+
assert result is None
90+
91+
92+
@patch("app.services.prompt_loader.boto3.client")
93+
@patch.dict("os.environ", {"AWS_REGION": "eu-west-2"})
94+
def test_get_prompt_id_client_error(mock_boto_client):
95+
mock_client = MagicMock()
96+
error = ClientError({"Error": {"Code": "AccessDenied"}}, "ListPrompts")
97+
mock_client.list_prompts.side_effect = error
98+
99+
result = get_prompt_id_from_name(mock_client, "test-prompt")
100+
assert result is None
101+
102+
103+
def test_load_prompt_missing_environment():
19104
with patch.dict("os.environ", {}, clear=True):
20105
with pytest.raises(Exception):
21106
load_prompt("test-prompt")
Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from unittest.mock import patch, MagicMock
2-
import json
1+
from unittest.mock import patch
32
from app.services.query_reformulator import reformulate_query
43

54

6-
def test_reformulate_query_success():
7-
with patch("app.services.query_reformulator.load_prompt") as mock_load_prompt, patch(
8-
"app.services.query_reformulator.boto3.client"
9-
) as mock_boto_client, patch.dict(
5+
def test_reformulate_query_returns_string():
6+
"""Test that reformulate_query returns a string without crashing"""
7+
with patch.dict(
108
"os.environ",
119
{
1210
"AWS_REGION": "eu-west-2",
@@ -15,58 +13,35 @@ def test_reformulate_query_success():
1513
},
1614
):
1715

18-
# Mock prompt loading
19-
mock_load_prompt.return_value = "Test prompt template with {user_query}"
20-
21-
# Mock Bedrock client with proper response
22-
mock_client = MagicMock()
23-
mock_boto_client.return_value = mock_client
24-
25-
# Create a simple mock that returns the expected JSON
26-
mock_client.invoke_model.return_value = {
27-
"body": type(
28-
"MockBody",
29-
(),
30-
{
31-
"read": lambda: json.dumps(
32-
{
33-
"content": [
34-
{"text": "NHS EPS Electronic Prescription Service API FHIR prescription dispensing"}
35-
]
36-
}
37-
).encode("utf-8")
38-
},
39-
)()
40-
}
41-
4216
result = reformulate_query("How do I use EPS?")
43-
44-
# Test that function doesn't crash and returns a string
17+
# Function should return a string (either reformulated or fallback to original)
4518
assert isinstance(result, str)
4619
assert len(result) > 0
4720

4821

49-
@patch("app.services.query_reformulator.load_prompt")
50-
@patch("app.services.query_reformulator.boto3.client")
51-
@patch.dict(
52-
"os.environ",
53-
{
54-
"AWS_REGION": "eu-west-2",
55-
"QUERY_REFORMULATION_MODEL_ID": "anthropic.claude-3-haiku-20240307-v1:0",
56-
"QUERY_REFORMULATION_PROMPT_NAME": "query-reformulation",
57-
},
58-
)
59-
def test_reformulate_query_fallback_on_error(mock_boto_client, mock_load_prompt):
60-
# Mock prompt loading
61-
mock_load_prompt.return_value = "Test prompt template with {user_query}"
22+
def test_reformulate_query_prompt_load_error():
23+
with patch("app.services.query_reformulator.load_prompt") as mock_load_prompt, patch.dict(
24+
"os.environ",
25+
{
26+
"AWS_REGION": "eu-west-2",
27+
"QUERY_REFORMULATION_MODEL_ID": "anthropic.claude-3-haiku-20240307-v1:0",
28+
"QUERY_REFORMULATION_PROMPT_NAME": "query-reformulation",
29+
},
30+
):
31+
32+
mock_load_prompt.side_effect = Exception("Prompt not found")
6233

63-
# Mock Bedrock client to raise exception
64-
mock_client = MagicMock()
65-
mock_boto_client.return_value = mock_client
66-
mock_client.invoke_model.side_effect = Exception("Bedrock error")
34+
original_query = "How do I use EPS?"
35+
result = reformulate_query(original_query)
36+
assert result == original_query
6737

68-
original_query = "How do I use EPS?"
69-
result = reformulate_query(original_query)
7038

71-
# Should fallback to original query on error
72-
assert result == original_query
39+
def test_reformulate_query_missing_prompt_name():
40+
with patch.dict(
41+
"os.environ",
42+
{"AWS_REGION": "eu-west-2", "QUERY_REFORMULATION_MODEL_ID": "anthropic.claude-3-haiku-20240307-v1:0"},
43+
):
44+
45+
original_query = "test query"
46+
result = reformulate_query(original_query)
47+
assert result == original_query

0 commit comments

Comments
 (0)