
Commit 495bd5e

Merge pull request #4531 from fedspendingtransparency/ftr/dev-12528-spark-download-zipping
[DEV-12528] Remove Hadoop copy merge step from Spark downloads
2 parents: f2e3b82 + b5e23fb

6 files changed (+141, -181 lines)

usaspending_api/common/etl/spark.py

Lines changed: 51 additions & 112 deletions
@@ -6,17 +6,18 @@
 """
 
 import logging
+import math
 import time
 from collections import namedtuple
 from itertools import chain
 from typing import List
 
-from py4j.protocol import Py4JError
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.functions import col, concat, concat_ws, expr, lit, regexp_replace, to_date, transform, when
 from pyspark.sql.types import ArrayType, DecimalType, StringType, StructType
 
 from usaspending_api.accounts.models import AppropriationAccountBalances, FederalAccount, TreasuryAppropriationAccount
+from usaspending_api.common.helpers.s3_helpers import rename_s3_object, retrieve_s3_bucket_object_list
 from usaspending_api.common.helpers.spark_helpers import (
     get_broker_jdbc_url,
     get_jdbc_connection_properties,
@@ -97,7 +98,6 @@ def extract_db_data_frame(
     is_date_partitioning_col: bool = False,
     custom_schema: StructType = None,
 ) -> DataFrame:
-
     logger.info(f"Getting partition bounds using SQL:\n{min_max_sql}")
 
     data_df = None
@@ -427,7 +427,7 @@ def diff(
     cols_to_show = (
         ["diff"]
         + [f"l.{unique_key_col}", f"r.{unique_key_col}"]
-        + list(chain(*zip([f"l.{c}" for c in compare_cols], [f"r.{c}" for c in compare_cols])))
+        + list(chain(*zip([f"l.{c}" for c in compare_cols], [f"r.{c}" for c in compare_cols], strict=False)))
     )
     differences = differences.select(*cols_to_show)
     if not include_unchanged_rows:
@@ -588,7 +588,6 @@ def write_csv_file(
     spark: SparkSession,
     df: DataFrame,
     parts_dir: str,
-    num_partitions: int,
     max_records_per_file=EXCEL_ROW_LIMIT,
     overwrite=True,
     logger=None,
@@ -599,7 +598,6 @@ def write_csv_file(
         spark: passed-in active SparkSession
         df: the DataFrame wrapping the data source to be dumped to CSV.
         parts_dir: Path to dir that will contain the outputted parts files from partitions
-        num_partitions: Indicates the number of partitions to use when writing the Dataframe
        overwrite: Whether to replace the file CSV files if they already exist by that name
        max_records_per_file: Suggestion to Spark of how many records to put in each written CSV file part,
            if it will end up writing multiple files.
@@ -617,12 +615,13 @@
     start = time.time()
     logger.info(f"Writing source data DataFrame to csv part files for file {parts_dir}...")
     df_record_count = df.count()
+    num_partitions = math.ceil(df_record_count / max_records_per_file) or 1
     df.repartition(num_partitions).write.options(
         # NOTE: this is a suggestion, to be used by Spark if partitions yield multiple files
         maxRecordsPerFile=max_records_per_file,
     ).csv(
         path=parts_dir,
-        header=False,
+        header=True,
         emptyValue="",  # "" creates the output of ,,, for null values to match behavior of previous Postgres job
         escape='"',  # " is used to escape the 'quote' character setting (which defaults to "). Escaped quote = ""
         ignoreLeadingWhiteSpace=False,  # must set for CSV write, as it defaults to true
@@ -636,112 +635,6 @@
     return df_record_count
 
 
-def hadoop_copy_merge(
-    spark: SparkSession,
-    parts_dir: str,
-    header: str,
-    part_merge_group_size: int,
-    logger=None,
-    file_format="csv",
-) -> List[str]:
-    """PySpark impl of Hadoop 2.x copyMerge() (deprecated in Hadoop 3.x)
-    Merges files from a provided input directory and then redivides them
-    into multiple files based on merge group size.
-    Args:
-        spark: passed-in active SparkSession
-        parts_dir: Path to the dir that contains the input parts files. The parts dir name
-            determines the name of the merged files. Parts_dir cannot have a trailing slash.
-        header: A comma-separated list of field names, to be placed as the first row of every final CSV file.
-            Individual part files must NOT therefore be created with their own header.
-        part_merge_group_size: Final CSV data will be subdivided into numbered files. This indicates how many part files
-            should be combined into a numbered file.
-        logger: The logger to use. If one note provided (e.g. to log to console or stdout) the underlying JVM-based
-            Logger will be extracted from the ``spark`` ``SparkSession`` and used as the logger.
-        file_format: The format of the part files and the format of the final merged file, e.g. "csv"
-
-    Returns:
-        A list of file paths where each element in the list denotes a path to
-        a merged file that was generated during the copy merge.
-    """
-    overwrite = True
-    hadoop = spark.sparkContext._jvm.org.apache.hadoop
-    conf = spark.sparkContext._jsc.hadoopConfiguration()
-
-    # Guard against incorrectly formatted argument value
-    parts_dir = parts_dir.rstrip("/")
-
-    parts_dir_path = hadoop.fs.Path(parts_dir)
-
-    fs = parts_dir_path.getFileSystem(conf)
-
-    if not fs.exists(parts_dir_path):
-        raise ValueError("Source directory {} does not exist".format(parts_dir))
-
-    file = parts_dir
-    file_path = hadoop.fs.Path(file)
-
-    # Don't delete first if disallowing overwrite.
-    if not overwrite and fs.exists(file_path):
-        raise Py4JError(
-            spark._jvm.org.apache.hadoop.fs.FileAlreadyExistsException(f"{str(file_path)} " f"already exists")
-        )
-    part_files = []
-
-    for f in fs.listStatus(parts_dir_path):
-        if f.isFile():
-            # Sometimes part files can be empty, we need to ignore them
-            if f.getLen() == 0:
-                continue
-            file_path = f.getPath()
-            if file_path.getName().startswith("_"):
-                logger.debug(f"Skipping non-part file: {file_path.getName()}")
-                continue
-            logger.debug(f"Including part file: {file_path.getName()}")
-            part_files.append(f.getPath())
-    if not part_files:
-        logger.warning("Source directory is empty with no part files. Attempting creation of file with CSV header only")
-        out_stream = None
-        try:
-            merged_file_path = f"{parts_dir}.{file_format}"
-            out_stream = fs.create(hadoop.fs.Path(merged_file_path), overwrite)
-            out_stream.writeBytes(header + "\n")
-        finally:
-            if out_stream is not None:
-                out_stream.close()
-        return [merged_file_path]
-
-    part_files.sort(key=lambda f: str(f))  # put parts in order by part number for merging
-    paths_to_merged_files = []
-    for parts_file_group in _merge_grouper(part_files, part_merge_group_size):
-        part_suffix = f"_{str(parts_file_group.part).zfill(2)}" if parts_file_group.part else ""
-        partial_merged_file = f"{parts_dir}.partial{part_suffix}"
-        partial_merged_file_path = hadoop.fs.Path(partial_merged_file)
-        merged_file_path = f"{parts_dir}{part_suffix}.{file_format}"
-        paths_to_merged_files.append(merged_file_path)
-        # Make path a hadoop path because we are working with a hadoop file system
-        merged_file_path = hadoop.fs.Path(merged_file_path)
-        if overwrite and fs.exists(merged_file_path):
-            fs.delete(merged_file_path, True)
-        out_stream = None
-        try:
-            if fs.exists(partial_merged_file_path):
-                fs.delete(partial_merged_file_path, True)
-            out_stream = fs.create(partial_merged_file_path)
-            out_stream.writeBytes(header + "\n")
-            _merge_file_parts(fs, out_stream, conf, hadoop, partial_merged_file_path, parts_file_group.file_list)
-        finally:
-            if out_stream is not None:
-                out_stream.close()
-        try:
-            fs.rename(partial_merged_file_path, merged_file_path)
-        except Exception:
-            if fs.exists(partial_merged_file_path):
-                fs.delete(partial_merged_file_path, True)
-            logger.exception("Exception encountered. See logs")
-            raise
-    return paths_to_merged_files
-
-
 def _merge_file_parts(fs, out_stream, conf, hadoop, partial_merged_file_path, part_file_list):
     """Read-in files in alphabetical order and append them one by one to the merged file"""
 
@@ -767,3 +660,49 @@ def _merge_grouper(items, group_size):
     group_generator = (items[i : i + group_size] for i in range(0, len(items), group_size))
     for i, group in enumerate(group_generator, start=1):
         yield FileMergeGroup(i, group)
+
+
+def rename_part_files(
+    bucket_name: str,
+    destination_file_name: str,
+    logger: logging.Logger,
+    temp_download_dir_name: str = "temp_download",
+    file_format: str = "csv",
+) -> list[str]:
+    """Renames the part-000.csv files to match the zip filename structure.
+
+    Args:
+        bucket_name: S3 bucket that contains the file to be renamed and will contain the renamed file.
+        destination_file_name: Timestamped download file name. This is used to find the correct folder within the
+            bucket.
+        logger: Logger instance.
+        temp_download_dir_name: Name of the folder to used to store the renamed CSV files before they are downloaded.
+            Defaults to "temp_download".
+        file_format: What file format to save the files in.
+            Defaults to "csv".
+
+    Returns:
+        A list of the full S3 paths for the CSV files.
+    """
+    list_of_part_files = sorted(
+        [
+            file.key
+            for file in retrieve_s3_bucket_object_list(bucket_name)
+            if (
+                file.key.startswith(f"{temp_download_dir_name}/{destination_file_name}/part-")
+                and file.key.endswith(file_format)
+            )
+        ]
+    )
+
+    full_file_paths = []
+
+    for index, part_file in enumerate(list_of_part_files):
+        old_key = f"{bucket_name}/{part_file}"
+        new_key = f"{temp_download_dir_name}/{destination_file_name}_{str(index + 1).zfill(2)}.{file_format}"
+        logger.info(f"Renaming {old_key} to {bucket_name}/{new_key}")
+
+        rename_s3_object(bucket_name=bucket_name, old_key=old_key, new_key=new_key)
+        full_file_paths.append(f"s3a://{bucket_name}/{new_key}")
+
+    return full_file_paths
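For reference, here is a minimal, hedged sketch (not code from the commit) of the two rules the rewritten `write_csv_file` and the new `rename_part_files` rely on: the partition count is now derived from the DataFrame's record count instead of being passed in, and each part file ends up renamed to a numbered `<destination_file_name>_NN.<format>` key. The `EXCEL_ROW_LIMIT` value and the sample file name below are assumptions for illustration only.

```python
import math

EXCEL_ROW_LIMIT = 1_000_000  # assumed stand-in for the constant imported from download_generation


def expected_partitions(df_record_count: int, max_records_per_file: int = EXCEL_ROW_LIMIT) -> int:
    # Mirrors the new line in write_csv_file: one partition per max_records_per_file records,
    # and never zero partitions, even for an empty DataFrame.
    return math.ceil(df_record_count / max_records_per_file) or 1


def renamed_key(temp_dir: str, destination_file_name: str, index: int, file_format: str = "csv") -> str:
    # Mirrors the key pattern built inside rename_part_files for the (index + 1)-th part file.
    return f"{temp_dir}/{destination_file_name}_{str(index + 1).zfill(2)}.{file_format}"


print(expected_partitions(0))          # 1  (at least one partition, even with no records)
print(expected_partitions(2_500_000))  # 3
print(renamed_key("temp_download", "All_PrimeTransactions_2024-01-01_H01M00S00", 0))
# temp_download/All_PrimeTransactions_2024-01-01_H01M00S00_01.csv
```

Because every part file is now written with `header=True`, each renamed file is a complete CSV on its own, which is what makes the copy-merge (and its header-prepending) step removable.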

usaspending_api/common/helpers/download_csv_strategies.py

Lines changed: 24 additions & 16 deletions
@@ -8,14 +8,19 @@
 
 from django.conf import settings
 from pyspark.sql import DataFrame
+
 from usaspending_api.common.csv_helpers import count_rows_in_delimited_file
-from usaspending_api.common.helpers.s3_helpers import delete_s3_objects, download_s3_object
+from usaspending_api.common.etl.spark import rename_part_files
+from usaspending_api.common.helpers.s3_helpers import (
+    delete_s3_objects,
+    download_s3_object,
+)
 from usaspending_api.download.filestreaming.download_generation import (
     EXCEL_ROW_LIMIT,
-    split_and_zip_data_files,
-    wait_for_process,
     execute_psql,
     generate_export_query_temp_file,
+    split_and_zip_data_files,
+    wait_for_process,
 )
 from usaspending_api.download.filestreaming.zip_file import append_files_to_zip_file
 from usaspending_api.download.lookups import FILE_FORMATS
@@ -139,7 +144,7 @@ def download_to_csv(
             # we do not want to force all containers where
             # other strategies run to have pyspark installed when the strategy
             # doesn't require it.
-            from usaspending_api.common.etl.spark import hadoop_copy_merge, write_csv_file
+            from usaspending_api.common.etl.spark import write_csv_file
             from usaspending_api.common.helpers.spark_helpers import configure_spark_session, get_active_spark_session
 
             self.spark = None
@@ -172,26 +177,24 @@ def download_to_csv(
                 self.spark,
                 df,
                 parts_dir=s3_destination_path,
-                num_partitions=1,
                 max_records_per_file=EXCEL_ROW_LIMIT,
                 logger=self._logger,
                 delimiter=delimiter,
             )
             column_count = len(df.columns)
-            # When combining these later, will prepend the extracted header to each resultant file.
-            # The parts therefore must NOT have headers or the headers will show up in the data when combined.
-            header = ",".join([_.name for _ in df.schema.fields])
             self._logger.info("Concatenating partitioned output files ...")
-            merged_file_paths = hadoop_copy_merge(
-                spark=self.spark,
-                parts_dir=s3_destination_path,
-                header=header,
+            merged_file_paths = rename_part_files(
+                bucket_name=s3_bucket_name,
+                destination_file_name=destination_file_name,
                 logger=self._logger,
-                part_merge_group_size=1,
                 file_format=file_format,
             )
             final_csv_data_file_locations = self._move_data_csv_s3_to_local(
-                s3_bucket_name, merged_file_paths, s3_bucket_path, s3_bucket_sub_path, destination_path_dir
+                s3_bucket_name,
+                merged_file_paths,
+                s3_bucket_path,
+                s3_bucket_sub_path,
+                destination_path_dir,
             )
         except Exception:
             self._logger.exception("Exception encountered. See logs")
@@ -205,7 +208,12 @@ def download_to_csv(
         return CSVDownloadMetadata(final_csv_data_file_locations, record_count, column_count)
 
     def _move_data_csv_s3_to_local(
-        self, bucket_name, s3_file_paths, s3_bucket_path, s3_bucket_sub_path, destination_path_dir
+        self,
+        bucket_name: str,
+        s3_file_paths: list[str],
+        s3_bucket_path: str,
+        s3_bucket_sub_path: str,
+        destination_path_dir: str,
     ) -> List[str]:
         """Moves files from s3 data csv location to a location on the local machine.
 
@@ -216,7 +224,7 @@ def _move_data_csv_s3_to_local(
             s3_bucket_path: The bucket path, e.g. s3a:// + bucket name
             s3_bucket_sub_path: The path to the s3 files in the bucket, exluding s3a:// + bucket name, e.g. temp_directory/files
             destination_path_dir: The location to move those files from s3 to, must not include the
-                file name in the path. This path should be a diretory.
+                file name in the path. This path should be a directory.
 
         Returns:
             A list of the final location on the local machine that the
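A hedged sketch of how the Spark download strategy composes these pieces after this change: part files are written with headers and then renamed in S3 instead of being copy-merged. The wrapper function below and the exact construction of `parts_dir` are assumptions for illustration; `write_csv_file`, `rename_part_files`, and the `temp_download/<name>/part-*` key layout come from the diffs above.

```python
import logging

from pyspark.sql import DataFrame, SparkSession

from usaspending_api.common.etl.spark import rename_part_files, write_csv_file
from usaspending_api.download.filestreaming.download_generation import EXCEL_ROW_LIMIT


def spark_download_to_s3(
    spark: SparkSession,
    df: DataFrame,
    s3_bucket_name: str,
    destination_file_name: str,
    logger: logging.Logger,
    temp_download_dir_name: str = "temp_download",
) -> list[str]:
    # Assumed layout: the part files land under the same key prefix that rename_part_files
    # later scans for ("<temp_download_dir_name>/<destination_file_name>/part-").
    parts_dir = f"s3a://{s3_bucket_name}/{temp_download_dir_name}/{destination_file_name}"

    # Each part file now carries its own header row (header=True inside write_csv_file),
    # so no copy-merge step is needed to prepend one afterwards.
    write_csv_file(
        spark,
        df,
        parts_dir=parts_dir,
        max_records_per_file=EXCEL_ROW_LIMIT,
        logger=logger,
    )

    # The part-*.csv objects are renamed in place to "<destination_file_name>_01.csv",
    # "<destination_file_name>_02.csv", ... and their s3a:// paths are returned for the
    # existing download-to-local and zip steps.
    return rename_part_files(
        bucket_name=s3_bucket_name,
        destination_file_name=destination_file_name,
        logger=logger,
        temp_download_dir_name=temp_download_dir_name,
    )
```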

usaspending_api/common/helpers/s3_helpers.py

Lines changed: 23 additions & 5 deletions
@@ -1,15 +1,15 @@
-import boto3
 import io
 import logging
 import math
 import time
-
-from boto3.s3.transfer import TransferConfig, S3Transfer
-from botocore.exceptions import ClientError
-from django.conf import settings
 from pathlib import Path
 from typing import Optional
+
+import boto3
+from boto3.s3.transfer import S3Transfer, TransferConfig
 from botocore.client import BaseClient
+from botocore.exceptions import ClientError
+from django.conf import settings
 
 from usaspending_api.config import CONFIG
 
@@ -167,3 +167,21 @@ def delete_s3_objects(
     resp = s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": object_list})
 
     return len(resp.get("Deleted", []))
+
+
+def rename_s3_object(bucket_name: str, old_key: str, new_key: str, region_name: str = settings.USASPENDING_AWS_REGION):
+    """Rename an existing S3 object by:
+    1) Copying the file (old_key) to a new file with the new name (new_key)
+    2) If the copy was successful, delete the old file (old_key)
+    Args:
+        bucket_name: The name of the bucket where the current object is located.
+        old_key: The current name of the key to be renamed.
+        new_key: The new name of the key.
+        region_name: AWS region to use; defaults to the settings provided region.
+    """
+
+    s3 = _get_boto3("client", "s3", region_name=region_name)
+    response = s3.copy_object(Bucket=bucket_name, CopySource=old_key, Key=new_key)
+
+    if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
+        s3.delete_object(Bucket=bucket_name, Key=old_key)
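S3 has no native rename operation, so the new helper renames by copying the object and then deleting the original. Below is a standalone sketch of the same copy-then-delete pattern with plain boto3; the bucket and key names are hypothetical, and unlike the real helper (which goes through this module's `_get_boto3` factory and passes `CopySource` as a string), it passes `CopySource` in dict form.

```python
import boto3


def rename_object_sketch(bucket: str, old_key: str, new_key: str) -> None:
    # Copy the object to its new key, then remove the old key only if the copy succeeded.
    s3 = boto3.client("s3")
    response = s3.copy_object(
        Bucket=bucket,
        CopySource={"Bucket": bucket, "Key": old_key},
        Key=new_key,
    )
    if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
        s3.delete_object(Bucket=bucket, Key=old_key)


# Example (hypothetical keys, mirroring the renaming done in rename_part_files):
# rename_object_sketch(
#     "my-download-bucket",
#     "temp_download/My_Download/part-00000-abc123.csv",
#     "temp_download/My_Download_01.csv",
# )
```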
