|
1 | 1 | """CLI for evaluations Cloud Run Jobs.""" |
2 | 2 |
|
3 | 3 | import argparse |
| 4 | +import json |
| 5 | +import os |
4 | 6 | import subprocess |
5 | 7 | import sys |
6 | 8 |
|
| 9 | +import boto3 |
| 10 | +from botocore.exceptions import BotoCoreError, ClientError |
| 11 | + |
7 | 12 | from evaluations.configs import ModelEval, TierName, get_tier |
8 | 13 | from evaluations.logging import logger |
9 | 14 | from evaluations.settings import settings |
10 | 15 |
|
11 | 16 |
|
| 17 | +def get_aws_secret_value(secret_arn: str, key: str | None = None) -> str: |
| 18 | + """Fetch a secret value from AWS Secrets Manager. |
| 19 | +
|
| 20 | + Args: |
| 21 | + secret_arn: The ARN of the secret. |
| 22 | + key: If provided, parse secret as JSON and extract this key. |
| 23 | +
|
| 24 | + Returns: |
| 25 | + The secret value (or extracted key value). |
| 26 | + """ |
| 27 | + region = settings.AWS_REGION |
| 28 | + client = boto3.client("secretsmanager", region_name=region) |
| 29 | + response = client.get_secret_value(SecretId=secret_arn) |
| 30 | + secret_string = response["SecretString"] |
| 31 | + |
| 32 | + if key is None: |
| 33 | + return secret_string |
| 34 | + |
| 35 | + return json.loads(secret_string)[key] |
| 36 | + |
| 37 | + |
| 38 | +def setup_db_credentials() -> None: |
| 39 | + """Set up database credentials from AWS Secrets Manager or settings. |
| 40 | +
|
| 41 | + If DB_SECRET_ARN is set, fetches the password from AWS Secrets Manager. |
| 42 | + Otherwise, uses PGPASSWORD from settings. |
| 43 | +
|
| 44 | + This supports automatically rotating database credentials when using AWS. |
| 45 | + """ |
| 46 | + if not settings.DB_SECRET_ARN: |
| 47 | + # Use PGPASSWORD as set in the environment |
| 48 | + return |
| 49 | + |
| 50 | + logger.info("Fetching database password from AWS Secrets Manager") |
| 51 | + |
| 52 | + try: |
| 53 | + password = get_aws_secret_value(settings.DB_SECRET_ARN, key="password") |
| 54 | + # override PGPASSWORD in the environment |
| 55 | + os.environ["PGPASSWORD"] = password |
| 56 | + logger.info("Database password loaded from AWS Secrets Manager") |
| 57 | + except (BotoCoreError, ClientError, KeyError, json.JSONDecodeError) as e: |
| 58 | + logger.error("Failed to fetch database password from AWS: %s", e) |
| 59 | + sys.exit(1) |
| 60 | + |
| 61 | + |
12 | 62 | def run_ad_hoc( |
13 | 63 | model: str, |
14 | 64 | tasks: str, |
@@ -99,6 +149,9 @@ def main() -> None: |
99 | 149 | 2. EVAL_TIER set: Run tier evaluation |
100 | 150 | 3. CLI arguments: Manual invocation |
101 | 151 | """ |
| 152 | + # Fetch database credentials from AWS if DB_SECRET_ARN is set |
| 153 | + setup_db_credentials() |
| 154 | + |
102 | 155 | # Check for ad-hoc mode |
103 | 156 | if settings.EVAL_MODE == "ad-hoc": |
104 | 157 | if not settings.AD_HOC_MODEL or not settings.AD_HOC_TASKS: |
|
0 commit comments