Merged
Changes from 7 commits
40 changes: 15 additions & 25 deletions usaspending_api/common/etl/spark.py
@@ -390,9 +390,7 @@ def load_es_index(
)


def merge_delta_table(
spark: SparkSession, source_df: DataFrame, delta_table_name: str, merge_column: str
) -> None:
def merge_delta_table(spark: SparkSession, source_df: DataFrame, delta_table_name: str, merge_column: str) -> None:
source_df.create_or_replace_temporary_view("temp_table")

spark.sql(
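For context, a minimal usage sketch of the reflowed helper above; the DataFrame, table name, and key column are hypothetical, and the MERGE SQL the function issues is elided in this diff.

```python
from pyspark.sql import SparkSession

from usaspending_api.common.etl.spark import merge_delta_table

spark = SparkSession.builder.getOrCreate()
updated_df = spark.createDataFrame([(1, "new value")], ["id", "description"])

# Upsert the rows of `updated_df` into an existing Delta table, matching on "id".
# The table name is a placeholder for illustration only.
merge_delta_table(
    spark,
    source_df=updated_df,
    delta_table_name="my_schema.my_delta_table",
    merge_column="id",
)
```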
@@ -409,7 +407,7 @@ def diff(
left: DataFrame,
right: DataFrame,
unique_key_col: str = "id",
compare_cols: list[str] | None = None,
compare_cols: list[str] = None,
include_unchanged_rows: bool = False,
) -> DataFrame:
"""Compares two Spark DataFrames that share a schema and returns row-level differences in a DataFrame
@@ -563,8 +561,7 @@ def convert_array_cols_to_string(
2. Escape any quotes inside the array element with backslash.
- A case that involves all of this will yield a CSV field value like this when viewed in a text editor,
assuming Spark CSV options are: quote='"', escape='"' (the default is for it to match quote)
...,"{""{\""simple\"": \""elem1\"", \""other\"": \""elem1\""}"",
""{\""simple\"": \""elem2\"", \""other\"": \""elem2\""}""}",...
...,"{""{\""simple\"": \""elem1\"", \""other\"": \""elem1\""}"",...
"""
arr_open_bracket = "["
arr_close_bracket = "]"
@@ -599,7 +596,8 @@ def convert_array_cols_to_string(
# Special handling in case of data that already has either a quote " or backslash \
# inside an array element
# First replace any single backslash character \ with TWO \\ (an escaped backslash)
# Then replace quote " character with \" (escaped quote, inside a quoted array elem)
# Then replace any quote " character with \"
# (escaped quote, inside a quoted array elem)
# NOTE: these regexp_replace get sent down to a Java replaceAll, which will require
# FOUR backslashes to represent ONE
regexp_replace(
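A standalone sketch of the two-step escaping the comment above describes, using a hypothetical string column named "elem"; the doubled backslashes are needed because the pattern and replacement pass through Java's regex engine.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, regexp_replace

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(r'say "hi" \ bye',)], ["elem"])

escaped = df.withColumn(
    "elem",
    regexp_replace(col("elem"), r"\\", r"\\\\"),  # one literal \ becomes \\
).withColumn(
    "elem",
    regexp_replace(col("elem"), '"', r'\\"'),  # " becomes \"
)
escaped.show(truncate=False)
```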
@@ -649,7 +647,7 @@ def _generate_global_view_sql_strings(tables: list[str], jdbc_url: str) -> list[

def create_ref_temp_views( # noqa: PLR0912
spark: SparkSession | DuckDBSparkSession, create_broker_views: bool = False
) -> None:
) -> None: # noqa: PLR0912
"""Create global temporary Spark reference views that sit atop remote PostgreSQL RDS tables
Setting create_broker_views to True will create views for all tables listed in _BROKER_REF_TABLES
Note: They will all be listed under global_temp.{table_name}
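For illustration, how the global temp views created by this helper are typically referenced afterwards; an active SparkSession is assumed and the table name is hypothetical.

```python
from usaspending_api.common.etl.spark import create_ref_temp_views

# Create the reference views once per session, then query them under global_temp.
create_ref_temp_views(spark)
spark.sql("SELECT * FROM global_temp.some_ref_table LIMIT 5").show()  # hypothetical table name
```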
@@ -792,9 +790,8 @@ def write_csv_file( # noqa: PLR0913
spark: SparkSession,
df: DataFrame,
parts_dir: str,
max_records_per_file: int = EXCEL_ROW_LIMIT,
overwrite: bool = True,
logger: logging.Logger | None = None,
logger: logging.Logger = None,
delimiter: str = ",",
) -> int:
"""Write DataFrame data to CSV file parts.
@@ -804,8 +801,6 @@ def write_csv_file( # noqa: PLR0913
parts_dir: Path to dir that will contain the outputted parts files from partitions
num_partitions: Indicates the number of partitions to use when writing the Dataframe
overwrite: Whether to replace the CSV files if they already exist by that name
max_records_per_file: Suggestion to Spark of how many records to put in each written CSV file part,
if it will end up writing multiple files.
logger: The logger to use. If one is not provided (e.g. to log to console or stdout) the underlying JVM-based
Logger will be extracted from the ``spark`` ``SparkSession`` and used as the logger.
delimiter: Character used to separate columns in the CSV
@@ -822,10 +817,10 @@ def write_csv_file( # noqa: PLR0913
f"Writing source data DataFrame to csv part files for file {parts_dir}..."
)
df_record_count = df.count()
num_partitions = math.ceil(df_record_count / max_records_per_file) or 1
num_partitions = math.ceil(df_record_count / EXCEL_ROW_LIMIT) or 1
df.repartition(num_partitions).write.options(
# NOTE: this is a suggestion, to be used by Spark if partitions yield multiple files
maxRecordsPerFile=max_records_per_file,
maxRecordsPerFile=EXCEL_ROW_LIMIT,
).csv(
path=parts_dir,
header=True,
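A worked sketch of the partition arithmetic above; the EXCEL_ROW_LIMIT value here is an illustrative stand-in, not necessarily the project's constant.

```python
import math

EXCEL_ROW_LIMIT = 1_000_000  # illustrative stand-in for the imported constant
df_record_count = 2_500_000

# ceil(2_500_000 / 1_000_000) == 3 part files of at most EXCEL_ROW_LIMIT rows each
num_partitions = math.ceil(df_record_count / EXCEL_ROW_LIMIT) or 1

# The `or 1` guards the empty-DataFrame case, where ceil(0 / limit) == 0
# would otherwise be passed to repartition().
assert num_partitions == 3
assert (math.ceil(0 / EXCEL_ROW_LIMIT) or 1) == 1
```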
@@ -848,7 +843,6 @@ def write_csv_file_duckdb(
df: DuckDBDataFrame,
download_file_name: str,
temp_csv_directory_path: str = CSV_LOCAL_PATH,
max_records_per_file: int = EXCEL_ROW_LIMIT,
logger: logging.Logger | None = None,
delimiter: str = ",",
) -> tuple[int, list[str] | list]:
@@ -858,8 +852,6 @@ def write_csv_file_duckdb(
download_file_name: Name of the download being generated.
temp_csv_directory_path: Directory that will contain the individual CSV files before zipping.
Defaults to CSV_LOCAL_PATH
max_records_per_file: Max number of records to put in each written CSV file.
Defaults to EXCEL_ROW_LIMIT
logger: Logging instance to use.
Defaults to None
delimiter: Character used to separate columns in the CSV
@@ -870,7 +862,7 @@ def write_csv_file_duckdb(
"""
start = time.time()
_pandas_df = df.toPandas()
_pandas_df["file_number"] = (_pandas_df.index // max_records_per_file) + 1
_pandas_df["file_number"] = (_pandas_df.index // EXCEL_ROW_LIMIT) + 1
df_record_count = len(_pandas_df)
rel = duckdb.from_df(_pandas_df)

@@ -894,15 +886,13 @@ def write_csv_file_duckdb(
f"{temp_csv_directory_path}{download_file_name}/{d}"
for d in os.listdir(f"{temp_csv_directory_path}{download_file_name}")
]
for dir in _partition_dirs:
_old_csv_path = f"{dir}/{os.listdir(dir)[0]}"
_new_csv_path = (
f"{temp_csv_directory_path}{download_file_name}"
f"/{download_file_name}_{dir.split('=')[1].zfill(2)}.csv"
)
for _dir in _partition_dirs:
_file_number = _dir.split("=")[1].zfill(2)
_old_csv_path = f"{_dir}/{os.listdir(_dir)[0]}"
_new_csv_path = f"{temp_csv_directory_path}{download_file_name}/{download_file_name}_{_file_number}.csv"
shutil.move(_old_csv_path, _new_csv_path)
full_file_paths.append(_new_csv_path)
os.rmdir(dir)
os.rmdir(_dir)

logger.info(
f"{temp_csv_directory_path}{download_file_name} contains {df_record_count:,} rows of data"
140 changes: 89 additions & 51 deletions usaspending_api/common/helpers/download_csv_strategies.py
@@ -7,6 +7,7 @@
from typing import List, Optional

from django.conf import settings
from duckdb.experimental.spark.sql import DataFrame as DuckDBDataFrame
from duckdb.experimental.spark.sql import SparkSession as DuckDBSparkSession
from pyspark.sql import DataFrame

@@ -17,7 +18,6 @@
download_s3_object,
)
from usaspending_api.download.filestreaming.download_generation import (
EXCEL_ROW_LIMIT,
execute_psql,
generate_export_query_temp_file,
split_and_zip_data_files,
@@ -80,20 +80,24 @@ def __init__(self, logger: logging.Logger, *args, **kwargs):

def download_to_csv(
self,
source_sql,
destination_path,
destination_file_name,
working_dir_path,
download_zip_path,
source_df=None,
):
source_sql: str,
destination_path: str,
destination_file_name: str,
working_dir_path: str,
download_zip_path: str,
source_df: DataFrame | None = None,
) -> CSVDownloadMetadata:
start_time = time.perf_counter()
self._logger.info(f"Downloading data to {destination_path}")
temp_data_file_name = destination_path.parent / (destination_path.name + "_temp")
temp_data_file_name = destination_path.parent / (
destination_path.name + "_temp"
)
options = FILE_FORMATS[self.file_format]["options"]
export_query = r"\COPY ({}) TO STDOUT {}".format(source_sql, options)
try:
temp_file, temp_file_path = generate_export_query_temp_file(export_query, None, working_dir_path)
temp_file, temp_file_path = generate_export_query_temp_file(
export_query, None, working_dir_path
)
# Create a separate process to run the PSQL command; wait
psql_process = multiprocessing.Process(
target=execute_psql, args=(temp_file_path, temp_data_file_name, None)
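A minimal sketch of the export query assembled above; the SQL and the options string are placeholders, since the real values come from the caller and FILE_FORMATS.

```python
# Hypothetical inputs standing in for source_sql and FILE_FORMATS[file_format]["options"]
source_sql = "SELECT id, name FROM some_table"
options = "WITH CSV HEADER"

export_query = r"\COPY ({}) TO STDOUT {}".format(source_sql, options)
# -> \COPY (SELECT id, name FROM some_table) TO STDOUT WITH CSV HEADER
# This text is written to a temp file and executed by psql in a separate process.
print(export_query)
```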
@@ -104,12 +108,20 @@ def download_to_csv(
delim = FILE_FORMATS[self.file_format]["delimiter"]

# Log how many rows we have
self._logger.info(f"Counting rows in delimited text file {temp_data_file_name}")
self._logger.info(
f"Counting rows in delimited text file {temp_data_file_name}"
)
try:
row_count = count_rows_in_delimited_file(filename=temp_data_file_name, has_header=True, delimiter=delim)
self._logger.info(f"{destination_path} contains {row_count:,} rows of data")
row_count = count_rows_in_delimited_file(
filename=temp_data_file_name, has_header=True, delimiter=delim
)
self._logger.info(
f"{destination_path} contains {row_count:,} rows of data"
)
except Exception:
self._logger.exception("Unable to obtain delimited text file line count")
self._logger.exception(
"Unable to obtain delimited text file line count"
)

start_time = time.perf_counter()
zip_process = multiprocessing.Process(
@@ -136,32 +148,36 @@ def __init__(self, logger: logging.Logger, *args, **kwargs):
super().__init__(*args, **kwargs)
self._logger = logger

def download_to_csv(
def download_to_csv( # noqa: PLR0913
self,
source_sql,
destination_path,
destination_file_name,
working_dir_path,
download_zip_path,
source_df=None,
delimiter=",",
file_format="csv",
):
# These imports are here for a reason.
# some strategies do not require spark
# we do not want to force all containers where
# other strategies run to have pyspark installed when the strategy
# doesn't require it.
source_sql: str,
destination_path: str,
destination_file_name: str,
working_dir_path: str,
download_zip_path: str,
source_df: DataFrame | None = None,
delimiter: str = ",",
file_format: str = "csv",
) -> CSVDownloadMetadata:
# Some strategies do not require spark; we do not want to force all containers where
# other strategies run to have pyspark installed when the strategy doesn't require it.
from usaspending_api.common.etl.spark import write_csv_file
from usaspending_api.common.helpers.spark_helpers import configure_spark_session, get_active_spark_session
from usaspending_api.common.helpers.spark_helpers import (
configure_spark_session,
get_active_spark_session,
)

self.spark = None
destination_path_dir = str(destination_path).replace(f"/{destination_file_name}", "")
destination_path_dir = str(destination_path).replace(
f"/{destination_file_name}", ""
)
# The place to write intermediate data files to in s3
s3_bucket_name = settings.BULK_DOWNLOAD_S3_BUCKET_NAME
s3_bucket_path = f"s3a://{s3_bucket_name}"
s3_bucket_sub_path = "temp_download"
s3_destination_path = f"{s3_bucket_path}/{s3_bucket_sub_path}/{destination_file_name}"
s3_destination_path = (
f"{s3_bucket_path}/{s3_bucket_sub_path}/{destination_file_name}"
)
try:
extra_conf = {
# Config for Delta Lake tables and SQL. Need these to keep Delta table metadata in the metastore
@@ -176,7 +192,9 @@ def download_to_csv(
self.spark_created_by_command = False
if not self.spark:
self.spark_created_by_command = True
self.spark = configure_spark_session(**extra_conf, spark_context=self.spark)
self.spark = configure_spark_session(
**extra_conf, spark_context=self.spark
)
if source_df is not None:
df = source_df
else:
@@ -185,7 +203,6 @@ def download_to_csv(
self.spark,
df,
parts_dir=s3_destination_path,
max_records_per_file=EXCEL_ROW_LIMIT,
logger=self._logger,
delimiter=delimiter,
)
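A sketch of the pared-down call, reflecting the removal of max_records_per_file from the signature; the parts path and logger name are hypothetical, and `spark` and `df` are assumed to be in scope as in the strategy above.

```python
import logging

from usaspending_api.common.etl.spark import write_csv_file

logger = logging.getLogger("download")

record_count = write_csv_file(
    spark,
    df,
    parts_dir="s3a://my-bucket/temp_download/my_download",  # hypothetical parts location
    logger=logger,
    delimiter=",",
)
```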
@@ -208,12 +225,19 @@ def download_to_csv(
self._logger.exception("Exception encountered. See logs")
raise
finally:
delete_s3_objects(s3_bucket_name, key_prefix=f"{s3_bucket_sub_path}/{destination_file_name}")
delete_s3_objects(
s3_bucket_name,
key_prefix=f"{s3_bucket_sub_path}/{destination_file_name}",
)
if self.spark_created_by_command:
self.spark.stop()
append_files_to_zip_file(final_csv_data_file_locations, download_zip_path)
self._logger.info(f"Generated the following data csv files {final_csv_data_file_locations}")
return CSVDownloadMetadata(final_csv_data_file_locations, record_count, column_count)
self._logger.info(
f"Generated the following data csv files {final_csv_data_file_locations}"
)
return CSVDownloadMetadata(
final_csv_data_file_locations, record_count, column_count
)

def _move_data_csv_s3_to_local(
self,
@@ -230,7 +254,7 @@ def _move_data_csv_s3_to_local(
s3_file_paths: A list of file paths to move from s3, name should
include s3a:// and bucket name
s3_bucket_path: The bucket path, e.g. s3a:// + bucket name
s3_bucket_sub_path: The path to the s3 files in the bucket, excluding s3a:// + bucket name, e.g. temp_directory/files
s3_bucket_sub_path: The path to the s3 files in the bucket, excluding s3a:// + bucket name
destination_path_dir: The location to move those files from s3 to, must not include the
file name in the path. This path should be a directory.

@@ -251,27 +275,31 @@ def _move_data_csv_s3_to_local(
final_path,
)
local_csv_file_paths.append(final_path)
self._logger.info(f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s")
self._logger.info(
f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s"
)
return local_csv_file_paths


class DuckDBToCSVStrategy(AbstractToCSVStrategy):
def __init__(self, logger: logging.Logger, spark: DuckDBSparkSession, *args, **kwargs):
def __init__(
self, logger: logging.Logger, spark: DuckDBSparkSession, *args, **kwargs
):
super().__init__(*args, **kwargs)
self._logger = logger
self.spark = spark

def download_to_csv(
def download_to_csv( # noqa: PLR0913
self,
source_sql: str | None,
destination_path: str,
destination_file_name: str,
working_dir_path: str,
download_zip_path: str,
source_df=None,
delimiter=",",
file_format="csv",
):
source_df: DuckDBDataFrame | None = None,
delimiter: str = ",",
file_format: str = "csv",
) -> CSVDownloadMetadata:
from usaspending_api.common.etl.spark import write_csv_file_duckdb

try:
@@ -282,7 +310,6 @@ def download_to_csv(
record_count, final_csv_data_file_locations = write_csv_file_duckdb(
df=df,
download_file_name=destination_file_name,
max_records_per_file=EXCEL_ROW_LIMIT,
logger=self._logger,
delimiter=delimiter,
)
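And the corresponding DuckDB-path call, also without max_records_per_file; the names are hypothetical, `df` is assumed to be a DuckDB Spark-API DataFrame, and the return value is (record count, list of CSV paths) per the signature in spark.py above.

```python
import logging

from usaspending_api.common.etl.spark import write_csv_file_duckdb

logger = logging.getLogger("download")

record_count, csv_paths = write_csv_file_duckdb(
    df=df,
    download_file_name="my_download",  # hypothetical download name
    logger=logger,
    delimiter=",",
)
```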
@@ -291,11 +318,20 @@ def download_to_csv(
self._logger.exception("Exception encountered. See logs")
raise
append_files_to_zip_file(final_csv_data_file_locations, download_zip_path)
self._logger.info(f"Generated the following data csv files {final_csv_data_file_locations}")
return CSVDownloadMetadata(final_csv_data_file_locations, record_count, column_count)
self._logger.info(
f"Generated the following data csv files {final_csv_data_file_locations}"
)
return CSVDownloadMetadata(
final_csv_data_file_locations, record_count, column_count
)

def _move_data_csv_s3_to_local(
self, bucket_name, s3_file_paths, s3_bucket_path, s3_bucket_sub_path, destination_path_dir
self,
bucket_name: str,
s3_file_paths: list[str] | set[str] | tuple[str],
s3_bucket_path: str,
s3_bucket_sub_path: str,
destination_path_dir: str,
) -> List[str]:
"""Moves files from s3 data csv location to a location on the local machine.

@@ -304,7 +340,7 @@ def _move_data_csv_s3_to_local(
s3_file_paths: A list of file paths to move from s3, name should
include s3a:// and bucket name
s3_bucket_path: The bucket path, e.g. s3a:// + bucket name
s3_bucket_sub_path: The path to the s3 files in the bucket, excluding s3a:// + bucket name, e.g. temp_directory/files
s3_bucket_sub_path: The path to the s3 files in the bucket, excluding s3a:// + bucket name
destination_path_dir: The location to move those files from s3 to, must not include the
file name in the path. This path should be a directory.

@@ -325,5 +361,7 @@ def _move_data_csv_s3_to_local(
final_path,
)
local_csv_file_paths.append(final_path)
self._logger.info(f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s")
self._logger.info(
f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s"
)
return local_csv_file_paths