 could be used as stages or steps of an ETL job (aka "data pipeline")
 """

-import inspect
 import logging
 import os
-import sys
 import re
-
+import sys
 from datetime import date, datetime
-from typing import Dict, List, Optional, Sequence, Set, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Union

 from django.core.management import call_command
 from py4j.java_gateway import (
     JavaGateway,
 )
-from py4j.protocol import Py4JJavaError
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import read_int, UTF8Deserializer
+from pyspark.serializers import UTF8Deserializer, read_int
 from pyspark.sql import SparkSession

 from usaspending_api.awards.delta_models.awards import AWARDS_COLUMNS
 from usaspending_api.awards.delta_models.financial_accounts_by_awards import (
     FINANCIAL_ACCOUNTS_BY_AWARDS_COLUMNS,
 )
-from usaspending_api.common.helpers.aws_helpers import is_aws, get_aws_credentials
+from usaspending_api.common.helpers.aws_helpers import get_aws_credentials, is_aws
 from usaspending_api.config import CONFIG
-from usaspending_api.config.utils import parse_pg_uri, parse_http_url
+from usaspending_api.config.utils import parse_http_url, parse_pg_uri
 from usaspending_api.transactions.delta_models import (
     DETACHED_AWARD_PROCUREMENT_DELTA_COLUMNS,
     PUBLISHED_FABS_COLUMNS,
@@ -72,7 +69,13 @@ def is_spark_context_stopped() -> bool:
     # Check the Singleton instance populated if there's an active SparkContext
     if SparkContext._active_spark_context is not None:
         sc = SparkContext._active_spark_context
-        is_stopped = not (sc._jvm and not sc._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped())
+        is_stopped = not (
+            sc._jvm
+            and not sc._jvm.SparkSession.getDefaultSession()
+            .get()
+            .sparkContext()
+            .isStopped()
+        )
     return is_stopped


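For context, a minimal sketch of how these two guards tend to be combined by calling code. The import path and app name are assumptions; the helper names and return values are the ones defined in this file:

# Sketch only: stop any lingering context before building a fresh session.
# The module path below is an assumption; the helpers are defined in this file.
from usaspending_api.common.helpers.spark_helpers import (
    configure_spark_session,
    is_spark_context_stopped,
    stop_spark_context,
)

if not is_spark_context_stopped():
    stop_spark_context()  # returns False only if the JVM-side stop raised an error
spark = configure_spark_session(app_name="example_etl_job")  # placeholder app name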
@@ -86,7 +89,10 @@ def stop_spark_context() -> bool:
             sc._jvm
             and hasattr(sc._jvm, "SparkSession")
             and sc._jvm.SparkSession
-            and not sc._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped()
+            and not sc._jvm.SparkSession.getDefaultSession()
+            .get()
+            .sparkContext()
+            .isStopped()
         ):
             try:
                 sc.stop()
@@ -96,11 +102,11 @@ def stop_spark_context() -> bool:
     return stopped_without_error


-def configure_spark_session(
+def configure_spark_session(  # noqa: C901,PLR0912,PLR0913,PLR0915
     java_gateway: JavaGateway = None,
     spark_context: Union[SparkContext, SparkSession] = None,
-    master=None,
-    app_name="Spark App",
+    master: str | None = None,
+    app_name: str = "Spark App",
     log_level: int = None,
     log_spark_config_vals: bool = False,
     log_hadoop_config_vals: bool = False,
@@ -151,9 +157,15 @@ def configure_spark_session(
             property, otherwise an error will be thrown.
     """
     if spark_context and (
-        not spark_context._jvm or spark_context._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped()
+        not spark_context._jvm
+        or spark_context._jvm.SparkSession.getDefaultSession()
+        .get()
+        .sparkContext()
+        .isStopped()
     ):
-        raise ValueError("The provided spark_context arg is a stopped SparkContext. It must be active.")
+        raise ValueError(
+            "The provided spark_context arg is a stopped SparkContext. It must be active."
+        )
     if spark_context and java_gateway:
         raise Exception(
             "Cannot provide BOTH spark_context and java_gateway args. The active spark_context supplies its own gateway"
@@ -259,7 +271,9 @@ def configure_spark_session(
     if spark_context:
         built_conf = spark.conf
         provided_conf_keys = [item[0] for item in conf.getAll()]
-        non_modifiable_conf = [k for k in provided_conf_keys if not built_conf.isModifiable(k)]
+        non_modifiable_conf = [
+            k for k in provided_conf_keys if not built_conf.isModifiable(k)
+        ]
         if non_modifiable_conf:
             raise ValueError(
                 "An active SparkContext was given along with NEW spark config values. The following "
@@ -309,8 +323,8 @@ def configure_spark_session(


 def read_java_gateway_connection_info(
-    gateway_conn_info_path,
-):  # pragma: no cover -- useful development util
+    gateway_conn_info_path: str,
+) -> tuple[int, str]:  # pragma: no cover -- useful development util
     """Read the port and auth token from a file holding connection info to a running spark-submit process

     Args:
@@ -326,8 +340,8 @@ def read_java_gateway_connection_info(


 def attach_java_gateway(
-    gateway_port,
-    gateway_auth_token,
+    gateway_port: int,
+    gateway_auth_token: str,
 ) -> JavaGateway:  # pragma: no cover -- useful development util
     """Create a new JavaGateway that latches onto the port of a running spark-submit process

@@ -374,7 +388,9 @@ def attach_java_gateway(
     return gateway


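A hedged sketch of the development workflow these two utilities support: read the port and auth token written by a running spark-submit process, attach to its JVM, and hand the gateway to configure_spark_session. The conn-info file path and import path are assumptions; the function names and signatures are the ones shown above:

# Sketch only: the file path and module path are placeholders.
from usaspending_api.common.helpers.spark_helpers import (
    attach_java_gateway,
    configure_spark_session,
    read_java_gateway_connection_info,
)

port, auth_token = read_java_gateway_connection_info("/tmp/spark_gateway_conn_info")  # hypothetical path
gateway = attach_java_gateway(port, auth_token)
spark = configure_spark_session(java_gateway=gateway)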
-def get_jdbc_connection_properties(fix_strings: bool = True, truncate: bool = False) -> dict:
+def get_jdbc_connection_properties(
+    fix_strings: bool = True, truncate: bool = False
+) -> dict:
     jdbc_props = {
         "driver": "org.postgresql.Driver",
         "fetchsize": str(CONFIG.SPARK_PARTITION_ROWS),
@@ -394,30 +410,32 @@ def get_jdbc_url_from_pg_uri(pg_uri: str) -> str:
     """Converts the passed-in Postgres DB connection URI to a JDBC-compliant Postgres DB connection string"""
     url_parts, user, password = parse_pg_uri(pg_uri)
     if user is None or password is None:
-        raise ValueError("pg_uri provided must have username and password with host or in query string")
+        raise ValueError(
+            "pg_uri provided must have username and password with host or in query string"
+        )
     # JDBC URLs only support postgresql://
     pg_uri = f"postgresql://{url_parts.hostname}:{url_parts.port}{url_parts.path}?user={user}&password={password}"

     return f"jdbc:{pg_uri}"


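A worked example of the conversion performed above; the host and credentials are made up:

# Hypothetical values, showing the rewrite done by get_jdbc_url_from_pg_uri.
pg_uri = "postgres://usaspending:secret@localhost:5432/data_store_api"
jdbc_url = get_jdbc_url_from_pg_uri(pg_uri)
# -> "jdbc:postgresql://localhost:5432/data_store_api?user=usaspending&password=secret"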
-def get_usas_jdbc_url():
+def get_usas_jdbc_url() -> str:
     """Getting a JDBC-compliant Postgres DB connection string hard-wired to the POSTGRES vars set in CONFIG"""
     if not CONFIG.DATABASE_URL:
         raise ValueError("DATABASE_URL config val must be provided")

     return get_jdbc_url_from_pg_uri(CONFIG.DATABASE_URL)


-def get_broker_jdbc_url():
+def get_broker_jdbc_url() -> str:
     """Getting a JDBC-compliant Broker Postgres DB connection string hard-wired to the POSTGRES vars set in CONFIG"""
     if not CONFIG.BROKER_DB:
         raise ValueError("BROKER_DB config val must be provided")

     return get_jdbc_url_from_pg_uri(CONFIG.BROKER_DB)


-def get_es_config():  # pragma: no cover -- will be used eventually
+def get_es_config() -> dict[str, Any]:  # pragma: no cover -- will be used eventually
     """
     Get a base template of Elasticsearch configuration settings tailored to the specific environment setup being
     used
@@ -454,7 +472,9 @@ def get_es_config(): # pragma: no cover -- will be used eventually
         "es.net.ssl": str(ssl).lower(),  # default false
         "es.net.ssl.cert.allow.self.signed": "true",  # default false
         "es.batch.size.entries": str(CONFIG.ES_BATCH_ENTRIES),  # default 1000
-        "es.batch.size.bytes": str(CONFIG.ES_MAX_BATCH_BYTES),  # default 1024*1024 (1mb)
+        "es.batch.size.bytes": str(
+            CONFIG.ES_MAX_BATCH_BYTES
+        ),  # default 1024*1024 (1mb)
         "es.batch.write.refresh": "false",  # default true, to refresh after configured batch size completes
     }

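For context, these "es." keys are option names understood by the elasticsearch-hadoop Spark connector. A sketch of how the returned dict is typically consumed, assuming that connector is on the Spark classpath; the DataFrame and index name are placeholders:

# Sketch only: df and the index name are placeholders; es_config comes from get_es_config().
es_config = get_es_config()
(
    df.write.format("org.elasticsearch.spark.sql")
    .options(**es_config)
    .mode("append")
    .save("award-summaries")  # placeholder index name
)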
@@ -466,48 +486,13 @@ def get_es_config(): # pragma: no cover -- will be used eventually
     return config


-def get_jvm_logger(spark: SparkSession, logger_name=None):
-    """
-    Get a JVM log4j Logger object instance to log through Java
-
-    WARNING about Logging: This is NOT python's `logging` module
-    This is a python proxy to a java Log4J Logger object
-    As such, you can't do everything with it you'd do in Python, NOTABLY: passing
-    keyword args, like `logger.error("msg here", exc_info=exc)`. Instead do e.g.:
-    `logger.error("msg here", exc)`
-    Also, errors may not be loggable in the traditional way. See: https://www.py4j.org/py4j_java_protocol.html#
-    `logger.error("msg here", exc)` should probably just format the stack trace from Java:
-    `logger.error("msg here:\n {str(exc)}")`
-    """
-    if not logger_name:
-        try:
-            calling_function_name = inspect.stack()[1][3]
-            logger_name = calling_function_name
-        except Exception:
-            logger_name = "pyspark_job"
-    logger = spark._jvm.org.apache.log4j.LogManager.getLogger(logger_name)
-    return logger
-
-
-def log_java_exception(logger, exc, err_msg=""):
-    if exc and (isinstance(exc, Py4JJavaError) or hasattr(exc, "java_exception")):
-        logger.error(f"{err_msg}\n{str(exc.java_exception)}")
-    elif exc and hasattr(exc, "printStackTrace"):
-        logger.error(f"{err_msg}\n{str(exc.printStackTrace)}")
-    else:
-        try:
-            logger.error(err_msg, exc)
-        except Exception:
-            logger.error(f"{err_msg}\n{str(exc)}")
-
-
 def configure_s3_credentials(
     conf: SparkConf,
     access_key: str = None,
     secret_key: str = None,
     profile: str = None,
     temporary_creds: bool = False,
-):
+) -> None:
     """Set Spark config values allowing authentication to S3 for bucket data

     See Also:
@@ -542,11 +527,15 @@ def configure_s3_credentials(
             "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider",
         )
         conf.set("spark.hadoop.fs.s3a.session.token", aws_creds.token)
-        conf.set("spark.hadoop.fs.s3a.assumed.role.sts.endpoint", CONFIG.AWS_STS_ENDPOINT)
-        conf.set("spark.hadoop.fs.s3a.assumed.role.sts.endpoint.region", CONFIG.AWS_REGION)
+        conf.set(
+            "spark.hadoop.fs.s3a.assumed.role.sts.endpoint", CONFIG.AWS_STS_ENDPOINT
+        )
+        conf.set(
+            "spark.hadoop.fs.s3a.assumed.role.sts.endpoint.region", CONFIG.AWS_REGION
+        )


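A brief usage sketch for the function above, assuming credentials come from a named AWS profile; the profile name is a placeholder:

# Sketch only: populate a SparkConf with S3A credential settings before a session is built.
from pyspark.conf import SparkConf

conf = SparkConf()
configure_s3_credentials(conf, profile="usaspending-dev", temporary_creds=True)  # placeholder profile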
-def log_spark_config(spark: SparkSession, config_key_contains=""):
+def log_spark_config(spark: SparkSession, config_key_contains: str = "") -> None:
     """Log at log4j INFO the values of the SparkConf object in the current SparkSession"""
     [
         logger.info(f"{item[0]}={item[1]}")
@@ -555,19 +544,28 @@ def log_spark_config(spark: SparkSession, config_key_contains=""):
     ]


-def log_hadoop_config(spark: SparkSession, config_key_contains=""):
+def log_hadoop_config(spark: SparkSession, config_key_contains: str = "") -> None:
     """Print out to the log the current config values for hadoop. Limit to only those whose key contains the string
     provided to narrow in on a particular subset of config values.
     """
     conf = spark.sparkContext._jsc.hadoopConfiguration()
     [
         logger.info(f"{k}={v}")
-        for (k, v) in {str(_).split("=")[0]: str(_).split("=")[1] for _ in conf.iterator()}.items()
+        for (k, v) in {
+            str(_).split("=")[0]: str(_).split("=")[1] for _ in conf.iterator()
+        }.items()
         if config_key_contains in k
     ]


-def load_dict_to_delta_table(spark, s3_data_bucket, table_schema, table_name, data, overwrite=False):
+def load_dict_to_delta_table(  # noqa: PLR0912
+    spark: SparkSession,
+    s3_data_bucket: str,
+    table_schema: str,
+    table_name: str,
+    data: list[dict],
+    overwrite: bool = False,
+) -> None:
     """Create a table in Delta and populate it with the contents of a provided dictionary (data). This should
     primarily be used for unit testing.

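For test authors, a minimal sketch of a call using the new signature. The bucket, schema, and record values are placeholders; "awards" is one of the table names mapped inside the helper:

# Sketch only: seed a small Delta table for a unit test; all values are placeholders.
load_dict_to_delta_table(
    spark,
    s3_data_bucket="unittest-data-bucket",
    table_schema="int",
    table_name="awards",
    data=[{"id": 1}, {"id": 2}],
    overwrite=True,
)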
@@ -584,16 +582,23 @@ def load_dict_to_delta_table(spark, s3_data_bucket, table_schema, table_name, da
     table_to_col_names_dict = {}
     table_to_col_names_dict["transaction_fabs"] = TRANSACTION_FABS_COLUMNS
     table_to_col_names_dict["transaction_fpds"] = TRANSACTION_FPDS_COLUMNS
-    table_to_col_names_dict["transaction_normalized"] = list(TRANSACTION_NORMALIZED_COLUMNS)
+    table_to_col_names_dict["transaction_normalized"] = list(
+        TRANSACTION_NORMALIZED_COLUMNS
+    )
     table_to_col_names_dict["awards"] = list(AWARDS_COLUMNS)
-    table_to_col_names_dict["financial_accounts_by_awards"] = list(FINANCIAL_ACCOUNTS_BY_AWARDS_COLUMNS)
-    table_to_col_names_dict["detached_award_procurement"] = list(DETACHED_AWARD_PROCUREMENT_DELTA_COLUMNS)
+    table_to_col_names_dict["financial_accounts_by_awards"] = list(
+        FINANCIAL_ACCOUNTS_BY_AWARDS_COLUMNS
+    )
+    table_to_col_names_dict["detached_award_procurement"] = list(
+        DETACHED_AWARD_PROCUREMENT_DELTA_COLUMNS
+    )
     table_to_col_names_dict["published_fabs"] = list(PUBLISHED_FABS_COLUMNS)

     table_to_col_info_dict = {}
     for tbl_name, col_info in zip(
         ("transaction_fabs", "transaction_fpds"),
         (TRANSACTION_FABS_COLUMN_INFO, TRANSACTION_FPDS_COLUMN_INFO),
+        strict=False,
     ):
         table_to_col_info_dict[tbl_name] = {}
         for col in col_info:
@@ -653,7 +658,7 @@ def clean_postgres_sql_for_spark_sql(
     postgres_sql_str: str,
     global_temp_view_proxies: List[str] = None,
     identifier_replacements: Dict[str, str] = None,
-):
+) -> str:
     """Convert some of the known-to-be-problematic PostgreSQL syntax, which is not compliant with Spark SQL,
     to an acceptable and compliant Spark SQL alternative.

@@ -681,7 +686,9 @@ def clean_postgres_sql_for_spark_sql(
     # Treat these type casts as string in Spark SQL
     # NOTE: If replacing a ::JSON cast, be sure that the string data coming from delta is treated as needed (e.g. as
     # JSON or converted to JSON or a dict) on the receiving side, and not just left as a string
-    spark_sql = re.sub(r"::text|::json", r"::string", spark_sql, flags=re.IGNORECASE | re.MULTILINE)
+    spark_sql = re.sub(
+        r"::text|::json", r"::string", spark_sql, flags=re.IGNORECASE | re.MULTILINE
+    )

     if global_temp_view_proxies:
         for vw in global_temp_view_proxies:
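To close, a small illustration of the cast rewrite in this hunk. The input SQL is made up, and this assumes no other rewrite in the function touches it:

# Placeholder SQL, demonstrating only the ::text/::json -> ::string substitution above.
pg_sql = "SELECT award_id::TEXT, raw_record::JSON FROM awards"
spark_sql = clean_postgres_sql_for_spark_sql(pg_sql)
# -> "SELECT award_id::string, raw_record::string FROM awards"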