@@ -47,10 +47,10 @@ def account_level(self) -> AccountLevel: ...
    @abstractmethod
    def submission_type(self) -> SubmissionType: ...

-    def _build_file_name(self) -> str:
+    def _build_file_names(self) -> list[str]:
        date_range = construct_data_date_range(self.filters.dict())
        agency = obtain_filename_prefix_from_agency_id(self.filters.agency)
        level = self.account_level.abbreviation
        title = self.submission_type.title
        timestamp = datetime.strftime(self.start_time, "%Y-%m-%d_H%HM%MS%S")
-        return f"{date_range}_{agency}_{level}_{title}_{timestamp}"
+        return [f"{date_range}_{agency}_{level}_{title}_{timestamp}"]
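Note: the strftime pattern "%Y-%m-%d_H%HM%MS%S" mixes literal H/M/S markers with format codes so the timestamp stays filesystem-safe (no colons). A quick illustration of the resulting file name, with made-up stand-ins for the helper outputs (these values are not from the PR):

    from datetime import datetime

    # Hypothetical values standing in for the outputs of
    # construct_data_date_range and obtain_filename_prefix_from_agency_id.
    date_range = "FY2024Q1-FY2024Q4"
    agency = "020"
    level = "FA"
    title = "AccountBalances"

    timestamp = datetime.strftime(datetime(2024, 6, 1, 14, 30, 45), "%Y-%m-%d_H%HM%MS%S")
    print(f"{date_range}_{agency}_{level}_{title}_{timestamp}")
    # FY2024Q1-FY2024Q4_020_FA_AccountBalances_2024-06-01_H14M30S45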
@@ -41,15 +41,15 @@ def spark(self) -> SparkSession | DuckDBSparkSession:
        return self._spark

    @cached_property
-    def file_name(self) -> str:
-        return self._build_file_name()
+    def file_names(self) -> list[str]:
+        return self._build_file_names()

    @cached_property
-    def dataframe(self) -> DataFrame | DuckDBSparkDataFrame:
-        return self._build_dataframe()
+    def dataframes(self) -> list[DataFrame | DuckDBSparkDataFrame]:
+        return self._build_dataframes()

    @abstractmethod
-    def _build_file_name(self) -> str: ...
+    def _build_file_names(self) -> list[str]: ...

    @abstractmethod
-    def _build_dataframe(self) -> DataFrame | DuckDBSparkDataFrame: ...
+    def _build_dataframes(self) -> list[DataFrame | DuckDBSparkDataFrame]: ...
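Note: because file_names and dataframes are cached_property attributes, each _build_* method runs at most once per download instance, and callers can index the two lists in parallel. A minimal sketch of a conforming subclass under those assumptions (the class and return values are illustrative, not from the PR):

    from abc import ABC, abstractmethod
    from functools import cached_property

    class MiniDownload(ABC):
        @cached_property
        def file_names(self) -> list[str]:
            return self._build_file_names()

        @abstractmethod
        def _build_file_names(self) -> list[str]: ...

    class OneFileDownload(MiniDownload):
        # Downloads that still produce a single file return a one-element list.
        def _build_file_names(self) -> list[str]:
            print("building")  # runs once; cached_property memoizes the result
            return ["example_file"]

    d = OneFileDownload()
    d.file_names  # prints "building"
    d.file_names  # cached, no second build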
usaspending_api/download/delta_downloads/account_balances.py (4 additions & 4 deletions)

@@ -35,8 +35,8 @@ def download_table(self) -> DataFrame:
            f"s3a://{CONFIG.SPARK_S3_BUCKET}/{CONFIG.DELTA_LAKE_S3_PATH}/rpt/account_balances_download"
        )

-    def _build_dataframe(self) -> DataFrame:
-        return (
+    def _build_dataframes(self) -> list[DataFrame]:
+        return [
            self.download_table.filter(
                sf.col("submission_id").isin(
                    get_submission_ids_for_periods(
@@ -49,8 +49,8 @@ def _build_dataframe(self) -> DataFrame:
            .filter(self.dynamic_filters)
            .groupby(self.group_by_cols)
            .agg(*self.agg_cols)
-            .select(*self.select_cols)
-        )
+            .select(*self.select_cols),
+        ]


class FederalAccountDownload(AccountBalancesMixin, AbstractAccountDownload):
usaspending_api/download/delta_downloads/award_financial.py (39 additions & 5 deletions)

@@ -1,3 +1,6 @@
+from datetime import datetime
+
+
from pyspark.sql import functions as sf, Column, DataFrame, SparkSession
from usaspending_api.config import CONFIG
@@ -11,6 +14,7 @@
    AbstractAccountDownloadFactory,
)
from usaspending_api.download.delta_downloads.filters.account_filters import AccountDownloadFilters
+from usaspending_api.download.download_utils import construct_data_date_range, obtain_filename_prefix_from_agency_id
from usaspending_api.download.v2.download_column_historical_lookups import query_paths
@@ -21,6 +25,9 @@ class AwardFinancialMixin:

    filters: AccountDownloadFilters
    dynamic_filters: Column
+    account_level: AccountLevel
+    submission_type: SubmissionType
+    start_time: datetime

    @property
    def download_table(self) -> DataFrame:
@@ -39,6 +46,25 @@ def non_zero_filters(self) -> Column:
            | (sf.col("transaction_obligated_amount") != 0)
        )

+    @property
+    def award_categories(self) -> dict[str, Column]:
+        return {
+            "Assistance": (sf.isnotnull(sf.col("is_fpds")) & ~sf.col("is_fpds")),
+            "Contracts": sf.col("is_fpds"),
+            "Unlinked": sf.isnull(sf.col("is_fpds")),
+        }
+
+    def _build_file_names(self) -> list[str]:
+        date_range = construct_data_date_range(self.filters.dict())
+        agency = obtain_filename_prefix_from_agency_id(self.filters.agency)
+        level = self.account_level.abbreviation
+        title = self.submission_type.title
+        timestamp = datetime.strftime(self.start_time, "%Y-%m-%d_H%HM%MS%S")
+        return [
+            f"{date_range}_{agency}_{level}_{award_category}_{title}_{timestamp}"
+            for award_category in self.award_categories
+        ]
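Note: is_fpds is a nullable boolean, so the three predicates each claim a row exactly once: true rows are Contracts, false rows are Assistance, and null rows (unlinked awards) match neither boolean test under SQL three-valued logic, hence the explicit isnull/isnotnull guards. _build_file_names then emits one name per category; iterating the dict uses Python's insertion order, so the file order is deterministic. A self-contained check of the partition, assuming a local Spark session and Spark 3.5+ for sf.isnotnull:

    from pyspark.sql import SparkSession, functions as sf

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, True), (2, False), (3, None)], ["id", "is_fpds"])

    award_categories = {
        "Assistance": (sf.isnotnull(sf.col("is_fpds")) & ~sf.col("is_fpds")),
        "Contracts": sf.col("is_fpds"),
        "Unlinked": sf.isnull(sf.col("is_fpds")),
    }
    for name, predicate in award_categories.items():
        # Each id lands in exactly one bucket: 2 -> Assistance, 1 -> Contracts, 3 -> Unlinked
        print(name, [row.id for row in df.filter(predicate).collect()])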


class FederalAccountDownload(AwardFinancialMixin, AbstractAccountDownload):

@@ -124,6 +150,7 @@ def group_by_cols(self) -> list[str]:
            "prime_award_summary_recipient_cd_current",
            "prime_award_summary_place_of_performance_cd_original",
            "prime_award_summary_place_of_performance_cd_current",
+            "is_fpds",
        ]

    @property
@@ -158,19 +185,22 @@ def select_cols(self) -> list[Column]:
            + ["last_modified_date"]
        )

-    def _build_dataframe(self) -> DataFrame:
+    def _build_dataframes(self) -> list[DataFrame]:
        # TODO: Should handle the aggregate columns via a new name instead of relying on drops. If the Delta tables are
        #   referenced by their location then the ability to use the table identifier is lost as it doesn't
        #   appear to use the metastore for the Delta tables.
-        return (
+        combined_download = (
            self.download_table.filter(self.dynamic_filters)
            .groupBy(self.group_by_cols)
            .agg(*[agg_func(col) for col, agg_func in self.agg_cols.items()])
            # drop original agg columns from the dataframe to avoid ambiguous column names
            .drop(*[sf.col(f"award_financial_download.{col}") for col in self.agg_cols])
            .filter(self.non_zero_filters)
-            .select(self.select_cols)
        )
+        return [
+            combined_download.filter(award_category_filter).select(self.select_cols)
+            for award_category_filter in self.award_categories.values()
+        ]
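Note: the refactor builds the filtered/aggregated plan once as combined_download and fans it out into one DataFrame per category. Since Spark DataFrames are lazy query plans, each returned frame re-executes the shared plan when it is written out; caching would be a possible trade-off, but is not something this PR does. A runnable sketch of the fan-out pattern with toy data (names and values are illustrative):

    from pyspark.sql import SparkSession, functions as sf

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    base = spark.createDataFrame([(1, "x", 10), (2, "y", 0), (3, "x", 5)], ["id", "grp", "amt"])

    # Shared upstream plan (the real code filters and aggregates here)...
    combined = base.filter(sf.col("amt") != 0)
    # ...then one filtered view per category, mirroring _build_dataframes.
    frames = [combined.filter(sf.col("grp") == g).select("id", "amt") for g in ("x", "y")]
    print([frame.count() for frame in frames])  # [2, 0]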


class TreasuryAccountDownload(AwardFinancialMixin, AbstractAccountDownload):
@@ -183,7 +213,7 @@ def account_level(self) -> AccountLevel:
    def submission_type(self) -> SubmissionType:
        return SubmissionType.AWARD_FINANCIAL

-    def _build_dataframe(self) -> DataFrame:
+    def _build_dataframes(self) -> list[DataFrame]:
        select_cols = (
            [sf.col("treasury_owning_agency_name").alias("owning_agency_name")]
            + [
@@ -193,7 +223,11 @@ def _build_dataframe(self) -> DataFrame:
            ]
            + [sf.date_format("last_modified_date", "yyyy-MM-dd").alias("last_modified_date")]
        )
-        return self.download_table.filter(self.dynamic_filters & self.non_zero_filters).select(select_cols)
+        combined_download = self.download_table.filter(self.dynamic_filters & self.non_zero_filters)
+        return [
+            combined_download.filter(award_category_filter).select(select_cols)
+            for award_category_filter in self.award_categories.values()
+        ]


class AwardFinancialDownloadFactory(AbstractAccountDownloadFactory):
@@ -50,8 +50,8 @@ def download_table(self) -> DataFrame | DuckDBSparkDataFrame:
            f"s3a://{CONFIG.SPARK_S3_BUCKET}/{CONFIG.DELTA_LAKE_S3_PATH}/rpt/object_class_program_activity_download"
        )

-    def _build_dataframe(self) -> DataFrame | DuckDBSparkDataFrame:
-        return (
+    def _build_dataframes(self) -> list[DataFrame | DuckDBSparkDataFrame]:
+        return [
            self.download_table.filter(
                self.sf.col("submission_id").isin(
                    get_submission_ids_for_periods(
@@ -67,8 +67,8 @@ def _build_dataframe(self) -> DataFrame | DuckDBSparkDataFrame:
            .drop(*[self.sf.col(f"object_class_program_activity_download.{col}") for col in self.agg_cols])
            .select(*self.select_cols)
            # Sorting by a value that is repeated often will help improve compression during the zipping step
-            .sort(self.sort_by_cols)
-        )
+            .sort(self.sort_by_cols),
+        ]


class FederalAccountDownload(ObjectClassProgramActivityMixin, AbstractAccountDownload):
@@ -118,6 +118,7 @@
        StructField("reporting_fiscal_quarter", IntegerType()),
        StructField("reporting_fiscal_year", IntegerType()),
        StructField("quarter_format_flag", BooleanType()),
+        StructField("is_fpds", BooleanType(), nullable=True),
        StructField("merge_hash_key", LongType()),
    ]
)

Review comment from the PR author on the new is_fpds field: nullable=True is the default, but it doesn't hurt to be explicit here since this is different than transaction_search.
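Note: as that review comment says, nullable=True is already StructField's default, so the explicit argument documents intent rather than changing behavior. A quick confirmation, assuming pyspark is installed:

    from pyspark.sql.types import BooleanType, StructField

    implicit = StructField("is_fpds", BooleanType())
    explicit = StructField("is_fpds", BooleanType(), nullable=True)
    assert implicit == explicit and implicit.nullable is True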
@@ -327,6 +328,7 @@ def award_financial_df(spark: SparkSession):
            sa.reporting_fiscal_quarter,
            sa.reporting_fiscal_year,
            sa.quarter_format_flag,
+            ts.is_fpds,
        )
        .withColumn("merge_hash_key", sf.xxhash64("*"))
    )
@@ -187,20 +187,21 @@ def process_download(self):
        download_request = self.get_download_request()
        if self.columns is not None:
            for download in download_request.download_list:
-                download.dataframe = download.dataframe.select(*self.columns)
+                download.dataframes = [df.select(*self.columns) for df in download.dataframes]

        csvs_metadata = [
            spark_to_csv_strategy.download_to_csv(
                source_sql=None,
-                destination_path=self.working_dir_path / download.file_name,
-                destination_file_name=download.file_name,
+                destination_path=self.working_dir_path,
+                destination_file_name=file_name,
                working_dir_path=self.working_dir_path,
                download_zip_path=zip_file_path,
-                source_df=download.dataframe,
+                source_df=df,
                delimiter=download_request.file_delimiter,
                file_format=download_request.file_extension,
            )
            for download in download_request.download_list
+            for file_name, df in zip(download.file_names, download.dataframes)
        ]
        for csv_metadata in csvs_metadata:
            files_to_cleanup.extend(csv_metadata.filepaths)
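Note: the nested comprehension flattens every (file name, dataframe) pair across all downloads, pairing the two lists positionally, so _build_file_names and _build_dataframes must return equal-length lists in the same category order (both iterate award_categories, and Python dicts preserve insertion order). zip truncates silently on a length mismatch; zip(..., strict=True) on Python 3.10+ would turn such a bug into an error. A toy illustration of the flattening, with made-up names:

    downloads = [  # hypothetical stand-in for download_request.download_list
        {"file_names": ["a_Contracts", "a_Assistance"], "dataframes": ["df1", "df2"]},
        {"file_names": ["b_Contracts"], "dataframes": ["df3"]},
    ]
    pairs = [
        (file_name, df)
        for download in downloads
        for file_name, df in zip(download["file_names"], download["dataframes"])
    ]
    print(pairs)  # [('a_Contracts', 'df1'), ('a_Assistance', 'df2'), ('b_Contracts', 'df3')]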