diff --git a/usaspending_api/common/spark/jobs.py b/usaspending_api/common/spark/jobs.py index 9c48d6734a..fe397c2667 100644 --- a/usaspending_api/common/spark/jobs.py +++ b/usaspending_api/common/spark/jobs.py @@ -5,9 +5,12 @@ from datetime import datetime from typing import TYPE_CHECKING, Generator +import boto3 +from botocore.client import BaseClient from databricks.sdk import WorkspaceClient from databricks.sdk.config import Config as DatabricksConfig from databricks.sdk.service.jobs import BaseJob, RunLifeCycleState +from django.conf import settings from django.core.management import call_command from duckdb.experimental.spark.sql import SparkSession as DuckDBSparkSession @@ -109,13 +112,65 @@ def _wait_for_run_to_start(self, job: BaseJob, job_run_id: int) -> None: class EmrServerlessStrategy(_AbstractStrategy): + _client: BaseClient | None = None + @property def name(self) -> str: return "EMR_SERVERLESS" + @property + def client(self) -> BaseClient: + if not self._client: + self._client = boto3.client("emr-serverless", settings.USASPENDING_AWS_REGION) + return self._client + + def _get_application_id(self, application_name: str) -> str: + paginator = self.client.get_paginator("list_applications") + matched_applications = [] + for list_applications_response in paginator.paginate(): + temp_applications = list_applications_response.get("applications", []) + matched_applications.extend( + [application for application in temp_applications if application["name"] == application_name] + ) + + match len(matched_applications): + case 1: + application_id = matched_applications[0]["id"] + case 0: + raise ValueError(f"No EMR Serverless application found with name '{application_name}'") + case _: + arns_to_log = [application["arn"] for application in matched_applications] + raise ValueError( + f"More than 1 EMR Serverless application found with name '{application_name}': {arns_to_log}" + ) + + return application_id + def handle_start(self, job_name: str, command_name: str, 
command_options: list[str], **kwargs) -> dict: - # TODO: This will be implemented as we migrate, but added as a placeholder for now - pass + application_id = kwargs.get("application_id") + application_name = kwargs.get("application_name") + execution_role_arn = kwargs.get("execution_role_arn") + + if not execution_role_arn: + raise ValueError("Execution role ARN is required to start an EMR Serverless job") + elif not application_name and not application_id: + raise ValueError("Application Name or ID is required to start an EMR Serverless job") + elif application_name and not application_id: + application_id = self._get_application_id(application_name) + + response = self.client.start_job_run( + applicationId=application_id, + executionRoleArn=execution_role_arn, + name=job_name, + mode="BATCH", + jobDriver={ + "sparkSubmit": { + "entryPoint": command_name, + "entryPointArguments": command_options, + } + }, + ) + return response class LocalStrategy(_AbstractStrategy): diff --git a/usaspending_api/common/tests/unit/test_spark_jobs.py b/usaspending_api/common/tests/unit/test_spark_jobs.py index 6a3e769a04..8bc5eb4eb8 100644 --- a/usaspending_api/common/tests/unit/test_spark_jobs.py +++ b/usaspending_api/common/tests/unit/test_spark_jobs.py @@ -51,3 +51,49 @@ def test_databricks_strategy_handle_start(databricks_strategy_client): spark_job = SparkJobs(DatabricksStrategy()) assert spark_job.start(job_name="", command_name="", command_options=[""]) == {"job_id": 1, "run_id": 10} + + +@patch("usaspending_api.common.spark.jobs.EmrServerlessStrategy.client") +def test_emr_serverless_strategy_handle_start(emr_serverless_strategy_client): + mock_application = MagicMock() + mock_application.application_id = 1 + mock_application.name = "application_1" + + mock_application_paginator = MagicMock() + mock_application_paginator.paginate = MagicMock( + return_value=[{"applications": [{"id": mock_application.application_id, "name": mock_application.name}]}] + ) + + 
emr_serverless_strategy_client.get_paginator = MagicMock(return_value=mock_application_paginator) + emr_serverless_strategy_client.start_job_run = MagicMock() + + strategy = EmrServerlessStrategy() + assert strategy._get_application_id("application_1") == 1 + + emr_serverless_strategy_client.reset_mock() + + spark_job = SparkJobs(strategy) + spark_job.start( + job_name="", + command_name="", + command_options=[""], + application_id="Some ID", + application_name="application_1", + execution_role_arn="Some ARN", + ) + assert emr_serverless_strategy_client.get_paginator.call_count == 0 + assert emr_serverless_strategy_client.start_job_run.call_count == 1 + + emr_serverless_strategy_client.reset_mock() + + spark_job = SparkJobs(strategy) + spark_job.start( + job_name="", + command_name="", + command_options=[""], + application_id=None, + application_name="application_1", + execution_role_arn="Some ARN", + ) + assert emr_serverless_strategy_client.get_paginator.call_count == 1 + assert emr_serverless_strategy_client.start_job_run.call_count == 1