Skip to content

Commit 83ba7ce

Browse files
authored
Merge pull request #32 from neurostuff/ref/use_ecs
[REF] use ecs
2 parents a9764a1 + 6196678 commit 83ba7ce

File tree

11 files changed

+795
-279
lines changed

11 files changed

+795
-279
lines changed

.github/workflows/deploy.yml

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ jobs:
4141
working-directory: infra/cdk
4242
env:
4343
RESULTS_PREFIX: compose-runner/results
44-
RUN_MEMORY_SIZE: 3008
45-
RUN_TIMEOUT_SECONDS: 900
44+
TASK_CPU: 4096
45+
TASK_MEMORY_MIB: 30720
46+
STATE_MACHINE_TIMEOUT_SECONDS: 7200
4647
run: |
4748
source .venv/bin/activate
4849
VERSION=${GITHUB_REF_NAME}
@@ -51,29 +52,57 @@ jobs:
5152
--outputs-file cdk-outputs.json \
5253
-c composeRunnerVersion=${VERSION} \
5354
-c resultsPrefix=${RESULTS_PREFIX} \
54-
-c runMemorySize=${RUN_MEMORY_SIZE} \
55-
-c runTimeoutSeconds=${RUN_TIMEOUT_SECONDS}
55+
-c taskCpu=${TASK_CPU} \
56+
-c taskMemoryMiB=${TASK_MEMORY_MIB} \
57+
-c stateMachineTimeoutSeconds=${STATE_MACHINE_TIMEOUT_SECONDS}
5658
57-
- name: Smoke test run endpoint
59+
- name: Smoke test submission and status endpoints
5860
working-directory: infra/cdk
5961
run: |
60-
RUN_URL=$(jq -r '.ComposeRunnerStack.ComposeRunnerFunctionUrl' cdk-outputs.json)
61-
if [ -z "$RUN_URL" ] || [ "$RUN_URL" = "null" ]; then
62-
echo "Run Function URL not found in outputs"
62+
SUBMIT_URL=$(jq -r '.ComposeRunnerStack.ComposeRunnerSubmitFunctionUrl' cdk-outputs.json)
63+
STATUS_URL=$(jq -r '.ComposeRunnerStack.ComposeRunnerStatusFunctionUrl' cdk-outputs.json)
64+
if [ -z "$SUBMIT_URL" ] || [ "$SUBMIT_URL" = "null" ]; then
65+
echo "Submit Function URL not found in outputs"
6366
exit 1
6467
fi
68+
if [ -z "$STATUS_URL" ] || [ "$STATUS_URL" = "null" ]; then
69+
echo "Status Function URL not found in outputs"
70+
exit 1
71+
fi
72+
6573
body='{"meta_analysis_id": "pFGy6g3LRo9x", "environment": "production", "no_upload": true}'
66-
response=$(curl -s -w "\n%{http_code}" -X POST "$RUN_URL" -H "Content-Type: application/json" -d "$body")
67-
http_code=$(echo "$response" | tail -n1)
68-
json_body=$(echo "$response" | head -n1)
69-
echo "$json_body" > smoke_run.json
70-
echo "Status code: $http_code"
71-
if [ "$http_code" != "200" ]; then
72-
echo "Run endpoint failed: $json_body"
74+
response=$(curl -s -w "\n%{http_code}" -X POST "$SUBMIT_URL" -H "Content-Type: application/json" -d "$body")
75+
submit_code=$(echo "$response" | tail -n1)
76+
submit_json=$(echo "$response" | head -n1)
77+
echo "$submit_json" > smoke_submit.json
78+
echo "Submit status code: $submit_code"
79+
if [ "$submit_code" != "202" ]; then
80+
echo "Submit endpoint failed: $submit_json"
81+
exit 1
82+
fi
83+
job_id=$(jq -r '.job_id' smoke_submit.json)
84+
artifact_prefix=$(jq -r '.artifact_prefix' smoke_submit.json)
85+
if [ -z "$job_id" ] || [ "$job_id" = "null" ]; then
86+
echo "Submit response missing job_id: $submit_json"
87+
exit 1
88+
fi
89+
if [ -z "$artifact_prefix" ] || [ "$artifact_prefix" = "null" ]; then
90+
echo "Submit response missing artifact_prefix: $submit_json"
91+
exit 1
92+
fi
93+
94+
status_body=$(printf '{"job_id":"%s"}' "$job_id")
95+
status_response=$(curl -s -w "\n%{http_code}" -X POST "$STATUS_URL" -H "Content-Type: application/json" -d "$status_body")
96+
status_code=$(echo "$status_response" | tail -n1)
97+
status_json=$(echo "$status_response" | head -n1)
98+
echo "$status_json" > smoke_status.json
99+
echo "Status status code: $status_code"
100+
if [ "$status_code" != "200" ]; then
101+
echo "Status endpoint failed: $status_json"
73102
exit 1
74103
fi
75-
status=$(jq -r '.status' smoke_run.json)
76-
if [ "$status" != "SUCCEEDED" ]; then
77-
echo "Run endpoint returned non-success status: $json_body"
104+
status_value=$(jq -r '.status' smoke_status.json)
105+
if [ "$status_value" = "null" ] || [ -z "$status_value" ]; then
106+
echo "Status response missing status: $status_json"
78107
exit 1
79108
fi

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ RUN hatch dep show requirements > requirements.txt && pip install -r requirement
1414

1515
COPY . .
1616

17-
# install the package (more likely to change, leverage caching!)
18-
RUN pip install .
17+
# install the package with AWS extras so the ECS task has boto3, etc.
18+
RUN pip install '.[aws]'
1919

2020
ENTRYPOINT ["compose-run"]

README.md

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,24 @@
33
Python package to execute meta-analyses created using neurosynth compose and NiMARE
44
as the meta-analysis execution engine.
55

6-
## AWS Lambda Deployment
6+
## AWS Deployment
77

8-
This repository includes an AWS CDK application for provisioning the Lambda-based
9-
execution environment and log polling function.
8+
This repository includes an AWS CDK application that turns compose-runner into a
9+
serverless batch pipeline using Step Functions, AWS Lambda, and ECS Fargate.
10+
The deployed architecture works like this:
11+
12+
- `ComposeRunnerSubmit` (Lambda Function URL) accepts HTTP requests, validates
13+
the meta-analysis payload, and starts a Step Functions execution. The response
14+
is immediate and returns both a durable `job_id` (the execution ARN) and the
15+
`artifact_prefix` used for S3 and log correlation.
16+
- A Standard state machine runs a single Fargate task (`compose_runner.ecs_task`)
17+
and waits for completion. The container downloads inputs, executes the
18+
meta-analysis on up to 4 vCPU / 30 GiB of memory, uploads artifacts to S3, and
19+
writes `metadata.json` into the same prefix.
20+
- `ComposeRunnerStatus` (Lambda Function URL) wraps `DescribeExecution`, merges
21+
metadata from S3, and exposes a simple status endpoint suitable for polling.
22+
- `ComposeRunnerLogPoller` streams the ECS CloudWatch Logs for a given `artifact_prefix`,
23+
while `ComposeRunnerResultsFetcher` returns presigned URLs for stored artifacts.
1024

1125
1. Create a virtual environment and install the CDK dependencies:
1226
```bash
@@ -19,21 +33,24 @@ execution environment and log polling function.
1933
```bash
2034
cdk bootstrap
2135
```
22-
3. Deploy the stack (supplying the compose-runner version you want baked into the Lambda image):
36+
3. Deploy the stack (supplying the compose-runner version you want baked into the images):
2337
```bash
2438
cdk deploy \
2539
-c composeRunnerVersion=$(hatch version) \
2640
-c resultsPrefix=compose-runner/results \
27-
-c runMemorySize=3008 \
28-
-c runTimeoutSeconds=900
41+
-c taskCpu=4096 \
42+
-c taskMemoryMiB=30720
2943
```
30-
The deployment output includes HTTPS endpoints for submitting runs (`ComposeRunnerFunctionUrl`), polling logs (`ComposeRunnerLogPollerFunctionUrl`), and fetching presigned S3 URLs (`ComposeRunnerResultsFunctionUrl`).
31-
Omit `resultsBucketName` to let the stack create a managed bucket, or pass an
32-
existing bucket name via `-c resultsBucketName=<bucket>`.
44+
Pass `-c resultsBucketName=<bucket>` to use an existing S3 bucket, or omit it
45+
to let the stack create and retain a dedicated bucket. Additional knobs:
46+
47+
- `-c stateMachineTimeoutSeconds=7200` to control the max wall clock per run
48+
- `-c submitTimeoutSeconds` / `-c statusTimeoutSeconds` / `-c pollTimeoutSeconds`
49+
to tune Lambda timeouts
50+
- `-c taskEphemeralStorageGiB` if the default 21 GiB scratch volume is insufficient
3351

34-
The deployment builds the Lambda container image from `aws_lambda/Dockerfile`,
35-
creates two functions (`ComposeRunnerFunction` and `ComposeRunnerLogPoller`),
36-
and provisions the S3 bucket used to store generated artifacts (including
37-
`meta_results.pkl`). The log poller function expects clients to call it with a
38-
job ID (the run Lambda invocation request ID) and returns filtered CloudWatch Logs
39-
entries for that job.
52+
The deployment builds both the Lambda image (`aws_lambda/Dockerfile`) and the
53+
Fargate task image (`Dockerfile`), provisions the Step Functions state machine,
54+
and configures a public VPC so each task has outbound internet access.
55+
The CloudFormation outputs list the HTTPS endpoints for submission, status,
56+
logs, and artifact retrieval, alongside the Step Functions ARN.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import json
5+
from dataclasses import dataclass
6+
from typing import Any, Dict, Optional
7+
8+
9+
def is_http_event(event: Any) -> bool:
    """Tell whether *event* came through a Lambda Function URL / HTTP integration.

    HTTP-style invocations always carry a ``requestContext`` key; direct
    invocations (e.g. from Step Functions or the console) do not.
    """
    if not isinstance(event, dict):
        return False
    return "requestContext" in event
11+
12+
13+
def _decode_body(event: Dict[str, Any]) -> Optional[str]:
14+
body = event.get("body")
15+
if not body:
16+
return None
17+
if event.get("isBase64Encoded"):
18+
return base64.b64decode(body).decode("utf-8")
19+
return body
20+
21+
22+
def extract_payload(event: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize a Lambda invocation event into its JSON payload.

    Direct (non-HTTP) invocations already *are* the payload and pass through
    untouched. HTTP invocations have their body decoded and parsed as JSON;
    a request with no body yields an empty dict.
    """
    if not is_http_event(event):
        return event
    text = _decode_body(event)
    return json.loads(text) if text else {}
29+
30+
31+
def http_response(body: Dict[str, Any], status_code: int = 200) -> Dict[str, Any]:
    """Wrap *body* in the JSON response envelope expected by Function URLs."""
    response: Dict[str, Any] = {"statusCode": status_code}
    response["headers"] = {"Content-Type": "application/json"}
    response["body"] = json.dumps(body)
    return response
37+
38+
39+
@dataclass(frozen=True)
class LambdaRequest:
    """A parsed Lambda invocation: the raw event plus its decoded payload.

    Bundles the original event, the normalized JSON payload, and whether the
    invocation arrived over HTTP, so handlers can validate input and shape
    their response (HTTP envelope vs. bare dict) through one object.
    """

    # The event exactly as Lambda delivered it.
    raw_event: Any
    # JSON payload extracted from the event (empty dict for bodyless HTTP calls).
    payload: Dict[str, Any]
    # True when the event came via a Function URL / HTTP integration.
    is_http: bool

    @classmethod
    def parse(cls, event: Any) -> "LambdaRequest":
        """Build a request from a raw invocation event."""
        return cls(
            raw_event=event,
            payload=extract_payload(event),
            is_http=is_http_event(event),
        )

    def respond(self, body: Dict[str, Any], status_code: int = 200) -> Dict[str, Any]:
        """Return *body* wrapped for HTTP callers, or as-is for direct callers."""
        if not self.is_http:
            return body
        return http_response(body, status_code)

    def bad_request(self, message: str, status_code: int = 400) -> Dict[str, Any]:
        """Shorthand for a FAILED response carrying an error *message*."""
        error_body = {"status": "FAILED", "error": message}
        return self.respond(error_body, status_code=status_code)

    def get(self, key: str, default: Any = None) -> Any:
        """Look up *key* in the decoded payload, like ``dict.get``."""
        return self.payload.get(key, default)
60+

compose_runner/aws_lambda/log_poll_handler.py

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,30 @@
22

33
import os
44
import time
5-
import base64
6-
import json
75
from typing import Any, Dict, List
86

97
import boto3
108

9+
from compose_runner.aws_lambda.common import LambdaRequest
10+
1111
_LOGS_CLIENT = boto3.client("logs", region_name=os.environ.get("AWS_REGION", "us-east-1"))
1212

1313
LOG_GROUP_ENV = "RUNNER_LOG_GROUP"
1414
DEFAULT_LOOKBACK_MS_ENV = "DEFAULT_LOOKBACK_MS"
1515

16-
def _is_http_event(event: Any) -> bool:
17-
return isinstance(event, dict) and "requestContext" in event
18-
19-
20-
def _extract_payload(event: Dict[str, Any]) -> Dict[str, Any]:
21-
if not _is_http_event(event):
22-
return event
23-
body = event.get("body")
24-
if not body:
25-
return {}
26-
if event.get("isBase64Encoded"):
27-
body = base64.b64decode(body).decode("utf-8")
28-
return json.loads(body)
29-
30-
31-
def _http_response(body: Dict[str, Any], status_code: int = 200) -> Dict[str, Any]:
32-
return {
33-
"statusCode": status_code,
34-
"headers": {"Content-Type": "application/json"},
35-
"body": json.dumps(body),
36-
}
37-
3816

3917
def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
40-
raw_event = event
41-
event = _extract_payload(event)
42-
job_id = event.get("job_id")
43-
if not job_id:
44-
message = "Request payload must include 'job_id'."
45-
if _is_http_event(raw_event):
46-
return _http_response({"status": "FAILED", "error": message}, status_code=400)
18+
request = LambdaRequest.parse(event)
19+
payload = request.payload
20+
artifact_prefix = payload.get("artifact_prefix")
21+
if not artifact_prefix:
22+
message = "Request payload must include 'artifact_prefix'."
23+
if request.is_http:
24+
return request.bad_request(message, status_code=400)
4725
raise KeyError(message)
48-
next_token = event.get("next_token")
49-
start_time = event.get("start_time")
50-
end_time = event.get("end_time")
26+
next_token = payload.get("next_token")
27+
start_time = payload.get("start_time")
28+
end_time = payload.get("end_time")
5129

5230
log_group = os.environ[LOG_GROUP_ENV]
5331
lookback_ms = int(os.environ.get(DEFAULT_LOOKBACK_MS_ENV, "3600000"))
@@ -60,7 +38,7 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
6038

6139
params: Dict[str, Any] = {
6240
"logGroupName": log_group,
63-
"filterPattern": f'"{job_id}"',
41+
"filterPattern": f'"{artifact_prefix}"',
6442
"startTime": int(start_time),
6543
}
6644
if end_time is not None:
@@ -75,10 +53,8 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
7553
]
7654

7755
body = {
78-
"job_id": job_id,
56+
"artifact_prefix": artifact_prefix,
7957
"events": events,
8058
"next_token": response.get("nextToken"),
8159
}
82-
if _is_http_event(raw_event):
83-
return _http_response(body)
84-
return body
60+
return request.respond(body)
Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
import os
4-
import base64
5-
import json
64
from datetime import datetime, timezone
75
from typing import Any, Dict, List
86

97
import boto3
108

9+
from compose_runner.aws_lambda.common import LambdaRequest
10+
1111
_S3 = boto3.client("s3", region_name=os.environ.get("AWS_REGION", "us-east-1"))
1212

1313
RESULTS_BUCKET_ENV = "RESULTS_BUCKET"
@@ -21,44 +21,21 @@ def _serialize_dt(value: datetime) -> str:
2121
return value.astimezone(timezone.utc).isoformat()
2222

2323

24-
def _is_http_event(event: Any) -> bool:
25-
return isinstance(event, dict) and "requestContext" in event
26-
27-
28-
def _extract_payload(event: Dict[str, Any]) -> Dict[str, Any]:
29-
if not _is_http_event(event):
30-
return event
31-
body = event.get("body")
32-
if not body:
33-
return {}
34-
if event.get("isBase64Encoded"):
35-
body = base64.b64decode(body).decode("utf-8")
36-
return json.loads(body)
37-
38-
39-
def _http_response(body: Dict[str, Any], status_code: int = 200) -> Dict[str, Any]:
40-
return {
41-
"statusCode": status_code,
42-
"headers": {"Content-Type": "application/json"},
43-
"body": json.dumps(body),
44-
}
45-
46-
4724
def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
48-
raw_event = event
49-
event = _extract_payload(event)
25+
request = LambdaRequest.parse(event)
26+
payload = request.payload
5027
bucket = os.environ[RESULTS_BUCKET_ENV]
5128
prefix = os.environ.get(RESULTS_PREFIX_ENV)
5229

53-
job_id = event.get("job_id")
54-
if not job_id:
55-
message = "Request payload must include 'job_id'."
56-
if _is_http_event(raw_event):
57-
return _http_response({"status": "FAILED", "error": message}, status_code=400)
30+
artifact_prefix = payload.get("artifact_prefix")
31+
if not artifact_prefix:
32+
message = "Request payload must include 'artifact_prefix'."
33+
if request.is_http:
34+
return request.bad_request(message, status_code=400)
5835
raise KeyError(message)
59-
expires_in = int(event.get("expires_in", DEFAULT_EXPIRES_IN))
36+
expires_in = int(payload.get("expires_in", DEFAULT_EXPIRES_IN))
6037

61-
key_prefix = f"{prefix.rstrip('/')}/{job_id}" if prefix else job_id
38+
key_prefix = f"{prefix.rstrip('/')}/{artifact_prefix}" if prefix else artifact_prefix
6239

6340
response = _S3.list_objects_v2(Bucket=bucket, Prefix=key_prefix)
6441
contents = response.get("Contents", [])
@@ -84,11 +61,9 @@ def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
8461
)
8562

8663
body = {
87-
"job_id": job_id,
64+
"artifact_prefix": artifact_prefix,
8865
"artifacts": artifacts,
8966
"bucket": bucket,
9067
"prefix": key_prefix,
9168
}
92-
if _is_http_event(raw_event):
93-
return _http_response(body)
94-
return body
69+
return request.respond(body)

0 commit comments

Comments
 (0)