Skip to content

Commit cc52c9a

Browse files
Update: [AEA-5873] - direct lambda invocation refactor of the slackbot (#131)
## Summary :sparkles: New Feature ### Details Added direct invocation to the Slack bot Lambda function to support upcoming regression testing. This comes with the following work: - direct invocation handler - moves shared logic - created the `ai_processor.py` service to avoid code duplication - types - there is already a good amount of typing in the codebase for this Lambda function, but this change starts extracting the types to a more centralised location for reusability - tests
1 parent ff734ff commit cc52c9a

File tree

9 files changed

+653
-49
lines changed

9 files changed

+653
-49
lines changed

packages/cdk/stacks/EpsAssistMeStack.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ export class EpsAssistMeStack extends Stack {
183183
exportName: `${props.stackName}:lambda:SlackBot:ExecutionRole:Arn`
184184
})
185185

186+
new CfnOutput(this, "SlackBotLambdaArn", {
187+
value: functions.slackBotLambda.function.functionArn,
188+
exportName: `${props.stackName}:lambda:SlackBot:Arn`
189+
})
190+
186191
if (isPullRequest) {
187192
new CfnOutput(this, "VERSION_NUMBER", {
188193
value: props.version,
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
types for direct lambda invocation - defines contracts for bypassing slack
3+
4+
centralizes all type definitions for direct invocation flow to avoid scattered
5+
inline type hints across handlers and processors.
6+
"""
7+
8+
from typing import Any, TypedDict, Literal
9+
from datetime import datetime, timezone
10+
11+
12+
class DirectInvocationRequest(TypedDict, total=False):
13+
"""payload contract for direct lambda calls - bypasses slack entirely"""
14+
15+
invocation_type: Literal["direct"]
16+
query: str
17+
session_id: str | None # conversation continuity across calls
18+
19+
20+
class DirectInvocationResponseData(TypedDict):
21+
"""successful ai response payload - matches slack handler output format"""
22+
23+
text: str
24+
session_id: str | None
25+
citations: list[dict[str, str]] # [{title: str, uri: str}, ...]
26+
timestamp: str # iso8601 with Z suffix
27+
28+
29+
class DirectInvocationErrorData(TypedDict):
    """body of a failed direct invocation reply - same shape for every failure"""

    # human readable description of what went wrong
    error: str
    # iso8601 utc timestamp string
    timestamp: str
34+
35+
36+
class DirectInvocationResponse(TypedDict):
    """outer envelope returned by the lambda for a direct invocation"""

    # http-style status code for the outcome
    statusCode: int
    # success payload or error payload, depending on the outcome
    response: DirectInvocationResponseData | DirectInvocationErrorData
41+
42+
43+
class AIProcessorResponse(TypedDict):
44+
"""ai processor output - shared between slack and direct invocation"""
45+
46+
text: str
47+
session_id: str | None
48+
citations: list[dict[str, str]]
49+
# TODO: ensure proper typing for bedrock response when refactoring other types in the future
50+
kb_response: dict[str, Any] # raw bedrock data for slack session handling
51+
52+
53+
# type guards for runtime validation
54+
def is_valid_direct_request(event: dict[str, Any]) -> bool:
    """validate direct invocation payload structure

    returns True only when the payload is a dict whose invocation_type is
    "direct" and whose query is a string with non-whitespace content.
    any other input, including a non-dict payload, yields False instead of
    raising.
    """
    # a directly invoked lambda can receive any json value, not only an object
    if not isinstance(event, dict):
        return False
    if event.get("invocation_type") != "direct":
        return False

    query = event.get("query")
    # non-empty after whitespace removal
    return isinstance(query, str) and bool(query.strip())
61+
62+
63+
def create_success_response(
    text: str, session_id: str | None, citations: list[dict[str, str]]
) -> DirectInvocationResponse:
    """factory for successful direct invocation responses

    wraps the answer in a 200 envelope and stamps it with the current utc
    time as an iso8601 string ending in Z.
    """
    # isoformat() renders the utc offset as "+00:00"; the contract is a Z suffix
    timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return {
        "statusCode": 200,
        "response": {
            "text": text,
            "session_id": session_id,
            "citations": citations,
            "timestamp": timestamp,
        },
    }
76+
77+
78+
def create_error_response(status_code: int, error_message: str) -> DirectInvocationResponse:
    """factory for error responses - ensures consistent timestamp format

    wraps the message in an envelope with the given status code and stamps it
    with the current utc time as an iso8601 string ending in Z.
    """
    # isoformat() renders the utc offset as "+00:00"; the contract is a Z suffix
    timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return {
        "statusCode": status_code,
        "response": {
            "error": error_message,
            "timestamp": timestamp,
        },
    }

packages/slackBotFunction/app/handler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@
99
from slack_bolt.adapter.aws_lambda import SlackRequestHandler
1010
from aws_lambda_powertools.utilities.typing import LambdaContext
1111

12+
from typing import Any
13+
1214
from app.core.config import get_logger
15+
from app.core.types import (
16+
DirectInvocationResponse,
17+
is_valid_direct_request,
18+
create_success_response,
19+
create_error_response,
20+
)
1321
from app.services.app import get_app
1422
from app.slack.slack_events import process_pull_request_slack_action, process_pull_request_slack_event
1523

@@ -33,6 +41,10 @@ def handler(event: dict, context: LambdaContext) -> dict:
3341
When subsequent actions or events are processed, this is looked up, and if it exists, then the pull request lambda
3442
is triggered with either pull_request_event or pull_request_action
3543
"""
44+
# direct invocation bypasses slack infrastructure entirely
45+
if event.get("invocation_type") == "direct":
46+
return handle_direct_invocation(event, context)
47+
3648
app = get_app(logger=logger)
3749
# handle pull request processing requests
3850
if event.get("pull_request_event"):
@@ -55,3 +67,28 @@ def handler(event: dict, context: LambdaContext) -> dict:
5567
# handle Slack webhook requests
5668
slack_handler = SlackRequestHandler(app=app)
5769
return slack_handler.handle(event=event, context=context)
70+
71+
72+
def handle_direct_invocation(event: dict[str, Any], context: LambdaContext) -> DirectInvocationResponse:
    """direct lambda invocation for ai assistance - bypasses slack entirely

    validates the payload, runs the shared ai processor and wraps the result
    in the direct invocation envelope.

    returns:
        400 envelope when the payload fails validation or session_id is not a string,
        200 envelope with text / session_id / citations on success,
        500 envelope when anything raises while processing.
    this function does not raise; every exception is logged and mapped to a 500.
    """
    try:
        # validate request structure using type guard
        if not is_valid_direct_request(event):
            return create_error_response(400, "Missing required field: query")

        session_id = event.get("session_id")
        # reject a malformed session id as a client error rather than letting it
        # fail further down and surface as a 500
        if session_id is not None and not isinstance(session_id, str):
            return create_error_response(400, "Invalid field: session_id must be a string")

        # shared logic: same AI processing as slack handlers use
        # NOTE(review): import kept function-local as in the original; presumably
        # this defers loading the ai services until a direct call - confirm
        from app.services.ai_processor import process_ai_query

        ai_response = process_ai_query(event["query"], session_id)

        return create_success_response(
            text=ai_response["text"],
            session_id=ai_response["session_id"],
            citations=ai_response["citations"],
        )
    except Exception as e:
        # exc_info keeps the traceback in the log; the caller only sees a generic message
        logger.error(f"Error in direct invocation: {e}", exc_info=True)
        return create_error_response(500, "Internal server error")
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
shared AI processing service - extracted to avoid duplication
3+
4+
both slack handlers and direct invocation use identical logic for query
5+
reformulation and bedrock interaction. single source of truth for AI flows.
6+
"""
7+
8+
from app.services.bedrock import query_bedrock
9+
from app.services.query_reformulator import reformulate_query
10+
from app.core.config import get_logger
11+
from app.core.types import AIProcessorResponse
12+
13+
logger = get_logger()
14+
15+
16+
def process_ai_query(user_query: str, session_id: str | None = None) -> AIProcessorResponse:
    """run one user query through reformulation and bedrock

    used by both the slack handlers and the direct invocation path. exceptions
    from either step are not caught here and reach the caller.
    """
    # the user text is reformulated first (original note: improves vector
    # search quality in the knowledge base)
    search_query = reformulate_query(user_query)

    # passing session_id lets bedrock continue an existing conversation
    raw = query_bedrock(search_query, session_id)

    answer_text = raw["output"]["text"]
    new_session_id = raw.get("sessionId")
    citations = raw.get("citations", [])

    return {
        "text": answer_text,
        "session_id": new_session_id,
        "citations": citations,
        # slack needs raw bedrock data for session handling
        "kb_response": raw,
    }

packages/slackBotFunction/app/slack/slack_events.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
get_bot_token,
1717
get_logger,
1818
)
19-
from app.services.bedrock import query_bedrock
19+
2020
from app.services.dynamo import (
2121
delete_state_information,
2222
get_state_information,
2323
store_state_information,
2424
update_state_information,
2525
)
26-
from app.services.query_reformulator import reformulate_query
26+
2727
from app.services.slack import get_friendly_channel_name, post_error_message
2828
from app.utils.handler_utils import (
2929
conversation_key_and_root,
@@ -34,6 +34,9 @@
3434
strip_mentions,
3535
)
3636

37+
from app.services.ai_processor import process_ai_query
38+
39+
3740
logger = get_logger()
3841

3942

@@ -321,16 +324,13 @@ def process_slack_message(event: Dict[str, Any], event_id: str, client: WebClien
321324
client.chat_postMessage(**post_params)
322325
return
323326

324-
# Reformulate query for better RAG retrieval
325-
reformulated_query = reformulate_query(user_query)
326-
327-
# Check if we have an existing Bedrock conversation session
327+
# conversation continuity: reuse bedrock session across slack messages
328328
session_data = get_conversation_session_data(conversation_key)
329329
session_id = session_data.get("session_id") if session_data else None
330330

331-
# Query Bedrock Knowledge Base with conversation context
332-
kb_response = query_bedrock(reformulated_query, session_id)
333-
response_text = kb_response["output"]["text"]
331+
ai_response = process_ai_query(user_query, session_id)
332+
kb_response = ai_response["kb_response"]
333+
response_text = ai_response["text"]
334334

335335
# Post the answer (plain) to get message_ts
336336
post_params = {"channel": channel, "text": response_text}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""shared ai processor - validates query reformulation and bedrock integration"""
2+
3+
import pytest
4+
from unittest.mock import patch
5+
from app.services.ai_processor import process_ai_query
6+
7+
8+
@patch("app.services.ai_processor.query_bedrock")
@patch("app.services.ai_processor.reformulate_query")
class TestAIProcessor:
    """process_ai_query with query reformulation and bedrock mocked out

    the class-level patches are applied to every test_* method, so each test
    receives (mock_reformulate, mock_bedrock) in that order.
    """

    def test_process_ai_query_without_session(self, mock_reformulate, mock_bedrock):
        """new conversation: no session context passed to bedrock"""
        mock_reformulate.return_value = "reformulated: How to authenticate EPS API?"
        mock_bedrock.return_value = {
            "output": {"text": "To authenticate with EPS API, you need..."},
            "sessionId": "new-session-abc123",
            "citations": [{"title": "EPS Authentication Guide", "uri": "https://example.com/auth"}],
        }

        result = process_ai_query("How to authenticate EPS API?")

        assert result["text"] == "To authenticate with EPS API, you need..."
        assert result["session_id"] == "new-session-abc123"
        assert len(result["citations"]) == 1
        assert result["citations"][0]["title"] == "EPS Authentication Guide"
        assert "kb_response" in result

        mock_reformulate.assert_called_once_with("How to authenticate EPS API?")
        mock_bedrock.assert_called_once_with("reformulated: How to authenticate EPS API?", None)

    def test_process_ai_query_with_session(self, mock_reformulate, mock_bedrock):
        """conversation continuity: existing session maintained across queries"""
        mock_reformulate.return_value = "reformulated: What about rate limits?"
        mock_bedrock.return_value = {
            "output": {"text": "EPS API has rate limits of..."},
            "sessionId": "existing-session-456",
            "citations": [],
        }

        result = process_ai_query("What about rate limits?", session_id="existing-session-456")

        assert result["text"] == "EPS API has rate limits of..."
        assert result["session_id"] == "existing-session-456"
        assert result["citations"] == []
        assert "kb_response" in result

        mock_reformulate.assert_called_once_with("What about rate limits?")
        mock_bedrock.assert_called_once_with("reformulated: What about rate limits?", "existing-session-456")

    def test_process_ai_query_reformulate_error(self, mock_reformulate, mock_bedrock):
        """reformulation failure: error propagated to caller, bedrock never queried"""
        mock_reformulate.side_effect = Exception("Query reformulation failed")

        with pytest.raises(Exception, match="Query reformulation failed"):
            process_ai_query("How to authenticate EPS API?")

        mock_bedrock.assert_not_called()

    def test_process_ai_query_bedrock_error(self, mock_reformulate, mock_bedrock):
        """bedrock service failure: error propagated to caller"""
        mock_reformulate.return_value = "reformulated query"
        mock_bedrock.side_effect = Exception("Bedrock service error")

        with pytest.raises(Exception, match="Bedrock service error"):
            process_ai_query("How to authenticate EPS API?")

        mock_reformulate.assert_called_once()
        # the failure must come from the real call made with the reformulated text
        mock_bedrock.assert_called_once_with("reformulated query", None)

    def test_process_ai_query_missing_citations(self, mock_reformulate, mock_bedrock):
        """bedrock response incomplete: citations default to empty list"""
        mock_reformulate.return_value = "reformulated query"
        mock_bedrock.return_value = {
            "output": {"text": "Response without citations"},
            "sessionId": "session-123",
            # No citations key
        }

        result = process_ai_query("test query")

        assert result["text"] == "Response without citations"
        assert result["session_id"] == "session-123"
        assert result["citations"] == []  # safe default when bedrock omits citations

    def test_process_ai_query_missing_session_id(self, mock_reformulate, mock_bedrock):
        """bedrock response incomplete: session_id properly handles None"""
        mock_reformulate.return_value = "reformulated query"
        mock_bedrock.return_value = {
            "output": {"text": "Response without session"},
            "citations": [],
            # No sessionId key
        }

        result = process_ai_query("test query")

        assert result["text"] == "Response without session"
        assert result["session_id"] is None  # explicit None when bedrock omits sessionId
        assert result["citations"] == []

    def test_process_ai_query_empty_query(self, mock_reformulate, mock_bedrock):
        """edge case: empty query still processed through full pipeline"""
        mock_reformulate.return_value = ""
        mock_bedrock.return_value = {
            "output": {"text": "Please provide a question"},
            "sessionId": "session-empty",
            "citations": [],
        }

        result = process_ai_query("")

        assert result["text"] == "Please provide a question"
        mock_reformulate.assert_called_once_with("")
        mock_bedrock.assert_called_once_with("", None)

    def test_process_ai_query_includes_raw_response(self, mock_reformulate, mock_bedrock):
        """slack needs raw bedrock data: kb_response preserved for session handling"""
        mock_reformulate.return_value = "reformulated query"
        raw_response = {
            "output": {"text": "Test response"},
            "sessionId": "test-123",
            "citations": [{"title": "Test", "uri": "test.com"}],
            "metadata": {"some": "extra_data"},
        }
        mock_bedrock.return_value = raw_response

        result = process_ai_query("test query")

        assert result["kb_response"] == raw_response
        assert result["kb_response"]["metadata"]["some"] == "extra_data"

0 commit comments

Comments
 (0)