Skip to content

Commit 41bc401

Browse files
committed
feat(bedrock_agent): create bedrock agents functions data class
1 parent 668db82 commit 41bc401

File tree

4 files changed

+238
-0
lines changed

4 files changed

+238
-0
lines changed

aws_lambda_powertools/utilities/data_classes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .appsync_resolver_events_event import AppSyncResolverEventsEvent
1010
from .aws_config_rule_event import AWSConfigRuleEvent
1111
from .bedrock_agent_event import BedrockAgentEvent
12+
from .bedrock_agent_function_event import BedrockAgentFunctionEvent
1213
from .cloud_watch_alarm_event import (
1314
CloudWatchAlarmConfiguration,
1415
CloudWatchAlarmData,
@@ -59,6 +60,7 @@
5960
"AppSyncResolverEventsEvent",
6061
"ALBEvent",
6162
"BedrockAgentEvent",
63+
"BedrockAgentFunctionEvent",
6264
"CloudWatchAlarmData",
6365
"CloudWatchAlarmEvent",
6466
"CloudWatchAlarmMetric",
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
6+
7+
8+
class BedrockAgentInfo(DictWrapper):
9+
@property
10+
def name(self) -> str:
11+
return self["name"]
12+
13+
@property
14+
def id(self) -> str: # noqa: A003
15+
return self["id"]
16+
17+
@property
18+
def alias(self) -> str:
19+
return self["alias"]
20+
21+
@property
22+
def version(self) -> str:
23+
return self["version"]
24+
25+
26+
class BedrockAgentFunctionParameter(DictWrapper):
27+
@property
28+
def name(self) -> str:
29+
return self["name"]
30+
31+
@property
32+
def type(self) -> str: # noqa: A003
33+
return self["type"]
34+
35+
@property
36+
def value(self) -> str:
37+
return self["value"]
38+
39+
40+
class BedrockAgentFunctionEvent(DictWrapper):
41+
"""
42+
Bedrock Agent Function input event
43+
44+
Documentation:
45+
https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
46+
"""
47+
48+
@classmethod
49+
def validate_required_fields(cls, data: dict[str, Any]) -> None:
50+
required_fields = {
51+
"messageVersion": str,
52+
"agent": dict,
53+
"inputText": str,
54+
"sessionId": str,
55+
"actionGroup": str,
56+
"function": str,
57+
}
58+
59+
for field, field_type in required_fields.items():
60+
if field not in data:
61+
raise ValueError(f"Missing required field: {field}")
62+
if not isinstance(data[field], field_type):
63+
raise TypeError(f"Field {field} must be of type {field_type}")
64+
65+
# Validate agent structure
66+
required_agent_fields = {"name", "id", "alias", "version"}
67+
if not all(field in data["agent"] for field in required_agent_fields):
68+
raise ValueError("Agent object missing required fields")
69+
70+
def __init__(self, data: dict[str, Any]) -> None:
71+
super().__init__(data)
72+
self.validate_required_fields(data)
73+
74+
@property
75+
def message_version(self) -> str:
76+
return self["messageVersion"]
77+
78+
@property
79+
def input_text(self) -> str:
80+
return self["inputText"]
81+
82+
@property
83+
def session_id(self) -> str:
84+
return self["sessionId"]
85+
86+
@property
87+
def action_group(self) -> str:
88+
return self["actionGroup"]
89+
90+
@property
91+
def function(self) -> str:
92+
return self["function"]
93+
94+
@property
95+
def parameters(self) -> list[BedrockAgentFunctionParameter]:
96+
parameters = self.get("parameters") or []
97+
return [BedrockAgentFunctionParameter(x) for x in parameters]
98+
99+
@property
100+
def agent(self) -> BedrockAgentInfo:
101+
return BedrockAgentInfo(self["agent"])
102+
103+
@property
104+
def session_attributes(self) -> dict[str, str]:
105+
return self.get("sessionAttributes", {}) or {}
106+
107+
@property
108+
def prompt_session_attributes(self) -> dict[str, str]:
109+
return self.get("promptSessionAttributes", {}) or {}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"messageVersion": "1.0",
3+
"agent": {
4+
"alias": "PROD",
5+
"name": "hr-assistant-function-def",
6+
"version": "1",
7+
"id": "1234abcd"
8+
},
9+
"sessionId": "123456789123458",
10+
"sessionAttributes": {
11+
"employeeId": "EMP123",
12+
"department": "Engineering"
13+
},
14+
"promptSessionAttributes": {
15+
"lastInteraction": "2024-02-01T15:30:00Z",
16+
"requestType": "vacation"
17+
},
18+
"inputText": "I want to request vacation from March 15 to March 20",
19+
"actionGroup": "VacationsActionGroup",
20+
"function": "submitVacationRequest",
21+
"parameters": [
22+
{
23+
"name": "startDate",
24+
"type": "string",
25+
"value": "2024-03-15"
26+
},
27+
{
28+
"name": "endDate",
29+
"type": "string",
30+
"value": "2024-03-20"
31+
}
32+
]
33+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
6+
from tests.functional.utils import load_event
7+
8+
9+
def test_bedrock_agent_function_event():
10+
raw_event = load_event("bedrockAgentFunctionEvent.json")
11+
parsed_event = BedrockAgentFunctionEvent(raw_event)
12+
13+
# Test basic event properties
14+
assert parsed_event.message_version == raw_event["messageVersion"]
15+
assert parsed_event.session_id == raw_event["sessionId"]
16+
assert parsed_event.input_text == raw_event["inputText"]
17+
assert parsed_event.action_group == raw_event["actionGroup"]
18+
assert parsed_event.function == raw_event["function"]
19+
20+
# Test agent information
21+
agent = parsed_event.agent
22+
raw_agent = raw_event["agent"]
23+
assert agent.alias == raw_agent["alias"]
24+
assert agent.name == raw_agent["name"]
25+
assert agent.version == raw_agent["version"]
26+
assert agent.id == raw_agent["id"]
27+
28+
# Test session attributes
29+
assert parsed_event.session_attributes == raw_event["sessionAttributes"]
30+
assert parsed_event.prompt_session_attributes == raw_event["promptSessionAttributes"]
31+
32+
# Test parameters
33+
parameters = parsed_event.parameters
34+
raw_parameters = raw_event["parameters"]
35+
assert len(parameters) == len(raw_parameters)
36+
37+
for param, raw_param in zip(parameters, raw_parameters):
38+
assert param.name == raw_param["name"]
39+
assert param.type == raw_param["type"]
40+
assert param.value == raw_param["value"]
41+
42+
43+
def test_bedrock_agent_function_event_minimal():
44+
"""Test with minimal required fields"""
45+
minimal_event = {
46+
"messageVersion": "1.0",
47+
"agent": {
48+
"alias": "PROD",
49+
"name": "hr-assistant-function-def",
50+
"version": "1",
51+
"id": "1234abcd-56ef-78gh-90ij-klmn12345678",
52+
},
53+
"sessionId": "87654321-abcd-efgh-ijkl-mnop12345678",
54+
"inputText": "I want to request vacation",
55+
"actionGroup": "VacationsActionGroup",
56+
"function": "submitVacationRequest",
57+
}
58+
59+
parsed_event = BedrockAgentFunctionEvent(minimal_event)
60+
61+
assert parsed_event.session_attributes == {}
62+
assert parsed_event.prompt_session_attributes == {}
63+
assert parsed_event.parameters == []
64+
65+
66+
def test_bedrock_agent_function_event_validation():
67+
"""Test validation of required fields"""
68+
# Test missing required field
69+
with pytest.raises(ValueError, match="Missing required field: messageVersion"):
70+
BedrockAgentFunctionEvent({})
71+
72+
# Test invalid field type
73+
invalid_event = {
74+
"messageVersion": 1, # should be string
75+
"agent": {"alias": "PROD", "name": "hr-assistant", "version": "1", "id": "1234"},
76+
"inputText": "",
77+
"sessionId": "",
78+
"actionGroup": "",
79+
"function": "",
80+
}
81+
with pytest.raises(TypeError, match="Field messageVersion must be of type <class 'str'>"):
82+
BedrockAgentFunctionEvent(invalid_event)
83+
84+
# Test missing agent fields
85+
invalid_agent_event = {
86+
"messageVersion": "1.0",
87+
"agent": {"name": "test"}, # missing required agent fields
88+
"inputText": "",
89+
"sessionId": "",
90+
"actionGroup": "",
91+
"function": "",
92+
}
93+
with pytest.raises(ValueError, match="Agent object missing required fields"):
94+
BedrockAgentFunctionEvent(invalid_agent_event)

0 commit comments

Comments
 (0)