Commit e27ad39

[DOP-15023] Pass Run to CREATE_SPARK_SESSION_FUNCTION
1 parent 67a17b0 commit e27ad39

File tree

21 files changed: +321 −317 lines changed


docker/Dockerfile.worker

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ COPY ./syncmaster/ /app/syncmaster/

 FROM base as test

-ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session.get_worker_spark_session
+ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session

 # CI runs tests in the worker container, so we need backend dependencies too
 RUN poetry install --no-root --extras "worker backend" --with test --without docs,dev
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+Pass current ``Run`` to ``CREATE_SPARK_SESSION_FUNCTION``. This allows using run/transfer/group information for Spark session options,
+like ``appName`` or custom ones.
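
For context, a session factory compatible with the new contract might look like the minimal sketch below. The keyword arguments mirror the controller's call site (``settings``, ``run``, ``source``, ``target``); the ``appName`` value and the ``run.id`` attribute are illustrative assumptions, not part of this commit.

    from pyspark.sql import SparkSession


    def get_custom_spark_session(settings, run, source, target) -> SparkSession:
        # Run/transfer/group information is now available for session options;
        # `run.id` is assumed to be an attribute of the Run model.
        return (
            SparkSession.builder
            .appName(f"syncmaster_run_{run.id}")
            .getOrCreate()
        )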

syncmaster/dto/connections.py

Lines changed: 7 additions & 6 deletions

@@ -1,11 +1,12 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass
+from typing import ClassVar


 @dataclass
 class ConnectionDTO:
-    pass
+    type: ClassVar[str]


 @dataclass
@@ -16,7 +17,7 @@ class PostgresConnectionDTO(ConnectionDTO):
     password: str
     additional_params: dict
     database_name: str
-    type: str = "postgres"
+    type: ClassVar[str] = "postgres"


 @dataclass
@@ -28,23 +29,23 @@ class OracleConnectionDTO(ConnectionDTO):
     additional_params: dict
     sid: str | None = None
     service_name: str | None = None
-    type: str = "oracle"
+    type: ClassVar[str] = "oracle"


 @dataclass
 class HiveConnectionDTO(ConnectionDTO):
     user: str
     password: str
     cluster: str
-    type: str = "hive"
+    type: ClassVar[str] = "hive"


 @dataclass
 class HDFSConnectionDTO(ConnectionDTO):
     user: str
     password: str
     cluster: str
-    type: str = "hdfs"
+    type: ClassVar[str] = "hdfs"


 @dataclass
@@ -57,4 +58,4 @@ class S3ConnectionDTO(ConnectionDTO):
     additional_params: dict
     region: str | None = None
     protocol: str = "https"
-    type: str = "s3"
+    type: ClassVar[str] = "s3"
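
Switching ``type`` to ``ClassVar[str]`` removes it from the generated ``__init__`` and from ``dataclasses.fields()``, so it acts as a class-level tag instead of an instance field. A minimal sketch of the effect, using stand-in values:

    from dataclasses import dataclass, fields
    from typing import ClassVar


    @dataclass
    class HiveConnectionDTO:
        user: str
        password: str
        cluster: str
        type: ClassVar[str] = "hive"


    dto = HiveConnectionDTO(user="u", password="p", cluster="c")  # no `type` argument
    print(dto.type)                       # "hive", resolved on the class
    print([f.name for f in fields(dto)])  # ['user', 'password', 'cluster']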

syncmaster/dto/transfers.py

Lines changed: 42 additions & 22 deletions

@@ -1,46 +1,66 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+import json
 from dataclasses import dataclass
+from typing import ClassVar

-from syncmaster.schemas.v1.transfers.file_format import CSV, JSON, JSONLine
+from onetl.file.format import CSV, JSON, JSONLine


 @dataclass
 class TransferDTO:
-    pass
+    type: ClassVar[str]


 @dataclass
-class PostgresTransferDTO(TransferDTO):
+class DBTransferDTO(TransferDTO):
     table_name: str
-    type: str = "postgres"


 @dataclass
-class OracleTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "oracle"
+class FileTransferDTO(TransferDTO):
+    directory_path: str
+    file_format: CSV | JSONLine | JSON
+    options: dict
+    df_schema: dict | None = None
+
+    def __post_init__(self):
+        if isinstance(self.file_format, dict):
+            self.file_format = self._get_format(self.file_format.copy())
+        if isinstance(self.df_schema, str):
+            self.df_schema = json.loads(self.df_schema)
+
+    def _get_format(self, file_format: dict):
+        file_type = file_format.pop("type", None)
+        if file_type == "csv":
+            return CSV.parse_obj(file_format)
+        if file_type == "jsonline":
+            return JSONLine.parse_obj(file_format)
+        if file_type == "json":
+            return JSON.parse_obj(file_format)
+        raise ValueError("Unknown file type")


 @dataclass
-class HiveTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "hive"
+class PostgresTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "postgres"


 @dataclass
-class S3TransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "s3"
+class OracleTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "oracle"


 @dataclass
-class HDFSTransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "hdfs"
+class HiveTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "hive"
+
+
+@dataclass
+class S3TransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "s3"
+
+
+@dataclass
+class HDFSTransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "hdfs"
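
The new ``FileTransferDTO.__post_init__`` lets callers pass either ready onetl format objects or raw dicts/JSON strings straight from transfer params. A hedged usage sketch (field values are made up; format defaults come from onetl's format models):

    from onetl.file.format import CSV

    dto = S3TransferDTO(
        directory_path="/data/incoming",               # illustrative path
        file_format={"type": "csv"},                   # raw dict, coerced via CSV.parse_obj
        options={},
        df_schema='{"type": "struct", "fields": []}',  # JSON string, coerced via json.loads
    )
    assert isinstance(dto.file_format, CSV)
    assert isinstance(dto.df_schema, dict)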

syncmaster/worker/controller.py

Lines changed: 22 additions & 28 deletions

@@ -4,7 +4,7 @@
 from typing import Any

 from syncmaster.config import Settings
-from syncmaster.db.models import Connection, Transfer
+from syncmaster.db.models import Connection, Run
 from syncmaster.dto.connections import (
     HDFSConnectionDTO,
     HiveConnectionDTO,
@@ -21,11 +21,11 @@
 )
 from syncmaster.exceptions.connection import ConnectionTypeNotRecognizedError
 from syncmaster.worker.handlers.base import Handler
+from syncmaster.worker.handlers.db.hive import HiveHandler
+from syncmaster.worker.handlers.db.oracle import OracleHandler
+from syncmaster.worker.handlers.db.postgres import PostgresHandler
 from syncmaster.worker.handlers.file.hdfs import HDFSHandler
 from syncmaster.worker.handlers.file.s3 import S3Handler
-from syncmaster.worker.handlers.hive import HiveHandler
-from syncmaster.worker.handlers.oracle import OracleHandler
-from syncmaster.worker.handlers.postgres import PostgresHandler

 logger = logging.getLogger(__name__)

@@ -65,47 +65,40 @@ class TransferController:

     def __init__(
         self,
-        transfer: Transfer,
+        run: Run,
         source_connection: Connection,
         source_auth_data: dict,
         target_connection: Connection,
         target_auth_data: dict,
         settings: Settings,
     ):
+        self.run = run
+        self.settings = settings
         self.source_handler = self.get_handler(
             connection_data=source_connection.data,
-            transfer_params=transfer.source_params,
+            transfer_params=run.transfer.source_params,
             connection_auth_data=source_auth_data,
         )
         self.target_handler = self.get_handler(
             connection_data=target_connection.data,
-            transfer_params=transfer.target_params,
+            transfer_params=run.transfer.target_params,
             connection_auth_data=target_auth_data,
         )
-        spark = settings.CREATE_SPARK_SESSION_FUNCTION(
-            settings,
-            target=self.target_handler.connection_dto,
+
+    def perform_transfer(self) -> None:
+        spark = self.settings.CREATE_SPARK_SESSION_FUNCTION(
+            settings=self.settings,
+            run=self.run,
             source=self.source_handler.connection_dto,
+            target=self.target_handler.connection_dto,
         )

-        self.source_handler.set_spark(spark)
-        self.target_handler.set_spark(spark)
-        logger.info("source connection = %s", self.source_handler)
-        logger.info("target connection = %s", self.target_handler)
-
-    def start_transfer(self) -> None:
-        self.source_handler.init_connection()
-        self.source_handler.init_reader()
-
-        self.target_handler.init_connection()
-        self.target_handler.init_writer()
-        logger.info("Source and target were initialized")
-
-        df = self.target_handler.normalize_column_name(self.source_handler.read())
-        logger.info("Data has been read")
+        with spark:
+            self.source_handler.connect(spark)
+            self.target_handler.connect(spark)

-        self.target_handler.write(df)
-        logger.info("Data has been inserted")
+            df = self.source_handler.read()
+            self.target_handler.write(df)

     def get_handler(
         self,
@@ -114,7 +107,8 @@ def get_handler(
         transfer_params: dict[str, Any],
     ) -> Handler:
         connection_data.update(connection_auth_data)
-        handler_type = connection_data["type"]
+        handler_type = connection_data.pop("type")
+        transfer_params.pop("type", None)

         if connection_handler_proxy.get(handler_type, None) is None:
             raise ConnectionTypeNotRecognizedError
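
With session creation moved out of ``__init__``, the worker drives the controller in two steps: construct it with the current ``Run``, then call ``perform_transfer()``, which builds the Spark session and stops it when the ``with spark:`` block exits (``SparkSession`` supports the context-manager protocol). A rough sketch, with run/connection/auth lookup elided:

    controller = TransferController(
        run=run,
        source_connection=source_connection,
        source_auth_data=source_auth_data,
        target_connection=target_connection,
        target_auth_data=target_auth_data,
        settings=settings,
    )
    controller.perform_transfer()  # session is created here, not in __init__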

syncmaster/worker/handlers/base.py

Lines changed: 15 additions & 31 deletions

@@ -1,49 +1,33 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
-from abc import ABC

-from onetl.db import DBReader, DBWriter
-from pyspark.sql import SparkSession
-from pyspark.sql.dataframe import DataFrame
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING

 from syncmaster.dto.connections import ConnectionDTO
 from syncmaster.dto.transfers import TransferDTO

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+

 class Handler(ABC):
     def __init__(
         self,
         connection_dto: ConnectionDTO,
         transfer_dto: TransferDTO,
-        spark: SparkSession | None = None,
-    ) -> None:
-        self.spark = spark
-        self.reader: DBReader | None = None
-        self.writer: DBWriter | None = None
+    ):
         self.connection_dto = connection_dto
         self.transfer_dto = transfer_dto

-    def init_connection(self): ...
-
-    def set_spark(self, spark: SparkSession):
-        self.spark = spark
-
-    def init_reader(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def init_writer(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def read(self) -> DataFrame:
-        if self.reader is None:
-            raise ValueError("Reader is not initialized")
-        return self.reader.run()
+    @abstractmethod
+    def connect(self, spark: SparkSession) -> None: ...

-    def write(self, df: DataFrame) -> None:
-        if self.writer is None:
-            raise ValueError("Writer is not initialized")
-        return self.writer.run(df=df)
+    @abstractmethod
+    def read(self) -> DataFrame: ...

-    def normalize_column_name(self, df: DataFrame) -> DataFrame: ...
+    @abstractmethod
+    def write(self, df: DataFrame) -> None: ...
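
``Handler`` is now a slim ABC: every concrete handler must implement ``connect``/``read``/``write``, and the pyspark imports are deferred under ``TYPE_CHECKING`` so importing the module does not require Spark at runtime. An illustrative subclass, not part of the commit:

    class NoopHandler(Handler):
        # Illustrative only: reads a tiny range and discards writes.
        def connect(self, spark):
            self.spark = spark

        def read(self):
            return self.spark.range(10)

        def write(self, df):
            df.write.format("noop").mode("overwrite").save()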
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+
+from onetl.base import BaseDBConnection
+from onetl.db import DBReader, DBWriter
+
+from syncmaster.dto.transfers import DBTransferDTO
+from syncmaster.worker.handlers.base import Handler
+
+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+
+
+class DBHandler(Handler):
+    connection: BaseDBConnection
+    transfer_dto: DBTransferDTO
+
+    def read(self) -> DataFrame:
+        reader = DBReader(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        return reader.run()
+
+    def write(self, df: DataFrame) -> None:
+        writer = DBWriter(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        return writer.run(df=self.normalize_column_names(df))
+
+    @abstractmethod
+    def normalize_column_names(self, df: DataFrame) -> DataFrame: ...
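
``DBHandler`` now owns the onetl ``DBReader``/``DBWriter`` plumbing, so database-specific subclasses only supply a connection and a column-name policy. A hypothetical subclass sketch (the commit's real Postgres/Oracle handlers are not shown in this view):

    class ExampleHandler(DBHandler):
        def connect(self, spark):
            # Build the concrete onetl connection with spark=spark, then .check()
            ...

        def normalize_column_names(self, df):
            # e.g. an Oracle-style policy: uppercase every column name
            for column_name in df.columns:
                df = df.withColumnRenamed(column_name, column_name.upper())
            return df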
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from onetl.connection import Hive
+
+from syncmaster.dto.connections import HiveConnectionDTO
+from syncmaster.dto.transfers import HiveTransferDTO
+from syncmaster.worker.handlers.db.base import DBHandler
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
+
+class HiveHandler(DBHandler):
+    connection: Hive
+    connection_dto: HiveConnectionDTO
+    transfer_dto: HiveTransferDTO
+
+    def connect(self, spark: SparkSession):
+        self.connection = Hive(
+            cluster=self.connection_dto.cluster,
+            spark=spark,
+        ).check()
+
+    def read(self) -> DataFrame:
+        self.connection.spark.catalog.refreshTable(self.transfer_dto.table_name)
+        return super().read()
+
+    def normalize_column_names(self, df: DataFrame) -> DataFrame:
+        for column_name in df.columns:
+            df = df.withColumnRenamed(column_name, column_name.lower())
+        return df
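
``HiveHandler`` refreshes table metadata before reading, so Spark picks up changes made outside the session, and lowercases column names on write to match Hive's case-insensitive catalog. A hedged wiring example, assuming an existing ``spark`` session and made-up DTO values:

    handler = HiveHandler(
        connection_dto=HiveConnectionDTO(user="etl", password="secret", cluster="test-cluster"),
        transfer_dto=HiveTransferDTO(table_name="schema.table"),
    )
    handler.connect(spark)  # builds the onetl Hive connection and runs .check()
    df = handler.read()     # refreshTable(...), then DBReader.run()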
