Skip to content

Commit a9f5cca

Browse files
committed
more typing
1 parent 97ea714 commit a9f5cca

File tree

6 files changed

+37
-33
lines changed

6 files changed

+37
-33
lines changed

packages/slackBotFunction/app/services/dynamo.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from typing import Any
12
from app.core.config import get_logger, get_slack_bot_state_table
23
from time import time
4+
from mypy_boto3_dynamodb.type_defs import GetItemOutputTableTypeDef
35

46
logger = get_logger()
57

68

7-
def get_state_information(key):
9+
def get_state_information(key: str) -> GetItemOutputTableTypeDef:
810
start_time = time()
911
table = get_slack_bot_state_table()
1012
is_success = True
@@ -27,7 +29,7 @@ def get_state_information(key):
2729
return results
2830

2931

30-
def store_state_information(item, condition=None):
32+
def store_state_information(item: dict[str, Any], condition: str = None):
3133
start_time = time()
3234
table = get_slack_bot_state_table()
3335
is_success = True
@@ -52,7 +54,7 @@ def store_state_information(item, condition=None):
5254
)
5355

5456

55-
def update_state_information(key, update_expression, expression_attribute_values):
57+
def update_state_information(key: str, update_expression: str, expression_attribute_values: dict[str, Any]):
5658
start_time = time()
5759
table = get_slack_bot_state_table()
5860
is_success = True
@@ -76,7 +78,7 @@ def update_state_information(key, update_expression, expression_attribute_values
7678
)
7779

7880

79-
def delete_state_information(pk, sk, condition):
81+
def delete_state_information(pk: str, sk: str, condition: str):
8082
start_time = time()
8183
table = get_slack_bot_state_table()
8284
is_success = True

packages/slackBotFunction/app/services/prompt_loader.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import traceback
33
import boto3
44
from botocore.exceptions import ClientError
5+
from app.core.config import get_logger
56
from app.services.exceptions import PromptNotFoundError, PromptLoadError
67
from mypy_boto3_bedrock_agent import AgentsforBedrockClient
78

9+
logger = get_logger()
810

9-
def load_prompt(logger, prompt_name: str, prompt_version: str = None) -> str:
11+
12+
def load_prompt(prompt_name: str, prompt_version: str = None) -> str:
1013
"""
1114
Load a prompt template from Amazon Bedrock Prompt Management.
1215
@@ -17,7 +20,7 @@ def load_prompt(logger, prompt_name: str, prompt_version: str = None) -> str:
1720
client: AgentsforBedrockClient = boto3.client("bedrock-agent", region_name=os.environ["AWS_REGION"])
1821

1922
# Get the prompt ID from the name
20-
prompt_id = get_prompt_id_from_name(logger, client, prompt_name)
23+
prompt_id = get_prompt_id_from_name(client, prompt_name)
2124
if not prompt_id:
2225
raise PromptNotFoundError(f"Could not find prompt ID for name '{prompt_name}'")
2326

@@ -73,7 +76,7 @@ def load_prompt(logger, prompt_name: str, prompt_version: str = None) -> str:
7376
raise PromptLoadError(f"Unexpected error loading prompt '{prompt_name}': {e}")
7477

7578

76-
def get_prompt_id_from_name(logger, client: AgentsforBedrockClient, prompt_name: str) -> str | None:
79+
def get_prompt_id_from_name(client: AgentsforBedrockClient, prompt_name: str) -> str | None:
7780
"""
7881
Get the 10-character prompt ID from the prompt name using ListPrompts.
7982
"""

packages/slackBotFunction/app/services/query_reformulator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import traceback
33
import boto3
44

5+
from app.core.config import get_logger
56
from app.services.bedrock import invoke_model
67
from .prompt_loader import load_prompt
78
from .exceptions import ConfigurationError
89
from mypy_boto3_bedrock_runtime.client import BedrockRuntimeClient
910

11+
logger = get_logger()
1012

11-
def reformulate_query(logger, user_query: str) -> str:
13+
14+
def reformulate_query(user_query: str) -> str:
1215
"""
1316
Reformulate user query using Claude Haiku for better RAG retrieval.
1417
@@ -27,7 +30,7 @@ def reformulate_query(logger, user_query: str) -> str:
2730
raise ConfigurationError("QUERY_REFORMULATION_PROMPT_NAME environment variable not set")
2831

2932
# Load prompt with specified version (DRAFT by default)
30-
prompt_template = load_prompt(logger, prompt_name, prompt_version)
33+
prompt_template = load_prompt(prompt_name, prompt_version)
3134

3235
logger.info(
3336
"Prompt loaded successfully from Bedrock",

packages/slackBotFunction/app/slack/slack_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def process_async_slack_event(slack_event_data: Dict[str, Any]):
202202
return
203203

204204
# Reformulate query for better RAG retrieval
205-
reformulated_query = reformulate_query(logger, user_query)
205+
reformulated_query = reformulate_query(user_query)
206206

207207
# Check if we have an existing Bedrock conversation session
208208
session_data = get_conversation_session_data(conversation_key)

packages/slackBotFunction/tests/test_prompt_loader.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def mock_logger():
1010

1111

1212
@patch("boto3.client")
13-
def test_load_prompt_success_draft(mock_boto_client: Mock, mock_logger: Mock, mock_env: Mock):
13+
def test_load_prompt_success_draft(mock_boto_client: Mock, mock_env: Mock):
1414
# set up mocks
1515
mock_client = MagicMock()
1616
mock_boto_client.return_value = mock_client
@@ -30,15 +30,15 @@ def test_load_prompt_success_draft(mock_boto_client: Mock, mock_logger: Mock, mo
3030
from app.services.prompt_loader import load_prompt
3131

3232
# perform operation
33-
result = load_prompt(mock_logger, "test-prompt")
33+
result = load_prompt("test-prompt")
3434

3535
# assertions
3636
assert result == "Test prompt"
3737
mock_client.get_prompt.assert_called_once_with(promptIdentifier="ABC1234567")
3838

3939

4040
@patch("boto3.client")
41-
def test_load_prompt_success_versioned(mock_boto_client: Mock, mock_logger: Mock, mock_env: Mock):
41+
def test_load_prompt_success_versioned(mock_boto_client: Mock, mock_env: Mock):
4242
# set up mocks
4343
mock_client = MagicMock()
4444
mock_boto_client.return_value = mock_client
@@ -56,15 +56,15 @@ def test_load_prompt_success_versioned(mock_boto_client: Mock, mock_logger: Mock
5656
from app.services.prompt_loader import load_prompt
5757

5858
# perform operation
59-
result = load_prompt(mock_logger, "test-prompt", "1")
59+
result = load_prompt("test-prompt", "1")
6060

6161
# assertions
6262
assert result == "Versioned prompt"
6363
mock_client.get_prompt.assert_called_once_with(promptIdentifier="ABC1234567", promptVersion="1")
6464

6565

6666
@patch("boto3.client")
67-
def test_load_prompt_not_found(mock_boto_client: Mock, mock_logger: Mock, mock_env: Mock):
67+
def test_load_prompt_not_found(mock_boto_client: Mock, mock_env: Mock):
6868
# set up mocks
6969
mock_client = MagicMock()
7070
mock_boto_client.return_value = mock_client
@@ -78,11 +78,11 @@ def test_load_prompt_not_found(mock_boto_client: Mock, mock_logger: Mock, mock_e
7878

7979
# perform operation
8080
with pytest.raises(Exception, match="Could not find prompt ID"):
81-
load_prompt(mock_logger, "nonexistent-prompt")
81+
load_prompt("nonexistent-prompt")
8282

8383

8484
@patch("boto3.client")
85-
def test_load_prompt_client_error(mock_boto_client: Mock, mock_logger: Mock, mock_env: Mock):
85+
def test_load_prompt_client_error(mock_boto_client: Mock, mock_env: Mock):
8686
# set up mocks
8787
mock_client = MagicMock()
8888
mock_boto_client.return_value = mock_client
@@ -99,10 +99,10 @@ def test_load_prompt_client_error(mock_boto_client: Mock, mock_logger: Mock, moc
9999

100100
# perform operation
101101
with pytest.raises(Exception, match="ValidationException - Invalid prompt"):
102-
load_prompt(mock_logger, "test-prompt")
102+
load_prompt("test-prompt")
103103

104104

105-
def test_get_prompt_id_from_name_success(mock_logger: Mock, mock_env: Mock):
105+
def test_get_prompt_id_from_name_success(mock_env: Mock):
106106
# set up mocks
107107
mock_client = MagicMock()
108108
mock_client.list_prompts.return_value = {"promptSummaries": [{"name": "test-prompt", "id": "ABC1234567"}]}
@@ -113,13 +113,13 @@ def test_get_prompt_id_from_name_success(mock_logger: Mock, mock_env: Mock):
113113
from app.services.prompt_loader import get_prompt_id_from_name
114114

115115
# perform operation
116-
result = get_prompt_id_from_name(mock_logger, mock_client, "test-prompt")
116+
result = get_prompt_id_from_name(mock_client, "test-prompt")
117117

118118
# assertions
119119
assert result == "ABC1234567"
120120

121121

122-
def test_get_prompt_id_from_name_not_found(mock_logger: Mock, mock_env: Mock):
122+
def test_get_prompt_id_from_name_not_found(mock_env: Mock):
123123
# set up mocks
124124
mock_client = MagicMock()
125125
mock_client.list_prompts.return_value = {"promptSummaries": []}
@@ -130,7 +130,7 @@ def test_get_prompt_id_from_name_not_found(mock_logger: Mock, mock_env: Mock):
130130
from app.services.prompt_loader import get_prompt_id_from_name
131131

132132
# perform operation
133-
result = get_prompt_id_from_name(mock_logger, mock_client, "nonexistent")
133+
result = get_prompt_id_from_name(mock_client, "nonexistent")
134134

135135
# assertions
136136
assert result is None
@@ -147,7 +147,7 @@ def test_get_prompt_id_client_error(mock_logger: Mock, mock_env: Mock):
147147
del sys.modules["app.services.prompt_loader"]
148148
from app.services.prompt_loader import get_prompt_id_from_name
149149

150-
result = get_prompt_id_from_name(mock_logger, mock_client, "test-prompt")
150+
result = get_prompt_id_from_name(mock_client, "test-prompt")
151151

152152
# assertions
153153
assert result is None

packages/slackBotFunction/tests/test_query_reformulator.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ def mock_logger():
1111

1212
@patch("app.services.prompt_loader.load_prompt")
1313
@patch("app.services.bedrock.invoke_model")
14-
def test_reformulate_query_returns_string(
15-
mock_invoke_model: Mock, mock_load_prompt: Mock, mock_logger: Mock, mock_env: Mock
16-
):
14+
def test_reformulate_query_returns_string(mock_invoke_model: Mock, mock_load_prompt: Mock, mock_env: Mock):
1715
"""Test that reformulate_query returns a string without crashing"""
1816
# set up mocks
1917
mock_load_prompt.return_value = "Test reformat. {{user_query}}"
@@ -25,7 +23,7 @@ def test_reformulate_query_returns_string(
2523
from app.services.query_reformulator import reformulate_query
2624

2725
# perform operation
28-
result = reformulate_query(mock_logger, "How do I use EPS?")
26+
result = reformulate_query("How do I use EPS?")
2927

3028
# assertions
3129
# Function should return a string (either reformulated or fallback to original)
@@ -39,7 +37,7 @@ def test_reformulate_query_returns_string(
3937

4038

4139
@patch("app.services.prompt_loader.load_prompt")
42-
def test_reformulate_query_prompt_load_error(mock_load_prompt: Mock, mock_logger: Mock, mock_env: Mock):
40+
def test_reformulate_query_prompt_load_error(mock_load_prompt: Mock, mock_env: Mock):
4341
# set up mocks
4442
mock_load_prompt.side_effect = Exception("Prompt not found")
4543

@@ -50,17 +48,15 @@ def test_reformulate_query_prompt_load_error(mock_load_prompt: Mock, mock_logger
5048

5149
# perform operation
5250
original_query = "How do I use EPS?"
53-
result = reformulate_query(mock_logger, original_query)
51+
result = reformulate_query(original_query)
5452

5553
# assertions
5654
assert result == original_query
5755

5856

5957
@patch("app.services.prompt_loader.load_prompt")
6058
@patch("app.services.bedrock.invoke_model")
61-
def test_reformulate_query_bedrock_error(
62-
mock_invoke_model: Mock, mock_load_prompt: Mock, mock_logger: Mock, mock_env: Mock
63-
):
59+
def test_reformulate_query_bedrock_error(mock_invoke_model: Mock, mock_load_prompt: Mock, mock_env: Mock):
6460
"""Test query reformulation with Bedrock API error"""
6561
# set up mocks
6662
mock_load_prompt.return_value = "Reformulate this query: {{user_query}}"
@@ -72,7 +68,7 @@ def test_reformulate_query_bedrock_error(
7268
from app.services.query_reformulator import reformulate_query
7369

7470
# perform operation
75-
result = reformulate_query(mock_logger, "original query")
71+
result = reformulate_query("original query")
7672

7773
# assertions
7874
assert result == "original query"

0 commit comments

Comments
 (0)