59 changes: 57 additions & 2 deletions usaspending_api/common/spark/jobs.py
@@ -5,9 +5,12 @@
from datetime import datetime
from typing import TYPE_CHECKING, Generator

import boto3
from botocore.client import BaseClient
from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config as DatabricksConfig
from databricks.sdk.service.jobs import BaseJob, RunLifeCycleState
from django.conf import settings
from django.core.management import call_command
from duckdb.experimental.spark.sql import SparkSession as DuckDBSparkSession

@@ -109,13 +112,65 @@ def _wait_for_run_to_start(self, job: BaseJob, job_run_id: int) -> None:


class EmrServerlessStrategy(_AbstractStrategy):
_client: BaseClient | None = None

@property
def name(self) -> str:
return "EMR_SERVERLESS"

@property
def client(self) -> BaseClient:
if not self._client:
self._client = boto3.client("emr-serverless", settings.USASPENDING_AWS_REGION)
return self._client

def _get_application_id(self, application_name: str) -> str:
paginator = self.client.get_paginator("list_applications")
matched_applications = []
for list_applications_response in paginator.paginate():
temp_applications = list_applications_response.get("applications", [])
matched_applications.extend(
[application for application in temp_applications if application["name"] == application_name]
)

match len(matched_applications):
case 1:
application_id = matched_applications[0]["id"]
case 0:
raise ValueError(f"No EMR Serverless application found with name '{application_name}'")
case _:
arns_to_log = [application["arn"] for application in matched_applications]
raise ValueError(
f"More than 1 EMR Serverless application found with name '{application_name}': {arns_to_log}"
)

return application_id

def handle_start(self, job_name: str, command_name: str, command_options: list[str], **kwargs) -> dict:
- # TODO: This will be implemented as we migrate, but added as a placeholder for now
- pass
application_id = kwargs.get("application_id")
application_name = kwargs.get("application_name")
execution_role_arn = kwargs.get("execution_role_arn")

if not execution_role_arn:
raise ValueError(f"Execution role ARN is required to start an EMR Serverless job")
elif not application_name and not application_id:
raise ValueError(f"Application Name or ID is required to start an EMR Serverless job")
elif application_name and not application_id:
application_id = self._get_application_id(application_name)

response = self.client.start_job_run(
applicationId=application_id,
executionRoleArn=execution_role_arn,
name=job_name,
mode="BATCH",
jobDriver={
"sparkSubmit": {
"entryPoint": command_name,
"entryPointArguments": command_options,
}
},
)
return response


class LocalStrategy(_AbstractStrategy):
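
For context, a minimal sketch of how the new strategy could be invoked through SparkJobs, mirroring the keyword arguments exercised in the tests below; the job name, entry point, application name, and role ARN are hypothetical placeholder values, not part of this PR.

# Sketch only: hypothetical values for the job name, entry point, application name, and role ARN.
from usaspending_api.common.spark.jobs import EmrServerlessStrategy, SparkJobs

spark_job = SparkJobs(EmrServerlessStrategy())
response = spark_job.start(
    job_name="example_spark_job",                      # hypothetical job run name
    command_name="s3://example-bucket/entrypoint.py",  # passed through as the sparkSubmit entryPoint
    command_options=["--example-flag"],                # passed through as entryPointArguments
    application_name="example_application",            # resolved to an application ID via list_applications
    execution_role_arn="arn:aws:iam::123456789012:role/example-emr-role",
)
# The boto3 start_job_run response is returned as-is (it includes identifiers such as jobRunId).
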
46 changes: 46 additions & 0 deletions usaspending_api/common/tests/unit/test_spark_jobs.py
@@ -51,3 +51,49 @@ def test_databricks_strategy_handle_start(databricks_strategy_client):

spark_job = SparkJobs(DatabricksStrategy())
assert spark_job.start(job_name="", command_name="", command_options=[""]) == {"job_id": 1, "run_id": 10}


@patch("usaspending_api.common.spark.jobs.EmrServerlessStrategy.client")
def test_emr_serverless_strategy_handle_start(emr_serverless_strategy_client):
mock_application = MagicMock()
mock_application.application_id = 1
mock_application.name = "application_1"

mock_application_paginator = MagicMock()
mock_application_paginator.paginate = MagicMock(
return_value=[{"applications": [{"id": mock_application.application_id, "name": mock_application.name}]}]
)

emr_serverless_strategy_client.get_paginator = MagicMock(return_value=mock_application_paginator)
emr_serverless_strategy_client.start_job_run = MagicMock()

strategy = EmrServerlessStrategy()
assert strategy._get_application_id("application_1") == 1

emr_serverless_strategy_client.reset_mock()

spark_job = SparkJobs(strategy)
spark_job.start(
job_name="",
command_name="",
command_options=[""],
application_id="Some ID",
application_name="application_1",
execution_role_arn="Some ARN",
)
assert emr_serverless_strategy_client.get_paginator.call_count == 0
assert emr_serverless_strategy_client.start_job_run.call_count == 1

emr_serverless_strategy_client.reset_mock()

spark_job = SparkJobs(strategy)
spark_job.start(
job_name="",
command_name="",
command_options=[""],
application_id=None,
application_name="application_1",
execution_role_arn="Some ARN",
)
assert emr_serverless_strategy_client.get_paginator.call_count == 1
assert emr_serverless_strategy_client.start_job_run.call_count == 1
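
If stricter verification were wanted, the second case could also assert on the arguments passed to start_job_run; a sketch under the same mocks (an assumption about desired coverage, not part of this PR):

# Sketch: confirm the application name was resolved to the mocked ID and that BATCH mode was used.
_, call_kwargs = emr_serverless_strategy_client.start_job_run.call_args
assert call_kwargs["applicationId"] == 1
assert call_kwargs["mode"] == "BATCH"
assert call_kwargs["jobDriver"]["sparkSubmit"]["entryPoint"] == ""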