diff --git a/usaspending_api/common/experimental_api_flags.py b/usaspending_api/common/experimental_api_flags.py
index 2aabc8b7f0..a9940dc777 100644
--- a/usaspending_api/common/experimental_api_flags.py
+++ b/usaspending_api/common/experimental_api_flags.py
@@ -20,6 +20,7 @@
 
 EXPERIMENTAL_API_HEADER = "HTTP_X_EXPERIMENTAL_API"
 ELASTICSEARCH_HEADER_VALUE = "elasticsearch"
+DOWNLOAD_HEADER_VALUE = "download"
 
 
 def is_experimental_elasticsearch_api(request: Request) -> bool:
@@ -29,6 +30,13 @@ def is_experimental_elasticsearch_api(request: Request) -> bool:
     return request.META.get(EXPERIMENTAL_API_HEADER) == ELASTICSEARCH_HEADER_VALUE
 
 
+def is_experimental_download_api(request: Request) -> bool:
+    """
+    Returns True if the experimental API header sent with the request matches DOWNLOAD_HEADER_VALUE
+    """
+    return request.META.get(EXPERIMENTAL_API_HEADER) == DOWNLOAD_HEADER_VALUE
+
+
 def mirror_request_to_elasticsearch(request: Union[HttpRequest, Request]):
     """Duplicate request and send-again against this server, with the ES header attached to mirror
     non-elasticsearch load against elasticsearch for load testing
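A note on how the flag travels: Django exposes incoming HTTP headers in `request.META` with an `HTTP_` prefix and hyphens converted to underscores, which is why `EXPERIMENTAL_API_HEADER` is `"HTTP_X_EXPERIMENTAL_API"` while clients (like the test later in this diff) send `X-Experimental-API: download`. A minimal sketch of the check in isolation; the `RequestFactory` usage here is illustrative and not part of this change:

```python
from django.test import RequestFactory
from rest_framework.request import Request

from usaspending_api.common.experimental_api_flags import is_experimental_download_api

factory = RequestFactory()

# "X-Experimental-API: download" arrives in request.META as HTTP_X_EXPERIMENTAL_API
opted_in = Request(factory.post("/api/v2/download/accounts/", HTTP_X_EXPERIMENTAL_API="download"))
opted_out = Request(factory.post("/api/v2/download/accounts/"))

assert is_experimental_download_api(opted_in) is True
assert is_experimental_download_api(opted_out) is False
```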
diff --git a/usaspending_api/common/spark/jobs.py b/usaspending_api/common/spark/jobs.py
index 0090a7e2d8..a6f6ffe1df 100644
--- a/usaspending_api/common/spark/jobs.py
+++ b/usaspending_api/common/spark/jobs.py
@@ -4,11 +4,17 @@
 from typing import TYPE_CHECKING, Generator
 
 from databricks.sdk import WorkspaceClient
+from django.conf import settings
 from django.core.management import call_command
 
-from usaspending_api.common.helpers.spark_helpers import configure_spark_session, get_active_spark_session
 from usaspending_api.common.spark.configs import LOCAL_EXTENDED_EXTRA_CONF
 
+if settings.IS_LOCAL:
+    # Import these helpers only for local development. The local strategy expects a Spark session to be
+    # available on the API itself; the Databricks and EMR strategies instead submit the job to Databricks
+    # or EMR, where the Spark session exists.
+    from usaspending_api.common.helpers.spark_helpers import configure_spark_session, get_active_spark_session
+
 if TYPE_CHECKING:
     from pyspark.sql import SparkSession
 
diff --git a/usaspending_api/download/tests/integration/test_download_accounts.py b/usaspending_api/download/tests/integration/test_download_accounts.py
index 71f39ca4f8..1cad52dcb7 100644
--- a/usaspending_api/download/tests/integration/test_download_accounts.py
+++ b/usaspending_api/download/tests/integration/test_download_accounts.py
@@ -6,6 +6,7 @@
 from unittest.mock import Mock
 
 from django.conf import settings
+from django.core.management import call_command
 from model_bakery import baker
 from rest_framework import status
 
@@ -497,3 +498,35 @@ def test_empty_array_filter_fail(client, download_test_data):
     assert (
         "Field 'filters|def_codes' value '[]' is below min '1' items" in resp.json()["detail"]
     ), "Incorrect error message"
+
+
+@pytest.mark.django_db(databases=[settings.DOWNLOAD_DB_ALIAS, settings.DEFAULT_DB_ALIAS])
+def test_file_c_spark_download(client, download_test_data, spark, s3_unittest_data_bucket, hive_unittest_metastore_db):
+    download_generation.retrieve_db_string = Mock(return_value=get_database_dsn_string())
+
+    call_command(
+        "create_delta_table",
+        f"--spark-s3-bucket={s3_unittest_data_bucket}",
+        "--destination-table=account_download",
+    )
+
+    resp = client.post(
+        "/api/v2/download/accounts/",
+        content_type="application/json",
+        data=json.dumps(
+            {
+                "account_level": "federal_account",
+                "filters": {
+                    "budget_function": "all",
+                    "agency": "all",
+                    "submission_types": ["award_financial"],
+                    "fy": "2021",
+                    "period": 12,
+                },
+                "file_format": "csv",
+            }
+        ),
+        headers={"X-Experimental-API": "download"},
+    )
+
+    assert resp.status_code == status.HTTP_200_OK
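The test drives the full experimental path through the API, but for local debugging it can be useful to skip the API layer and invoke the same management command the Spark strategies run. A hedged sketch, assuming a `DownloadJob` row with id 123 already exists; the command name and flags are taken from the call sites in this diff:

```python
from django.core.management import call_command

# Same invocation LocalStrategy performs below; --skip-local-cleanup keeps
# local artifacts around for inspection. The id 123 is a placeholder for a
# real download_job_id.
call_command(
    "generate_spark_download",
    "--download-job-id=123",
    "--skip-local-cleanup",
)
```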
diff --git a/usaspending_api/download/v2/base_download_viewset.py b/usaspending_api/download/v2/base_download_viewset.py
index 534e178fdd..f560084607 100644
--- a/usaspending_api/download/v2/base_download_viewset.py
+++ b/usaspending_api/download/v2/base_download_viewset.py
@@ -13,7 +13,9 @@
 from usaspending_api.broker.lookups import EXTERNAL_DATA_TYPE_DICT
 from usaspending_api.broker.models import ExternalDataLoadDate
 from usaspending_api.common.api_versioning import API_TRANSFORM_FUNCTIONS, api_transformations
+from usaspending_api.common.experimental_api_flags import is_experimental_download_api
 from usaspending_api.common.helpers.dict_helpers import order_nested_object
+from usaspending_api.common.spark.jobs import DatabricksStrategy, LocalStrategy, SparkJobs
 from usaspending_api.common.sqs.sqs_handler import get_sqs_queue
 from usaspending_api.download.download_utils import create_unique_filename, log_new_download_job
 from usaspending_api.download.filestreaming import download_generation
@@ -68,12 +70,19 @@ def post(
         )
         log_new_download_job(request, download_job)
 
-        self.process_request(download_job)
+        self.process_request(download_job, request, json_request)
 
         return self.get_download_response(file_name=final_output_zip_name)
 
-    def process_request(self, download_job: DownloadJob):
-        if settings.IS_LOCAL and settings.RUN_LOCAL_DOWNLOAD_IN_PROCESS:
+    def process_request(self, download_job: DownloadJob, request: Request, json_request: dict):
+        if (
+            is_experimental_download_api(request)
+            and json_request["request_type"] == "account"
+            and "award_financial" in json_request["download_types"]
+        ):
+            # Route File C account downloads to Spark
+            self.process_account_download_in_spark(download_job=download_job)
+        elif settings.IS_LOCAL and settings.RUN_LOCAL_DOWNLOAD_IN_PROCESS:
             # Eagerly execute the download in this running process
             download_generation.generate_download(download_job)
         else:
@@ -85,6 +94,25 @@ def process_request(self, download_job: DownloadJob):
             queue = get_sqs_queue(queue_name=settings.BULK_DOWNLOAD_SQS_QUEUE_NAME)
             queue.send_message(MessageBody=str(download_job.download_job_id))
 
+    def process_account_download_in_spark(self, download_job: DownloadJob):
+        """
+        Process File C downloads through Spark instead of SQS for better performance
+        """
+        if settings.IS_LOCAL:
+            spark_jobs = SparkJobs(LocalStrategy())
+            spark_jobs.start(
+                job_name="api_download-accounts",
+                command_name="generate_spark_download",
+                command_options=[f"--download-job-id={download_job.download_job_id}", "--skip-local-cleanup"],
+            )
+        else:
+            spark_jobs = SparkJobs(DatabricksStrategy())
+            spark_jobs.start(
+                job_name="api_download-accounts",
+                command_name="generate_spark_download",
+                command_options=[f"--download-job-id={download_job.download_job_id}"],
+            )
+
     def get_download_response(self, file_name: str):
         """
         Generate download response which encompasses various elements to provide accurate status for state of a
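The `SparkJobs`/strategy split keeps the viewset agnostic about where the job runs: locally the command executes in-process against a local Spark session, while in deployed environments it is submitted to Databricks. The internals of `SparkJobs`, `LocalStrategy`, and `DatabricksStrategy` are not shown in this diff; the following is a minimal sketch of the interface implied by their call sites, with every body an assumption:

```python
from abc import ABC, abstractmethod

from django.core.management import call_command


class SparkJobStrategy(ABC):
    """Assumed base interface: a strategy knows how to run one Django management command somewhere."""

    @abstractmethod
    def run(self, job_name: str, command_name: str, command_options: list[str]) -> None: ...


class LocalStrategy(SparkJobStrategy):
    """Sketch: run the command in-process, where a local Spark session is available."""

    def run(self, job_name: str, command_name: str, command_options: list[str]) -> None:
        call_command(command_name, *command_options)


class SparkJobs:
    """Thin dispatcher matching the call sites: SparkJobs(strategy).start(...)."""

    def __init__(self, strategy: SparkJobStrategy):
        self.strategy = strategy

    def start(self, job_name: str, command_name: str, command_options: list[str]) -> None:
        self.strategy.run(job_name, command_name, command_options)
```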