66"""
77
88import logging
9+ import math
910import time
1011from collections import namedtuple
1112from itertools import chain
1213from typing import List
1314
14- from py4j .protocol import Py4JError
1515from pyspark .sql import DataFrame , SparkSession
1616from pyspark .sql .functions import col , concat , concat_ws , expr , lit , regexp_replace , to_date , transform , when
1717from pyspark .sql .types import ArrayType , DecimalType , StringType , StructType
1818
1919from usaspending_api .accounts .models import AppropriationAccountBalances , FederalAccount , TreasuryAppropriationAccount
20+ from usaspending_api .common .helpers .s3_helpers import rename_s3_object , retrieve_s3_bucket_object_list
2021from usaspending_api .common .helpers .spark_helpers import (
2122 get_broker_jdbc_url ,
2223 get_jdbc_connection_properties ,
@@ -97,7 +98,6 @@ def extract_db_data_frame(
     is_date_partitioning_col: bool = False,
     custom_schema: StructType = None,
 ) -> DataFrame:
-
     logger.info(f"Getting partition bounds using SQL:\n{min_max_sql}")
 
     data_df = None
@@ -427,7 +427,7 @@ def diff(
     cols_to_show = (
         ["diff"]
         + [f"l.{unique_key_col}", f"r.{unique_key_col}"]
-        + list(chain(*zip([f"l.{c}" for c in compare_cols], [f"r.{c}" for c in compare_cols])))
+        + list(chain(*zip([f"l.{c}" for c in compare_cols], [f"r.{c}" for c in compare_cols], strict=False)))
     )
     differences = differences.select(*cols_to_show)
     if not include_unchanged_rows:
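A minimal sketch (plain Python, outside the diff) of what the interleaving above produces; the compare_cols values are hypothetical. Passing strict=False to zip() keeps the default truncating behavior for unequal-length inputs and is likely there to satisfy a zip-strictness lint rule on Python 3.10+:

    from itertools import chain

    compare_cols = ["a", "b"]  # hypothetical column names, for illustration only
    left = [f"l.{c}" for c in compare_cols]   # ['l.a', 'l.b']
    right = [f"r.{c}" for c in compare_cols]  # ['r.a', 'r.b']
    print(list(chain(*zip(left, right, strict=False))))
    # ['l.a', 'r.a', 'l.b', 'r.b'] -> each compared column shows its left and right values side by side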
@@ -588,7 +588,6 @@ def write_csv_file(
     spark: SparkSession,
     df: DataFrame,
     parts_dir: str,
-    num_partitions: int,
     max_records_per_file=EXCEL_ROW_LIMIT,
     overwrite=True,
     logger=None,
@@ -599,7 +598,6 @@ def write_csv_file(
         spark: passed-in active SparkSession
         df: the DataFrame wrapping the data source to be dumped to CSV.
         parts_dir: Path to dir that will contain the outputted parts files from partitions
-        num_partitions: Indicates the number of partitions to use when writing the Dataframe
         overwrite: Whether to replace the file CSV files if they already exist by that name
         max_records_per_file: Suggestion to Spark of how many records to put in each written CSV file part,
             if it will end up writing multiple files.
@@ -617,12 +615,13 @@ def write_csv_file(
     start = time.time()
     logger.info(f"Writing source data DataFrame to csv part files for file {parts_dir}...")
     df_record_count = df.count()
+    num_partitions = math.ceil(df_record_count / max_records_per_file) or 1
     df.repartition(num_partitions).write.options(
         # NOTE: this is a suggestion, to be used by Spark if partitions yield multiple files
         maxRecordsPerFile=max_records_per_file,
     ).csv(
         path=parts_dir,
-        header=False,
+        header=True,
         emptyValue="",  # "" creates the output of ,,, for null values to match behavior of previous Postgres job
         escape='"',  # " is used to escape the 'quote' character setting (which defaults to "). Escaped quote = ""
         ignoreLeadingWhiteSpace=False,  # must set for CSV write, as it defaults to true
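A minimal sketch of the partition-count expression added above, assuming EXCEL_ROW_LIMIT is Excel's usual 1,048,576-row maximum (the module imports the real constant elsewhere); the "or 1" guard matters because math.ceil(0 / n) is 0 for an empty DataFrame, and repartitioning to zero partitions is not valid:

    import math

    EXCEL_ROW_LIMIT = 1_048_576  # assumed value for illustration only

    def partition_count(record_count: int, max_records_per_file: int = EXCEL_ROW_LIMIT) -> int:
        # Same expression as the added line above: at least one partition, even when the DataFrame is empty
        return math.ceil(record_count / max_records_per_file) or 1

    print(partition_count(0))          # 1
    print(partition_count(1_048_576))  # 1
    print(partition_count(1_048_577))  # 2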
@@ -636,112 +635,6 @@ def write_csv_file(
     return df_record_count
 
 
-def hadoop_copy_merge(
-    spark: SparkSession,
-    parts_dir: str,
-    header: str,
-    part_merge_group_size: int,
-    logger=None,
-    file_format="csv",
-) -> List[str]:
-    """PySpark impl of Hadoop 2.x copyMerge() (deprecated in Hadoop 3.x)
-    Merges files from a provided input directory and then redivides them
-    into multiple files based on merge group size.
-    Args:
-        spark: passed-in active SparkSession
-        parts_dir: Path to the dir that contains the input parts files. The parts dir name
-            determines the name of the merged files. Parts_dir cannot have a trailing slash.
-        header: A comma-separated list of field names, to be placed as the first row of every final CSV file.
-            Individual part files must NOT therefore be created with their own header.
-        part_merge_group_size: Final CSV data will be subdivided into numbered files. This indicates how many part files
-            should be combined into a numbered file.
-        logger: The logger to use. If one note provided (e.g. to log to console or stdout) the underlying JVM-based
-            Logger will be extracted from the ``spark`` ``SparkSession`` and used as the logger.
-        file_format: The format of the part files and the format of the final merged file, e.g. "csv"
-
-    Returns:
-        A list of file paths where each element in the list denotes a path to
-        a merged file that was generated during the copy merge.
-    """
-    overwrite = True
-    hadoop = spark.sparkContext._jvm.org.apache.hadoop
-    conf = spark.sparkContext._jsc.hadoopConfiguration()
-
-    # Guard against incorrectly formatted argument value
-    parts_dir = parts_dir.rstrip("/")
-
-    parts_dir_path = hadoop.fs.Path(parts_dir)
-
-    fs = parts_dir_path.getFileSystem(conf)
-
-    if not fs.exists(parts_dir_path):
-        raise ValueError("Source directory {} does not exist".format(parts_dir))
-
-    file = parts_dir
-    file_path = hadoop.fs.Path(file)
-
-    # Don't delete first if disallowing overwrite.
-    if not overwrite and fs.exists(file_path):
-        raise Py4JError(
-            spark._jvm.org.apache.hadoop.fs.FileAlreadyExistsException(f"{str(file_path)} " f"already exists")
-        )
-    part_files = []
-
-    for f in fs.listStatus(parts_dir_path):
-        if f.isFile():
-            # Sometimes part files can be empty, we need to ignore them
-            if f.getLen() == 0:
-                continue
-            file_path = f.getPath()
-            if file_path.getName().startswith("_"):
-                logger.debug(f"Skipping non-part file: {file_path.getName()}")
-                continue
-            logger.debug(f"Including part file: {file_path.getName()}")
-            part_files.append(f.getPath())
-    if not part_files:
-        logger.warning("Source directory is empty with no part files. Attempting creation of file with CSV header only")
-        out_stream = None
-        try:
-            merged_file_path = f"{parts_dir}.{file_format}"
-            out_stream = fs.create(hadoop.fs.Path(merged_file_path), overwrite)
-            out_stream.writeBytes(header + "\n")
-        finally:
-            if out_stream is not None:
-                out_stream.close()
-        return [merged_file_path]
-
-    part_files.sort(key=lambda f: str(f))  # put parts in order by part number for merging
-    paths_to_merged_files = []
-    for parts_file_group in _merge_grouper(part_files, part_merge_group_size):
-        part_suffix = f"_{str(parts_file_group.part).zfill(2)}" if parts_file_group.part else ""
-        partial_merged_file = f"{parts_dir}.partial{part_suffix}"
-        partial_merged_file_path = hadoop.fs.Path(partial_merged_file)
-        merged_file_path = f"{parts_dir}{part_suffix}.{file_format}"
-        paths_to_merged_files.append(merged_file_path)
-        # Make path a hadoop path because we are working with a hadoop file system
-        merged_file_path = hadoop.fs.Path(merged_file_path)
-        if overwrite and fs.exists(merged_file_path):
-            fs.delete(merged_file_path, True)
-        out_stream = None
-        try:
-            if fs.exists(partial_merged_file_path):
-                fs.delete(partial_merged_file_path, True)
-            out_stream = fs.create(partial_merged_file_path)
-            out_stream.writeBytes(header + "\n")
-            _merge_file_parts(fs, out_stream, conf, hadoop, partial_merged_file_path, parts_file_group.file_list)
-        finally:
-            if out_stream is not None:
-                out_stream.close()
-        try:
-            fs.rename(partial_merged_file_path, merged_file_path)
-        except Exception:
-            if fs.exists(partial_merged_file_path):
-                fs.delete(partial_merged_file_path, True)
-            logger.exception("Exception encountered. See logs")
-            raise
-    return paths_to_merged_files
-
-
 def _merge_file_parts(fs, out_stream, conf, hadoop, partial_merged_file_path, part_file_list):
     """Read-in files in alphabetical order and append them one by one to the merged file"""
 
@@ -767,3 +660,49 @@ def _merge_grouper(items, group_size):
     group_generator = (items[i : i + group_size] for i in range(0, len(items), group_size))
     for i, group in enumerate(group_generator, start=1):
         yield FileMergeGroup(i, group)
+
+
+def rename_part_files(
+    bucket_name: str,
+    destination_file_name: str,
+    logger: logging.Logger,
+    temp_download_dir_name: str = "temp_download",
+    file_format: str = "csv",
+) -> list[str]:
+    """Renames the part-000.csv files to match the zip filename structure.
+
+    Args:
+        bucket_name: S3 bucket that contains the file to be renamed and will contain the renamed file.
+        destination_file_name: Timestamped download file name. This is used to find the correct folder within the
+            bucket.
+        logger: Logger instance.
+        temp_download_dir_name: Name of the folder used to store the renamed CSV files before they are downloaded.
+            Defaults to "temp_download".
+        file_format: What file format to save the files in.
+            Defaults to "csv".
+
+    Returns:
+        A list of the full S3 paths for the CSV files.
+    """
+    list_of_part_files = sorted(
+        [
+            file.key
+            for file in retrieve_s3_bucket_object_list(bucket_name)
+            if (
+                file.key.startswith(f"{temp_download_dir_name}/{destination_file_name}/part-")
+                and file.key.endswith(file_format)
+            )
+        ]
+    )
+
+    full_file_paths = []
+
+    for index, part_file in enumerate(list_of_part_files):
+        old_key = f"{bucket_name}/{part_file}"
+        new_key = f"{temp_download_dir_name}/{destination_file_name}_{str(index + 1).zfill(2)}.{file_format}"
+        logger.info(f"Renaming {old_key} to {bucket_name}/{new_key}")
+
+        rename_s3_object(bucket_name=bucket_name, old_key=old_key, new_key=new_key)
+        full_file_paths.append(f"s3a://{bucket_name}/{new_key}")
+
+    return full_file_paths
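A hypothetical end-to-end sketch of how the two changed pieces fit together, based only on the signatures and docstrings above; the bucket name, download name, temp_download prefix, and the existence of spark, df, and logger are assumptions, not part of this commit:

    bucket = "example-download-bucket"                   # hypothetical bucket name
    download_name = "example_download_20240101T000000"   # hypothetical timestamped download name
    parts_dir = f"s3a://{bucket}/temp_download/{download_name}"

    # Each part file now carries its own header row, and the partition count is derived from the record count
    record_count = write_csv_file(spark, df, parts_dir=parts_dir, logger=logger)

    # Rename part-*.csv objects to <download_name>_01.csv, <download_name>_02.csv, ... and collect their S3 paths
    csv_paths = rename_part_files(bucket, download_name, logger=logger)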