
Commit 6a23778

style checks
1 parent ed5ce8b commit 6a23778

File tree: 7 files changed (+477 additions, -243 deletions)


usaspending_api/common/helpers/spark_helpers.py

Lines changed: 80 additions & 73 deletions
@@ -5,34 +5,31 @@
 could be used as stages or steps of an ETL job (aka "data pipeline")
 """
 
-import inspect
 import logging
 import os
-import sys
 import re
-
+import sys
 from datetime import date, datetime
-from typing import Dict, List, Optional, Sequence, Set, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Union
 
 from django.core.management import call_command
 from py4j.java_gateway import (
     JavaGateway,
 )
-from py4j.protocol import Py4JJavaError
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import read_int, UTF8Deserializer
+from pyspark.serializers import UTF8Deserializer, read_int
 from pyspark.sql import SparkSession
 
 from usaspending_api.awards.delta_models.awards import AWARDS_COLUMNS
 from usaspending_api.awards.delta_models.financial_accounts_by_awards import (
     FINANCIAL_ACCOUNTS_BY_AWARDS_COLUMNS,
 )
-from usaspending_api.common.helpers.aws_helpers import is_aws, get_aws_credentials
+from usaspending_api.common.helpers.aws_helpers import get_aws_credentials, is_aws
 from usaspending_api.config import CONFIG
-from usaspending_api.config.utils import parse_pg_uri, parse_http_url
+from usaspending_api.config.utils import parse_http_url, parse_pg_uri
 from usaspending_api.transactions.delta_models import (
     DETACHED_AWARD_PROCUREMENT_DELTA_COLUMNS,
     PUBLISHED_FABS_COLUMNS,
@@ -72,7 +69,13 @@ def is_spark_context_stopped() -> bool:
     # Check the Singleton instance populated if there's an active SparkContext
     if SparkContext._active_spark_context is not None:
         sc = SparkContext._active_spark_context
-        is_stopped = not (sc._jvm and not sc._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped())
+        is_stopped = not (
+            sc._jvm
+            and not sc._jvm.SparkSession.getDefaultSession()
+            .get()
+            .sparkContext()
+            .isStopped()
+        )
     return is_stopped
 
 
@@ -86,7 +89,10 @@ def stop_spark_context() -> bool:
             sc._jvm
             and hasattr(sc._jvm, "SparkSession")
             and sc._jvm.SparkSession
-            and not sc._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped()
+            and not sc._jvm.SparkSession.getDefaultSession()
+            .get()
+            .sparkContext()
+            .isStopped()
         ):
             try:
                 sc.stop()
@@ -96,11 +102,11 @@ def stop_spark_context() -> bool:
     return stopped_without_error
 
 
-def configure_spark_session(
+def configure_spark_session(  # noqa: C901,PLR0912,PLR0913,PLR0915
     java_gateway: JavaGateway = None,
     spark_context: Union[SparkContext, SparkSession] = None,
-    master=None,
-    app_name="Spark App",
+    master: str | None = None,
+    app_name: str = "Spark App",
     log_level: int = None,
     log_spark_config_vals: bool = False,
     log_hadoop_config_vals: bool = False,
@@ -151,9 +157,15 @@ def configure_spark_session(
     property, otherwise an error will be thrown.
     """
     if spark_context and (
-        not spark_context._jvm or spark_context._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped()
+        not spark_context._jvm
+        or spark_context._jvm.SparkSession.getDefaultSession()
+        .get()
+        .sparkContext()
+        .isStopped()
     ):
-        raise ValueError("The provided spark_context arg is a stopped SparkContext. It must be active.")
+        raise ValueError(
+            "The provided spark_context arg is a stopped SparkContext. It must be active."
+        )
     if spark_context and java_gateway:
         raise Exception(
             "Cannot provide BOTH spark_context and java_gateway args. The active spark_context supplies its own gateway"
@@ -259,7 +271,9 @@ def configure_spark_session(
     if spark_context:
         built_conf = spark.conf
         provided_conf_keys = [item[0] for item in conf.getAll()]
-        non_modifiable_conf = [k for k in provided_conf_keys if not built_conf.isModifiable(k)]
+        non_modifiable_conf = [
+            k for k in provided_conf_keys if not built_conf.isModifiable(k)
+        ]
         if non_modifiable_conf:
             raise ValueError(
                 "An active SparkContext was given along with NEW spark config values. The following "
@@ -309,8 +323,8 @@ def configure_spark_session(
 
 
 def read_java_gateway_connection_info(
-    gateway_conn_info_path,
-):  # pragma: no cover -- useful development util
+    gateway_conn_info_path: str,
+) -> (int, str):  # pragma: no cover -- useful development util
     """Read the port and auth token from a file holding connection info to a running spark-submit process
 
     Args:
@@ -326,8 +340,8 @@ def read_java_gateway_connection_info(
 
 
 def attach_java_gateway(
-    gateway_port,
-    gateway_auth_token,
+    gateway_port: int,
+    gateway_auth_token: str,
 ) -> JavaGateway:  # pragma: no cover -- useful development util
     """Create a new JavaGateway that latches onto the port of a running spark-submit process
 
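Taken together, the two development utilities above read the port and auth token of a running spark-submit process and attach a JavaGateway to it. A minimal sketch, assuming a hypothetical connection-info file path:

gateway_port, gateway_auth_token = read_java_gateway_connection_info(
    "/tmp/spark_gateway_conn_info"  # hypothetical path to the connection-info file
)
gateway = attach_java_gateway(gateway_port, gateway_auth_token)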
@@ -374,7 +388,9 @@ def attach_java_gateway(
     return gateway
 
 
-def get_jdbc_connection_properties(fix_strings: bool = True, truncate: bool = False) -> dict:
+def get_jdbc_connection_properties(
+    fix_strings: bool = True, truncate: bool = False
+) -> dict:
     jdbc_props = {
         "driver": "org.postgresql.Driver",
         "fetchsize": str(CONFIG.SPARK_PARTITION_ROWS),
@@ -394,30 +410,32 @@ def get_jdbc_url_from_pg_uri(pg_uri: str) -> str:
     """Converts the passed-in Postgres DB connection URI to a JDBC-compliant Postgres DB connection string"""
     url_parts, user, password = parse_pg_uri(pg_uri)
     if user is None or password is None:
-        raise ValueError("pg_uri provided must have username and password with host or in query string")
+        raise ValueError(
+            "pg_uri provided must have username and password with host or in query string"
+        )
     # JDBC URLs only support postgresql://
     pg_uri = f"postgresql://{url_parts.hostname}:{url_parts.port}{url_parts.path}?user={user}&password={password}"
 
     return f"jdbc:{pg_uri}"
 
 
-def get_usas_jdbc_url():
+def get_usas_jdbc_url() -> str:
     """Getting a JDBC-compliant Postgres DB connection string hard-wired to the POSTGRES vars set in CONFIG"""
     if not CONFIG.DATABASE_URL:
         raise ValueError("DATABASE_URL config val must provided")
 
     return get_jdbc_url_from_pg_uri(CONFIG.DATABASE_URL)
 
 
-def get_broker_jdbc_url():
+def get_broker_jdbc_url() -> str:
     """Getting a JDBC-compliant Broker Postgres DB connection string hard-wired to the POSTGRES vars set in CONFIG"""
     if not CONFIG.BROKER_DB:
         raise ValueError("BROKER_DB config val must provided")
 
     return get_jdbc_url_from_pg_uri(CONFIG.BROKER_DB)
 
 
-def get_es_config():  # pragma: no cover -- will be used eventually
+def get_es_config() -> dict[str, Any]:  # pragma: no cover -- will be used eventually
     """
     Get a base template of Elasticsearch configuration settings tailored to the specific environment setup being
     used
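A worked example of the conversion performed by get_jdbc_url_from_pg_uri, assuming parse_pg_uri extracts the embedded credentials (all values hypothetical):

pg_uri = "postgres://usaspending:secret@db.example.com:5432/data_store_api"
get_jdbc_url_from_pg_uri(pg_uri)
# -> "jdbc:postgresql://db.example.com:5432/data_store_api?user=usaspending&password=secret"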
@@ -454,7 +472,9 @@ def get_es_config(): # pragma: no cover -- will be used eventually
         "es.net.ssl": str(ssl).lower(),  # default false
         "es.net.ssl.cert.allow.self.signed": "true",  # default false
         "es.batch.size.entries": str(CONFIG.ES_BATCH_ENTRIES),  # default 1000
-        "es.batch.size.bytes": str(CONFIG.ES_MAX_BATCH_BYTES),  # default 1024*1024 (1mb)
+        "es.batch.size.bytes": str(
+            CONFIG.ES_MAX_BATCH_BYTES
+        ),  # default 1024*1024 (1mb)
         "es.batch.write.refresh": "false",  # default true, to refresh after configured batch size completes
     }
 
@@ -466,48 +486,13 @@ def get_es_config(): # pragma: no cover -- will be used eventually
     return config
 
 
-def get_jvm_logger(spark: SparkSession, logger_name=None):
-    """
-    Get a JVM log4j Logger object instance to log through Java
-
-    WARNING about Logging: This is NOT python's `logging` module
-    This is a python proxy to a java Log4J Logger object
-    As such, you can't do everything with it you'd do in Python, NOTABLY: passing
-    keyword args, like `logger.error("msg here", exc_info=exc)`. Instead do e.g.:
-    `logger.error("msg here", exc)`
-    Also, errors may not be loggable in the traditional way. See: https://www.py4j.org/py4j_java_protocol.html#
-    `logger.error("msg here", exc)` should probably just format the stack track from Java:
-    `logger.error("msg here:\n {str(exc)}")`
-    """
-    if not logger_name:
-        try:
-            calling_function_name = inspect.stack()[1][3]
-            logger_name = calling_function_name
-        except Exception:
-            logger_name = "pyspark_job"
-    logger = spark._jvm.org.apache.log4j.LogManager.getLogger(logger_name)
-    return logger
-
-
-def log_java_exception(logger, exc, err_msg=""):
-    if exc and (isinstance(exc, Py4JJavaError) or hasattr(exc, "java_exception")):
-        logger.error(f"{err_msg}\n{str(exc.java_exception)}")
-    elif exc and hasattr(exc, "printStackTrace"):
-        logger.error(f"{err_msg}\n{str(exc.printStackTrace)}")
-    else:
-        try:
-            logger.error(err_msg, exc)
-        except Exception:
-            logger.error(f"{err_msg}\n{str(exc)}")
-
-
 def configure_s3_credentials(
     conf: SparkConf,
     access_key: str = None,
     secret_key: str = None,
     profile: str = None,
     temporary_creds: bool = False,
-):
+) -> None:
     """Set Spark config values allowing authentication to S3 for bucket data
 
     See Also:
@@ -542,11 +527,15 @@ def configure_s3_credentials(
             "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider",
         )
         conf.set("spark.hadoop.fs.s3a.session.token", aws_creds.token)
-        conf.set("spark.hadoop.fs.s3a.assumed.role.sts.endpoint", CONFIG.AWS_STS_ENDPOINT)
-        conf.set("spark.hadoop.fs.s3a.assumed.role.sts.endpoint.region", CONFIG.AWS_REGION)
+        conf.set(
+            "spark.hadoop.fs.s3a.assumed.role.sts.endpoint", CONFIG.AWS_STS_ENDPOINT
+        )
+        conf.set(
+            "spark.hadoop.fs.s3a.assumed.role.sts.endpoint.region", CONFIG.AWS_REGION
+        )
 
 
-def log_spark_config(spark: SparkSession, config_key_contains=""):
+def log_spark_config(spark: SparkSession, config_key_contains: str = "") -> None:
     """Log at log4j INFO the values of the SparkConf object in the current SparkSession"""
     [
         logger.info(f"{item[0]}={item[1]}")
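The hunk above finishes the temporary-credentials branch of configure_s3_credentials (session token plus the STS endpoint settings). A minimal sketch of a call using only the signature shown earlier in this diff; the profile name is hypothetical:

conf = SparkConf()
configure_s3_credentials(conf, profile="usaspending-dev", temporary_creds=True)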
@@ -555,19 +544,28 @@ def log_spark_config(spark: SparkSession, config_key_contains=""):
     ]
 
 
-def log_hadoop_config(spark: SparkSession, config_key_contains=""):
+def log_hadoop_config(spark: SparkSession, config_key_contains: str = "") -> None:
     """Print out to the log the current config values for hadoop. Limit to only those whose key contains the string
     provided to narrow in on a particular subset of config values.
     """
     conf = spark.sparkContext._jsc.hadoopConfiguration()
     [
         logger.info(f"{k}={v}")
-        for (k, v) in {str(_).split("=")[0]: str(_).split("=")[1] for _ in conf.iterator()}.items()
+        for (k, v) in {
+            str(_).split("=")[0]: str(_).split("=")[1] for _ in conf.iterator()
+        }.items()
         if config_key_contains in k
     ]
 
 
-def load_dict_to_delta_table(spark, s3_data_bucket, table_schema, table_name, data, overwrite=False):
+def load_dict_to_delta_table(  # noqa: PLR0912
+    spark: SparkSession,
+    s3_data_bucket: str,
+    table_schema: str,
+    table_name: str,
+    data: list[dict],
+    overwrite: bool = False,
+) -> None:
     """Create a table in Delta and populate it with the contents of a provided dicationary (data). This should
     primarily be used for unit testing.
 
@@ -584,16 +582,23 @@ def load_dict_to_delta_table(spark, s3_data_bucket, table_schema, table_name, da
     table_to_col_names_dict = {}
     table_to_col_names_dict["transaction_fabs"] = TRANSACTION_FABS_COLUMNS
     table_to_col_names_dict["transaction_fpds"] = TRANSACTION_FPDS_COLUMNS
-    table_to_col_names_dict["transaction_normalized"] = list(TRANSACTION_NORMALIZED_COLUMNS)
+    table_to_col_names_dict["transaction_normalized"] = list(
+        TRANSACTION_NORMALIZED_COLUMNS
+    )
     table_to_col_names_dict["awards"] = list(AWARDS_COLUMNS)
-    table_to_col_names_dict["financial_accounts_by_awards"] = list(FINANCIAL_ACCOUNTS_BY_AWARDS_COLUMNS)
-    table_to_col_names_dict["detached_award_procurement"] = list(DETACHED_AWARD_PROCUREMENT_DELTA_COLUMNS)
+    table_to_col_names_dict["financial_accounts_by_awards"] = list(
+        FINANCIAL_ACCOUNTS_BY_AWARDS_COLUMNS
+    )
+    table_to_col_names_dict["detached_award_procurement"] = list(
+        DETACHED_AWARD_PROCUREMENT_DELTA_COLUMNS
+    )
     table_to_col_names_dict["published_fabs"] = list(PUBLISHED_FABS_COLUMNS)
 
     table_to_col_info_dict = {}
     for tbl_name, col_info in zip(
         ("transaction_fabs", "transaction_fpds"),
         (TRANSACTION_FABS_COLUMN_INFO, TRANSACTION_FPDS_COLUMN_INFO),
+        strict=False,
     ):
         table_to_col_info_dict[tbl_name] = {}
         for col in col_info:
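Per its docstring, load_dict_to_delta_table is intended mainly for unit tests: it creates the named Delta table and loads the provided rows. A hedged test-style sketch; only the parameter names and types come from the signature above, the values are hypothetical:

load_dict_to_delta_table(
    spark,
    s3_data_bucket="unit-test-bucket",
    table_schema="int",
    table_name="awards",
    data=[{"id": 1, "type": "A"}],
    overwrite=True,
)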
@@ -653,7 +658,7 @@ def clean_postgres_sql_for_spark_sql(
     postgres_sql_str: str,
     global_temp_view_proxies: List[str] = None,
     identifier_replacements: Dict[str, str] = None,
-):
+) -> str:
     """Convert some of the known-to-be-problematic PostgreSQL syntax, which is not compliant with Spark SQL,
     to an acceptable and compliant Spark SQL alternative.
 
@@ -681,7 +686,9 @@ def clean_postgres_sql_for_spark_sql(
     # Treat these type casts as string in Spark SQL
     # NOTE: If replacing a ::JSON cast, be sure that the string data coming from delta is treated as needed (e.g. as
     # JSON or converted to JSON or a dict) on the receiving side, and not just left as a string
-    spark_sql = re.sub(r"::text|::json", r"::string", spark_sql, flags=re.IGNORECASE | re.MULTILINE)
+    spark_sql = re.sub(
+        r"::text|::json", r"::string", spark_sql, flags=re.IGNORECASE | re.MULTILINE
+    )
 
     if global_temp_view_proxies:
         for vw in global_temp_view_proxies:
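A worked example of just the cast rewrite above (the rest of the function, such as the global temp view proxying, is not shown here):

import re

postgres_sql = "SELECT award_id::TEXT, raw_record::json FROM awards"
re.sub(r"::text|::json", r"::string", postgres_sql, flags=re.IGNORECASE | re.MULTILINE)
# -> "SELECT award_id::string, raw_record::string FROM awards"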
