Skip to content

Commit a2c9fb9

Browse files
committed
add cost cutoff lambda function
1 parent f6be04c commit a2c9fb9

File tree

2 files changed

+120
-1
lines changed

2 files changed

+120
-1
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from __future__ import annotations
2+
3+
import datetime as _dt
4+
import logging
5+
import os
6+
from decimal import Decimal
7+
from typing import Any, Dict
8+
9+
import boto3
10+
from botocore.exceptions import BotoCoreError, ClientError
11+
12+
# Module-level logger; its level is set to INFO on the next statement.
logger = logging.getLogger(__name__)
13+
logger.setLevel(logging.INFO)
14+
15+
# Cost Explorer ("ce") client, created once when the module is imported.
# The region is read from AWS_REGION and falls back to "us-east-1".
_CE_CLIENT = boto3.client("ce", region_name=os.environ.get("AWS_REGION", "us-east-1"))
16+
17+
# Name of the environment variable the handler reads the spend limit from.
# NOTE(review): the name implies USD, but no currency check is visible here — confirm.
COST_LIMIT_ENV = "COST_LIMIT_USD"
18+
19+
20+
def _month_range(today: _dt.date) -> Dict[str, str]:
21+
start = today.replace(day=1)
22+
# Cost Explorer end date is exclusive; add a day to include today.
23+
end = today + _dt.timedelta(days=1)
24+
return {"Start": start.isoformat(), "End": end.isoformat()}
25+
26+
27+
def _current_month_cost() -> Dict[str, Any]:
    """Query Cost Explorer for the unblended cost of the current month so far.

    Returns:
        A dict with three keys: ``"amount"`` (the cost as a float),
        ``"currency"`` (the unit reported with the cost) and
        ``"time_period"`` (the ``Start``/``End`` window that was queried).
        If the response carries no results, the amount is ``0.0`` and the
        currency is ``"USD"``.
    """
    time_period = _month_range(_dt.date.today())
    ce_response = _CE_CLIENT.get_cost_and_usage(
        TimePeriod=time_period,
        Granularity="MONTHLY",
        Metrics=["UnblendedCost"],
    )

    by_time = ce_response.get("ResultsByTime", [])
    if by_time:
        # Only the first time bucket is read.
        unblended = by_time[0]["Total"]["UnblendedCost"]
    else:
        unblended = {"Amount": "0", "Unit": "USD"}

    return {
        "amount": float(Decimal(unblended.get("Amount", "0"))),
        "currency": unblended.get("Unit", "USD"),
        "time_period": time_period,
    }
39+
40+
41+
def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    """Report whether the current month's spend is below the configured limit.

    The limit is read from the environment variable named by
    ``COST_LIMIT_ENV``. Neither *event* nor *context* is inspected.

    Args:
        event: Invocation payload (unused).
        context: Lambda context object (unused).

    Returns:
        On success, a dict with ``"status": "OK"``, ``"allowed"`` (True when
        spend is strictly below the limit), ``"current_spend"``, ``"limit"``,
        ``"currency"`` and ``"time_period"``.
        If the Cost Explorer query fails, a dict with ``"status": "ERROR"``,
        ``"allowed": False``, ``"error": "cost_explorer_unavailable"`` and
        ``"limit"``.

    Raises:
        RuntimeError: If the limit variable is unset, empty, not a number,
            or NaN.
    """
    import math  # local import: only needed for the NaN check below

    limit_raw = os.environ.get(COST_LIMIT_ENV)
    if not limit_raw:
        raise RuntimeError(f"{COST_LIMIT_ENV} environment variable must be set.")

    try:
        limit = float(limit_raw)
        # float() accepts "nan", but every comparison against NaN is False,
        # which would block all executions while still reporting "OK".
        # Treat it as a configuration error instead.
        if math.isnan(limit):
            raise ValueError("limit must not be NaN")
    except ValueError as exc:
        raise RuntimeError(f"Invalid {COST_LIMIT_ENV}: {limit_raw}") from exc

    try:
        cost = _current_month_cost()
    except (ClientError, BotoCoreError) as exc:
        # Return allowed=False when spend cannot be determined.
        # logger.exception keeps the traceback in the log record.
        logger.exception("Failed to query Cost Explorer: %s", exc)
        return {
            "status": "ERROR",
            "allowed": False,
            "error": "cost_explorer_unavailable",
            "limit": limit,
        }

    amount = cost["amount"]
    allowed = amount < limit
    return {
        "status": "OK",
        "allowed": allowed,
        "current_spend": amount,
        "limit": limit,
        "currency": cost.get("currency", "USD"),
        "time_period": cost.get("time_period"),
    }

infra/cdk/stacks/compose_runner_stack.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
3636
poll_memory_size = int(self.node.try_get_context("pollMemorySize") or 512)
3737
poll_timeout_seconds = int(self.node.try_get_context("pollTimeoutSeconds") or 30)
3838
poll_lookback_ms = int(self.node.try_get_context("pollLookbackMs") or 3600000)
39+
monthly_spend_limit_usd = float(self.node.try_get_context("monthlySpendLimit") or 100)
3940

4041
task_cpu = int(self.node.try_get_context("taskCpu") or 4096)
4142
task_memory_mib = int(self.node.try_get_context("taskMemoryMiB") or 30720)
@@ -243,6 +244,31 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
243244
max_attempts=2,
244245
)
245246

247+
cost_check_code = lambda_.DockerImageCode.from_image_asset(
248+
str(project_root),
249+
file="aws_lambda/Dockerfile",
250+
cmd=["compose_runner.aws_lambda.cost_check_handler.handler"],
251+
build_args=build_args,
252+
)
253+
254+
cost_check_function = lambda_.DockerImageFunction(
255+
self,
256+
"ComposeRunnerCostCheck",
257+
code=cost_check_code,
258+
memory_size=256,
259+
timeout=Duration.seconds(15),
260+
environment={
261+
"COST_LIMIT_USD": str(monthly_spend_limit_usd),
262+
},
263+
description="Blocks executions when monthly spend exceeds the configured limit.",
264+
)
265+
cost_check_function.add_to_role_policy(
266+
iam.PolicyStatement(
267+
actions=["ce:GetCostAndUsage"],
268+
resources=["*"],
269+
)
270+
)
271+
246272
run_output = sfn.Pass(
247273
self,
248274
"ComposeRunnerOutput",
@@ -256,7 +282,7 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
256282
},
257283
)
258284

259-
definition_chain = sfn.Choice(
285+
task_selection = sfn.Choice(
260286
self,
261287
"SelectFargateTask",
262288
).when(
@@ -266,6 +292,28 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
266292
run_task_standard.next(run_output)
267293
)
268294

295+
cost_limit_exceeded = sfn.Fail(
296+
self,
297+
"CostLimitExceeded",
298+
cause="Monthly spend limit exceeded.",
299+
error="CostLimitExceeded",
300+
)
301+
302+
enforce_cost_limit = sfn.Choice(self, "EnforceMonthlyCostLimit").when(
303+
sfn.Condition.boolean_equals("$.cost_check.Payload.allowed", False),
304+
cost_limit_exceeded,
305+
).otherwise(task_selection)
306+
307+
cost_check_step = tasks.LambdaInvoke(
308+
self,
309+
"CheckMonthlyCost",
310+
lambda_function=cost_check_function,
311+
payload=sfn.TaskInput.from_object({"stateInput.$": "$"}),
312+
result_path="$.cost_check",
313+
)
314+
315+
definition_chain = cost_check_step.next(enforce_cost_limit)
316+
269317
state_machine = sfn.StateMachine(
270318
self,
271319
"ComposeRunnerStateMachine",

0 commit comments

Comments
 (0)