Skip to content

Commit 86ce4e7

Browse files
authored
Merge pull request #33 from dwhswenson/fix-resulttype-injection
Inject `result_type` on publish to SNS, not on read
2 parents 0f882be + c699223 commit 86ce4e7

File tree

12 files changed

+227
-62
lines changed

12 files changed

+227
-62
lines changed

docs/explanation/data-flow.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ Let's say the lambda runs a few tasks, which either pass (status `OK`) or fail (
2626

2727
## Phase 2: What Gets Published to SNS
2828

29-
We publish one SNS message per status type, so in this case two messages: one for `OK` and one for `ERROR`. Each message includes the relevant portion of the `_perform_task` return value (JSON-encoded) as the `Message` payload. The `result_type` is included as a `MessageAttribute` to allow for filtering by the SQS subscriptions.
29+
We publish one SNS message per status type, so in this case two messages: one for `OK` and one for `ERROR`. Each message includes the relevant portion of the `_perform_task` return value, with a top-level `result_type` added before publication. The same `result_type` is also included as a `MessageAttribute` to allow for filtering by the SQS subscriptions.
3030

3131
For `OK`, the publish call payload shape is:
3232

3333
```json
3434
{
3535
"TopicArn": "arn:aws:sns:us-east-1:123456789012:lambdacron-results.fifo",
36-
"Message": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}]}",
36+
"Message": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}], \"result_type\": \"OK\"}",
3737
"Subject": "Notification for OK",
3838
"MessageAttributes": {
3939
"result_type": {
@@ -47,7 +47,7 @@ For `OK`, the publish call payload shape is:
4747

4848
For `ERROR`, the shape is identical except:
4949

50-
- `Message` is `"{\"tasks\": [{\"taskid\": 3, \"name\": \"Baz\"}]}"`
50+
- `Message` is `"{\"tasks\": [{\"taskid\": 3, \"name\": \"Baz\"}], \"result_type\": \"ERROR\"}"`
5151
- `Subject` is `"Notification for ERROR"`
5252
- `MessageAttributes.result_type.StringValue` is `"ERROR"`
5353

@@ -63,7 +63,7 @@ The SQS event record (as seen by notifier Lambda) looks like:
6363
{
6464
"messageId": "msg-ok-1",
6565
"eventSource": "aws:sqs",
66-
"body": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}]}",
66+
"body": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}], \"result_type\": \"OK\"}",
6767
"messageAttributes": {
6868
"result_type": {
6969
"stringValue": "OK",
@@ -77,8 +77,7 @@ The SQS event record (as seen by notifier Lambda) looks like:
7777

7878
## Phase 4: Notifier Parse Behavior with This Shape
7979

80-
The notifier parser converts the JSON string back into an object, and, if `result_type` is missing from the payload, injects it from the SQS `messageAttributes`. In this case, the payload doesn't include `result_type`, so it is injected.
81-
80+
The notifier parser converts the JSON string back into an object and validates that the payload `result_type` matches the SQS `messageAttributes.result_type` value when that attribute is present.
8281
The result, which is fed to the notifier's templates, is:
8382

8483
```json

docs/how-to/set-up-ses.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ Set up Amazon SES once per AWS account and region where your notifier Lambda run
1212

1313
* Choose the AWS account/region where email will be sent (SES setup is regional).
1414
* Choose a sender identity:
15+
1516
* Email identity for one sender address.
1617
* Domain identity if you want to send from multiple addresses in one domain.
18+
1719
* Make sure you can edit DNS records if you choose a domain identity.
1820

1921
SES requires verified identities for senders and (in sandbox) recipients. That means you must verify the email address or domain you want to send from, and if in sandbox, also verify any recipient addresses. Sandbox mode limits how much you can send, and you'll probably want to request production access if you want to send to more than one recipient.

docs/how-to/write-lambda-and-templates.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Based on `examples/basic/lambda/lambda_module.py`:
3333

3434
* Key (`example`) is the `result_type`.
3535
* Value (`{"message": "Hello World"}`) is the payload rendered by templates.
36-
* The payload automatically has the key `result_type` injected by the notification handler, so you can access it in templates as `{{ result_type }}`.
36+
* LambdaCron adds `result_type` to the published message body when it sends the payload to SNS, so you can access it in templates as `{{ result_type }}`.
3737

3838
## 2. Create templates that use the payload fields
3939

@@ -90,7 +90,8 @@ At runtime:
9090

9191
Important detail:
9292

93-
* `result_type` is injected into the render payload by the notification handler when available in message attributes.
93+
* `result_type` is injected into the message body by the publisher before it reaches the notifier.
94+
* The SNS message attribute still carries the same `result_type`. It is used for filter policies and validation.
9495
* That is why templates like `{{ result_type }}` work even when your `_perform_task` payload does not explicitly include a `result_type` field.
9596

9697
## 4. Checklist Before Deploying

src/lambdacron/lambda_task.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,46 @@ def load_sns_message_group_id(env_var: str = "SNS_MESSAGE_GROUP_ID") -> str:
158158
return os.environ.get(env_var, "lambdacron")
159159

160160

161+
def build_result_message_payload(*, result_type: str, message: Any) -> dict[str, Any]:
162+
"""
163+
Build the JSON object published to SNS for a single result type.
164+
165+
Parameters
166+
----------
167+
result_type : str
168+
Result type key from the task output mapping.
169+
message : Any
170+
JSON-serializable payload associated with the result type.
171+
172+
Returns
173+
-------
174+
dict[str, Any]
175+
Payload object with a top-level ``result_type`` field.
176+
177+
Raises
178+
------
179+
ValueError
180+
If the payload is not a JSON object or if it contains a conflicting
181+
``result_type`` value.
182+
"""
183+
if not isinstance(message, Mapping):
184+
raise ValueError(
185+
f"Result payload for type '{result_type}' must be a JSON object"
186+
)
187+
188+
payload = dict(message)
189+
existing_result_type = payload.get("result_type")
190+
if existing_result_type is None:
191+
payload["result_type"] = result_type
192+
else:
193+
if existing_result_type != result_type:
194+
raise ValueError(
195+
f"Result payload for type '{result_type}' has conflicting "
196+
f"result_type '{existing_result_type}'"
197+
)
198+
return payload
199+
200+
161201
def dispatch_sns_messages(
162202
*,
163203
result: Mapping[str, Any],
@@ -180,9 +220,10 @@ def dispatch_sns_messages(
180220
Logger used to emit structured publish logs.
181221
"""
182222
for result_type, message in result.items():
223+
payload = build_result_message_payload(result_type=result_type, message=message)
183224
sns_client.publish(
184225
TopicArn=sns_topic_arn,
185-
Message=json.dumps(message),
226+
Message=json.dumps(payload),
186227
Subject=f"Notification for {result_type}",
187228
MessageAttributes={
188229
"result_type": {

src/lambdacron/notifications/base.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ class RenderedTemplateNotificationHandler(ABC):
9494
Providers keyed by template name for rendering.
9595
expected_queue_arn : str, optional
9696
Queue ARN to validate incoming SQS records.
97-
include_result_type : bool, optional
98-
Whether to include the SNS message attribute ``result_type`` in the payload.
9997
logger : logging.Logger, optional
10098
Logger used for structured logging.
10199
jinja_env : jinja2.Environment, optional
@@ -107,13 +105,11 @@ def __init__(
107105
template_providers: Mapping[str, TemplateProvider],
108106
*,
109107
expected_queue_arn: Optional[str] = None,
110-
include_result_type: bool = True,
111108
logger: Optional[logging.Logger] = None,
112109
jinja_env: Optional[Environment] = None,
113110
) -> None:
114111
self.template_providers = dict(template_providers)
115112
self.expected_queue_arn = expected_queue_arn
116-
self.include_result_type = include_result_type
117113
self.logger = logger or logging.getLogger(self.__class__.__name__)
118114
self.jinja_env = jinja_env or Environment(undefined=StrictUndefined)
119115

@@ -213,10 +209,18 @@ def _parse_result(self, record: Mapping[str, Any]) -> Mapping[str, Any]:
213209
raise ValueError("SNS message must be valid JSON") from exc
214210
if not isinstance(payload, dict):
215211
raise ValueError("Result payload must be a JSON object")
216-
if self.include_result_type:
217-
result_type = self._extract_result_type(record)
218-
if result_type and "result_type" not in payload:
219-
payload["result_type"] = result_type
212+
payload_result_type = payload.get("result_type")
213+
if not isinstance(payload_result_type, str) or not payload_result_type:
214+
raise ValueError(
215+
"Result payload must include a non-empty string result_type"
216+
)
217+
218+
attribute_result_type = self._extract_result_type(record)
219+
if attribute_result_type and attribute_result_type != payload_result_type:
220+
raise ValueError(
221+
"Result type mismatch between payload and message attributes "
222+
f"(payload {payload_result_type}, attribute {attribute_result_type})"
223+
)
220224
return payload
221225

222226
def _render_template(self, template: str, result: Mapping[str, Any]) -> str:

src/lambdacron/notifications/print_handler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,11 @@ def __init__(
2525
*,
2626
template_provider: TemplateProvider,
2727
expected_queue_arn: str | None = None,
28-
include_result_type: bool = True,
2928
logger: Any | None = None,
3029
) -> None:
3130
super().__init__(
3231
template_providers={"body": template_provider},
3332
expected_queue_arn=expected_queue_arn,
34-
include_result_type=include_result_type,
3533
logger=logger,
3634
)
3735

src/lambdacron/render.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from jinja2 import TemplateError
88

9+
from lambdacron.lambda_task import build_result_message_payload
910
from lambdacron.notifications.base import (
1011
FileTemplateProvider,
1112
RenderedTemplateNotificationHandler,
@@ -46,8 +47,7 @@ def build_parser() -> argparse.ArgumentParser:
4647
class RenderNotificationHandler(RenderedTemplateNotificationHandler):
4748
def __init__(self, *, template_path: Path, stream: TextIO | None = None) -> None:
4849
super().__init__(
49-
template_providers={"body": FileTemplateProvider(template_path)},
50-
include_result_type=True,
50+
template_providers={"body": FileTemplateProvider(template_path)}
5151
)
5252
self.stream = stream or sys.stdout
5353

@@ -95,11 +95,15 @@ def extract_result_payload(payload_json: str, *, result_type: str) -> str:
9595
if not isinstance(payload, dict):
9696
raise ValueError("Task output must be a JSON object keyed by result type")
9797
selected = payload.get(result_type)
98-
if not isinstance(selected, dict):
98+
if not isinstance(selected, Mapping):
9999
raise ValueError(
100-
f"Result payload for type '{result_type}' must be a JSON object"
100+
f"Result payload for type '{result_type}' must be a JSON object, "
101+
f"got {type(selected).__name__}"
101102
)
102-
return json.dumps(selected)
103+
payload_for_publish = build_result_message_payload(
104+
result_type=result_type, message=selected
105+
)
106+
return json.dumps(payload_for_publish)
103107

104108

105109
def main(argv: list[str] | None = None) -> int:

tests/notifications/test_base.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ def test_file_template_provider_raises_for_missing_file(tmp_path):
7171
def test_notification_handler_parses_sqs_json_body(monkeypatch):
7272
monkeypatch.setenv("TEMPLATE", "Status {{ status }}")
7373
handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()})
74-
event = build_sqs_event(json.dumps({"status": "ok"}))
74+
event = build_sqs_event(json.dumps({"status": "ok", "result_type": "success"}))
7575

7676
response = handler.lambda_handler(event, context=None)
7777

7878
assert handler.calls == [
7979
{
80-
"result": {"status": "ok"},
80+
"result": {"status": "ok", "result_type": "success"},
8181
"rendered": {"body": "Status ok"},
8282
"record": event["Records"][0],
8383
}
@@ -88,7 +88,9 @@ def test_notification_handler_parses_sqs_json_body(monkeypatch):
8888
def test_notification_handler_parses_sns_envelope(monkeypatch):
8989
monkeypatch.setenv("TEMPLATE", "Result {{ status }}")
9090
handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()})
91-
sns_body = json.dumps({"Message": json.dumps({"status": "good"})})
91+
sns_body = json.dumps(
92+
{"Message": json.dumps({"status": "good", "result_type": "success"})}
93+
)
9294
event = build_sqs_event(sns_body)
9395

9496
response = handler.lambda_handler(event, context=None)
@@ -141,8 +143,14 @@ def test_notification_handler_logs_invocation(monkeypatch, caplog):
141143
)
142144
event = {
143145
"Records": [
144-
{"body": json.dumps({"name": "Ada"}), "eventSource": "aws:sqs"},
145-
{"body": json.dumps({"name": "Grace"}), "eventSource": "aws:sqs"},
146+
{
147+
"body": json.dumps({"name": "Ada", "result_type": "success"}),
148+
"eventSource": "aws:sqs",
149+
},
150+
{
151+
"body": json.dumps({"name": "Grace", "result_type": "success"}),
152+
"eventSource": "aws:sqs",
153+
},
146154
]
147155
}
148156

@@ -204,30 +212,41 @@ def test_parse_result_rejects_non_object_payload(monkeypatch):
204212
assert response == {"batchItemFailures": [{"itemIdentifier": "msg-123"}]}
205213

206214

207-
@pytest.mark.parametrize("include_result_type", [True, False])
208-
@pytest.mark.parametrize("payload_has_result_type", [True, False])
209-
def test_notification_handler_result_type_injection(
210-
monkeypatch, include_result_type, payload_has_result_type
211-
):
215+
def test_notification_handler_payload_result_type_passes_through(monkeypatch):
216+
monkeypatch.setenv("TEMPLATE", "Result {{ result_type }}")
217+
handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()})
218+
payload = {"status": "ok", "result_type": "payload"}
219+
event = build_sqs_event(
220+
json.dumps(payload),
221+
message_attributes={"result_type": {"stringValue": "payload"}},
222+
)
223+
224+
handler.lambda_handler(event, context=None)
225+
226+
assert handler.calls[0]["result"] == payload
227+
228+
229+
def test_notification_handler_requires_payload_result_type(monkeypatch):
212230
monkeypatch.setenv("TEMPLATE", "Result {{ result_type | default('none') }}")
213-
handler = CapturingHandler(
214-
template_providers={"body": EnvVarTemplateProvider()},
215-
include_result_type=include_result_type,
231+
handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()})
232+
event = build_sqs_event(
233+
json.dumps({"status": "ok"}),
234+
message_attributes={"result_type": {"stringValue": "attribute"}},
216235
)
217-
payload = {"status": "ok"}
218-
if payload_has_result_type:
219-
payload["result_type"] = "payload"
236+
237+
response = handler.lambda_handler(event, context=None)
238+
239+
assert response == {"batchItemFailures": [{"itemIdentifier": "msg-123"}]}
240+
241+
242+
def test_notification_handler_rejects_result_type_mismatch(monkeypatch):
243+
monkeypatch.setenv("TEMPLATE", "Result {{ result_type }}")
244+
handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()})
220245
event = build_sqs_event(
221-
json.dumps(payload),
246+
json.dumps({"status": "ok", "result_type": "payload"}),
222247
message_attributes={"result_type": {"stringValue": "attribute"}},
223248
)
224249

225-
handler.lambda_handler(event, context=None)
250+
response = handler.lambda_handler(event, context=None)
226251

227-
result = handler.calls[0]["result"]
228-
if payload_has_result_type:
229-
assert result["result_type"] == "payload"
230-
elif include_result_type:
231-
assert result["result_type"] == "attribute"
232-
else:
233-
assert "result_type" not in result
252+
assert response == {"batchItemFailures": [{"itemIdentifier": "msg-123"}]}

tests/notifications/test_email_handler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_email_handler_sends_rendered_templates(monkeypatch):
4040
recipients=["alice@example.com", "bob@example.com"],
4141
ses_client=ses_client,
4242
)
43-
event = build_sqs_event({"name": "Ada"})
43+
event = build_sqs_event({"name": "Ada", "result_type": "success"})
4444

4545
handler.lambda_handler(event, context=None)
4646

@@ -71,7 +71,7 @@ def test_email_handler_includes_optional_fields(monkeypatch):
7171
config_set="alerts",
7272
reply_to=["reply@example.com"],
7373
)
74-
event = build_sqs_event({"name": "Grace"})
74+
event = build_sqs_event({"name": "Grace", "result_type": "success"})
7575

7676
handler.lambda_handler(event, context=None)
7777

@@ -102,7 +102,9 @@ def send_email(self, **kwargs):
102102
recipients=["ops@example.com"],
103103
ses_client=ErrorSesClient(),
104104
)
105-
event = build_sqs_event({"name": "Ada"}, message_id="msg-err")
105+
event = build_sqs_event(
106+
{"name": "Ada", "result_type": "success"}, message_id="msg-err"
107+
)
106108

107109
response = handler.lambda_handler(event, context=None)
108110

tests/notifications/test_print_handler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ def test_print_handler_prints_rendered_template(monkeypatch, capsys):
77
monkeypatch.setenv("TEMPLATE", "Hello {{ name }}")
88
handler = PrintNotificationHandler(template_provider=EnvVarTemplateProvider())
99
event = {
10-
"Records": [{"body": json.dumps({"name": "Ada"}), "eventSource": "aws:sqs"}]
10+
"Records": [
11+
{
12+
"body": json.dumps({"name": "Ada", "result_type": "success"}),
13+
"eventSource": "aws:sqs",
14+
}
15+
]
1116
}
1217

1318
handler.lambda_handler(event, context=None)

0 commit comments

Comments
 (0)