@@ -47,44 +47,41 @@ def get_worker_spark_session(
     return spark_builder.getOrCreate()
 
 
-def get_packages(connection_type: str) -> list[str]:  # noqa: WPS212
+def get_packages(connection_types: set[str]) -> list[str]:  # noqa: WPS212
     import pyspark
     from onetl.connection import MSSQL, Clickhouse, MySQL, Oracle, Postgres, SparkS3
     from onetl.file.format import XML, Excel
 
+    spark_version = pyspark.__version__
     # excel version is hardcoded due to https://github.com/nightscape/spark-excel/issues/902
     file_formats_spark_packages: list[str] = [
-        *XML.get_packages(spark_version=pyspark.__version__),
+        *XML.get_packages(spark_version=spark_version),
         *Excel.get_packages(spark_version="3.5.1"),
     ]
 
-    if connection_type == "postgres":
-        return Postgres.get_packages()
-    if connection_type == "oracle":
-        return Oracle.get_packages()
-    if connection_type == "clickhouse":
-        return [
-            "io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2",
-            *Clickhouse.get_packages(),
-        ]
-    if connection_type == "mssql":
-        return MSSQL.get_packages()
-    if connection_type == "mysql":
-        return MySQL.get_packages()
-    if connection_type == "s3":
-        import pyspark
-
-        spark_version = pyspark.__version__
-        return SparkS3.get_packages(spark_version=spark_version) + file_formats_spark_packages
-
-    if connection_type in ("hdfs", "sftp", "ftp", "ftps", "samba", "webdav"):
-        return file_formats_spark_packages
-
-    # If the database type does not require downloading .jar packages
-    return []
-
-
-def get_excluded_packages(db_type: str) -> list[str]:
+    result = []
+    if connection_types & {"postgres", "all"}:
+        result.extend(Postgres.get_packages())
+    if connection_types & {"oracle", "all"}:
+        result.extend(Oracle.get_packages())
+    if connection_types & {"clickhouse", "all"}:
+        result.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
+        result.extend(Clickhouse.get_packages())
+    if connection_types & {"mssql", "all"}:
+        result.extend(MSSQL.get_packages())
+    if connection_types & {"mysql", "all"}:
+        result.extend(MySQL.get_packages())
+
+    if connection_types & {"s3", "all"}:
+        result.extend(SparkS3.get_packages(spark_version=spark_version))
+
+    if connection_types & {"s3", "hdfs", "sftp", "ftp", "ftps", "samba", "webdav", "all"}:
+        result.extend(file_formats_spark_packages)
+
+    return result
+
+
+def get_excluded_packages() -> list[str]:
     from onetl.connection import SparkS3
 
     return SparkS3.get_exclude_packages()
@@ -95,16 +92,11 @@ def get_spark_session_conf(
     target: ConnectionDTO,
     resources: dict,
 ) -> dict:
-    maven_packages: list[str] = []
-    excluded_packages: list[str] = []
-
-    for db_type in source, target:
-        maven_packages.extend(get_packages(connection_type=db_type.type))  # type: ignore
-        excluded_packages.extend(get_excluded_packages(db_type=db_type.type))  # type: ignore
+    maven_packages: list[str] = get_packages(connection_types={source.type, target.type})
+    excluded_packages: list[str] = get_excluded_packages()
 
     memory_mb = math.ceil(resources["ram_bytes_per_task"] / 1024 / 1024)
     config = {
-        "spark.jars.packages": ",".join(maven_packages),
         "spark.sql.pyspark.jvmStacktrace.enabled": "true",
         "spark.hadoop.mapreduce.fileoutputcommitter.marksuccessfuljobs": "false",
        "spark.executor.cores": resources["cpu_cores_per_task"],