
Commit 970b8f6

Merge pull request #4591 from fedspendingtransparency/ftr/dev-13727-file-c-duckdb
[DEV-13727] File C Custom Account DuckDB downloads
2 parents 7577cfd + 8201e35 commit 970b8f6

File tree

3 files changed: +148 -102 lines changed


usaspending_api/common/etl/spark.py

Lines changed: 15 additions & 25 deletions
@@ -390,9 +390,7 @@ def load_es_index(
     )
 
 
-def merge_delta_table(
-    spark: SparkSession, source_df: DataFrame, delta_table_name: str, merge_column: str
-) -> None:
+def merge_delta_table(spark: SparkSession, source_df: DataFrame, delta_table_name: str, merge_column: str) -> None:
     source_df.create_or_replace_temporary_view("temp_table")
 
     spark.sql(
@@ -409,7 +407,7 @@ def diff(
     left: DataFrame,
     right: DataFrame,
     unique_key_col: str = "id",
-    compare_cols: list[str] | None = None,
+    compare_cols: list[str] = None,
     include_unchanged_rows: bool = False,
 ) -> DataFrame:
     """Compares two Spark DataFrames that share a schema and returns row-level differences in a DataFrame
@@ -563,8 +561,7 @@ def convert_array_cols_to_string(
         2. Escape any quotes inside the array element with backslash.
     - A case that involves all of this will yield CSV field value like this when viewed in a text editor,
       assuming Spark CSV options are: quote='"', escape='"' (the default is for it to match quote)
-        ...,"{""{\""simple\"": \""elem1\"", \""other\"": \""elem1\""}"",
-        ""{\""simple\"": \""elem2\"", \""other\"": \""elem2\""}""}",...
+        ...,"{""{\""simple\"": \""elem1\"", \""other\"": \""elem1\""}"",...
     """
     arr_open_bracket = "["
     arr_close_bracket = "]"
@@ -599,7 +596,8 @@ def convert_array_cols_to_string(
            # Special handling in case of data that already has either a quote " or backslash \
            # inside an array element
            # First replace any single backslash character \ with TWO \\ (an escaped backslash)
-           # Then replace quote " character with \" (escaped quote, inside a quoted array elem)
+           # Then replace any quote " character with \"
+           # (escaped quote, inside a quoted array elem)
            # NOTE: these regexp_replace get sent down to a Java replaceAll, which will require
            # FOUR backslashes to represent ONE
            regexp_replace(
@@ -649,7 +647,7 @@ def _generate_global_view_sql_strings(tables: list[str], jdbc_url: str) -> list[
 
 def create_ref_temp_views( # noqa: PLR0912
     spark: SparkSession | DuckDBSparkSession, create_broker_views: bool = False
-) -> None:
+) -> None: # noqa: PLR0912
     """Create global temporary Spark reference views that sit atop remote PostgreSQL RDS tables
     Setting create_broker_views to True will create views for all tables list in _BROKER_REF_TABLES
     Note: They will all be listed under global_temp.{table_name}
@@ -792,9 +790,8 @@ def write_csv_file( # noqa: PLR0913
     spark: SparkSession,
     df: DataFrame,
     parts_dir: str,
-    max_records_per_file: int = EXCEL_ROW_LIMIT,
     overwrite: bool = True,
-    logger: logging.Logger | None = None,
+    logger: logging.Logger = None,
     delimiter: str = ",",
 ) -> int:
     """Write DataFrame data to CSV file parts.
@@ -804,8 +801,6 @@ def write_csv_file( # noqa: PLR0913
         parts_dir: Path to dir that will contain the outputted parts files from partitions
         num_partitions: Indicates the number of partitions to use when writing the Dataframe
         overwrite: Whether to replace the file CSV files if they already exist by that name
-        max_records_per_file: Suggestion to Spark of how many records to put in each written CSV file part,
-            if it will end up writing multiple files.
         logger: The logger to use. If one note provided (e.g. to log to console or stdout) the underlying JVM-based
             Logger will be extracted from the ``spark`` ``SparkSession`` and used as the logger.
         delimiter: Charactor used to separate columns in the CSV
@@ -822,10 +817,10 @@ def write_csv_file( # noqa: PLR0913
         f"Writing source data DataFrame to csv part files for file {parts_dir}..."
     )
     df_record_count = df.count()
-    num_partitions = math.ceil(df_record_count / max_records_per_file) or 1
+    num_partitions = math.ceil(df_record_count / EXCEL_ROW_LIMIT) or 1
     df.repartition(num_partitions).write.options(
         # NOTE: this is a suggestion, to be used by Spark if partitions yield multiple files
-        maxRecordsPerFile=max_records_per_file,
+        maxRecordsPerFile=EXCEL_ROW_LIMIT,
     ).csv(
         path=parts_dir,
         header=True,
@@ -848,7 +843,6 @@ def write_csv_file_duckdb(
     df: DuckDBDataFrame,
     download_file_name: str,
     temp_csv_directory_path: str = CSV_LOCAL_PATH,
-    max_records_per_file: int = EXCEL_ROW_LIMIT,
     logger: logging.Logger | None = None,
     delimiter: str = ",",
 ) -> tuple[int, list[str] | list]:
@@ -858,8 +852,6 @@ def write_csv_file_duckdb(
         download_file_name: Name of the download being generated.
         temp_csv_directory_path: Directory that will contain the individual CSV files before zipping.
             Defaults to CSV_LOCAL_PATH
-        max_records_per_file: Max number of records to put in each written CSV file.
-            Defaults to EXCEL_ROW_LIMIT
         logger: Logging instance to use.
             Defaults to None
         delimiter: Charactor used to separate columns in the CSV
@@ -870,7 +862,7 @@ def write_csv_file_duckdb(
     """
     start = time.time()
     _pandas_df = df.toPandas()
-    _pandas_df["file_number"] = (_pandas_df.index // max_records_per_file) + 1
+    _pandas_df["file_number"] = (_pandas_df.index // EXCEL_ROW_LIMIT) + 1
     df_record_count = len(_pandas_df)
     rel = duckdb.from_df(_pandas_df)
 
@@ -894,15 +886,13 @@ def write_csv_file_duckdb(
         f"{temp_csv_directory_path}{download_file_name}/{d}"
         for d in os.listdir(f"{temp_csv_directory_path}{download_file_name}")
     ]
-    for dir in _partition_dirs:
-        _old_csv_path = f"{dir}/{os.listdir(dir)[0]}"
-        _new_csv_path = (
-            f"{temp_csv_directory_path}{download_file_name}"
-            f"/{download_file_name}_{dir.split('=')[1].zfill(2)}.csv"
-        )
+    for _dir in _partition_dirs:
+        _file_number = _dir.split("=")[1].zfill(2)
+        _old_csv_path = f"{_dir}/{os.listdir(_dir)[0]}"
+        _new_csv_path = f"{temp_csv_directory_path}{download_file_name}/{download_file_name}_{_file_number}.csv"
         shutil.move(_old_csv_path, _new_csv_path)
         full_file_paths.append(_new_csv_path)
-        os.rmdir(dir)
+        os.rmdir(_dir)
 
     logger.info(
         f"{temp_csv_directory_path}{download_file_name} contains {df_record_count:,} rows of data"

usaspending_api/common/helpers/download_csv_strategies.py

Lines changed: 90 additions & 51 deletions
@@ -7,6 +7,7 @@
 from typing import List, Optional
 
 from django.conf import settings
+from duckdb.experimental.spark.sql import DataFrame as DuckDBDataFrame
 from duckdb.experimental.spark.sql import SparkSession as DuckDBSparkSession
 from pyspark.sql import DataFrame
 
@@ -17,7 +18,6 @@
     download_s3_object,
 )
 from usaspending_api.download.filestreaming.download_generation import (
-    EXCEL_ROW_LIMIT,
     execute_psql,
     generate_export_query_temp_file,
     split_and_zip_data_files,
@@ -80,20 +80,25 @@ def __init__(self, logger: logging.Logger, *args, **kwargs):
 
     def download_to_csv(
         self,
-        source_sql,
-        destination_path,
-        destination_file_name,
-        working_dir_path,
-        download_zip_path,
-        source_df=None,
-    ):
+        source_sql: str,
+        destination_path: str,
+        destination_file_name: str,
+        working_dir_path: str,
+        download_zip_path: str,
+        source_df: DataFrame | None = None,
+    ) -> CSVDownloadMetadata:
         start_time = time.perf_counter()
         self._logger.info(f"Downloading data to {destination_path}")
-        temp_data_file_name = destination_path.parent / (destination_path.name + "_temp")
+        row_count = 0
+        temp_data_file_name = destination_path.parent / (
+            destination_path.name + "_temp"
+        )
         options = FILE_FORMATS[self.file_format]["options"]
         export_query = r"\COPY ({}) TO STDOUT {}".format(source_sql, options)
         try:
-            temp_file, temp_file_path = generate_export_query_temp_file(export_query, None, working_dir_path)
+            temp_file, temp_file_path = generate_export_query_temp_file(
+                export_query, None, working_dir_path
+            )
             # Create a separate process to run the PSQL command; wait
             psql_process = multiprocessing.Process(
                 target=execute_psql, args=(temp_file_path, temp_data_file_name, None)
@@ -104,12 +109,20 @@ def download_to_csv(
             delim = FILE_FORMATS[self.file_format]["delimiter"]
 
             # Log how many rows we have
-            self._logger.info(f"Counting rows in delimited text file {temp_data_file_name}")
+            self._logger.info(
+                f"Counting rows in delimited text file {temp_data_file_name}"
+            )
             try:
-                row_count = count_rows_in_delimited_file(filename=temp_data_file_name, has_header=True, delimiter=delim)
-                self._logger.info(f"{destination_path} contains {row_count:,} rows of data")
+                row_count = count_rows_in_delimited_file(
+                    filename=temp_data_file_name, has_header=True, delimiter=delim
+                )
+                self._logger.info(
+                    f"{destination_path} contains {row_count:,} rows of data"
+                )
             except Exception:
-                self._logger.exception("Unable to obtain delimited text file line count")
+                self._logger.exception(
+                    "Unable to obtain delimited text file line count"
+                )
 
             start_time = time.perf_counter()
             zip_process = multiprocessing.Process(
@@ -136,32 +149,36 @@ def __init__(self, logger: logging.Logger, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._logger = logger
 
-    def download_to_csv(
+    def download_to_csv( # noqa: PLR0913
         self,
-        source_sql,
-        destination_path,
-        destination_file_name,
-        working_dir_path,
-        download_zip_path,
-        source_df=None,
-        delimiter=",",
-        file_format="csv",
-    ):
-        # These imports are here for a reason.
-        # some strategies do not require spark
-        # we do not want to force all containers where
-        # other strategies run to have pyspark installed when the strategy
-        # doesn't require it.
+        source_sql: str,
+        destination_path: str,
+        destination_file_name: str,
+        working_dir_path: str,
+        download_zip_path: str,
+        source_df: DataFrame | None = None,
+        delimiter: str = ",",
+        file_format: str = "csv",
+    ) -> CSVDownloadMetadata:
+        # Some strategies do not require spark we do not want to force all containers where
+        # other strategies run to have pyspark installed when the strategy doesn't require it.
         from usaspending_api.common.etl.spark import write_csv_file
-        from usaspending_api.common.helpers.spark_helpers import configure_spark_session, get_active_spark_session
+        from usaspending_api.common.helpers.spark_helpers import (
+            configure_spark_session,
+            get_active_spark_session,
+        )
 
         self.spark = None
-        destination_path_dir = str(destination_path).replace(f"/{destination_file_name}", "")
+        destination_path_dir = str(destination_path).replace(
+            f"/{destination_file_name}", ""
+        )
         # The place to write intermediate data files to in s3
         s3_bucket_name = settings.BULK_DOWNLOAD_S3_BUCKET_NAME
         s3_bucket_path = f"s3a://{s3_bucket_name}"
         s3_bucket_sub_path = "temp_download"
-        s3_destination_path = f"{s3_bucket_path}/{s3_bucket_sub_path}/{destination_file_name}"
+        s3_destination_path = (
+            f"{s3_bucket_path}/{s3_bucket_sub_path}/{destination_file_name}"
+        )
         try:
             extra_conf = {
                 # Config for Delta Lake tables and SQL. Need these to keep Dela table metadata in the metastore
@@ -176,7 +193,9 @@ def download_to_csv(
             self.spark_created_by_command = False
             if not self.spark:
                 self.spark_created_by_command = True
-                self.spark = configure_spark_session(**extra_conf, spark_context=self.spark)
+                self.spark = configure_spark_session(
+                    **extra_conf, spark_context=self.spark
+                )
             if source_df is not None:
                 df = source_df
             else:
@@ -185,7 +204,6 @@ def download_to_csv(
                 self.spark,
                 df,
                 parts_dir=s3_destination_path,
-                max_records_per_file=EXCEL_ROW_LIMIT,
                 logger=self._logger,
                 delimiter=delimiter,
             )
@@ -208,12 +226,19 @@ def download_to_csv(
            self._logger.exception("Exception encountered. See logs")
            raise
        finally:
-           delete_s3_objects(s3_bucket_name, key_prefix=f"{s3_bucket_sub_path}/{destination_file_name}")
+           delete_s3_objects(
+               s3_bucket_name,
+               key_prefix=f"{s3_bucket_sub_path}/{destination_file_name}",
+           )
            if self.spark_created_by_command:
                self.spark.stop()
        append_files_to_zip_file(final_csv_data_file_locations, download_zip_path)
-       self._logger.info(f"Generated the following data csv files {final_csv_data_file_locations}")
-       return CSVDownloadMetadata(final_csv_data_file_locations, record_count, column_count)
+       self._logger.info(
+           f"Generated the following data csv files {final_csv_data_file_locations}"
+       )
+       return CSVDownloadMetadata(
+           final_csv_data_file_locations, record_count, column_count
+       )
 
     def _move_data_csv_s3_to_local(
         self,
@@ -230,7 +255,7 @@ def _move_data_csv_s3_to_local(
            s3_file_paths: A list of file paths to move from s3, name should
                include s3a:// and bucket name
            s3_bucket_path: The bucket path, e.g. s3a:// + bucket name
-           s3_bucket_sub_path: The path to the s3 files in the bucket, exluding s3a:// + bucket name, e.g. temp_directory/files
+           s3_bucket_sub_path: The path to the s3 files in the bucket, exluding s3a:// + bucket name
            destination_path_dir: The location to move those files from s3 to, must not include the
                file name in the path. This path should be a directory.
 
@@ -251,27 +276,31 @@ def _move_data_csv_s3_to_local(
                final_path,
            )
            local_csv_file_paths.append(final_path)
-       self._logger.info(f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s")
+       self._logger.info(
+           f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s"
+       )
        return local_csv_file_paths
 
 
 class DuckDBToCSVStrategy(AbstractToCSVStrategy):
-    def __init__(self, logger: logging.Logger, spark: DuckDBSparkSession, *args, **kwargs):
+    def __init__(
+        self, logger: logging.Logger, spark: DuckDBSparkSession, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)
         self._logger = logger
         self.spark = spark
 
-    def download_to_csv(
+    def download_to_csv( # noqa: PLR0913
         self,
         source_sql: str | None,
         destination_path: str,
         destination_file_name: str,
         working_dir_path: str,
         download_zip_path: str,
-        source_df=None,
-        delimiter=",",
-        file_format="csv",
-    ):
+        source_df: DuckDBDataFrame | None = None,
+        delimiter: str = ",",
+        file_format: str = "csv",
+    ) -> CSVDownloadMetadata:
         from usaspending_api.common.etl.spark import write_csv_file_duckdb
 
         try:
@@ -282,7 +311,6 @@ def download_to_csv(
             record_count, final_csv_data_file_locations = write_csv_file_duckdb(
                 df=df,
                 download_file_name=destination_file_name,
-                max_records_per_file=EXCEL_ROW_LIMIT,
                 logger=self._logger,
                 delimiter=delimiter,
             )
@@ -291,11 +319,20 @@ def download_to_csv(
            self._logger.exception("Exception encountered. See logs")
            raise
        append_files_to_zip_file(final_csv_data_file_locations, download_zip_path)
-       self._logger.info(f"Generated the following data csv files {final_csv_data_file_locations}")
-       return CSVDownloadMetadata(final_csv_data_file_locations, record_count, column_count)
+       self._logger.info(
+           f"Generated the following data csv files {final_csv_data_file_locations}"
+       )
+       return CSVDownloadMetadata(
+           final_csv_data_file_locations, record_count, column_count
+       )
 
     def _move_data_csv_s3_to_local(
-        self, bucket_name, s3_file_paths, s3_bucket_path, s3_bucket_sub_path, destination_path_dir
+        self,
+        bucket_name: str,
+        s3_file_paths: list[str] | set[str] | tuple[str],
+        s3_bucket_path: str,
+        s3_bucket_sub_path: str,
+        destination_path_dir: str,
     ) -> List[str]:
         """Moves files from s3 data csv location to a location on the local machine.
 
@@ -304,7 +341,7 @@ def _move_data_csv_s3_to_local(
            s3_file_paths: A list of file paths to move from s3, name should
                include s3a:// and bucket name
            s3_bucket_path: The bucket path, e.g. s3a:// + bucket name
-           s3_bucket_sub_path: The path to the s3 files in the bucket, exluding s3a:// + bucket name, e.g. temp_directory/files
+           s3_bucket_sub_path: The path to the s3 files in the bucket, exluding s3a:// + bucket name
            destination_path_dir: The location to move those files from s3 to, must not include the
                file name in the path. This path should be a diretory.
 
@@ -325,5 +362,7 @@ def _move_data_csv_s3_to_local(
                final_path,
            )
            local_csv_file_paths.append(final_path)
-       self._logger.info(f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s")
+       self._logger.info(
+           f"Copied data files from S3 to local machine in {(time.time() - start_time):3f}s"
+       )
        return local_csv_file_paths
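
For context on the strategy changes above: every download_to_csv now declares a CSVDownloadMetadata return type, constructed positionally from the generated file locations, the record count, and the column count. The following is a minimal sketch of that contract using an illustrative stand-in dataclass and caller, not the project's actual CSVDownloadMetadata definition.

from dataclasses import dataclass
from typing import List


# Illustrative stand-in, NOT the project's class: the diff constructs
# CSVDownloadMetadata positionally as (file locations, record count, column count).
@dataclass
class CSVDownloadMetadataSketch:
    final_csv_data_file_locations: List[str]
    record_count: int
    column_count: int


def summarize(metadata: CSVDownloadMetadataSketch) -> str:
    # Hypothetical caller showing how a download task might report the result
    # returned by any of the download_to_csv strategies.
    return (
        f"{len(metadata.final_csv_data_file_locations)} csv file(s), "
        f"{metadata.record_count:,} rows, {metadata.column_count} columns"
    )


print(summarize(CSVDownloadMetadataSketch(["File_C_01.csv", "File_C_02.csv"], 1_234_567, 42)))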

0 commit comments