Skip to content

Commit a9764a1

Browse files
authored
Merge pull request #31 from neurostuff/fix/lambda_deploy
[FIX] the logging/result fetching endpoints
2 parents 1a466e6 + 8c2a119 commit a9764a1

File tree

8 files changed

+102
-10
lines changed

8 files changed

+102
-10
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
.runner
33
.pytest_cache
44
_version.py
5+
__pycache__
-171 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
-7.34 KB
Binary file not shown.

compose_runner/aws_lambda/log_poll_handler.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import os
44
import time
5+
import base64
6+
import json
57
from typing import Any, Dict, List
68

79
import boto3
@@ -11,9 +13,38 @@
1113
LOG_GROUP_ENV = "RUNNER_LOG_GROUP"
1214
DEFAULT_LOOKBACK_MS_ENV = "DEFAULT_LOOKBACK_MS"
1315

16+
def _is_http_event(event: Any) -> bool:
17+
return isinstance(event, dict) and "requestContext" in event
18+
19+
20+
def _extract_payload(event: Dict[str, Any]) -> Dict[str, Any]:
21+
if not _is_http_event(event):
22+
return event
23+
body = event.get("body")
24+
if not body:
25+
return {}
26+
if event.get("isBase64Encoded"):
27+
body = base64.b64decode(body).decode("utf-8")
28+
return json.loads(body)
29+
30+
31+
def _http_response(body: Dict[str, Any], status_code: int = 200) -> Dict[str, Any]:
32+
return {
33+
"statusCode": status_code,
34+
"headers": {"Content-Type": "application/json"},
35+
"body": json.dumps(body),
36+
}
37+
1438

1539
def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
16-
job_id = event["job_id"]
40+
raw_event = event
41+
event = _extract_payload(event)
42+
job_id = event.get("job_id")
43+
if not job_id:
44+
message = "Request payload must include 'job_id'."
45+
if _is_http_event(raw_event):
46+
return _http_response({"status": "FAILED", "error": message}, status_code=400)
47+
raise KeyError(message)
1748
next_token = event.get("next_token")
1849
start_time = event.get("start_time")
1950
end_time = event.get("end_time")
@@ -29,7 +60,7 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
2960

3061
params: Dict[str, Any] = {
3162
"logGroupName": log_group,
32-
"filterPattern": f'{{"job_id": "{job_id}"}}',
63+
"filterPattern": f'"{job_id}"',
3364
"startTime": int(start_time),
3465
}
3566
if end_time is not None:
@@ -43,8 +74,11 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
4374
for item in response.get("events", [])
4475
]
4576

46-
return {
77+
body = {
4778
"job_id": job_id,
4879
"events": events,
4980
"next_token": response.get("nextToken"),
5081
}
82+
if _is_http_event(raw_event):
83+
return _http_response(body)
84+
return body

compose_runner/aws_lambda/results_handler.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
import os
4+
import base64
5+
import json
46
from datetime import datetime, timezone
57
from typing import Any, Dict, List
68

@@ -19,11 +21,41 @@ def _serialize_dt(value: datetime) -> str:
1921
return value.astimezone(timezone.utc).isoformat()
2022

2123

24+
def _is_http_event(event: Any) -> bool:
25+
return isinstance(event, dict) and "requestContext" in event
26+
27+
28+
def _extract_payload(event: Dict[str, Any]) -> Dict[str, Any]:
29+
if not _is_http_event(event):
30+
return event
31+
body = event.get("body")
32+
if not body:
33+
return {}
34+
if event.get("isBase64Encoded"):
35+
body = base64.b64decode(body).decode("utf-8")
36+
return json.loads(body)
37+
38+
39+
def _http_response(body: Dict[str, Any], status_code: int = 200) -> Dict[str, Any]:
40+
return {
41+
"statusCode": status_code,
42+
"headers": {"Content-Type": "application/json"},
43+
"body": json.dumps(body),
44+
}
45+
46+
2247
def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
48+
raw_event = event
49+
event = _extract_payload(event)
2350
bucket = os.environ[RESULTS_BUCKET_ENV]
2451
prefix = os.environ.get(RESULTS_PREFIX_ENV)
2552

26-
job_id = event["job_id"]
53+
job_id = event.get("job_id")
54+
if not job_id:
55+
message = "Request payload must include 'job_id'."
56+
if _is_http_event(raw_event):
57+
return _http_response({"status": "FAILED", "error": message}, status_code=400)
58+
raise KeyError(message)
2759
expires_in = int(event.get("expires_in", DEFAULT_EXPIRES_IN))
2860

2961
key_prefix = f"{prefix.rstrip('/')}/{job_id}" if prefix else job_id
@@ -51,9 +83,12 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
5183
}
5284
)
5385

54-
return {
86+
body = {
5587
"job_id": job_id,
5688
"artifacts": artifacts,
5789
"bucket": bucket,
5890
"prefix": key_prefix,
5991
}
92+
if _is_http_event(raw_event):
93+
return _http_response(body)
94+
return body

compose_runner/tests/test_lambda_handlers.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ def filter_log_events(self, **kwargs):
8787
assert result["events"][0]["message"] == events_payload[0]["message"]
8888

8989

90+
def test_log_poll_handler_http_missing_job_id(monkeypatch):
91+
monkeypatch.setenv("RUNNER_LOG_GROUP", "/aws/lambda/test")
92+
http_event = _make_http_event({})
93+
response = log_poll_handler.handler(http_event, DummyContext())
94+
body = json.loads(response["body"])
95+
assert response["statusCode"] == 400
96+
assert body["status"] == "FAILED"
97+
assert "job_id" in body["error"]
98+
99+
90100
def test_results_handler(monkeypatch):
91101
objects = [
92102
{"Key": "prefix/id/file1.nii.gz", "Size": 10, "LastModified": results_handler.datetime.now()}
@@ -109,8 +119,20 @@ def generate_presigned_url(self, client_method, Params, ExpiresIn):
109119
monkeypatch.setenv("RESULTS_PREFIX", "prefix")
110120
monkeypatch.setattr(results_handler, "_S3", FakeS3())
111121

112-
event = {"job_id": "id"}
113-
result = results_handler.handler(event, DummyContext())
114-
assert result["job_id"] == "id"
115-
assert result["artifacts"][0]["url"] == "https://signed/url"
116-
assert result["artifacts"][0]["filename"] == "file1.nii.gz"
122+
event = _make_http_event({"job_id": "id"})
123+
response = results_handler.handler(event, DummyContext())
124+
body = json.loads(response["body"])
125+
assert response["statusCode"] == 200
126+
assert body["job_id"] == "id"
127+
assert body["artifacts"][0]["url"] == "https://signed/url"
128+
assert body["artifacts"][0]["filename"] == "file1.nii.gz"
129+
130+
131+
def test_results_handler_missing_job_id(monkeypatch):
132+
monkeypatch.setenv("RESULTS_BUCKET", "bucket")
133+
event = _make_http_event({})
134+
response = results_handler.handler(event, DummyContext())
135+
body = json.loads(response["body"])
136+
assert response["statusCode"] == 400
137+
assert body["status"] == "FAILED"
138+
assert "job_id" in body["error"]

0 commit comments

Comments
 (0)