@@ -47,44 +47,42 @@ def get_worker_spark_session(
     return spark_builder.getOrCreate()
 
 
-def get_packages(connection_type: str) -> list[str]:  # noqa: WPS212
+def get_packages(connection_types: set[str]) -> list[str]:  # noqa: WPS212
     import pyspark
     from onetl.connection import MSSQL, Clickhouse, MySQL, Oracle, Postgres, SparkS3
     from onetl.file.format import XML, Excel
 
+    spark_version = pyspark.__version__
     # excel version is hardcoded due to https://github.com/nightscape/spark-excel/issues/902
     file_formats_spark_packages: list[str] = [
-        *XML.get_packages(spark_version=pyspark.__version__),
+        *XML.get_packages(spark_version=spark_version),
         *Excel.get_packages(spark_version="3.5.1"),
     ]
 
-    if connection_type == "postgres":
-        return Postgres.get_packages()
-    if connection_type == "oracle":
-        return Oracle.get_packages()
-    if connection_type == "clickhouse":
-        return [
-            "io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2",
-            *Clickhouse.get_packages(),
-        ]
-    if connection_type == "mssql":
-        return MSSQL.get_packages()
-    if connection_type == "mysql":
-        return MySQL.get_packages()
-    if connection_type == "s3":
-        import pyspark
-
-        spark_version = pyspark.__version__
-        return SparkS3.get_packages(spark_version=spark_version) + file_formats_spark_packages
-
-    if connection_type in ("hdfs", "sftp", "ftp", "ftps", "samba", "webdav"):
-        return file_formats_spark_packages
-
-    # If the database type does not require downloading .jar packages
-    return []
-
-
-def get_excluded_packages(db_type: str) -> list[str]:
+    result = []
+
+    if connection_types & {"postgres", "all"}:
+        result.extend(Postgres.get_packages())
+    if connection_types & {"oracle", "all"}:
+        result.extend(Oracle.get_packages())
+    if connection_types & {"clickhouse", "all"}:
+        result.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
+        result.extend(Clickhouse.get_packages())
+    if connection_types & {"mssql", "all"}:
+        result.extend(MSSQL.get_packages())
+    if connection_types & {"mysql", "all"}:
+        result.extend(MySQL.get_packages())
+
+    if connection_types & {"s3", "all"}:
+        result.extend(SparkS3.get_packages(spark_version=spark_version))
+
+    if connection_types & {"s3", "hdfs", "sftp", "ftp", "ftps", "samba", "webdav", "all"}:
+        result.extend(file_formats_spark_packages)
+
+    return result
+
+
+def get_excluded_packages() -> list[str]:
     from onetl.connection import SparkS3
 
     return SparkS3.get_exclude_packages()
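
The rework replaces the per-type early returns with set intersections against each supported group, so a single call can resolve packages for several connection types at once, and an "all" sentinel matches every group. A minimal usage sketch, assuming the hypothetical import path below (the diff does not show which module exposes get_packages):

# from syncmaster.worker.spark import get_packages  # hypothetical path, not shown in the diff
packages = get_packages(connection_types={"postgres", "s3"})
# -> Postgres JDBC driver plus SparkS3 jars, plus the XML/Excel packages,
#    because "s3" also matches the file-format branch.

everything = get_packages(connection_types={"all"})
# -> "all" intersects every guard set, so the result covers every supported type.

The second hunk below updates the only call site to match the new signature.
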
@@ -95,16 +93,11 @@ def get_spark_session_conf(
     target: ConnectionDTO,
     resources: dict,
 ) -> dict:
-    maven_packages: list[str] = []
-    excluded_packages: list[str] = []
-
-    for db_type in source, target:
-        maven_packages.extend(get_packages(connection_type=db_type.type))  # type: ignore
-        excluded_packages.extend(get_excluded_packages(db_type=db_type.type))  # type: ignore
+    maven_packages: list[str] = get_packages(connection_types={source.type, target.type})
+    excluded_packages: list[str] = get_excluded_packages()
 
     memory_mb = math.ceil(resources["ram_bytes_per_task"] / 1024 / 1024)
     config = {
-        "spark.jars.packages": ",".join(maven_packages),
         "spark.sql.pyspark.jvmStacktrace.enabled": "true",
         "spark.hadoop.mapreduce.fileoutputcommitter.marksuccessfuljobs": "false",
         "spark.executor.cores": resources["cpu_cores_per_task"],
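
With the set-based signature the call site no longer loops over source and target, and the set collapses duplicate types, so a transfer between two connections of the same kind resolves its jars once. A small sketch under the same assumption about where get_packages lives:

# Hypothetical: both ends of the transfer are Postgres.
maven_packages = get_packages(connection_types={"postgres", "postgres"})
# The set literal collapses to {"postgres"}, so the JDBC coordinates appear once,
# whereas the old per-endpoint loop extended the list twice.
excluded_packages = get_excluded_packages()  # no longer takes a db_type argument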